利用TensorFlow训练简单的二分类神经网络模型的方法


Posted in Python onMarch 05, 2018

利用TensorFlow实现《神经网络与机器学习》一书中4.7模式分类练习

具体问题是将如下图所示双月牙数据集分类。

利用TensorFlow训练简单的二分类神经网络模型的方法

使用到的工具:

python3.5    tensorflow1.2.1   numpy   matplotlib

1.产生双月环数据集

def produceData(r,w,d,num): 
  r1 = r-w/2 
  r2 = r+w/2 
  #上半圆 
  theta1 = np.random.uniform(0, np.pi ,num) 
  X_Col1 = np.random.uniform( r1*np.cos(theta1),r2*np.cos(theta1),num)[:, np.newaxis] 
  X_Row1 = np.random.uniform(r1*np.sin(theta1),r2*np.sin(theta1),num)[:, np.newaxis] 
  Y_label1 = np.ones(num) #类别标签为1 
  #下半圆 
  theta2 = np.random.uniform(-np.pi, 0 ,num) 
  X_Col2 = (np.random.uniform( r1*np.cos(theta2),r2*np.cos(theta2),num) + r)[:, np.newaxis] 
  X_Row2 = (np.random.uniform(r1 * np.sin(theta2), r2 * np.sin(theta2), num) -d)[:,np.newaxis] 
  Y_label2 = -np.ones(num) #类别标签为-1,注意:由于采取双曲正切函数作为激活函数,类别标签不能为0 
  #合并 
  X_Col = np.vstack((X_Col1, X_Col2)) 
  X_Row = np.vstack((X_Row1, X_Row2)) 
  X = np.hstack((X_Col, X_Row)) 
  Y_label = np.hstack((Y_label1,Y_label2)) 
  Y_label.shape = (num*2 , 1) 
  return X,Y_label

其中r为月环半径,w为月环宽度,d为上下月环距离(与书中一致)

2.利用TensorFlow搭建神经网络模型

2.1 神经网络层添加

def add_layer(layername,inputs, in_size, out_size, activation_function=None): 
  # add one more layer and return the output of this layer 
  with tf.variable_scope(layername,reuse=None): 
    Weights = tf.get_variable("weights",shape=[in_size, out_size], 
                 initializer=tf.truncated_normal_initializer(stddev=0.1)) 
    biases = tf.get_variable("biases", shape=[1, out_size], 
                 initializer=tf.truncated_normal_initializer(stddev=0.1)) 
   
  Wx_plus_b = tf.matmul(inputs, Weights) + biases 
  if activation_function is None: 
    outputs = Wx_plus_b 
  else: 
    outputs = activation_function(Wx_plus_b) 
  return outputs

2.2 利用tensorflow建立神经网络模型

输入层大小:2

隐藏层大小:20

输出层大小:1

激活函数:双曲正切函数

学习率:0.1(与书中略有不同)

(具体的搭建过程可参考莫烦的视频,链接就不附上了自行搜索吧......)

###define placeholder for inputs to network 
xs = tf.placeholder(tf.float32, [None, 2]) 
ys = tf.placeholder(tf.float32, [None, 1]) 
###添加隐藏层 
l1 = add_layer("layer1",xs, 2, 20, activation_function=tf.tanh) 
###添加输出层 
prediction = add_layer("layer2",l1, 20, 1, activation_function=tf.tanh) 
###MSE 均方误差 
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys-prediction), reduction_indices=[1])) 
###优化器选取 学习率设置 此处学习率置为0.1 
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss) 
###tensorflow变量初始化,打开会话 
init = tf.global_variables_initializer()#tensorflow更新后初始化所有变量不再用tf.initialize_all_variables() 
sess = tf.Session() 
sess.run(init)

2.3 训练模型

###训练2000次 
for i in range(2000): 
  sess.run(train_step, feed_dict={xs: x_data, ys: y_label})

3.利用训练好的网络模型寻找分类决策边界

3.1 产生二维空间随机点

def produce_random_data(r,w,d,num): 
  X1 = np.random.uniform(-r-w/2,2*r+w/2, num) 
  X2 = np.random.uniform(-r - w / 2-d, r+w/2, num) 
  X = np.vstack((X1, X2)) 
  return X.transpose()

3.2 用训练好的模型采集决策边界附近的点

向网络输入一个二维空间随机点,计算输出值大于-0.5小于0.5即认为该点落在决策边界附近(双曲正切函数)

def collect_boundary_data(v_xs): 
  global prediction 
  X = np.empty([1,2]) 
  X = list() 
  for i in range(len(v_xs)): 
    x_input = v_xs[i] 
    x_input.shape = [1,2] 
    y_pre = sess.run(prediction, feed_dict={xs: x_input}) 
    if abs(y_pre - 0) < 0.5: 
      X.append(v_xs[i]) 
  return np.array(X)

3.3 用numpy工具将采集到的边界附近点拟合成决策边界曲线,用matplotlib.pyplot画图

###产生空间随机数据 
  X_NUM = produce_random_data(10, 6, -4, 5000) 
  ###边界数据采样 
  X_b = collect_boundary_data(X_NUM) 
  ###画出数据 
  fig = plt.figure() 
  ax = fig.add_subplot(1, 1, 1) 
  ###设置坐标轴名称 
  plt.xlabel('x1') 
  plt.ylabel('x2') 
  ax.scatter(x_data[:, 0], x_data[:, 1], marker='x') 
  ###用采样的边界数据拟合边界曲线 7次曲线最佳 
  z1 = np.polyfit(X_b[:, 0], X_b[:, 1], 7) 
  p1 = np.poly1d(z1) 
  x = X_b[:, 0] 
  x.sort() 
  yvals = p1(x) 
  plt.plot(x, yvals, 'r', label='boundray line') 
  plt.legend(loc=4) 
  #plt.ion() 
  plt.show()

4.效果

利用TensorFlow训练简单的二分类神经网络模型的方法

5.附上源码Github链接

https://github.com/Peakulorain/Practices.git 里的PatternClassification.py文件

另注:分类问题还是用softmax去做吧.....我只是用这做书上的练习而已。

(初学者水平有限,有问题请指出,各位大佬轻喷)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python mysqldb连接数据库
Mar 16 Python
python抓取京东价格分析京东商品价格走势
Jan 09 Python
Python中使用haystack实现django全文检索搜索引擎功能
Aug 26 Python
Python实现的计数排序算法示例
Nov 29 Python
为什么选择python编程语言入门黑客攻防 给你几个理由!
Feb 02 Python
利用python如何处理nc数据详解
May 23 Python
python ddt数据驱动最简实例代码
Feb 22 Python
python tkinter库实现气泡屏保和锁屏
Jul 29 Python
Python爬虫 批量爬取下载抖音视频代码实例
Aug 16 Python
Python tkinter三种布局实例详解
Jan 06 Python
一文读懂Python 枚举
Aug 25 Python
Python调用JavaScript代码的方法
Oct 27 Python
python使用Pycharm创建一个Django项目
Mar 05 #Python
python爬虫基本知识
Mar 05 #Python
用tensorflow构建线性回归模型的示例代码
Mar 05 #Python
详解python实现线程安全的单例模式
Mar 05 #Python
分析python动态规划的递归、非递归实现
Mar 04 #Python
python3.x上post发送json数据
Mar 04 #Python
python数据封装json格式数据
Mar 04 #Python
You might like
资料注册后发信小技巧
2006/10/09 PHP
php垃圾代码优化操作代码
2010/08/05 PHP
php生成html文件方法总结
2014/12/01 PHP
PHP读取mssql json数据中文乱码的解决办法
2016/04/11 PHP
ThinkPHP3.2.3实现分页的方法详解
2016/06/03 PHP
PHP生成图片缩略图类示例
2017/01/12 PHP
基于ThinkPHP5.0实现图片上传插件
2017/09/25 PHP
比较详细的javascript对象的property和prototype是什么一种关系
2007/08/06 Javascript
javascript 写类方式之七
2009/07/05 Javascript
读jQuery之四(优雅的迭代)
2011/06/20 Javascript
js生成随机数之random函数随机示例
2013/12/20 Javascript
jquery 操作两个select实现值之间的互相传递
2014/03/07 Javascript
原生js实现淘宝首页点击按钮缓慢回到顶部效果
2014/04/06 Javascript
JavaScript中如何通过arguments对象实现对象的重载
2014/05/12 Javascript
使用jQuery.wechat构建微信WEB应用
2014/10/09 Javascript
javascript自动生成包含数字与字符的随机字符串
2015/02/09 Javascript
JavaScript中函数表达式和函数声明及函数声明与函数表达式的不同
2015/11/15 Javascript
鼠标悬停小图标显示大图标
2016/01/22 Javascript
AngularJs定制样式插入到ueditor中的问题小结
2016/08/01 Javascript
用js实现简单算法的实例代码
2016/09/24 Javascript
AngularJs验证重复密码的方法(两种)
2016/11/25 Javascript
jQuery中的select操作详解
2016/11/29 Javascript
jQuery使用ajax方法解析返回的json数据功能示例
2017/01/10 Javascript
Ionic3实现图片瀑布流布局
2017/08/09 Javascript
详解node nvm进行node多版本管理
2017/10/21 Javascript
基于node打包可执行文件工具_Pkg使用心得分享
2018/01/24 Javascript
浅谈针对Vue相同路由不同参数的刷新问题
2018/09/29 Javascript
Vue基本使用之对象提供的属性功能
2019/04/30 Javascript
Python实现的几个常用排序算法实例
2014/06/16 Python
python之virtualenv的简单使用方法(必看篇)
2017/11/25 Python
python实现数据导出到excel的示例--普通格式
2018/05/03 Python
Python 经典面试题 21 道【不可错过】
2018/09/21 Python
Python-opencv实现红绿两色识别操作
2020/06/04 Python
StubHub新加坡:购买和出售全球活动门票
2017/03/10 全球购物
班队活动设计方案
2014/01/30 职场文书
手把手带你彻底卸载MySQL数据库
2022/06/14 MySQL