TensorFlow实现Logistic回归


Posted in Python onSeptember 07, 2018

本文实例为大家分享了TensorFlow实现Logistic回归的具体代码,供大家参考,具体内容如下

1.导入模块

import numpy as np
import pandas as pd
from pandas import Series,DataFrame

from matplotlib import pyplot as plt
%matplotlib inline

#导入tensorflow
import tensorflow as tf

#导入MNIST(手写数字数据集)
from tensorflow.examples.tutorials.mnist import input_data

2.获取训练数据和测试数据

import ssl 
ssl._create_default_https_context = ssl._create_unverified_context

mnist = input_data.read_data_sets('./TensorFlow',one_hot=True)

test = mnist.test
test_images = test.images

train = mnist.train
images = train.images

3.模拟线性方程

#创建占矩阵位符X,Y
X = tf.placeholder(tf.float32,shape=[None,784])
Y = tf.placeholder(tf.float32,shape=[None,10])

#随机生成斜率W和截距b
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))

#根据模拟线性方程得出预测值
y_pre = tf.matmul(X,W)+b

#将预测值结果概率化
y_pre_r = tf.nn.softmax(y_pre)

4.构造损失函数

# -y*tf.log(y_pre_r) --->-Pi*log(Pi)  信息熵公式

cost = tf.reduce_mean(-tf.reduce_sum(Y*tf.log(y_pre_r),axis=1))

5.实现梯度下降,获取最小损失函数

#learning_rate:学习率,是进行训练时在最陡的梯度方向上所采取的「步」长;
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

6.TensorFlow初始化,并进行训练

#定义相关参数

#训练循环次数
training_epochs = 25
#batch 一批,每次训练给算法10个数据
batch_size = 10
#每隔5次,打印输出运算的结果
display_step = 5


#预定义初始化
init = tf.global_variables_initializer()

#开始训练
with tf.Session() as sess:
  #初始化
  sess.run(init)
  #循环训练次数
  for epoch in range(training_epochs):
    avg_cost = 0.
    #总训练批次total_batch =训练总样本量/每批次样本数量
    total_batch = int(train.num_examples/batch_size)
    for i in range(total_batch):
      #每次取出100个数据作为训练数据
      batch_xs,batch_ys = mnist.train.next_batch(batch_size)
      _, c = sess.run([optimizer,cost],feed_dict={X:batch_xs,Y:batch_ys})
      avg_cost +=c/total_batch
    if(epoch+1)%display_step == 0:
      print(batch_xs.shape,batch_ys.shape)
      print('epoch:','%04d'%(epoch+1),'cost=','{:.9f}'.format(avg_cost))
  print('Optimization Finished!')

  #7.评估效果
  # Test model
  correct_prediction = tf.equal(tf.argmax(y_pre_r,1),tf.argmax(Y,1))
  # Calculate accuracy for 3000 examples
  # tf.cast类型转换
  accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
  print("Accuracy:",accuracy.eval({X: mnist.test.images[:3000], Y: mnist.test.labels[:3000]}))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python列出一个文件夹及其子目录的所有文件
Jun 30 Python
Python遍历文件夹和读写文件的实现代码
Aug 28 Python
浅谈Python中带_的变量或函数命名
Dec 04 Python
Python3之简单搭建自带服务器的实例讲解
Jun 04 Python
python 定义n个变量方法 (变量声明自动化)
Nov 10 Python
基于sklearn实现Bagging算法(python)
Jul 11 Python
Django CBV与FBV原理及实例详解
Aug 12 Python
关于Python中的向量相加和numpy中的向量相加效率对比
Aug 26 Python
python多进程(加入进程池)操作常见案例
Oct 21 Python
python分布式编程实现过程解析
Nov 08 Python
使用python处理题库表格并转化为word形式的实现
Apr 14 Python
Python参数传递对象的引用原理解析
May 22 Python
tensorflow实现简单逻辑回归
Sep 07 #Python
Tensorflow使用支持向量机拟合线性回归
Sep 07 #Python
TensorFlow实现iris数据集线性回归
Sep 07 #Python
TensorFlow实现模型评估
Sep 07 #Python
使用tensorflow实现线性svm
Sep 07 #Python
Python多进程池 multiprocessing Pool用法示例
Sep 07 #Python
详解python while 函数及while和for的区别
Sep 07 #Python
You might like
便携利器 — TECSUN PL-365简评
2021/03/02 无线电
使用php判断浏览器的类型和语言的函数代码
2013/02/28 PHP
php实现的一段简单概率相关代码
2016/05/30 PHP
PHP使用preg_split()分割特殊字符(元字符等)的方法分析
2017/02/04 PHP
php基于环形链表解决约瑟夫环问题示例
2017/11/07 PHP
thinkPHP5.1框架中Request类四种调用方式示例
2019/08/03 PHP
用javascript删除当前行,添加行(示例代码)
2013/11/25 Javascript
jQuery选择器源码解读(六):Sizzle选择器匹配逻辑分析
2015/03/31 Javascript
bootstrap3 兼容IE8浏览器!
2016/05/02 Javascript
BootStrap智能表单实战系列(九)表单图片上传的支持
2016/06/13 Javascript
artDialog+plupload实现多文件上传
2016/07/19 Javascript
jquery+css3问卷答题卡翻页动画效果示例
2016/10/26 Javascript
vue 点击按钮增加一行的方法
2018/09/07 Javascript
JS实现的杨辉三角【帕斯卡三角形】算法示例
2019/02/26 Javascript
深入浅析nuxt.js基于ssh的vue通用框架
2019/05/21 Javascript
JavaScript实现串行请求的示例代码
2020/09/14 Javascript
python读取html中指定元素生成excle文件示例
2014/04/03 Python
Python中List.index()方法的使用教程
2015/05/20 Python
python装饰器与递归算法详解
2016/02/18 Python
Python使用matplotlib实现绘制自定义图形功能示例
2018/01/18 Python
python用plt画图时,cmp设置方法
2018/12/13 Python
Python和Java的语法对比分析语法简洁上python的确完美胜出
2019/05/10 Python
在django中实现页面倒数几秒后自动跳转的例子
2019/08/16 Python
python3.6生成器yield用法实例分析
2019/08/23 Python
Python坐标线性插值应用实现
2019/11/13 Python
flask框架json数据的拿取和返回操作示例
2019/11/28 Python
解决python-docx打包之后找不到default.docx的问题
2020/02/13 Python
python 使用while循环输出*组成的菱形实例
2020/04/12 Python
pyecharts在数据可视化中的应用详解
2020/06/08 Python
CSS3 linear-gradient线性渐变生成加号和减号的方法
2017/11/21 HTML / CSS
意大利高端时尚买手店:Stefania Mode
2018/03/01 全球购物
Topshop法国官网:英国快速时尚品牌
2018/04/08 全球购物
舞会礼服和舞会鞋:PromGirl
2019/04/22 全球购物
英国最受欢迎的平价女士时装零售商:Roman Originals
2019/11/02 全球购物
Linux管理员面试经常问道的相关命令
2013/04/29 面试题
我的网上商城创业计划书
2013/12/26 职场文书