用tensorflow构建线性回归模型的示例代码


Posted in Python onMarch 05, 2018

用tensorflow构建简单的线性回归模型是tensorflow的一个基础样例,但是原有的样例存在一些问题,我在实际调试的过程中做了一点自己的改进,并且有一些体会。

首先总结一下tf构建模型的总体套路

1、先定义模型的整体图结构,未知的部分,比如输入就用placeholder来代替。

2、再定义最后与目标的误差函数。

3、最后选择优化方法。

另外几个值得注意的地方是:

1、tensorflow构建模型第一步是先用代码搭建图模型,此时图模型是静止的,是不产生任何运算结果的,必须使用Session来驱动。

2、第二步根据问题的不同要求构建不同的误差函数,这个函数就是要求优化的函数。

3、调用合适的优化器优化误差函数,注意,此时反向传播调整参数的过程隐藏在了图模型当中,并没有显式显现出来。

4、tensorflow的中文意思是张量流动,也就是说有两个意思,一个是参与运算的不仅仅是标量或是矩阵,甚至可以是具有很高维度的张量,第二个意思是这些数据在图模型中流动,不停地更新。

5、session的run函数中,按照传入的操作向上查找,凡是操作中涉及的无论是变量、常量都要参与运算,占位符则要在run过程中以字典形式传入。

以上时tensorflow的一点认识,下面是关于梯度下降的一点新认识。

1、梯度下降法分为批量梯度下降和随机梯度下降法,第一种是所有数据都参与运算后,计算误差函数,根据此误差函数来更新模型参数,实际调试发现,如果定义误差函数为平方误差函数,这个值很快就会飞掉,原因是,批量平方误差都加起来可能会很大,如果此时学习率比较高,那么调整就会过,造成模型参数向一个方向大幅调整,造成最终结果发散。所以这个时候要降低学习率,让参数变化不要太快。

2、随机梯度下降法,每次用一个数据计算误差函数,然后更新模型参数,这个方法有可能会造成结果出现震荡,而且麻烦的是由于要一个个取出数据参与运算,而不是像批量计算那样采用了广播或者向量化乘法的机制,收敛会慢一些。但是速度要比使用批量梯度下降要快,原因是不需要每次计算全部数据的梯度了。比较折中的办法是mini-batch,也就是每次选用一小部分数据做梯度下降,目前这也是最为常用的方法了。

3、epoch概念:所有样本集过完一轮,就是一个epoch,很明显,如果是严格的随机梯度下降法,一个epoch内更新了样本个数这么多次参数,而批量法只更新了一次。

以上是我个人的一点认识,希望大家看到有不对的地方及时批评指针,不胜感激!

#encoding=utf-8 
__author__ = 'freedom' 
import tensorflow as tf 
import numpy as np 
 
def createData(dataNum,w,b,sigma): 
 train_x = np.arange(dataNum) 
 train_y = w*train_x+b+np.random.randn()*sigma 
 #print train_x 
 #print train_y 
 return train_x,train_y 
 
def linerRegression(train_x,train_y,epoch=100000,rate = 0.000001): 
 train_x = np.array(train_x) 
 train_y = np.array(train_y) 
 n = train_x.shape[0] 
 x = tf.placeholder("float") 
 y = tf.placeholder("float") 
 w = tf.Variable(tf.random_normal([1])) # 生成随机权重 
 b = tf.Variable(tf.random_normal([1])) 
 
 pred = tf.add(tf.mul(x,w),b) 
 loss = tf.reduce_sum(tf.pow(pred-y,2)) 
 optimizer = tf.train.GradientDescentOptimizer(rate).minimize(loss) 
 init = tf.initialize_all_variables() 
 
 sess = tf.Session() 
 sess.run(init) 
 print 'w start is ',sess.run(w) 
 print 'b start is ',sess.run(b) 
 for index in range(epoch): 
  #for tx,ty in zip(train_x,train_y): 
   #sess.run(optimizer,{x:tx,y:ty}) 
  sess.run(optimizer,{x:train_x,y:train_y}) 
  # print 'w is ',sess.run(w) 
  # print 'b is ',sess.run(b) 
  # print 'pred is ',sess.run(pred,{x:train_x}) 
  # print 'loss is ',sess.run(loss,{x:train_x,y:train_y}) 
  #print '------------------' 
 print 'loss is ',sess.run(loss,{x:train_x,y:train_y}) 
 w = sess.run(w) 
 b = sess.run(b) 
 return w,b 
 
def predictionTest(test_x,test_y,w,b): 
 W = tf.placeholder(tf.float32) 
 B = tf.placeholder(tf.float32) 
 X = tf.placeholder(tf.float32) 
 Y = tf.placeholder(tf.float32) 
 n = test_x.shape[0] 
 pred = tf.add(tf.mul(X,W),B) 
 loss = tf.reduce_mean(tf.pow(pred-Y,2)) 
 sess = tf.Session() 
 loss = sess.run(loss,{X:test_x,Y:test_y,W:w,B:b}) 
 return loss 
 
if __name__ == "__main__": 
 train_x,train_y = createData(50,2.0,7.0,1.0) 
 test_x,test_y = createData(20,2.0,7.0,1.0) 
 w,b = linerRegression(train_x,train_y) 
 print 'weights',w 
 print 'bias',b 
 loss = predictionTest(test_x,test_y,w,b) 
 print loss

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python的print用法示例
Feb 11 Python
Python scikit-learn 做线性回归的示例代码
Nov 01 Python
Python实现图片滑动式验证识别方法
Nov 09 Python
win7+Python3.5下scrapy的安装方法
Jul 31 Python
使用python 写一个静态服务(实战)
Jun 28 Python
基于Python实现签到脚本过程解析
Oct 25 Python
Python魔法方法 容器部方法详解
Jan 02 Python
python统计字符串中字母出现次数代码实例
Mar 02 Python
PyQt5实现登录页面
May 30 Python
解析python 中/ 和 % 和 //(地板除)
Jun 28 Python
用python进行视频剪辑
Nov 02 Python
教你如何使用Python Tkinter库制作记事本
Jun 10 Python
详解python实现线程安全的单例模式
Mar 05 #Python
分析python动态规划的递归、非递归实现
Mar 04 #Python
python3.x上post发送json数据
Mar 04 #Python
python数据封装json格式数据
Mar 04 #Python
Python爬虫实例扒取2345天气预报
Mar 04 #Python
Python爬虫设置代理IP的方法(爬虫技巧)
Mar 04 #Python
浅析python实现scrapy定时执行爬虫
Mar 04 #Python
You might like
PHP语法速查表
2007/01/02 PHP
使用php清除bom示例
2014/03/03 PHP
php实现向javascript传递数组的方法
2015/07/27 PHP
Zend Framework教程之请求对象的封装Zend_Controller_Request实例详解
2016/03/07 PHP
Nigma vs AM BO3 第一场2.13
2021/03/10 DOTA
ASP中用Join和Array,可以加快字符连接速度的代码
2007/08/22 Javascript
JavaScript入门教程(2) JS基础知识
2009/01/31 Javascript
解析javascript系统错误:-1072896658的解决办法
2013/07/08 Javascript
JavaScript中的console.log()函数详细介绍
2014/12/29 Javascript
JS选项卡动态替换banner图片路径的方法
2015/05/11 Javascript
jQuery添加删除DOM元素方法详解
2016/01/18 Javascript
JavaScript仿网易选项卡制作代码
2016/10/06 Javascript
jquery中用jsonp实现搜索框功能
2016/10/18 Javascript
js实现带简单弹性运动的导航条
2017/02/22 Javascript
Vue2组件tree实现无限级树形菜单
2017/03/29 Javascript
vue实现微信分享功能
2018/11/28 Javascript
浅谈JavaScript面向对象--继承
2019/03/20 Javascript
用webpack4开发小程序的实现方法
2019/06/04 Javascript
[05:37]DOTA2-DPC中国联赛 正赛 Elephant vs iG 选手采访
2021/03/11 DOTA
获取python文件扩展名和文件名方法
2018/02/02 Python
python如何让类支持比较运算
2018/03/20 Python
python编写暴力破解zip文档程序的实例讲解
2018/04/24 Python
python实现两张图片的像素融合
2019/02/23 Python
Python多进程方式抓取基金网站内容的方法分析
2019/06/03 Python
python单例模式的多种实现方法
2019/07/26 Python
使用Python的networkx绘制精美网络图教程
2019/11/21 Python
美国知名的家庭连锁百货商店:Boscov’s
2017/07/27 全球购物
Spartoo比利时:欧洲时尚购物网站
2017/12/06 全球购物
荷兰鞋子在线:Nelson Schoenen
2017/12/25 全球购物
联想印度官方网上商店:Lenovo India
2019/08/24 全球购物
腾讯公司的一个sql题
2013/01/22 面试题
买房子个人收入证明
2014/01/16 职场文书
奥巴马竞选演讲稿
2014/05/15 职场文书
工作时间擅自离岗检讨书
2014/10/24 职场文书
2014年学校团委工作总结
2014/12/20 职场文书
元宵节寄语大全
2015/02/27 职场文书