A Detailed Guide to Implementing Logistic Regression with TensorFlow


Posted in Python on May 02, 2018

This article implements logistic regression in TensorFlow to predict the probability of low birth weight.
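Before the full listing, it may help to see the model on its own. The sketch below evaluates y = sigmoid(Ax + b) for a single sample in plain Python. The function names, weights, bias, and feature values are made up for illustration; they are not taken from the birth weight data or from the trained model.

```python
import math

def sigmoid(z):
    """Logistic function, mapping any real number into the range [0, 1]."""
    # Branch on the sign so math.exp never receives a large positive argument.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)

def predict_proba(x, weights, bias):
    """Return sigmoid(w . x + b) for one feature vector."""
    logit = sum(w * xi for w, xi in zip(weights, x)) + bias
    return sigmoid(logit)

# Hypothetical values, chosen only to illustrate the computation.
example_x = [0.2, 0.5, 1.0]
example_w = [1.0, -2.0, 0.5]
example_b = 0.3
p = predict_proba(example_x, example_w, example_b)
```

In the TensorFlow listing the same computation is split in two: `model_output` holds the logit Ax + b, and the sigmoid is applied inside the loss function and again when predictions are made.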

# Logistic Regression
#----------------------------------
#
# This script shows how to use TensorFlow to
# solve logistic regression.
# y = sigmoid(Ax + b)
#
# We will use the low birth weight data, specifically:
# y = 0 or 1 (1 = low birth weight)
# x = demographic and medical history data

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf  # written for the TensorFlow 1.x graph API (tf.Session, tf.placeholder)
import requests
from tensorflow.python.framework import ops
import os.path
import csv


ops.reset_default_graph()

# Create graph session
sess = tf.Session()

###
# Obtain and prepare data for modeling
###

# name of data file
birth_weight_file = 'birth_weight.csv'

# download data and create data file if file does not exist in current directory
if not os.path.exists(birth_weight_file):
  birthdata_url = 'https://github.com/nfmcclure/tensorflow_cookbook/raw/master/01_Introduction/07_Working_with_Data_Sources/birthweight_data/birthweight.dat'
  birth_file = requests.get(birthdata_url)
  # splitlines() handles both '\r\n' and '\n' line endings
  birth_data = birth_file.text.splitlines()
  birth_header = birth_data[0].split('\t')
  birth_data = [[float(x) for x in y.split('\t') if len(x)>=1] for y in birth_data[1:] if len(y)>=1]
  # newline='' keeps the csv module from writing blank rows on Windows;
  # the with-block closes the file, so no explicit close() is needed
  with open(birth_weight_file, "w", newline='') as f:
    writer = csv.writer(f)
    writer.writerow(birth_header)
    writer.writerows(birth_data)

# read birth weight data into memory
birth_data = []
with open(birth_weight_file, newline='') as csvfile:
  csv_reader = csv.reader(csvfile)
  birth_header = next(csv_reader)
  for row in csv_reader:
    birth_data.append(row)

birth_data = [[float(x) for x in row] for row in birth_data]

# Pull out target variable
y_vals = np.array([x[0] for x in birth_data])
# Pull out predictor variables (columns 1-7: skips the target in column 0 and the birth weight column)
x_vals = np.array([x[1:8] for x in birth_data])

# set for reproducible results
seed = 99
np.random.seed(seed)
tf.set_random_seed(seed)

# Split the data set into training and test sets (80%/20%)
train_indices = np.random.choice(len(x_vals), round(len(x_vals)*0.8), replace=False)
test_indices = np.array(list(set(range(len(x_vals))) - set(train_indices)))
x_vals_train = x_vals[train_indices]
x_vals_test = x_vals[test_indices]
y_vals_train = y_vals[train_indices]
y_vals_test = y_vals[test_indices]

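# Normalize by column (min-max norm)
# Scale every feature to the [0, 1] range (min-max scaling) so that
# logistic regression converges more reliably.
# The column min/max come from the training set and are reused for the
# test set, so no test-set statistics leak into preprocessing.
def normalize_cols(m, col_min, col_max):
  return (m - col_min) / (col_max - col_min)

train_min = x_vals_train.min(axis=0)
train_max = x_vals_train.max(axis=0)
x_vals_train = np.nan_to_num(normalize_cols(x_vals_train, train_min, train_max))
x_vals_test = np.nan_to_num(normalize_cols(x_vals_test, train_min, train_max))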

###
# Define TensorFlow computational graph
###

# Declare batch size
batch_size = 25

# Initialize placeholders
x_data = tf.placeholder(shape=[None, 7], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# Create variables for logistic regression
A = tf.Variable(tf.random_normal(shape=[7,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))

# Declare model operations
model_output = tf.add(tf.matmul(x_data, A), b)

# Declare loss function (Cross Entropy loss)
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=model_output, labels=y_target))

# Declare optimizer
my_opt = tf.train.GradientDescentOptimizer(0.01)
train_step = my_opt.minimize(loss)

###
# Train model
###

# Initialize variables
init = tf.global_variables_initializer()
sess.run(init)

# Actual Prediction
# Besides the loss, we also track the classifier's accuracy on the
# training and test sets, so create prediction ops that return accuracy.
prediction = tf.round(tf.sigmoid(model_output))
predictions_correct = tf.cast(tf.equal(prediction, y_target), tf.float32)
accuracy = tf.reduce_mean(predictions_correct)

# Training loop
# Iterate over training steps, recording the loss and both accuracies
loss_vec = []
train_acc = []
test_acc = []
for i in range(1500):
  rand_index = np.random.choice(len(x_vals_train), size=batch_size)
  rand_x = x_vals_train[rand_index]
  rand_y = np.transpose([y_vals_train[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})

  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  loss_vec.append(temp_loss)
  temp_acc_train = sess.run(accuracy, feed_dict={x_data: x_vals_train, y_target: np.transpose([y_vals_train])})
  train_acc.append(temp_acc_train)
  temp_acc_test = sess.run(accuracy, feed_dict={x_data: x_vals_test, y_target: np.transpose([y_vals_test])})
  test_acc.append(temp_acc_test)
  if (i+1)%300==0:
    print('Loss = ' + str(temp_loss))


###
# Display model performance
###

# Plot the loss, then the accuracy
plt.plot(loss_vec, 'k-')
plt.title('Cross Entropy Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('Cross Entropy Loss')
plt.show()

# Plot train and test accuracy
plt.plot(train_acc, 'k-', label='Train Set Accuracy')
plt.plot(test_acc, 'r--', label='Test Set Accuracy')
plt.title('Train and Test Accuracy')
plt.xlabel('Generation')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')
plt.show()

Output:

Loss = 0.845124
Loss = 0.658061
Loss = 0.471852
Loss = 0.643469
Loss = 0.672077
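
Each value printed above is the mean loss on one random batch of 25 samples, so it fluctuates from print to print rather than falling steadily. As a rough sketch of what `tf.nn.sigmoid_cross_entropy_with_logits` evaluates per sample, the plain-Python function below uses the numerically stable form max(x, 0) - x*z + log(1 + exp(-|x|)) described in the TensorFlow documentation, where x is the logit and z the label. The function names, logits, and labels here are invented for illustration.

```python
import math

def sigmoid_cross_entropy(logit, label):
    """Cross-entropy between sigmoid(logit) and a 0/1 label, computed stably."""
    # Algebraically equal to -label*log(p) - (1-label)*log(1-p) with
    # p = sigmoid(logit), but exp() only ever sees a non-positive argument,
    # so it cannot overflow for large logits.
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))

def mean_loss(logits, labels):
    """Average the per-sample losses, as tf.reduce_mean does over a batch."""
    losses = [sigmoid_cross_entropy(x, z) for x, z in zip(logits, labels)]
    return sum(losses) / len(losses)

# Hypothetical logits and labels for a batch of three samples.
batch_logits = [0.0, 2.0, -2.0]
batch_labels = [1.0, 1.0, 0.0]
batch_loss = mean_loss(batch_logits, batch_labels)
```

Working on logits instead of probabilities is why the listing passes `model_output` to the loss directly, without applying `tf.sigmoid` first.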

[Figure: cross-entropy loss over 1,500 iterations]

[Figure: training-set and test-set accuracy over 1,500 iterations]
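
The accuracy curves come from thresholding the predicted probability at 0.5 and comparing the result with the label. Since sigmoid(x) > 0.5 exactly when x > 0, the same decision can be read off the sign of the logit. The sketch below shows that logic in plain Python, assuming a logit of exactly zero is assigned to class 0; the function names, logits, and labels are hypothetical.

```python
def predict_label(logit):
    """Return 1.0 when sigmoid(logit) > 0.5, which is the same as logit > 0."""
    return 1.0 if logit > 0 else 0.0

def accuracy(logits, labels):
    """Fraction of samples whose thresholded prediction matches the label."""
    correct = [predict_label(x) == z for x, z in zip(logits, labels)]
    return sum(correct) / len(correct)

# Hypothetical logits and labels for four samples.
sample_logits = [1.2, -0.4, 0.3, -2.0]
sample_labels = [1.0, 0.0, 0.0, 0.0]
acc = accuracy(sample_logits, sample_labels)
```

In this toy batch the third sample is the only misclassification: its logit is positive but its label is 0.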

That is all for this article. I hope it helps with your studies.
