TensorFlow实现iris数据集线性回归


Posted in Python onSeptember 07, 2018

本文将遍历批量数据点并让TensorFlow更新斜率和y截距。这次将使用Scikit Learn的内建iris数据集。特别地,我们将用数据点(x值代表花瓣宽度,y值代表花瓣长度)找到最优直线。选择这两种特征是因为它们具有线性关系,在后续结果中将会看到。本文将使用L2正则损失函数。

# 用TensorFlow实现线性回归算法
#----------------------------------
#
# This function shows how to use TensorFlow to
# solve linear regression.
# y = Ax + b
#
# We will use the iris data, specifically:
# y = Sepal Length
# x = Petal Width

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn import datasets
from tensorflow.python.framework import ops
ops.reset_default_graph()

# Create graph
sess = tf.Session()

# Load the data
# iris.data = [(Sepal Length, Sepal Width, Petal Length, Petal Width)]
iris = datasets.load_iris()
x_vals = np.array([x[3] for x in iris.data])
y_vals = np.array([y[0] for y in iris.data])

# 批量大小
batch_size = 25

# Initialize 占位符
x_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# 模型变量
A = tf.Variable(tf.random_normal(shape=[1,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))

# 增加线性模型,y=Ax+b
model_output = tf.add(tf.matmul(x_data, A), b)

# 声明L2损失函数,其为批量损失的平均值。
loss = tf.reduce_mean(tf.square(y_target - model_output))

# 声明优化器 学习率设为0.05
my_opt = tf.train.GradientDescentOptimizer(0.05)
train_step = my_opt.minimize(loss)

# 初始化变量
init = tf.global_variables_initializer()
sess.run(init)

# 批量训练遍历迭代
# 迭代100次,每25次迭代输出变量值和损失值
loss_vec = []
for i in range(100):
  rand_index = np.random.choice(len(x_vals), size=batch_size)
  rand_x = np.transpose([x_vals[rand_index]])
  rand_y = np.transpose([y_vals[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  loss_vec.append(temp_loss)
  if (i+1)%25==0:
    print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)) + ' b = ' + str(sess.run(b)))
    print('Loss = ' + str(temp_loss))

# 抽取系数
[slope] = sess.run(A)
[y_intercept] = sess.run(b)

# 创建最佳拟合直线
best_fit = []
for i in x_vals:
 best_fit.append(slope*i+y_intercept)

# 绘制两幅图
# 拟合的直线
plt.plot(x_vals, y_vals, 'o', label='Data Points')
plt.plot(x_vals, best_fit, 'r-', label='Best fit line', linewidth=3)
plt.legend(loc='upper left')
plt.title('Sepal Length vs Pedal Width')
plt.xlabel('Pedal Width')
plt.ylabel('Sepal Length')
plt.show()

# Plot loss over time
# 迭代100次的L2正则损失函数
plt.plot(loss_vec, 'k-')
plt.title('L2 Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('L2 Loss')
plt.show()

结果:

Step #25 A = [[ 1.93474162]] b = [[ 3.11190438]]
Loss = 1.21364
Step #50 A = [[ 1.48641717]] b = [[ 3.81004381]]
Loss = 0.945256
Step #75 A = [[ 1.26089203]] b = [[ 4.221035]]
Loss = 0.254756
Step #100 A = [[ 1.1693294]] b = [[ 4.47258472]]
Loss = 0.281654

TensorFlow实现iris数据集线性回归

TensorFlow实现iris数据集线性回归

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
深入理解Python中变量赋值的问题
Jan 12 Python
mysql 之通过配置文件链接数据库
Aug 12 Python
Python实现的计算马氏距离算法示例
Apr 03 Python
pip命令无法使用的解决方法
Jun 12 Python
如何通过雪花算法用Python实现一个简单的发号器
Jul 03 Python
python matplotlib库绘制散点图例题解析
Aug 10 Python
python实现代码统计器
Sep 19 Python
python实现按关键字筛选日志文件
Dec 24 Python
使用Django清空数据库并重新生成
Apr 03 Python
python实现一个猜拳游戏
Apr 05 Python
解决pycharm修改代码后第一次运行不生效的问题
Feb 06 Python
python Protobuf定义消息类型知识点讲解
Mar 02 Python
TensorFlow实现模型评估
Sep 07 #Python
使用tensorflow实现线性svm
Sep 07 #Python
Python多进程池 multiprocessing Pool用法示例
Sep 07 #Python
详解python while 函数及while和for的区别
Sep 07 #Python
使用TensorFlow实现SVM
Sep 06 #Python
使用Python制作自动推送微信消息提醒的备忘录功能
Sep 06 #Python
python实现机器学习之多元线性回归
Sep 06 #Python
You might like
postfixadmin忘记密码后的修改密码方法详解
2016/07/20 PHP
js chrome浏览器判断代码
2010/03/28 Javascript
Extjs中TabPane如何嵌套在其他网页中实现思路及代码
2013/01/27 Javascript
JS匀速运动演示示例代码
2013/11/26 Javascript
jQuery学习笔记之 Ajax操作篇(一) - 数据加载
2014/06/23 Javascript
JQuery异步获取返回值中文乱码的解决方法
2015/01/29 Javascript
JavaScript取得WEB安全颜色列表的方法
2015/07/14 Javascript
第一次接触神奇的Bootstrap网格系统
2016/07/27 Javascript
详谈js对url进行编码和解码(三种方式的区别)
2017/08/16 Javascript
如何通过setTimeout理解JS运行机制详解
2019/03/23 Javascript
微信小程序引入模块中wxml、wxss、js的方法示例
2019/08/09 Javascript
微信小程序实现比较功能的方法汇总(五种方法)
2020/03/07 Javascript
python实现查找两个字符串中相同字符并输出的方法
2015/07/11 Python
利用Python yagmail三行代码实现发送邮件
2018/05/11 Python
python实现简易内存监控
2018/06/21 Python
python+unittest+requests实现接口自动化的方法
2018/11/29 Python
Python面向对象基础入门之编码细节与注意事项
2018/12/11 Python
使用PIL(Python-Imaging)反转图像的颜色方法
2019/01/24 Python
对Python中画图时候的线类型详解
2019/07/07 Python
python与C、C++混编的四种方式(小结)
2019/07/15 Python
Python图像处理之图片文字识别功能(OCR)
2019/07/30 Python
django项目环境搭建及在虚拟机本地创建django项目的教程
2019/08/02 Python
twilio python自动拨打电话,播放自定义mp3音频的方法
2019/08/08 Python
pytorch多进程加速及代码优化方法
2019/08/19 Python
简单了解python协程的相关知识
2019/08/31 Python
解决pycharm中导入自己写的.py函数出错问题
2020/02/12 Python
小白教你PyCharm从下载到安装再到科学使用PyCharm2020最新激活码
2020/09/25 Python
Johnson Fitness澳大利亚:高级健身器材
2021/03/16 全球购物
索引覆盖(Index Covering)查询含义
2012/02/18 面试题
学雷锋先进个人事迹
2014/05/26 职场文书
医院护士党的群众路线教育实践活动对照检查材料思想汇报
2014/10/04 职场文书
企业工会工作总结2015
2015/05/13 职场文书
党性教育心得体会(共6篇)
2016/01/21 职场文书
关于应聘教师的自荐信
2016/01/28 职场文书
Django路由层如何获取正确的url
2021/07/15 Python
Nginx实现会话保持的两种方式
2022/03/18 Servers