tensorflow建立一个简单的神经网络的方法


Posted in Python onFebruary 10, 2018

本笔记目的是通过tensorflow实现一个两层的神经网络。目的是实现一个二次函数的拟合。

如何添加一层网络

代码如下:

def add_layer(inputs, in_size, out_size, activation_function=None):
  # add one more layer and return the output of this layer
  Weights = tf.Variable(tf.random_normal([in_size, out_size]))
  biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
  Wx_plus_b = tf.matmul(inputs, Weights) + biases
  if activation_function is None:
    outputs = Wx_plus_b
  else:
    outputs = activation_function(Wx_plus_b)
  return outputs

注意该函数中是xW+b,而不是Wx+b。所以要注意乘法的顺序。x应该定义为[类别数量, 数据数量], W定义为[数据类别,类别数量]。

创建一些数据

# Make up some real data
x_data = np.linspace(-1,1,300)[:, np.newaxis]
noise = np.random.normal(0, 0.05, x_data.shape)
y_data = np.square(x_data) - 0.5 + noise

numpy的linspace函数能够产生等差数列。start,stop决定等差数列的起止值。endpoint参数指定包不包括终点值。

numpy.linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None)[source] 
Return evenly spaced numbers over a specified interval. 
Returns num evenly spaced samples, calculated over the interval [start, stop].

tensorflow建立一个简单的神经网络的方法

noise函数为添加噪声所用,这样二次函数的点不会与二次函数曲线完全重合。

numpy的newaxis可以新增一个维度而不需要重新创建相应的shape在赋值,非常方便,如上面的例子中就将x_data从一维变成了二维。

添加占位符,用作输入

# define placeholder for inputs to network
xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])

添加隐藏层和输出层

# add hidden layer
l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu)
# add output layer
prediction = add_layer(l1, 10, 1, activation_function=None)

计算误差,并用梯度下降使得误差最小

# the error between prediciton and real data
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction),reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

完整代码如下:

from __future__ import print_function
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

def add_layer(inputs, in_size, out_size, activation_function=None):
  # add one more layer and return the output of this layer
  Weights = tf.Variable(tf.random_normal([in_size, out_size]))
  biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
  Wx_plus_b = tf.matmul(inputs, Weights) + biases
  if activation_function is None:
    outputs = Wx_plus_b
  else:
    outputs = activation_function(Wx_plus_b)
  return outputs

# Make up some real data
x_data = np.linspace(-1,1,300)[:, np.newaxis]
noise = np.random.normal(0, 0.05, x_data.shape)
y_data = np.square(x_data) - 0.5 + noise

# define placeholder for inputs to network
xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])
# add hidden layer
l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu)
# add output layer
prediction = add_layer(l1, 10, 1, activation_function=None)

# the error between prediciton and real data
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction),
           reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

# important step
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)

# plot the real data
fig = plt.figure()
ax = fig.add_subplot(1,1,1)
ax.scatter(x_data, y_data)
plt.ion()
plt.show()

for i in range(1000):
  # training
  sess.run(train_step, feed_dict={xs: x_data, ys: y_data})
  if i % 50 == 0:
    # to visualize the result and improvement
    try:
      ax.lines.remove(lines[0])
    except Exception:
      pass
    prediction_value = sess.run(prediction, feed_dict={xs: x_data})
    # plot the prediction
    lines = ax.plot(x_data, prediction_value, 'r-', lw=5)
    plt.pause(0.1)

运行结果:

tensorflow建立一个简单的神经网络的方法

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解Python2.x中对Unicode编码的使用
Apr 03 Python
用Python编写简单的定时器的方法
May 02 Python
python中while循环语句用法简单实例
May 07 Python
PyQt5主窗口动态加载Widget实例代码
Feb 07 Python
python基础知识(一)变量与简单数据类型详解
Apr 17 Python
python交互模式下输入换行/输入多行命令的方法
Jul 02 Python
python多线程实现代码(模拟银行服务操作流程)
Jan 13 Python
Python对Tornado请求与响应的数据处理
Feb 12 Python
Python基于requests库爬取网站信息
Mar 02 Python
简单了解Python字典copy与赋值的区别
Sep 16 Python
Pycharm同步远程服务器调试的方法步骤
Nov 04 Python
python asyncio 协程库的使用
Jan 21 Python
python取代netcat过程分析
Feb 10 #Python
浅谈Python黑帽子取代netcat
Feb 10 #Python
python3爬取淘宝信息代码分析
Feb 10 #Python
Python中property属性实例解析
Feb 10 #Python
Java编程迭代地删除文件夹及其下的所有文件实例
Feb 10 #Python
Python中协程用法代码详解
Feb 10 #Python
Python实现简单生成验证码功能【基于random模块】
Feb 10 #Python
You might like
PHP--用万网的接口实现域名查询功能
2012/12/13 PHP
PHP基于反射机制实现插件的可插拔设计详解
2016/11/10 PHP
CI(CodeIgniter)框架视图中加载视图的方法
2017/03/24 PHP
PHP unlink与rmdir删除目录及目录下所有文件实例代码
2018/02/07 PHP
PHP使用PhpSpreadsheet操作Excel实例详解
2020/03/26 PHP
如何在PHP中使用数组
2020/06/09 PHP
获取DOM对象的几种扩展及简写
2006/10/09 Javascript
犀利的js 函数集合
2009/06/11 Javascript
js实现文本框中焦点在最后位置
2014/03/04 Javascript
JavaScript利用构造函数和原型的方式模拟C#类的功能
2014/03/06 Javascript
JavaScript让网页出现渐隐渐显背景颜色的方法
2015/04/21 Javascript
详解JavaScript函数对象
2015/11/15 Javascript
使用CSS+JavaScript或纯js实现半透明遮罩效果的实例分享
2016/05/09 Javascript
javascript js 操作数组 增删改查的简单实现
2016/06/20 Javascript
Bootstrap CSS组件之导航条(navbar)
2016/12/17 Javascript
ES6中数组array新增方法实例总结
2017/11/07 Javascript
JS实现标签滚动切换效果
2017/12/25 Javascript
React Component存在的几种形式详解
2018/11/06 Javascript
python实现代理服务功能实例
2013/11/15 Python
Python实现将DOC文档转换为PDF的方法
2015/07/25 Python
详解MySQL数据类型int(M)中M的含义
2016/11/20 Python
Python语言描述机器学习之Logistic回归算法
2017/12/21 Python
python实现批量解析邮件并下载附件
2018/06/19 Python
python模拟鼠标点击和键盘输入的操作
2019/08/04 Python
window7下的python2.7版本和python3.5版本的opencv-python安装过程
2019/10/24 Python
pycharm激活码免费分享适用最新pycharm2020.2.3永久激活
2020/11/25 Python
Python文件名匹配与文件复制的实现
2020/12/11 Python
KEETSA环保床垫:更好的睡眠,更好的生活!
2016/11/24 全球购物
幼师自我鉴定范文
2013/10/01 职场文书
本科生求职简历的自我评价
2013/10/21 职场文书
电气工程及其自动化专业求职信
2014/06/23 职场文书
实习指导老师意见
2015/06/04 职场文书
本科毕业论文答辩稿
2015/06/23 职场文书
《巨人的花园》教学反思
2016/02/19 职场文书
浅谈Python类的单继承相关知识
2021/05/12 Python
Java虚拟机内存结构及编码实战分享
2022/04/07 Java/Android