TensorFlow实现iris数据集线性回归


Posted in Python onSeptember 07, 2018

本文将遍历批量数据点并让TensorFlow更新斜率和y截距。这次将使用Scikit Learn的内建iris数据集。特别地,我们将用数据点(x值代表花瓣宽度,y值代表花瓣长度)找到最优直线。选择这两种特征是因为它们具有线性关系,在后续结果中将会看到。本文将使用L2正则损失函数。

# 用TensorFlow实现线性回归算法
#----------------------------------
#
# This function shows how to use TensorFlow to
# solve linear regression.
# y = Ax + b
#
# We will use the iris data, specifically:
# y = Sepal Length
# x = Petal Width

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn import datasets
from tensorflow.python.framework import ops
ops.reset_default_graph()

# Create graph
sess = tf.Session()

# Load the data
# iris.data = [(Sepal Length, Sepal Width, Petal Length, Petal Width)]
iris = datasets.load_iris()
x_vals = np.array([x[3] for x in iris.data])
y_vals = np.array([y[0] for y in iris.data])

# 批量大小
batch_size = 25

# Initialize 占位符
x_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# 模型变量
A = tf.Variable(tf.random_normal(shape=[1,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))

# 增加线性模型,y=Ax+b
model_output = tf.add(tf.matmul(x_data, A), b)

# 声明L2损失函数,其为批量损失的平均值。
loss = tf.reduce_mean(tf.square(y_target - model_output))

# 声明优化器 学习率设为0.05
my_opt = tf.train.GradientDescentOptimizer(0.05)
train_step = my_opt.minimize(loss)

# 初始化变量
init = tf.global_variables_initializer()
sess.run(init)

# 批量训练遍历迭代
# 迭代100次,每25次迭代输出变量值和损失值
loss_vec = []
for i in range(100):
  rand_index = np.random.choice(len(x_vals), size=batch_size)
  rand_x = np.transpose([x_vals[rand_index]])
  rand_y = np.transpose([y_vals[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  loss_vec.append(temp_loss)
  if (i+1)%25==0:
    print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)) + ' b = ' + str(sess.run(b)))
    print('Loss = ' + str(temp_loss))

# 抽取系数
[slope] = sess.run(A)
[y_intercept] = sess.run(b)

# 创建最佳拟合直线
best_fit = []
for i in x_vals:
 best_fit.append(slope*i+y_intercept)

# 绘制两幅图
# 拟合的直线
plt.plot(x_vals, y_vals, 'o', label='Data Points')
plt.plot(x_vals, best_fit, 'r-', label='Best fit line', linewidth=3)
plt.legend(loc='upper left')
plt.title('Sepal Length vs Pedal Width')
plt.xlabel('Pedal Width')
plt.ylabel('Sepal Length')
plt.show()

# Plot loss over time
# 迭代100次的L2正则损失函数
plt.plot(loss_vec, 'k-')
plt.title('L2 Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('L2 Loss')
plt.show()

结果:

Step #25 A = [[ 1.93474162]] b = [[ 3.11190438]]
Loss = 1.21364
Step #50 A = [[ 1.48641717]] b = [[ 3.81004381]]
Loss = 0.945256
Step #75 A = [[ 1.26089203]] b = [[ 4.221035]]
Loss = 0.254756
Step #100 A = [[ 1.1693294]] b = [[ 4.47258472]]
Loss = 0.281654

TensorFlow实现iris数据集线性回归

TensorFlow实现iris数据集线性回归

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
用Python的urllib库提交WEB表单
Feb 24 Python
python实现图片变亮或者变暗的方法
Jun 01 Python
浅谈python中的变量默认是什么类型
Sep 11 Python
python处理按钮消息的实例详解
Jul 11 Python
Python上下文管理器和with块详解
Sep 09 Python
python opencv之SURF算法示例
Feb 24 Python
Python cookbook(数据结构与算法)找出序列中出现次数最多的元素算法示例
Mar 15 Python
解决pip install的时候报错timed out的问题
Jun 12 Python
Python中存取文件的4种不同操作
Jul 02 Python
Pycharm最新激活码2019(推荐)
Dec 31 Python
python 实现仿微信聊天时间格式化显示的代码
Apr 17 Python
2020年10款优秀的Python第三方库,看看有你中意的吗?
Jan 12 Python
TensorFlow实现模型评估
Sep 07 #Python
使用tensorflow实现线性svm
Sep 07 #Python
Python多进程池 multiprocessing Pool用法示例
Sep 07 #Python
详解python while 函数及while和for的区别
Sep 07 #Python
使用TensorFlow实现SVM
Sep 06 #Python
使用Python制作自动推送微信消息提醒的备忘录功能
Sep 06 #Python
python实现机器学习之多元线性回归
Sep 06 #Python
You might like
各种咖啡的英文名子是什么
2021/03/03 新手入门
实现“上一页”和“下一页按钮
2006/10/09 PHP
PHP中几个常用的魔术常量
2012/02/23 PHP
使用PHP静态变量当缓存的方法
2013/11/13 PHP
如何解决phpmyadmin导入数据库文件最大限制2048KB
2015/10/09 PHP
php mysql 封装类实例代码
2016/09/18 PHP
javascript 跨浏览器开发经验总结(五) js 事件
2010/05/19 Javascript
jqGrid jQuery 表格插件测试代码
2011/08/23 Javascript
javascript实现带节日和农历的日历特效
2015/02/01 Javascript
jQuery实现的漂亮表单效果代码
2015/08/18 Javascript
超赞的jQuery图片滑块动画特效代码汇总
2016/01/25 Javascript
浅述节点的创建及常见功能的实现
2016/12/15 Javascript
关于foreach循环中遇到的问题小结
2017/05/08 Javascript
JavaScript实现三级联动效果
2017/07/15 Javascript
详解使用Vue Router导航钩子与Vuex来实现后退状态保存
2017/09/11 Javascript
浅谈Node模块系统及其模式
2017/11/17 Javascript
javascript实现QQ空间相册展示源码
2017/12/12 Javascript
LayUi中接口传数据成功,表格不显示数据的解决方法
2018/08/19 Javascript
Element Input组件分析小结
2018/10/11 Javascript
基于JavaScript canvas绘制贝塞尔曲线
2018/12/25 Javascript
详解vue 命名视图
2019/08/14 Javascript
[04:27]2014DOTA2国际邀请赛 NAVI战队官方纪录片
2014/07/21 DOTA
python通过自定义isnumber函数判断字符串是否为数字的方法
2015/04/23 Python
Python图片裁剪实例代码(如头像裁剪)
2017/06/21 Python
python脚本执行CMD命令并返回结果的例子
2019/08/14 Python
python每5分钟从kafka中提取数据的例子
2019/12/23 Python
Python基于内置库pytesseract实现图片验证码识别功能
2020/02/24 Python
python实现发送邮件
2021/03/02 Python
类的核心特性有哪些
2014/01/01 面试题
中文教师求职信
2014/02/22 职场文书
最常使用的求职信
2014/05/25 职场文书
银行柜员求职自荐书
2014/06/18 职场文书
一份恶作剧的检讨书
2014/09/13 职场文书
幼儿园庆六一主持词
2015/06/30 职场文书
大学生社会实践感想
2015/08/11 职场文书
Springboot集成kafka高级应用实战分享
2022/08/14 Java/Android