TensorFlow实现iris数据集线性回归


Posted in Python onSeptember 07, 2018

本文将遍历批量数据点并让TensorFlow更新斜率和y截距。这次将使用Scikit Learn的内建iris数据集。特别地,我们将用数据点(x值代表花瓣宽度,y值代表花瓣长度)找到最优直线。选择这两种特征是因为它们具有线性关系,在后续结果中将会看到。本文将使用L2正则损失函数。

# 用TensorFlow实现线性回归算法
#----------------------------------
#
# This function shows how to use TensorFlow to
# solve linear regression.
# y = Ax + b
#
# We will use the iris data, specifically:
# y = Sepal Length
# x = Petal Width

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn import datasets
from tensorflow.python.framework import ops
ops.reset_default_graph()

# Create graph
sess = tf.Session()

# Load the data
# iris.data = [(Sepal Length, Sepal Width, Petal Length, Petal Width)]
iris = datasets.load_iris()
x_vals = np.array([x[3] for x in iris.data])
y_vals = np.array([y[0] for y in iris.data])

# 批量大小
batch_size = 25

# Initialize 占位符
x_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# 模型变量
A = tf.Variable(tf.random_normal(shape=[1,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))

# 增加线性模型,y=Ax+b
model_output = tf.add(tf.matmul(x_data, A), b)

# 声明L2损失函数,其为批量损失的平均值。
loss = tf.reduce_mean(tf.square(y_target - model_output))

# 声明优化器 学习率设为0.05
my_opt = tf.train.GradientDescentOptimizer(0.05)
train_step = my_opt.minimize(loss)

# 初始化变量
init = tf.global_variables_initializer()
sess.run(init)

# 批量训练遍历迭代
# 迭代100次,每25次迭代输出变量值和损失值
loss_vec = []
for i in range(100):
  rand_index = np.random.choice(len(x_vals), size=batch_size)
  rand_x = np.transpose([x_vals[rand_index]])
  rand_y = np.transpose([y_vals[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  loss_vec.append(temp_loss)
  if (i+1)%25==0:
    print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)) + ' b = ' + str(sess.run(b)))
    print('Loss = ' + str(temp_loss))

# 抽取系数
[slope] = sess.run(A)
[y_intercept] = sess.run(b)

# 创建最佳拟合直线
best_fit = []
for i in x_vals:
 best_fit.append(slope*i+y_intercept)

# 绘制两幅图
# 拟合的直线
plt.plot(x_vals, y_vals, 'o', label='Data Points')
plt.plot(x_vals, best_fit, 'r-', label='Best fit line', linewidth=3)
plt.legend(loc='upper left')
plt.title('Sepal Length vs Pedal Width')
plt.xlabel('Pedal Width')
plt.ylabel('Sepal Length')
plt.show()

# Plot loss over time
# 迭代100次的L2正则损失函数
plt.plot(loss_vec, 'k-')
plt.title('L2 Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('L2 Loss')
plt.show()

结果:

Step #25 A = [[ 1.93474162]] b = [[ 3.11190438]]
Loss = 1.21364
Step #50 A = [[ 1.48641717]] b = [[ 3.81004381]]
Loss = 0.945256
Step #75 A = [[ 1.26089203]] b = [[ 4.221035]]
Loss = 0.254756
Step #100 A = [[ 1.1693294]] b = [[ 4.47258472]]
Loss = 0.281654

TensorFlow实现iris数据集线性回归

TensorFlow实现iris数据集线性回归

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
浅析Python中的join()方法的使用
May 19 Python
Python如何快速上手? 快速掌握一门新语言的方法
Nov 14 Python
python读取txt文件,去掉空格计算每行长度的方法
Dec 20 Python
Python3分析处理声音数据的例子
Aug 27 Python
Windows下pycharm创建Django 项目(虚拟环境)过程解析
Sep 16 Python
pandas factorize实现将字符串特征转化为数字特征
Dec 19 Python
开启Django博客的RSS功能的实现方法
Feb 17 Python
Python使用xlrd实现读取合并单元格
Jul 09 Python
python实现自动打卡的示例代码
Oct 10 Python
django注册用邮箱发送验证码的实现
Apr 18 Python
Python趣味实战之手把手教你实现举牌小人生成器
Jun 07 Python
python 进阶学习之python装饰器小结
Sep 04 Python
TensorFlow实现模型评估
Sep 07 #Python
使用tensorflow实现线性svm
Sep 07 #Python
Python多进程池 multiprocessing Pool用法示例
Sep 07 #Python
详解python while 函数及while和for的区别
Sep 07 #Python
使用TensorFlow实现SVM
Sep 06 #Python
使用Python制作自动推送微信消息提醒的备忘录功能
Sep 06 #Python
python实现机器学习之多元线性回归
Sep 06 #Python
You might like
php注入实例
2006/10/09 PHP
改写ThinkPHP的U方法使其路由下分页正常
2014/07/02 PHP
PHP 使用redis简单示例分享
2015/03/05 PHP
php实现文本数据导入SQL SERVER
2015/05/17 PHP
php发送短信验证码完成注册功能
2015/11/24 PHP
JS图片浏览组件PhotoLook的公开属性方法介绍和进阶实例代码
2010/11/09 Javascript
JS中图片缓冲loading技术的实例代码
2013/08/29 Javascript
jquery鼠标停止移动事件
2013/12/21 Javascript
js获取鼠标点击的位置实现思路及代码
2014/05/09 Javascript
jQuery实现提示密码强度的代码
2015/07/15 Javascript
easyui Draggable组件实现拖动效果
2015/08/19 Javascript
AngularJS 简单应用实例
2016/07/28 Javascript
JavaScipt选取文档元素的方法(推荐)
2016/08/05 Javascript
JS实现图片剪裁并预览效果
2016/08/12 Javascript
Javascript ES6中数据类型Symbol的使用详解
2017/05/02 Javascript
iview日期控件,双向绑定日期格式的方法
2018/03/15 Javascript
JS插入排序简单理解与实现方法分析
2019/11/25 Javascript
[01:02:09]Liquid vs TNC 2019国际邀请赛淘汰赛 胜者组 BO3 第二场 8.21
2020/07/19 DOTA
Python3连接SQLServer、Oracle、MySql的方法
2018/06/28 Python
pytorch: tensor类型的构建与相互转换实例
2018/07/26 Python
python利用7z批量解压rar的实现
2019/08/07 Python
使用Python刷淘宝喵币(低阶入门版)
2019/10/30 Python
pytorch: Parameter 的数据结构实例
2019/12/31 Python
python实现滑雪者小游戏
2020/02/22 Python
Python3 mmap内存映射文件示例解析
2020/03/23 Python
基于python代码批量处理图片resize
2020/06/04 Python
基于Python下载网络图片方法汇总代码实例
2020/06/24 Python
python实现PolynomialFeatures多项式的方法
2021/01/06 Python
html5用video标签流式加载的实现
2020/05/20 HTML / CSS
英国最大的奢侈品零售网络商城:Flannels
2016/09/16 全球购物
装潢设计实习自我鉴定
2013/09/19 职场文书
自我评价是什么
2014/01/04 职场文书
解除合同协议书
2014/04/17 职场文书
2014城乡环境综合治理工作总结
2014/12/19 职场文书
深入理解redis中multi与pipeline
2021/06/02 Redis
Python常用配置文件ini、json、yaml读写总结
2021/07/09 Python