tensorflow建立一个简单的神经网络的方法


Posted in Python onFebruary 10, 2018

本笔记目的是通过tensorflow实现一个两层的神经网络。目的是实现一个二次函数的拟合。

如何添加一层网络

代码如下:

def add_layer(inputs, in_size, out_size, activation_function=None):
  # add one more layer and return the output of this layer
  Weights = tf.Variable(tf.random_normal([in_size, out_size]))
  biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
  Wx_plus_b = tf.matmul(inputs, Weights) + biases
  if activation_function is None:
    outputs = Wx_plus_b
  else:
    outputs = activation_function(Wx_plus_b)
  return outputs

注意该函数中是xW+b,而不是Wx+b。所以要注意乘法的顺序。x应该定义为[类别数量, 数据数量], W定义为[数据类别,类别数量]。

创建一些数据

# Make up some real data
x_data = np.linspace(-1,1,300)[:, np.newaxis]
noise = np.random.normal(0, 0.05, x_data.shape)
y_data = np.square(x_data) - 0.5 + noise

numpy的linspace函数能够产生等差数列。start,stop决定等差数列的起止值。endpoint参数指定包不包括终点值。

numpy.linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None)[source] 
Return evenly spaced numbers over a specified interval. 
Returns num evenly spaced samples, calculated over the interval [start, stop].

tensorflow建立一个简单的神经网络的方法

noise函数为添加噪声所用,这样二次函数的点不会与二次函数曲线完全重合。

numpy的newaxis可以新增一个维度而不需要重新创建相应的shape在赋值,非常方便,如上面的例子中就将x_data从一维变成了二维。

添加占位符,用作输入

# define placeholder for inputs to network
xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])

添加隐藏层和输出层

# add hidden layer
l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu)
# add output layer
prediction = add_layer(l1, 10, 1, activation_function=None)

计算误差,并用梯度下降使得误差最小

# the error between prediciton and real data
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction),reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

完整代码如下:

from __future__ import print_function
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

def add_layer(inputs, in_size, out_size, activation_function=None):
  # add one more layer and return the output of this layer
  Weights = tf.Variable(tf.random_normal([in_size, out_size]))
  biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
  Wx_plus_b = tf.matmul(inputs, Weights) + biases
  if activation_function is None:
    outputs = Wx_plus_b
  else:
    outputs = activation_function(Wx_plus_b)
  return outputs

# Make up some real data
x_data = np.linspace(-1,1,300)[:, np.newaxis]
noise = np.random.normal(0, 0.05, x_data.shape)
y_data = np.square(x_data) - 0.5 + noise

# define placeholder for inputs to network
xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])
# add hidden layer
l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu)
# add output layer
prediction = add_layer(l1, 10, 1, activation_function=None)

# the error between prediciton and real data
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction),
           reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

# important step
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)

# plot the real data
fig = plt.figure()
ax = fig.add_subplot(1,1,1)
ax.scatter(x_data, y_data)
plt.ion()
plt.show()

for i in range(1000):
  # training
  sess.run(train_step, feed_dict={xs: x_data, ys: y_data})
  if i % 50 == 0:
    # to visualize the result and improvement
    try:
      ax.lines.remove(lines[0])
    except Exception:
      pass
    prediction_value = sess.run(prediction, feed_dict={xs: x_data})
    # plot the prediction
    lines = ax.plot(x_data, prediction_value, 'r-', lw=5)
    plt.pause(0.1)

运行结果:

tensorflow建立一个简单的神经网络的方法

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python两种遍历字典(dict)的方法比较
May 29 Python
python中pandas.DataFrame的简单操作方法(创建、索引、增添与删除)
Mar 12 Python
Python实现求笛卡尔乘积的方法
Sep 16 Python
致Python初学者 Anaconda入门使用指南完整版
Apr 05 Python
pycharm 主题theme设置调整仿sublime的方法
May 23 Python
详解python数据结构和算法
Apr 18 Python
python3.6根据m3u8下载mp4视频
Jun 17 Python
使用Python画股票的K线图的方法步骤
Jun 28 Python
tensorboard 可以显示graph,却不能显示scalar的解决方式
Feb 15 Python
新手入门学习python Numpy基础操作
Mar 02 Python
pandas读取csv文件提示不存在的解决方法及原因分析
Apr 21 Python
解决python3.x安装numpy成功但import出错的问题
Nov 17 Python
python取代netcat过程分析
Feb 10 #Python
浅谈Python黑帽子取代netcat
Feb 10 #Python
python3爬取淘宝信息代码分析
Feb 10 #Python
Python中property属性实例解析
Feb 10 #Python
Java编程迭代地删除文件夹及其下的所有文件实例
Feb 10 #Python
Python中协程用法代码详解
Feb 10 #Python
Python实现简单生成验证码功能【基于random模块】
Feb 10 #Python
You might like
destoon实现资讯信息前面调用它所属分类的方法
2014/07/15 PHP
图文详解phpstorm配置Xdebug进行调试PHP教程
2016/06/13 PHP
浅谈PHP中try{}catch{}的使用方法
2016/12/09 PHP
PHP chunk_split()函数讲解
2019/02/12 PHP
Aster vs KG BO3 第三场2.18
2021/03/10 DOTA
JQuery 图片滚动轮播示例代码
2014/03/24 Javascript
JavaScript数组迭代器实例分析
2015/06/09 Javascript
jQuery中数据缓存$.data的用法及源码完全解析
2016/04/29 Javascript
json实现添加、遍历与删除属性的方法
2016/06/17 Javascript
JavaScript SHA1加密算法实现详细代码
2016/10/06 Javascript
浅谈jQuery hover(over, out)事件函数
2016/12/03 Javascript
JS实现DOM节点插入操作之子节点与兄弟节点插入操作示例
2018/07/30 Javascript
JavaScript函数定义方法实例详解
2019/03/05 Javascript
vue cli使用融云实现聊天功能的实例代码
2019/04/19 Javascript
基于 vue-skeleton-webpack-plugin 的骨架屏实战
2019/08/05 Javascript
vue vant中picker组件的使用
2020/11/03 Javascript
用Python的urllib库提交WEB表单
2009/02/24 Python
Python+Selenium自动化实现分页(pagination)处理
2017/03/31 Python
基于python3 OpenCV3实现静态图片人脸识别
2018/05/25 Python
Python3列表内置方法大全及示例代码小结
2019/05/10 Python
logging level级别介绍
2020/02/21 Python
Python+Appium实现自动化测试的使用步骤
2020/03/24 Python
PIP和conda 更换国内安装源的方法步骤
2020/09/21 Python
纯CSS3实现手风琴风格菜单具体步骤
2013/05/06 HTML / CSS
浅谈CSS3特性查询(Feature Query: @supports)功能简介
2017/07/31 HTML / CSS
CSS3+HTML5+JS 实现一个块的收缩与展开动画效果
2020/11/17 HTML / CSS
NFL墨西哥官方商店:Tienda NFL
2017/11/28 全球购物
办公室前台的岗位职责
2013/12/20 职场文书
超市活动计划书
2014/04/24 职场文书
四风问题个人剖析材料
2014/10/07 职场文书
虎兄虎弟观后感
2015/06/12 职场文书
vue3如何优雅的实现移动端登录注册模块
2021/03/29 Vue.js
python 制作一个gui界面的翻译工具
2021/05/14 Python
pytorch 运行一段时间后出现GPU OOM的问题
2021/06/02 Python
Python字符串格式化方式
2022/04/07 Python
5个实用的JavaScript新特性
2022/06/16 Javascript