Implementing Logistic Regression with TensorFlow, Explained


Posted in Python on May 02, 2018

This article implements the logistic regression algorithm and uses it to predict the probability of low birth weight.
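In this model the predicted probability is the logistic (sigmoid) function applied to a linear score, p = sigmoid(xA + b), and a sample is assigned to the low-birth-weight class when p exceeds 0.5. The NumPy sketch below illustrates only that forward computation. The weights and feature rows are made-up toy values, not parameters fitted to the birth weight data.

```python
import numpy as np

def sigmoid(z):
    """Logistic function: maps any real number into (0, 1)."""
    return 1.0 / (1.0 + np.exp(-z))

def predict_proba(x, A, b):
    """P(y = 1 | x) for a logistic regression model y = sigmoid(xA + b).

    x: (n, d) feature matrix, A: (d, 1) weights, b: (1, 1) bias.
    Returns an (n, 1) array of probabilities.
    """
    return sigmoid(x @ A + b)

# Hypothetical toy weights and two example rows with 2 features each.
A = np.array([[2.0], [-1.0]])
b = np.array([[0.0]])
x = np.array([[1.0, 3.0],    # linear score = 2*1 - 1*3 = -1
              [3.0, 1.0]])   # linear score = 2*3 - 1*1 = 5

p = predict_proba(x, A, b)
labels = (p > 0.5).astype(int)   # class 1 only when the probability exceeds 0.5
print(p.ravel(), labels.ravel())
```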

# Logistic Regression
#----------------------------------
#
# This script shows how to use TensorFlow to
# fit a logistic regression model.
# y = sigmoid(Ax + b)
#
# We will use the low birth weight data, specifically:
# y = 0 or 1 = low birth weight
# x = demographic and medical history data

import matplotlib.pyplot as plt
import numpy as np
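# This script uses the TensorFlow 1.x graph API (tf.Session, tf.placeholder).
# Under TensorFlow 2.x, use `import tensorflow.compat.v1 as tf` followed by
# `tf.disable_v2_behavior()` instead.
import tensorflow as tf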
import requests
from tensorflow.python.framework import ops
import os.path
import csv


ops.reset_default_graph()

# Start a graph session
sess = tf.Session()

###
# Obtain and prepare data for modeling
###

# name of data file
birth_weight_file = 'birth_weight.csv'

# download data and create data file if file does not exist in current directory
if not os.path.exists(birth_weight_file):
  birthdata_url = 'https://github.com/nfmcclure/tensorflow_cookbook/raw/master/01_Introduction/07_Working_with_Data_Sources/birthweight_data/birthweight.dat'
  birth_file = requests.get(birthdata_url)
  birth_data = birth_file.text.split('\r\n')
  birth_header = birth_data[0].split('\t')
  birth_data = [[float(x) for x in y.split('\t') if len(x)>=1] for y in birth_data[1:] if len(y)>=1]
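  # newline='' stops the csv module from writing blank rows on Windows;
  # the with-block closes the file automatically
  with open(birth_weight_file, "w", newline='') as f:
    writer = csv.writer(f)
    writer.writerow(birth_header)
    writer.writerows(birth_data)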

# read birth weight data into memory
birth_data = []
with open(birth_weight_file, newline='') as csvfile:
  csv_reader = csv.reader(csvfile)
  birth_header = next(csv_reader)
  for row in csv_reader:
    birth_data.append(row)

birth_data = [[float(x) for x in row] for row in birth_data]

# Pull out target variable
y_vals = np.array([x[0] for x in birth_data])
# Pull out predictor variables (not id, not target, and not birthweight)
x_vals = np.array([x[1:8] for x in birth_data])

# set for reproducible results
seed = 99
np.random.seed(seed)
tf.set_random_seed(seed)

# Split data into train/test = 80%/20%
train_indices = np.random.choice(len(x_vals), round(len(x_vals)*0.8), replace=False)
test_indices = np.array(list(set(range(len(x_vals))) - set(train_indices)))
x_vals_train = x_vals[train_indices]
x_vals_test = x_vals[test_indices]
y_vals_train = y_vals[train_indices]
y_vals_test = y_vals[test_indices]

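# Normalize by column (min-max scaling)
# Scaling every feature into the [0, 1] range helps logistic regression converge.
# The column min/max are computed on the training set only and reused for the
# test set, so both splits are scaled identically.
def normalize_cols(m, col_min=None, col_max=None):
  if col_min is None:
    col_min = m.min(axis=0)
  if col_max is None:
    col_max = m.max(axis=0)
  return (m - col_min) / (col_max - col_min)

train_min = x_vals_train.min(axis=0)
train_max = x_vals_train.max(axis=0)
x_vals_train = np.nan_to_num(normalize_cols(x_vals_train, train_min, train_max))
x_vals_test = np.nan_to_num(normalize_cols(x_vals_test, train_min, train_max))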

###
# Define TensorFlow computational graph
###

# Declare batch size
batch_size = 25

# Initialize placeholders
x_data = tf.placeholder(shape=[None, 7], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# Create variables for logistic regression (weights A and bias b)
A = tf.Variable(tf.random_normal(shape=[7,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))

# Declare model operations: logits = xA + b (the sigmoid is applied inside the loss)
model_output = tf.add(tf.matmul(x_data, A), b)

# Declare loss function (Cross Entropy loss)
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=model_output, labels=y_target))

# Declare optimizer
my_opt = tf.train.GradientDescentOptimizer(0.01)
train_step = my_opt.minimize(loss)

###
# Train model
###

# Initialize variables
init = tf.global_variables_initializer()
sess.run(init)

# Actual prediction
# Besides the loss, we also want to track the classifier's accuracy on the
# training and test sets, so define an op that returns the accuracy.
prediction = tf.round(tf.sigmoid(model_output))
predictions_correct = tf.cast(tf.equal(prediction, y_target), tf.float32)
accuracy = tf.reduce_mean(predictions_correct)

# Training loop: iterate, recording the loss and the train/test accuracy
loss_vec = []
train_acc = []
test_acc = []
for i in range(1500):
  rand_index = np.random.choice(len(x_vals_train), size=batch_size)
  rand_x = x_vals_train[rand_index]
  rand_y = np.transpose([y_vals_train[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})

  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  loss_vec.append(temp_loss)
  temp_acc_train = sess.run(accuracy, feed_dict={x_data: x_vals_train, y_target: np.transpose([y_vals_train])})
  train_acc.append(temp_acc_train)
  temp_acc_test = sess.run(accuracy, feed_dict={x_data: x_vals_test, y_target: np.transpose([y_vals_test])})
  test_acc.append(temp_acc_test)
  if (i+1)%300==0:
    print('Loss = ' + str(temp_loss))


###
# Display model performance
###

# Plot the loss and the accuracy
plt.plot(loss_vec, 'k-')
plt.title('Cross Entropy Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('Cross Entropy Loss')
plt.show()

# Plot train and test accuracy
plt.plot(train_acc, 'k-', label='Train Set Accuracy')
plt.plot(test_acc, 'r--', label='Test Set Accuracy')
plt.title('Train and Test Accuracy')
plt.xlabel('Generation')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')
plt.show()
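Min-max scaling maps each feature column to [0, 1] with (x - min) / (max - min). A minimal NumPy sketch follows, assuming the column minimum and maximum are fitted on the training split and then reused unchanged on the test split. The helper names `min_max_fit` and `min_max_apply` are hypothetical, and mapping constant columns to 0 is a choice made here rather than something the script specifies.

```python
import numpy as np

def min_max_fit(train):
    """Column-wise minimum and maximum, computed on the training split only."""
    return train.min(axis=0), train.max(axis=0)

def min_max_apply(m, col_min, col_max):
    """Scale each column with (m - col_min) / (col_max - col_min).

    Columns that are constant in the training data (max == min) are mapped
    to 0 here (an assumption) instead of producing NaN or inf.
    """
    span = col_max - col_min
    safe_span = np.where(span == 0, 1.0, span)   # avoid dividing by zero
    return np.where(span == 0, 0.0, (m - col_min) / safe_span)

# Usage with toy data: column 0 varies, column 1 is constant.
train = np.array([[10.0, 1.0], [20.0, 1.0], [30.0, 1.0]])
test = np.array([[25.0, 1.0], [40.0, 1.0]])
col_min, col_max = min_max_fit(train)
train_scaled = min_max_apply(train, col_min, col_max)
test_scaled = min_max_apply(test, col_min, col_max)
print(train_scaled)
print(test_scaled)
```

Test values outside the training range land outside [0, 1] (the 40.0 row scales to 1.5), which is expected when the test split is scaled with training statistics.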

Output:

Loss = 0.845124
Loss = 0.658061
Loss = 0.471852
Loss = 0.643469
Loss = 0.672077
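Each printed value is the mean cross-entropy on one random batch of 25 training rows (the batch the step just trained on), so some step-to-step fluctuation is expected. tf.nn.sigmoid_cross_entropy_with_logits takes logits rather than probabilities; its documentation gives the per-example loss in the numerically stable form max(x, 0) - x*z + log(1 + exp(-|x|)) for logits x and labels z. The NumPy sketch below is an illustration, not TensorFlow's actual implementation. It checks that form against the textbook -[z log p + (1 - z) log(1 - p)].

```python
import numpy as np

def naive_bce(logits, labels):
    """Textbook binary cross-entropy computed from sigmoid probabilities."""
    p = 1.0 / (1.0 + np.exp(-logits))
    return -(labels * np.log(p) + (1 - labels) * np.log(1 - p))

def stable_bce(logits, labels):
    """max(x, 0) - x*z + log(1 + exp(-|x|)): same value, but never takes log(0)."""
    return np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))

logits = np.array([-2.0, 0.0, 3.0])
labels = np.array([0.0, 1.0, 1.0])
print(naive_bce(logits, labels))
print(stable_bce(logits, labels))
print("batch mean:", stable_bce(logits, labels).mean())
```

For a large logit such as 100 with label 0, the naive version computes log(1 - 1.0) = log(0) in float64, while the stable form returns a finite loss of about 100.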

[Figure: Cross-entropy loss over 1,500 iterations]
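The loss curve comes from 1500 steps of tf.train.GradientDescentOptimizer(0.01) on that cross-entropy. For logistic regression the gradient has a simple closed form: with n rows, dL/dA = X^T (sigmoid(XA + b) - y) / n and dL/db = sum(sigmoid(XA + b) - y) / n. The sketch below implements a plain gradient-descent step in NumPy on synthetic random data (an assumption; it does not reproduce the figure). It also uses full-batch updates, whereas the script samples a random batch of 25 each step.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def mean_bce(X, y, A, b):
    """Mean sigmoid cross-entropy of logits XA + b against 0/1 labels y."""
    z = X @ A + b
    return np.mean(np.maximum(z, 0) - z * y + np.log1p(np.exp(-np.abs(z))))

def gradients(X, y, A, b):
    """Gradients of mean_bce with respect to A (d, 1) and b (1, 1)."""
    n = X.shape[0]
    err = sigmoid(X @ A + b) - y           # (n, 1): predicted probability minus label
    grad_A = X.T @ err / n                  # (d, 1)
    grad_b = err.sum(keepdims=True) / n     # (1, 1)
    return grad_A, grad_b

def gd_step(X, y, A, b, lr=0.01):
    """One plain gradient-descent update with learning rate lr."""
    grad_A, grad_b = gradients(X, y, A, b)
    return A - lr * grad_A, b - lr * grad_b

# Usage on synthetic data (hypothetical; not the birth weight data set).
rng = np.random.default_rng(0)
X = rng.random((100, 7))                          # 7 features scaled to [0, 1)
y = (rng.random((100, 1)) > 0.5).astype(float)    # random 0/1 labels
A0 = rng.normal(size=(7, 1))
b0 = rng.normal(size=(1, 1))

A, b = A0, b0
for _ in range(1500):
    A, b = gd_step(X, y, A, b, lr=0.01)
print("loss before:", mean_bce(X, y, A0, b0), "after:", mean_bce(X, y, A, b))
```

Because the loss is smooth and 0.01 is a small step for features in [0, 1], full-batch updates decrease the loss monotonically here. Random mini-batches, as in the script, make individual steps noisier.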

[Figure: Training and test set accuracy over 1,500 iterations]
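The accuracy curves use the script's prediction rule: round the sigmoid output to 0 or 1, compare it with the label, and average. Because sigmoid(z) > 0.5 exactly when z > 0, this is the same as thresholding the logit at zero. Here is a small NumPy sketch with made-up logits and labels:

```python
import numpy as np

def accuracy(logits, labels):
    """Fraction of examples whose rounded sigmoid output matches the 0/1 label."""
    probs = 1.0 / (1.0 + np.exp(-logits))
    # np.round, like tf.round, rounds halves to even, so exactly 0.5 becomes 0.0
    preds = np.round(probs)
    return np.mean(preds == labels)

# Hypothetical logits: two are classified correctly, two are not.
logits = np.array([[-1.5], [0.3], [2.0], [-0.2]])
labels = np.array([[0.0], [0.0], [1.0], [1.0]])
print(accuracy(logits, labels))
```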

