用TensorFlow实现lasso回归和岭回归算法的示例


Posted in Python onMay 02, 2018

也有些正则方法可以限制回归算法输出结果中系数的影响,其中最常用的两种正则方法是lasso回归和岭回归。

lasso回归和岭回归算法跟常规线性回归算法极其相似,有一点不同的是,在公式中增加正则项来限制斜率(或者净斜率)。这样做的主要原因是限制特征对因变量的影响,通过增加一个依赖斜率A的损失函数实现。

对于lasso回归算法,在损失函数上增加一项:斜率A的某个给定倍数。我们使用TensorFlow的逻辑操作,但没有这些操作相关的梯度,而是使用阶跃函数的连续估计,也称作连续阶跃函数,其会在截止点跳跃扩大。一会就可以看到如何使用lasso回归算法。

对于岭回归算法,增加一个L2范数,即斜率系数的L2正则。

# LASSO and Ridge Regression
# lasso回归和岭回归
# 
# This function shows how to use TensorFlow to solve LASSO or 
# Ridge regression for 
# y = Ax + b
# 
# We will use the iris data, specifically: 
#  y = Sepal Length 
#  x = Petal Width

# import required libraries
import matplotlib.pyplot as plt
import sys
import numpy as np
import tensorflow as tf
from sklearn import datasets
from tensorflow.python.framework import ops


# Specify 'Ridge' or 'LASSO'
regression_type = 'LASSO'

# clear out old graph
ops.reset_default_graph()

# Create graph
sess = tf.Session()

###
# Load iris data
###

# iris.data = [(Sepal Length, Sepal Width, Petal Length, Petal Width)]
iris = datasets.load_iris()
x_vals = np.array([x[3] for x in iris.data])
y_vals = np.array([y[0] for y in iris.data])

###
# Model Parameters
###

# Declare batch size
batch_size = 50

# Initialize placeholders
x_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# make results reproducible
seed = 13
np.random.seed(seed)
tf.set_random_seed(seed)

# Create variables for linear regression
A = tf.Variable(tf.random_normal(shape=[1,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))

# Declare model operations
model_output = tf.add(tf.matmul(x_data, A), b)

###
# Loss Functions
###

# Select appropriate loss function based on regression type

if regression_type == 'LASSO':
  # Declare Lasso loss function
  # 增加损失函数,其为改良过的连续阶跃函数,lasso回归的截止点设为0.9。
  # 这意味着限制斜率系数不超过0.9
  # Lasso Loss = L2_Loss + heavyside_step,
  # Where heavyside_step ~ 0 if A < constant, otherwise ~ 99
  lasso_param = tf.constant(0.9)
  heavyside_step = tf.truediv(1., tf.add(1., tf.exp(tf.multiply(-50., tf.subtract(A, lasso_param)))))
  regularization_param = tf.multiply(heavyside_step, 99.)
  loss = tf.add(tf.reduce_mean(tf.square(y_target - model_output)), regularization_param)

elif regression_type == 'Ridge':
  # Declare the Ridge loss function
  # Ridge loss = L2_loss + L2 norm of slope
  ridge_param = tf.constant(1.)
  ridge_loss = tf.reduce_mean(tf.square(A))
  loss = tf.expand_dims(tf.add(tf.reduce_mean(tf.square(y_target - model_output)), tf.multiply(ridge_param, ridge_loss)), 0)

else:
  print('Invalid regression_type parameter value',file=sys.stderr)


###
# Optimizer
###

# Declare optimizer
my_opt = tf.train.GradientDescentOptimizer(0.001)
train_step = my_opt.minimize(loss)

###
# Run regression
###

# Initialize variables
init = tf.global_variables_initializer()
sess.run(init)

# Training loop
loss_vec = []
for i in range(1500):
  rand_index = np.random.choice(len(x_vals), size=batch_size)
  rand_x = np.transpose([x_vals[rand_index]])
  rand_y = np.transpose([y_vals[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  loss_vec.append(temp_loss[0])
  if (i+1)%300==0:
    print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)) + ' b = ' + str(sess.run(b)))
    print('Loss = ' + str(temp_loss))
    print('\n')

###
# Extract regression results
###

# Get the optimal coefficients
[slope] = sess.run(A)
[y_intercept] = sess.run(b)

# Get best fit line
best_fit = []
for i in x_vals:
 best_fit.append(slope*i+y_intercept)


###
# Plot results
###

# Plot regression line against data points
plt.plot(x_vals, y_vals, 'o', label='Data Points')
plt.plot(x_vals, best_fit, 'r-', label='Best fit line', linewidth=3)
plt.legend(loc='upper left')
plt.title('Sepal Length vs Pedal Width')
plt.xlabel('Pedal Width')
plt.ylabel('Sepal Length')
plt.show()

# Plot loss over time
plt.plot(loss_vec, 'k-')
plt.title(regression_type + ' Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('Loss')
plt.show()

输出结果:

Step #300 A = [[ 0.77170753]] b = [[ 1.82499862]]
Loss = [[ 10.26473045]]
Step #600 A = [[ 0.75908542]] b = [[ 3.2220633]]
Loss = [[ 3.06292033]]
Step #900 A = [[ 0.74843585]] b = [[ 3.9975822]]
Loss = [[ 1.23220456]]
Step #1200 A = [[ 0.73752165]] b = [[ 4.42974091]]
Loss = [[ 0.57872057]]
Step #1500 A = [[ 0.72942668]] b = [[ 4.67253113]]
Loss = [[ 0.40874988]]

用TensorFlow实现lasso回归和岭回归算法的示例 

用TensorFlow实现lasso回归和岭回归算法的示例

通过在标准线性回归估计的基础上,增加一个连续的阶跃函数,实现lasso回归算法。由于阶跃函数的坡度,我们需要注意步长,因为太大的步长会导致最终不收敛。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python网络爬虫采集联想词示例
Feb 11 Python
Python最基本的数据类型以及对元组的介绍
Apr 14 Python
Python WXPY实现微信监控报警功能的代码
Oct 20 Python
Python3中条件控制、循环与函数的简易教程
Nov 21 Python
手把手教你用python抢票回家过年(代码简单)
Jan 21 Python
python实现快速排序的示例(二分法思想)
Mar 12 Python
解决python读取几千万行的大表内存问题
Jun 26 Python
python爱心表白 每天都是浪漫七夕!
Aug 18 Python
python 重命名轴索引的方法
Nov 10 Python
python实现串口自动触发工作的示例
Jul 02 Python
python GUI库图形界面开发之PyQt5单行文本框控件QLineEdit详细使用方法与实例
Feb 27 Python
win10+anaconda安装yolov5的方法及问题解决方案
Apr 29 Python
Python实现确认字符串是否包含指定字符串的实例
May 02 #Python
详解用TensorFlow实现逻辑回归算法
May 02 #Python
Python获取指定字符前面的所有字符方法
May 02 #Python
Python 查找字符在字符串中的位置实例
May 02 #Python
python 巧用正则寻找字符串中的特定字符的位置方法
May 02 #Python
Python 在字符串中加入变量的实例讲解
May 02 #Python
Python 实现字符串中指定位置插入一个字符
May 02 #Python
You might like
PHP 编程安全性小结
2010/01/08 PHP
JS 对象介绍
2010/01/20 Javascript
jQuery获取CSS样式中的颜色值的问题,不同浏览器格式不同的解决办法
2013/05/13 Javascript
JQuery分别取得每行最后一列和最后一行的示例代码
2013/08/18 Javascript
Javascript中的方法和匿名方法实例详解
2015/06/13 Javascript
纯js实现瀑布流布局及ajax动态新增数据
2016/04/07 Javascript
js 转json格式的字符串为对象或数组(前后台)的方法
2016/11/02 Javascript
微信小程序中form 表单提交和取值实例详解
2017/04/20 Javascript
微信小程序制作表格的方法
2019/02/14 Javascript
原生JS实现随机点名项目的实例代码
2019/04/30 Javascript
vue路由权限校验功能的实现代码
2020/06/07 Javascript
Javascript表单序列化原理及实现代码详解
2020/10/30 Javascript
Vue多选列表组件深入详解
2021/03/02 Vue.js
[01:05]DOTA2完美大师赛趣味视频之选手教你打职业
2017/11/23 DOTA
用Python的Django框架完成视频处理任务的教程
2015/04/02 Python
总结Python编程中三条常用的技巧
2015/05/11 Python
解决python3在anaconda下安装caffe失败的问题
2017/06/15 Python
python3+PyQt5实现支持多线程的页面索引器应用程序
2018/04/20 Python
通过pycharm使用git的步骤(图文详解)
2019/06/13 Python
python GUI库图形界面开发之PyQt5信号与槽事件处理机制详细介绍与实例解析
2020/03/08 Python
使用Keras 实现查看model weights .h5 文件的内容
2020/06/09 Python
Python测试框架:pytest学习笔记
2020/10/20 Python
Django使用django-simple-captcha做验证码的实现示例
2021/01/07 Python
Html5实现首页动态视频背景的示例代码
2019/09/25 HTML / CSS
美国婴儿用品店:Babies”R”Us
2017/10/12 全球购物
吉尔德利巧克力公司:Ghirardelli Chocolate Company
2019/03/27 全球购物
英国电气世界:Electrical World
2019/09/08 全球购物
意大利奢侈品综合电商网站:MODES
2019/12/14 全球购物
自我鉴定书面格式
2014/01/13 职场文书
事业单位绩效考核实施方案
2014/03/27 职场文书
廉洁自律演讲稿
2014/05/22 职场文书
会计学毕业生求职信
2014/06/25 职场文书
2015年留守儿童工作总结
2015/05/22 职场文书
迎新生欢迎词2015
2015/07/16 职场文书
初中英语教师个人工作总结2015
2015/07/21 职场文书
利用python进行数据加载
2021/06/20 Python