TensorFlow实现iris数据集线性回归


Posted in Python onSeptember 07, 2018

本文将遍历批量数据点并让TensorFlow更新斜率和y截距。这次将使用Scikit Learn的内建iris数据集。特别地,我们将用数据点(x值代表花瓣宽度,y值代表花瓣长度)找到最优直线。选择这两种特征是因为它们具有线性关系,在后续结果中将会看到。本文将使用L2正则损失函数。

# 用TensorFlow实现线性回归算法
#----------------------------------
#
# This function shows how to use TensorFlow to
# solve linear regression.
# y = Ax + b
#
# We will use the iris data, specifically:
# y = Sepal Length
# x = Petal Width

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn import datasets
from tensorflow.python.framework import ops
ops.reset_default_graph()

# Create graph
sess = tf.Session()

# Load the data
# iris.data = [(Sepal Length, Sepal Width, Petal Length, Petal Width)]
iris = datasets.load_iris()
x_vals = np.array([x[3] for x in iris.data])
y_vals = np.array([y[0] for y in iris.data])

# 批量大小
batch_size = 25

# Initialize 占位符
x_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# 模型变量
A = tf.Variable(tf.random_normal(shape=[1,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))

# 增加线性模型,y=Ax+b
model_output = tf.add(tf.matmul(x_data, A), b)

# 声明L2损失函数,其为批量损失的平均值。
loss = tf.reduce_mean(tf.square(y_target - model_output))

# 声明优化器 学习率设为0.05
my_opt = tf.train.GradientDescentOptimizer(0.05)
train_step = my_opt.minimize(loss)

# 初始化变量
init = tf.global_variables_initializer()
sess.run(init)

# 批量训练遍历迭代
# 迭代100次,每25次迭代输出变量值和损失值
loss_vec = []
for i in range(100):
  rand_index = np.random.choice(len(x_vals), size=batch_size)
  rand_x = np.transpose([x_vals[rand_index]])
  rand_y = np.transpose([y_vals[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  loss_vec.append(temp_loss)
  if (i+1)%25==0:
    print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)) + ' b = ' + str(sess.run(b)))
    print('Loss = ' + str(temp_loss))

# 抽取系数
[slope] = sess.run(A)
[y_intercept] = sess.run(b)

# 创建最佳拟合直线
best_fit = []
for i in x_vals:
 best_fit.append(slope*i+y_intercept)

# 绘制两幅图
# 拟合的直线
plt.plot(x_vals, y_vals, 'o', label='Data Points')
plt.plot(x_vals, best_fit, 'r-', label='Best fit line', linewidth=3)
plt.legend(loc='upper left')
plt.title('Sepal Length vs Pedal Width')
plt.xlabel('Pedal Width')
plt.ylabel('Sepal Length')
plt.show()

# Plot loss over time
# 迭代100次的L2正则损失函数
plt.plot(loss_vec, 'k-')
plt.title('L2 Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('L2 Loss')
plt.show()

结果:

Step #25 A = [[ 1.93474162]] b = [[ 3.11190438]]
Loss = 1.21364
Step #50 A = [[ 1.48641717]] b = [[ 3.81004381]]
Loss = 0.945256
Step #75 A = [[ 1.26089203]] b = [[ 4.221035]]
Loss = 0.254756
Step #100 A = [[ 1.1693294]] b = [[ 4.47258472]]
Loss = 0.281654

TensorFlow实现iris数据集线性回归

TensorFlow实现iris数据集线性回归

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Python中操作时间之mktime()方法的使用教程
May 22 Python
Python获取暗黑破坏神3战网前1000命位玩家的英雄技能统计
Jul 04 Python
Python使用add_subplot与subplot画子图操作示例
Jun 01 Python
Django 创建新App及其常用命令的实现方法
Aug 04 Python
PyTorch中Tensor的维度变换实现
Aug 18 Python
python3使用GUI统计代码量
Sep 18 Python
Python上下文管理器类和上下文管理器装饰器contextmanager用法实例分析
Nov 07 Python
python实现快递价格查询系统
Mar 03 Python
解决pymysql cursor.fetchall() 获取不到数据的问题
May 15 Python
如何表示python中的相对路径
Jul 08 Python
详解Django中的FBV和CBV对比分析
Mar 01 Python
一劳永逸彻底解决pip install慢的办法
May 24 Python
TensorFlow实现模型评估
Sep 07 #Python
使用tensorflow实现线性svm
Sep 07 #Python
Python多进程池 multiprocessing Pool用法示例
Sep 07 #Python
详解python while 函数及while和for的区别
Sep 07 #Python
使用TensorFlow实现SVM
Sep 06 #Python
使用Python制作自动推送微信消息提醒的备忘录功能
Sep 06 #Python
python实现机器学习之多元线性回归
Sep 06 #Python
You might like
PHP扩展编写点滴 技巧收集
2010/03/09 PHP
laravel 5 实现模板主题功能(续)
2015/03/02 PHP
php简单获取文件扩展名的方法
2015/03/24 PHP
PHP实现验证码校验功能
2017/11/16 PHP
thinkphp5 加载静态资源路径与常量的方法
2017/12/24 PHP
JavaScript中Object和Function的关系小结
2009/09/26 Javascript
解析javascript系统错误:-1072896658的解决办法
2013/07/08 Javascript
在每个匹配元素的外部插入新元素的方法
2013/12/20 Javascript
jquery解析xml字符串示例分享
2014/03/25 Javascript
js实现简单秒表走动的时钟特效
2020/03/25 Javascript
javascript url几种编码方式详解
2016/06/06 Javascript
Javascript之String对象详解
2016/06/08 Javascript
require.js 加载 vue组件 r.js 合并压缩的实例
2016/10/14 Javascript
详解JS: reduce方法实现 webpack多文件入口
2017/02/14 Javascript
原生node.js案例--前后台交互
2017/02/20 Javascript
浅谈React深度编程之受控组件与非受控组件
2017/12/26 Javascript
vue.js与后台数据交互的实例讲解
2018/08/08 Javascript
Vue.directive使用注意(小结)
2018/08/31 Javascript
react-navigation之动态修改title的内容
2018/09/26 Javascript
实例讲解JS中pop使用方法
2019/01/27 Javascript
python实现中文输出的两种方法
2015/05/09 Python
深入理解Python3 内置函数大全
2017/11/23 Python
使用Python爬取最好大学网大学排名
2018/02/24 Python
python并发编程多进程 互斥锁原理解析
2019/08/20 Python
pandas中ix的使用详细讲解
2020/03/09 Python
python中if及if-else如何使用
2020/06/02 Python
python 进程池pool使用详解
2020/10/15 Python
Html5百叶窗效果的示例代码
2017/12/11 HTML / CSS
英国领先的男装设计师服装购物网站:Mainline Menswear
2018/02/04 全球购物
Pandora德国官网:购买潘多拉手链、戒指、项链和耳环
2020/02/20 全球购物
中专毕业自我鉴定
2013/10/16 职场文书
运动会广播稿30字
2014/01/21 职场文书
应用心理学专业求职信
2014/08/04 职场文书
党的群众路线剖析材料
2014/10/09 职场文书
缓刑期间思想汇报范文
2014/10/10 职场文书
于丹讲座视频观后感
2015/06/15 职场文书