TensorFlow实现创建分类器


Posted in Python onFebruary 06, 2018

本文实例为大家分享了TensorFlow实现创建分类器的具体代码,供大家参考,具体内容如下

创建一个iris数据集的分类器。

加载样本数据集,实现一个简单的二值分类器来预测一朵花是否为山鸢尾。iris数据集有三类花,但这里仅预测是否是山鸢尾。导入iris数据集和工具库,相应地对原数据集进行转换。

# Combining Everything Together
#----------------------------------
# This file will perform binary classification on the
# iris dataset. We will only predict if a flower is
# I.setosa or not.
#
# We will create a simple binary classifier by creating a line
# and running everything through a sigmoid to get a binary predictor.
# The two features we will use are pedal length and pedal width.
#
# We will use batch training, but this can be easily
# adapted to stochastic training.

import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets
import tensorflow as tf
from tensorflow.python.framework import ops
ops.reset_default_graph()

# 导入iris数据集
# 根据目标数据是否为山鸢尾将其转换成1或者0。
# 由于iris数据集将山鸢尾标记为0,我们将其从0置为1,同时把其他物种标记为0。
# 本次训练只使用两种特征:花瓣长度和花瓣宽度,这两个特征在x-value的第三列和第四列
# iris.target = {0, 1, 2}, where '0' is setosa
# iris.data ~ [sepal.width, sepal.length, pedal.width, pedal.length]
iris = datasets.load_iris()
binary_target = np.array([1. if x==0 else 0. for x in iris.target])
iris_2d = np.array([[x[2], x[3]] for x in iris.data])

# 声明批量训练大小
batch_size = 20

# 初始化计算图
sess = tf.Session()

# 声明数据占位符
x1_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
x2_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# 声明模型变量
# Create variables A and b (0 = x1 - A*x2 + b)
A = tf.Variable(tf.random_normal(shape=[1, 1]))
b = tf.Variable(tf.random_normal(shape=[1, 1]))

# 定义线性模型:
# 如果找到的数据点在直线以上,则将数据点代入x2-x1*A-b计算出的结果大于0;
# 同理找到的数据点在直线以下,则将数据点代入x2-x1*A-b计算出的结果小于0。
# x1 - A*x2 + b
my_mult = tf.matmul(x2_data, A)
my_add = tf.add(my_mult, b)
my_output = tf.subtract(x1_data, my_add)

# 增加TensorFlow的sigmoid交叉熵损失函数(cross entropy)
xentropy = tf.nn.sigmoid_cross_entropy_with_logits(logits=my_output, labels=y_target)

# 声明优化器方法
my_opt = tf.train.GradientDescentOptimizer(0.05)
train_step = my_opt.minimize(xentropy)

# 创建一个变量初始化操作
init = tf.global_variables_initializer()
sess.run(init)

# 运行迭代1000次
for i in range(1000):
  rand_index = np.random.choice(len(iris_2d), size=batch_size)
  # rand_x = np.transpose([iris_2d[rand_index]])
  # 传入三种数据:花瓣长度、花瓣宽度和目标变量
  rand_x = iris_2d[rand_index]
  rand_x1 = np.array([[x[0]] for x in rand_x])
  rand_x2 = np.array([[x[1]] for x in rand_x])
  #rand_y = np.transpose([binary_target[rand_index]])
  rand_y = np.array([[y] for y in binary_target[rand_index]])
  sess.run(train_step, feed_dict={x1_data: rand_x1, x2_data: rand_x2, y_target: rand_y})
  if (i+1)%200==0:
    print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)) + ', b = ' + str(sess.run(b)))


# 绘图
# 获取斜率/截距
# Pull out slope/intercept
[[slope]] = sess.run(A)
[[intercept]] = sess.run(b)

# 创建拟合线
x = np.linspace(0, 3, num=50)
ablineValues = []
for i in x:
 ablineValues.append(slope*i+intercept)

# 绘制拟合曲线
setosa_x = [a[1] for i,a in enumerate(iris_2d) if binary_target[i]==1]
setosa_y = [a[0] for i,a in enumerate(iris_2d) if binary_target[i]==1]
non_setosa_x = [a[1] for i,a in enumerate(iris_2d) if binary_target[i]==0]
non_setosa_y = [a[0] for i,a in enumerate(iris_2d) if binary_target[i]==0]
plt.plot(setosa_x, setosa_y, 'rx', ms=10, mew=2, label='setosa')
plt.plot(non_setosa_x, non_setosa_y, 'ro', label='Non-setosa')
plt.plot(x, ablineValues, 'b-')
plt.xlim([0.0, 2.7])
plt.ylim([0.0, 7.1])
plt.suptitle('Linear Separator For I.setosa', fontsize=20)
plt.xlabel('Petal Length')
plt.ylabel('Petal Width')
plt.legend(loc='lower right')
plt.show()

输出:

Step #200 A = [[ 8.70572948]], b = [[-3.46638322]]
Step #400 A = [[ 10.21302414]], b = [[-4.720438]]
Step #600 A = [[ 11.11844635]], b = [[-5.53361702]]
Step #800 A = [[ 11.86427212]], b = [[-6.0110755]]
Step #1000 A = [[ 12.49524498]], b = [[-6.29990339]]

TensorFlow实现创建分类器

 以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python获取文件版本信息、公司名和产品名的方法
Oct 05 Python
python3 模拟登录v2ex实例讲解
Jul 13 Python
python计算列表内各元素的个数实例
Jun 29 Python
python 对key为时间的dict排序方法
Oct 17 Python
Python实现常见的回文字符串算法
Nov 14 Python
python程序封装为win32服务的方法
Mar 07 Python
Python实现连接MySql数据库及增删改查操作详解
Apr 16 Python
在python image 中安装中文字体的实现方法
Aug 22 Python
python实现批量文件重命名
Oct 31 Python
python解析命令行参数的三种方法详解
Nov 29 Python
Python 中由 yield 实现异步操作
May 04 Python
Python列表的深复制和浅复制示例详解
Feb 12 Python
Python模拟随机游走图形效果示例
Feb 06 #Python
Python 12306抢火车票脚本 Python京东抢手机脚本
Feb 06 #Python
TensorFlow高效读取数据的方法示例
Feb 06 #Python
django使用xlwt导出excel文件实例代码
Feb 06 #Python
Python使用装饰器进行django开发实例代码
Feb 06 #Python
Python yield与实现方法代码分析
Feb 06 #Python
Django中间件工作流程及写法实例代码
Feb 06 #Python
You might like
如何使用动态共享对象的模式来安装PHP
2006/10/09 PHP
php中ob_get_length缓冲与获取缓冲长度实例
2014/11/20 PHP
php实现批量删除挂马文件及批量替换页面内容完整实例
2016/07/08 PHP
ThinkPHP 3.2.2实现事务操作的方法
2017/05/05 PHP
thinkphp5框架结合mysql实现微信登录和自定义分享链接与图文功能示例
2019/08/13 PHP
对laravel的session获取与存取方法详解
2019/10/08 PHP
js键盘事件的keyCode
2014/07/29 Javascript
javascript实现捕捉键盘上按下的键
2015/05/05 Javascript
JavaScript观察者模式(经典)
2015/12/09 Javascript
使用postMesssage()实现iframe跨域页面间的信息传递
2016/03/29 Javascript
总结在前端排序中遇到的问题
2016/07/19 Javascript
JS中script标签defer和async属性的区别详解
2016/08/12 Javascript
利用策略模式与装饰模式扩展JavaScript表单验证功能
2017/02/14 Javascript
原生JS实现《别踩白块》游戏(兼容IE)
2017/02/20 Javascript
AngularJS常见过滤器用法实例总结
2017/07/06 Javascript
js实现简单数字变动效果
2017/11/06 Javascript
vue实现点击当前标签高亮效果【推荐】
2018/06/22 Javascript
js实现移动端tab切换时下划线滑动效果
2019/09/08 Javascript
[26:50]2018完美盛典DOTA2表演赛
2018/12/17 DOTA
从零学Python之入门(四)运算
2014/05/27 Python
Python中的推导式使用详解
2015/06/03 Python
Python实现解析Bit Torrent种子文件内容的方法
2017/08/29 Python
CentOS 7 安装python3.7.1的方法及注意事项
2018/11/01 Python
Python实现打砖块小游戏代码实例
2019/05/18 Python
Python获取命令实时输出-原样彩色输出并返回输出结果的示例
2019/07/11 Python
举例讲解Python装饰器
2020/12/24 Python
Python如何实现单例模式
2016/06/03 面试题
会计与审计专业大专生求职信
2013/10/03 职场文书
暑期实习鉴定
2013/12/16 职场文书
业务员的岗位职责
2014/03/15 职场文书
运动会报道稿300字
2014/10/02 职场文书
小学优秀教师先进事迹材料
2014/12/16 职场文书
2015年物流客服工作总结
2015/07/27 职场文书
2016大学生入党积极分子心得体会
2016/01/06 职场文书
《合作意向书》怎么写?
2019/08/20 职场文书
高并发下Redis如何保持数据一致性(避免读后写)
2022/03/18 Redis