TensorFlow实现Logistic回归


Posted in Python onSeptember 07, 2018

本文实例为大家分享了TensorFlow实现Logistic回归的具体代码,供大家参考,具体内容如下

1.导入模块

import numpy as np
import pandas as pd
from pandas import Series,DataFrame

from matplotlib import pyplot as plt
%matplotlib inline

#导入tensorflow
import tensorflow as tf

#导入MNIST(手写数字数据集)
from tensorflow.examples.tutorials.mnist import input_data

2.获取训练数据和测试数据

import ssl 
ssl._create_default_https_context = ssl._create_unverified_context

mnist = input_data.read_data_sets('./TensorFlow',one_hot=True)

test = mnist.test
test_images = test.images

train = mnist.train
images = train.images

3.模拟线性方程

#创建占矩阵位符X,Y
X = tf.placeholder(tf.float32,shape=[None,784])
Y = tf.placeholder(tf.float32,shape=[None,10])

#随机生成斜率W和截距b
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))

#根据模拟线性方程得出预测值
y_pre = tf.matmul(X,W)+b

#将预测值结果概率化
y_pre_r = tf.nn.softmax(y_pre)

4.构造损失函数

# -y*tf.log(y_pre_r) --->-Pi*log(Pi)  信息熵公式

cost = tf.reduce_mean(-tf.reduce_sum(Y*tf.log(y_pre_r),axis=1))

5.实现梯度下降,获取最小损失函数

#learning_rate:学习率,是进行训练时在最陡的梯度方向上所采取的「步」长;
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

6.TensorFlow初始化,并进行训练

#定义相关参数

#训练循环次数
training_epochs = 25
#batch 一批,每次训练给算法10个数据
batch_size = 10
#每隔5次,打印输出运算的结果
display_step = 5


#预定义初始化
init = tf.global_variables_initializer()

#开始训练
with tf.Session() as sess:
  #初始化
  sess.run(init)
  #循环训练次数
  for epoch in range(training_epochs):
    avg_cost = 0.
    #总训练批次total_batch =训练总样本量/每批次样本数量
    total_batch = int(train.num_examples/batch_size)
    for i in range(total_batch):
      #每次取出100个数据作为训练数据
      batch_xs,batch_ys = mnist.train.next_batch(batch_size)
      _, c = sess.run([optimizer,cost],feed_dict={X:batch_xs,Y:batch_ys})
      avg_cost +=c/total_batch
    if(epoch+1)%display_step == 0:
      print(batch_xs.shape,batch_ys.shape)
      print('epoch:','%04d'%(epoch+1),'cost=','{:.9f}'.format(avg_cost))
  print('Optimization Finished!')

  #7.评估效果
  # Test model
  correct_prediction = tf.equal(tf.argmax(y_pre_r,1),tf.argmax(Y,1))
  # Calculate accuracy for 3000 examples
  # tf.cast类型转换
  accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
  print("Accuracy:",accuracy.eval({X: mnist.test.images[:3000], Y: mnist.test.labels[:3000]}))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在 Python 应用中使用 MongoDB的方法
Jan 05 Python
Python字符串处理实现单词反转
Jun 14 Python
浅谈pandas中Dataframe的查询方法([], loc, iloc, at, iat, ix)
Apr 10 Python
python pandas.DataFrame选取、修改数据最好用.loc,.iloc,.ix实现
Jun 11 Python
可能是最全面的 Python 字符串拼接总结【收藏】
Jul 09 Python
FFT快速傅里叶变换的python实现过程解析
Oct 21 Python
彻底搞懂 python 中文乱码问题(深入分析)
Feb 28 Python
利用python生成照片墙的示例代码
Apr 09 Python
python小白切忌乱用表达式
May 29 Python
Python tkinter制作单机五子棋游戏
Sep 14 Python
详解scrapy内置中间件的顺序
Sep 28 Python
详解python模块pychartdir安装及导入问题
Oct 22 Python
tensorflow实现简单逻辑回归
Sep 07 #Python
Tensorflow使用支持向量机拟合线性回归
Sep 07 #Python
TensorFlow实现iris数据集线性回归
Sep 07 #Python
TensorFlow实现模型评估
Sep 07 #Python
使用tensorflow实现线性svm
Sep 07 #Python
Python多进程池 multiprocessing Pool用法示例
Sep 07 #Python
详解python while 函数及while和for的区别
Sep 07 #Python
You might like
PHP+MYSQL的文章管理系统(一)
2006/10/09 PHP
PHP经典的给图片加水印程序
2006/12/06 PHP
PHP和Mysqlweb应用开发核心技术 第1部分 Php基础-1 开始了解php
2011/07/03 PHP
php自动加载机制的深入分析
2013/06/08 PHP
PHP设计模式之解释器模式的深入解析
2013/06/13 PHP
五款PHP代码重构工具推荐
2014/10/14 PHP
THINKPHP2.0到3.0有哪些改进之处
2015/01/04 PHP
PHP如何通过传引用的思想实现无限分类(代码简单)
2015/10/13 PHP
php入门教程之Zend Studio设置与开发实例
2016/09/09 PHP
PHP使用PDO、mysqli扩展实现与数据库交互操作详解
2019/07/20 PHP
PHP实现二维数组(或多维数组)转换成一维数组的常见方法总结
2019/12/04 PHP
Laravel 微信小程序后端搭建步骤详解
2019/11/26 PHP
JS通过分析userAgent属性来判断浏览器的类型及版本
2014/03/28 Javascript
jQuery鼠标悬停内容动画切换效果
2017/04/27 jQuery
关于javascript sort()排序你可能忽略的一点理解
2017/07/18 Javascript
详谈Node.js之操作文件系统
2017/08/29 Javascript
详解关于vue2.0工程发布上线操作步骤
2018/09/27 Javascript
浅谈HTTP 缓存的那些事儿
2018/10/17 Javascript
微信小程序上线发布流程图文详解
2019/05/06 Javascript
vue实现标签云效果的方法详解
2019/08/28 Javascript
python获取代理IP的实例分享
2018/05/07 Python
python计算日期之间的放假日期
2018/06/05 Python
Python列表(List)知识点总结
2019/02/18 Python
python交易记录链的实现过程详解
2019/07/03 Python
Python 日期时间datetime 加一天,减一天,加减一小时一分钟,加减一年
2020/04/16 Python
HTML5实现文件断点续传的方法
2017/01/04 HTML / CSS
健身场所或家用健身设备:Life Fitness
2017/11/01 全球购物
国贸类专业毕业生的求职信分享
2013/12/08 职场文书
学子宴答谢词
2014/01/25 职场文书
节能减耗标语
2014/06/21 职场文书
2015年世界环境日演讲稿
2015/03/18 职场文书
初中毕业感言300字
2015/07/31 职场文书
六一儿童节园长致辞
2015/07/31 职场文书
运动会主持人开幕词
2016/03/04 职场文书
如何理解python接口自动化之logging日志模块
2021/06/15 Python
DIY胆机必读:各国电子管评价
2022/04/06 无线电