TensorFlow实现创建分类器


Posted in Python onFebruary 06, 2018

本文实例为大家分享了TensorFlow实现创建分类器的具体代码,供大家参考,具体内容如下

创建一个iris数据集的分类器。

加载样本数据集,实现一个简单的二值分类器来预测一朵花是否为山鸢尾。iris数据集有三类花,但这里仅预测是否是山鸢尾。导入iris数据集和工具库,相应地对原数据集进行转换。

# Combining Everything Together
#----------------------------------
# This file will perform binary classification on the
# iris dataset. We will only predict if a flower is
# I.setosa or not.
#
# We will create a simple binary classifier by creating a line
# and running everything through a sigmoid to get a binary predictor.
# The two features we will use are pedal length and pedal width.
#
# We will use batch training, but this can be easily
# adapted to stochastic training.

import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets
import tensorflow as tf
from tensorflow.python.framework import ops
ops.reset_default_graph()

# 导入iris数据集
# 根据目标数据是否为山鸢尾将其转换成1或者0。
# 由于iris数据集将山鸢尾标记为0,我们将其从0置为1,同时把其他物种标记为0。
# 本次训练只使用两种特征:花瓣长度和花瓣宽度,这两个特征在x-value的第三列和第四列
# iris.target = {0, 1, 2}, where '0' is setosa
# iris.data ~ [sepal.width, sepal.length, pedal.width, pedal.length]
iris = datasets.load_iris()
binary_target = np.array([1. if x==0 else 0. for x in iris.target])
iris_2d = np.array([[x[2], x[3]] for x in iris.data])

# 声明批量训练大小
batch_size = 20

# 初始化计算图
sess = tf.Session()

# 声明数据占位符
x1_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
x2_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# 声明模型变量
# Create variables A and b (0 = x1 - A*x2 + b)
A = tf.Variable(tf.random_normal(shape=[1, 1]))
b = tf.Variable(tf.random_normal(shape=[1, 1]))

# 定义线性模型:
# 如果找到的数据点在直线以上,则将数据点代入x2-x1*A-b计算出的结果大于0;
# 同理找到的数据点在直线以下,则将数据点代入x2-x1*A-b计算出的结果小于0。
# x1 - A*x2 + b
my_mult = tf.matmul(x2_data, A)
my_add = tf.add(my_mult, b)
my_output = tf.subtract(x1_data, my_add)

# 增加TensorFlow的sigmoid交叉熵损失函数(cross entropy)
xentropy = tf.nn.sigmoid_cross_entropy_with_logits(logits=my_output, labels=y_target)

# 声明优化器方法
my_opt = tf.train.GradientDescentOptimizer(0.05)
train_step = my_opt.minimize(xentropy)

# 创建一个变量初始化操作
init = tf.global_variables_initializer()
sess.run(init)

# 运行迭代1000次
for i in range(1000):
  rand_index = np.random.choice(len(iris_2d), size=batch_size)
  # rand_x = np.transpose([iris_2d[rand_index]])
  # 传入三种数据:花瓣长度、花瓣宽度和目标变量
  rand_x = iris_2d[rand_index]
  rand_x1 = np.array([[x[0]] for x in rand_x])
  rand_x2 = np.array([[x[1]] for x in rand_x])
  #rand_y = np.transpose([binary_target[rand_index]])
  rand_y = np.array([[y] for y in binary_target[rand_index]])
  sess.run(train_step, feed_dict={x1_data: rand_x1, x2_data: rand_x2, y_target: rand_y})
  if (i+1)%200==0:
    print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)) + ', b = ' + str(sess.run(b)))


# 绘图
# 获取斜率/截距
# Pull out slope/intercept
[[slope]] = sess.run(A)
[[intercept]] = sess.run(b)

# 创建拟合线
x = np.linspace(0, 3, num=50)
ablineValues = []
for i in x:
 ablineValues.append(slope*i+intercept)

# 绘制拟合曲线
setosa_x = [a[1] for i,a in enumerate(iris_2d) if binary_target[i]==1]
setosa_y = [a[0] for i,a in enumerate(iris_2d) if binary_target[i]==1]
non_setosa_x = [a[1] for i,a in enumerate(iris_2d) if binary_target[i]==0]
non_setosa_y = [a[0] for i,a in enumerate(iris_2d) if binary_target[i]==0]
plt.plot(setosa_x, setosa_y, 'rx', ms=10, mew=2, label='setosa')
plt.plot(non_setosa_x, non_setosa_y, 'ro', label='Non-setosa')
plt.plot(x, ablineValues, 'b-')
plt.xlim([0.0, 2.7])
plt.ylim([0.0, 7.1])
plt.suptitle('Linear Separator For I.setosa', fontsize=20)
plt.xlabel('Petal Length')
plt.ylabel('Petal Width')
plt.legend(loc='lower right')
plt.show()

输出:

Step #200 A = [[ 8.70572948]], b = [[-3.46638322]]
Step #400 A = [[ 10.21302414]], b = [[-4.720438]]
Step #600 A = [[ 11.11844635]], b = [[-5.53361702]]
Step #800 A = [[ 11.86427212]], b = [[-6.0110755]]
Step #1000 A = [[ 12.49524498]], b = [[-6.29990339]]

TensorFlow实现创建分类器

 以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 变量类型及命名规则介绍
Jun 08 Python
Windows系统下安装Python的SSH模块教程
Feb 05 Python
在Python操作时间和日期之asctime()方法的使用
May 22 Python
python字符类型的一些方法小结
May 16 Python
python3+PyQt5实现使用剪贴板做复制与粘帖示例
Jan 24 Python
Python判断一个list中是否包含另一个list全部元素的方法分析
Dec 24 Python
Python实现图像去噪方式(中值去噪和均值去噪)
Dec 18 Python
TensorFlow tf.nn.conv2d实现卷积的方式
Jan 03 Python
Python中常见的数制转换有哪些
May 27 Python
Python datetime模块使用方法小结
Jun 18 Python
Python数据可视化常用4大绘图库原理详解
Oct 23 Python
如何在向量化NumPy数组上进行移动窗口
May 18 Python
Python模拟随机游走图形效果示例
Feb 06 #Python
Python 12306抢火车票脚本 Python京东抢手机脚本
Feb 06 #Python
TensorFlow高效读取数据的方法示例
Feb 06 #Python
django使用xlwt导出excel文件实例代码
Feb 06 #Python
Python使用装饰器进行django开发实例代码
Feb 06 #Python
Python yield与实现方法代码分析
Feb 06 #Python
Django中间件工作流程及写法实例代码
Feb 06 #Python
You might like
使用MaxMind 根据IP地址对访问者定位
2006/10/09 PHP
关于BIG5-HKSCS的解决方法
2007/03/20 PHP
关于查看MSSQL 数据库 用户每个表 占用的空间大小
2013/06/21 PHP
PHP上传文件时文件过大$_FILES为空的解决方法
2013/11/26 PHP
详解php的socket通信
2015/08/11 PHP
PHP+Mysql+jQuery实现发布微博程序 php篇
2015/10/15 PHP
PHP经典面试题之设计模式(经常遇到)
2015/10/15 PHP
对比分析php中Cookie与Session的异同
2016/02/19 PHP
Zend Framework实现自定义过滤器的方法
2016/12/09 PHP
PHP数据库操作二:memcache用法分析
2017/08/16 PHP
Thinkphp 框架扩展之Widget扩展实现方法分析
2020/04/23 PHP
js post方式传递提交的实现代码
2010/05/31 Javascript
JavaScript运行时库属性一览表
2014/03/14 Javascript
基于jquery实现鼠标滚轮驱动的图片切换效果
2015/10/26 Javascript
html中鼠标滚轮事件onmousewheel的处理方法
2016/11/11 Javascript
react中使用swiper的具体方法
2018/05/15 Javascript
浅谈Webpack下多环境配置的思路
2018/06/27 Javascript
微信小程序提交form操作示例
2018/12/30 Javascript
React通过redux-persist持久化数据存储的方法示例
2019/02/14 Javascript
详解使用WebPack搭建React开发环境
2019/08/06 Javascript
详解vue中$nextTick和$forceUpdate的用法
2019/12/11 Javascript
python通过yield实现数组全排列的方法
2015/03/18 Python
利用Python将数值型特征进行离散化操作的方法
2018/11/06 Python
Python实现九宫格式的朋友圈功能内附“马云”朋友圈
2019/05/07 Python
django框架用户权限中的session缓存到redis中的方法
2019/08/06 Python
使用Tkinter制作信息提示框
2020/02/18 Python
用HTML5制作数字时钟的教程
2015/05/11 HTML / CSS
工作疏忽检讨书
2014/01/25 职场文书
新学期红领巾广播稿
2014/10/04 职场文书
2015年师德师风自我评价范文
2015/03/05 职场文书
五四青年节比赛演讲稿
2015/03/18 职场文书
2015仓库保管员年终工作总结
2015/05/13 职场文书
2015年社区宣传工作总结
2015/05/20 职场文书
2015初一年级组工作总结
2015/07/24 职场文书
golang 实现菜单树的生成方式
2021/04/28 Golang
JavaScript实例 ODO List分析
2022/01/22 Javascript