tensorflow实现逻辑回归模型


Posted in Python onSeptember 08, 2018

逻辑回归模型

逻辑回归是应用非常广泛的一个分类机器学习算法,它将数据拟合到一个logit函数(或者叫做logistic函数)中,从而能够完成对事件发生的概率进行预测。

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
#下载好的mnist数据集存在F:/mnist/data/中
mnist = input_data.read_data_sets('F:/mnist/data/',one_hot = True)
print(mnist.train.num_examples)
print(mnist.test.num_examples)

trainimg = mnist.train.images
trainlabel = mnist.train.labels
testimg = mnist.test.images
testlabel = mnist.test.labels

print(type(trainimg))
print(trainimg.shape,)
print(trainlabel.shape,)
print(testimg.shape,)
print(testlabel.shape,)

nsample = 5
randidx = np.random.randint(trainimg.shape[0],size = nsample)

for i in randidx:
  curr_img = np.reshape(trainimg[i,:],(28,28))
  curr_label = np.argmax(trainlabel[i,:])
  plt.matshow(curr_img,cmap=plt.get_cmap('gray'))
  plt.title(""+str(i)+"th Training Data"+"label is"+str(curr_label))
  print(""+str(i)+"th Training Data"+"label is"+str(curr_label))
  plt.show()


x = tf.placeholder("float",[None,784])
y = tf.placeholder("float",[None,10])
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))

#
actv = tf.nn.softmax(tf.matmul(x,W)+b)
#计算损失
cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(actv),reduction_indices=1))
#学习率
learning_rate = 0.01
#随机梯度下降
optm = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

#求1位置索引值 对比预测值索引与label索引是否一样,一样返回True
pred = tf.equal(tf.argmax(actv,1),tf.argmax(y,1))
#tf.cast把True和false转换为float类型 0,1
#把所有预测结果加在一起求精度
accr = tf.reduce_mean(tf.cast(pred,"float"))
init = tf.global_variables_initializer()
"""
#测试代码 
sess = tf.InteractiveSession()
arr = np.array([[31,23,4,24,27,34],[18,3,25,4,5,6],[4,3,2,1,5,67]])
#返回数组的维数 2
print(tf.rank(arr).eval())
#返回数组的行列数 [3 6]
print(tf.shape(arr).eval())
#返回数组中每一列中最大元素的索引[0 0 1 0 0 2]
print(tf.argmax(arr,0).eval())
#返回数组中每一行中最大元素的索引[5 2 5]
print(tf.argmax(arr,1).eval()) 
J"""
#把所有样本迭代50次
training_epochs = 50
#每次迭代选择多少样本
batch_size = 100
display_step = 5

sess = tf.Session()
sess.run(init)

#循环迭代
for epoch in range(training_epochs):
  avg_cost = 0
  num_batch = int(mnist.train.num_examples/batch_size)
  for i in range(num_batch):
    batch_xs,batch_ys = mnist.train.next_batch(batch_size)
    sess.run(optm,feed_dict = {x:batch_xs,y:batch_ys})
    feeds = {x:batch_xs,y:batch_ys}
    avg_cost += sess.run(cost,feed_dict = feeds)/num_batch

  if epoch % display_step ==0:
    feeds_train = {x:batch_xs,y:batch_ys}
    feeds_test = {x:mnist.test.images,y:mnist.test.labels}
    train_acc = sess.run(accr,feed_dict = feeds_train)
    test_acc = sess.run(accr,feed_dict = feeds_test)
    #每五个epoch打印一次信息
    print("Epoch:%03d/%03d cost:%.9f train_acc:%.3f test_acc: %.3f" %(epoch,training_epochs,avg_cost,train_acc,test_acc))

print("Done")

程序训练结果如下:

Epoch:000/050 cost:1.177228655 train_acc:0.800 test_acc: 0.855
Epoch:005/050 cost:0.440933891 train_acc:0.890 test_acc: 0.894
Epoch:010/050 cost:0.383387268 train_acc:0.930 test_acc: 0.905
Epoch:015/050 cost:0.357281335 train_acc:0.930 test_acc: 0.909
Epoch:020/050 cost:0.341473956 train_acc:0.890 test_acc: 0.913
Epoch:025/050 cost:0.330586549 train_acc:0.920 test_acc: 0.915
Epoch:030/050 cost:0.322370980 train_acc:0.870 test_acc: 0.916
Epoch:035/050 cost:0.315942993 train_acc:0.940 test_acc: 0.916
Epoch:040/050 cost:0.310728854 train_acc:0.890 test_acc: 0.917
Epoch:045/050 cost:0.306357428 train_acc:0.870 test_acc: 0.918
Done

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现的重启关机程序实例
Aug 21 Python
用于统计项目中代码总行数的Python脚本分享
Apr 21 Python
python实现画一颗树和一片森林
Jun 25 Python
Python实现矩阵相乘的三种方法小结
Jul 26 Python
Python选择网卡发包及接收数据包
Apr 04 Python
python的pstuil模块使用方法总结
Jul 26 Python
Numpy与Pytorch 矩阵操作方式
Dec 27 Python
python中列表的含义及用法
May 26 Python
Numpy 多维数据数组的实现
Jun 18 Python
Python3基于plotly模块保存图片表格
Aug 03 Python
Django自带用户认证系统使用方法解析
Nov 12 Python
python中24小时制转换为12小时制的方法
Jun 18 Python
Django实现表单验证
Sep 08 #Python
python实现排序算法解析
Sep 08 #Python
TensorFlow实现Logistic回归
Sep 07 #Python
tensorflow实现简单逻辑回归
Sep 07 #Python
Tensorflow使用支持向量机拟合线性回归
Sep 07 #Python
TensorFlow实现iris数据集线性回归
Sep 07 #Python
TensorFlow实现模型评估
Sep 07 #Python
You might like
example2.php
2006/10/09 PHP
php 修改zen-cart下单和付款流程以防止漏单
2010/03/08 PHP
如果文字过长,则将过长的部分变成省略号显示
2006/06/26 Javascript
jQuery 插件开发指南
2014/11/14 Javascript
jQuery+jsp实现省市县三级联动效果(附源码)
2015/12/03 Javascript
基于Javascript倒计时效果
2016/12/22 Javascript
bootstrap suggest搜索建议插件使用详解
2017/03/25 Javascript
jquery.guide.js新版上线操作向导镂空提示jQuery插件(推荐)
2017/05/20 jQuery
AngularJS实现tab选项卡的方法详解
2017/07/05 Javascript
Angularjs的启动过程分析
2017/07/18 Javascript
vue-cli配置文件——config篇
2018/01/04 Javascript
在vue项目中正确使用iconfont的方法
2018/09/28 Javascript
zepto.js 实时监听输入框的方法
2018/12/04 Javascript
Vue 后台管理类项目兼容IE9+的方法示例
2019/02/20 Javascript
vue组件数据传递、父子组件数据获取,slot,router路由功能示例
2019/03/19 Javascript
vue表单验证你真的会了吗?vue表单验证(form)validate
2019/04/07 Javascript
微信小程序常用赋值方法小结
2019/04/30 Javascript
使用Vue+Django+Ant Design做一个留言评论模块的示例代码
2020/06/01 Javascript
vue实现验证用户名是否可用
2021/01/20 Vue.js
python中enumerate的用法实例解析
2014/08/18 Python
python轻松查到删除自己的微信好友
2016/01/10 Python
Python机器学习算法库scikit-learn学习之决策树实现方法详解
2019/07/04 Python
PyCharm vs VSCode,作为python开发者,你更倾向哪种IDE呢?
2020/08/17 Python
教你使用Sublime text3搭建Python开发环境及常用插件安装另分享Sublime text3最新激活注册码
2020/11/12 Python
Lululemon英国官网:加拿大瑜伽服装品牌
2019/01/14 全球购物
伯克斯奥特莱斯:Burkes Outlet
2019/03/30 全球购物
香港中原电器网上商店:Chung Yuen
2019/06/26 全球购物
公共汽车、火车和飞机票的通用在线预订和销售平台:INFOBUS
2019/11/30 全球购物
学校运动会开幕演讲稿
2014/01/04 职场文书
大学生求职计划书
2014/04/30 职场文书
医院竞聘演讲稿
2014/05/16 职场文书
迁户口计划生育证明
2014/10/19 职场文书
2015年度优秀员工自荐书
2015/03/06 职场文书
小学生运动会广播
2015/08/19 职场文书
Nginx中break与last的区别详析
2021/03/31 Servers
Mybatis是这样防止sql注入的
2021/12/06 Java/Android