用TensorFlow实现lasso回归和岭回归算法的示例


Posted in Python onMay 02, 2018

也有些正则方法可以限制回归算法输出结果中系数的影响,其中最常用的两种正则方法是lasso回归和岭回归。

lasso回归和岭回归算法跟常规线性回归算法极其相似,有一点不同的是,在公式中增加正则项来限制斜率(或者净斜率)。这样做的主要原因是限制特征对因变量的影响,通过增加一个依赖斜率A的损失函数实现。

对于lasso回归算法,在损失函数上增加一项:斜率A的某个给定倍数。我们使用TensorFlow的逻辑操作,但没有这些操作相关的梯度,而是使用阶跃函数的连续估计,也称作连续阶跃函数,其会在截止点跳跃扩大。一会就可以看到如何使用lasso回归算法。

对于岭回归算法,增加一个L2范数,即斜率系数的L2正则。

# LASSO and Ridge Regression
# lasso回归和岭回归
# 
# This function shows how to use TensorFlow to solve LASSO or 
# Ridge regression for 
# y = Ax + b
# 
# We will use the iris data, specifically: 
#  y = Sepal Length 
#  x = Petal Width

# import required libraries
import matplotlib.pyplot as plt
import sys
import numpy as np
import tensorflow as tf
from sklearn import datasets
from tensorflow.python.framework import ops


# Specify 'Ridge' or 'LASSO'
regression_type = 'LASSO'

# clear out old graph
ops.reset_default_graph()

# Create graph
sess = tf.Session()

###
# Load iris data
###

# iris.data = [(Sepal Length, Sepal Width, Petal Length, Petal Width)]
iris = datasets.load_iris()
x_vals = np.array([x[3] for x in iris.data])
y_vals = np.array([y[0] for y in iris.data])

###
# Model Parameters
###

# Declare batch size
batch_size = 50

# Initialize placeholders
x_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# make results reproducible
seed = 13
np.random.seed(seed)
tf.set_random_seed(seed)

# Create variables for linear regression
A = tf.Variable(tf.random_normal(shape=[1,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))

# Declare model operations
model_output = tf.add(tf.matmul(x_data, A), b)

###
# Loss Functions
###

# Select appropriate loss function based on regression type

if regression_type == 'LASSO':
  # Declare Lasso loss function
  # 增加损失函数,其为改良过的连续阶跃函数,lasso回归的截止点设为0.9。
  # 这意味着限制斜率系数不超过0.9
  # Lasso Loss = L2_Loss + heavyside_step,
  # Where heavyside_step ~ 0 if A < constant, otherwise ~ 99
  lasso_param = tf.constant(0.9)
  heavyside_step = tf.truediv(1., tf.add(1., tf.exp(tf.multiply(-50., tf.subtract(A, lasso_param)))))
  regularization_param = tf.multiply(heavyside_step, 99.)
  loss = tf.add(tf.reduce_mean(tf.square(y_target - model_output)), regularization_param)

elif regression_type == 'Ridge':
  # Declare the Ridge loss function
  # Ridge loss = L2_loss + L2 norm of slope
  ridge_param = tf.constant(1.)
  ridge_loss = tf.reduce_mean(tf.square(A))
  loss = tf.expand_dims(tf.add(tf.reduce_mean(tf.square(y_target - model_output)), tf.multiply(ridge_param, ridge_loss)), 0)

else:
  print('Invalid regression_type parameter value',file=sys.stderr)


###
# Optimizer
###

# Declare optimizer
my_opt = tf.train.GradientDescentOptimizer(0.001)
train_step = my_opt.minimize(loss)

###
# Run regression
###

# Initialize variables
init = tf.global_variables_initializer()
sess.run(init)

# Training loop
loss_vec = []
for i in range(1500):
  rand_index = np.random.choice(len(x_vals), size=batch_size)
  rand_x = np.transpose([x_vals[rand_index]])
  rand_y = np.transpose([y_vals[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})
  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  loss_vec.append(temp_loss[0])
  if (i+1)%300==0:
    print('Step #' + str(i+1) + ' A = ' + str(sess.run(A)) + ' b = ' + str(sess.run(b)))
    print('Loss = ' + str(temp_loss))
    print('\n')

###
# Extract regression results
###

# Get the optimal coefficients
[slope] = sess.run(A)
[y_intercept] = sess.run(b)

# Get best fit line
best_fit = []
for i in x_vals:
 best_fit.append(slope*i+y_intercept)


###
# Plot results
###

# Plot regression line against data points
plt.plot(x_vals, y_vals, 'o', label='Data Points')
plt.plot(x_vals, best_fit, 'r-', label='Best fit line', linewidth=3)
plt.legend(loc='upper left')
plt.title('Sepal Length vs Pedal Width')
plt.xlabel('Pedal Width')
plt.ylabel('Sepal Length')
plt.show()

# Plot loss over time
plt.plot(loss_vec, 'k-')
plt.title(regression_type + ' Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('Loss')
plt.show()

输出结果:

Step #300 A = [[ 0.77170753]] b = [[ 1.82499862]]
Loss = [[ 10.26473045]]
Step #600 A = [[ 0.75908542]] b = [[ 3.2220633]]
Loss = [[ 3.06292033]]
Step #900 A = [[ 0.74843585]] b = [[ 3.9975822]]
Loss = [[ 1.23220456]]
Step #1200 A = [[ 0.73752165]] b = [[ 4.42974091]]
Loss = [[ 0.57872057]]
Step #1500 A = [[ 0.72942668]] b = [[ 4.67253113]]
Loss = [[ 0.40874988]]

用TensorFlow实现lasso回归和岭回归算法的示例 

用TensorFlow实现lasso回归和岭回归算法的示例

通过在标准线性回归估计的基础上,增加一个连续的阶跃函数,实现lasso回归算法。由于阶跃函数的坡度,我们需要注意步长,因为太大的步长会导致最终不收敛。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python按照多个字符对字符串进行分割的方法
Mar 17 Python
Python浅拷贝与深拷贝用法实例
May 09 Python
在Python的Django框架中包装视图函数
Jul 20 Python
Python实现的堆排序算法示例
Apr 29 Python
Python爬虫——爬取豆瓣电影Top250代码实例
Apr 17 Python
Python Django的安装配置教程图文详解
Jul 17 Python
详解用Python爬虫获取百度企业信用中企业基本信息
Jul 02 Python
Python爬虫之Selenium库的使用方法
Jan 03 Python
Python爬取某平台短视频的方法
Feb 08 Python
python机器学习Github已达8.9Kstars模型解释器LIME
Nov 23 Python
Python中的datetime包与time包包和模块详情
Feb 28 Python
方法汇总:Python 安装第三方库常用
Apr 26 Python
Python实现确认字符串是否包含指定字符串的实例
May 02 #Python
详解用TensorFlow实现逻辑回归算法
May 02 #Python
Python获取指定字符前面的所有字符方法
May 02 #Python
Python 查找字符在字符串中的位置实例
May 02 #Python
python 巧用正则寻找字符串中的特定字符的位置方法
May 02 #Python
Python 在字符串中加入变量的实例讲解
May 02 #Python
Python 实现字符串中指定位置插入一个字符
May 02 #Python
You might like
一个简单计数器的源代码
2006/10/09 PHP
PHP序列号生成函数和字符串替换函数代码
2012/06/07 PHP
php 判断数组是几维数组
2013/03/20 PHP
解析php获取字符串的编码格式的方法(函数)
2013/06/21 PHP
codeigniter教程之上传视频并使用ffmpeg转flv示例
2014/02/13 PHP
php获取网页请求状态程序示例
2014/06/17 PHP
PHP基于yii框架实现生成ICO图标
2015/11/13 PHP
joomla组件开发入门教程
2016/05/04 PHP
PHP长连接实现与使用方法详解
2018/02/11 PHP
PHP实现15位身份证号转18位的方法分析
2019/10/16 PHP
Laravel 5.5 异常处理 &amp; 错误日志的解决
2019/10/17 PHP
JavaScript中URL编码函数代码
2011/01/11 Javascript
jQuery的slideToggle方法实例
2013/05/07 Javascript
js模拟hashtable的简单实例
2014/03/06 Javascript
上传文件返回的json数据会被提示下载问题解决方案
2014/12/03 Javascript
jQuery+HTML5美女瀑布流布局实现方法
2015/09/21 Javascript
jQuery使用$.ajax进行异步刷新的方法(附demo下载)
2015/12/04 Javascript
无需 Flash 使用 jQuery 复制文字到剪贴板
2016/04/26 Javascript
详谈JS中实现种子随机数及作用
2016/07/19 Javascript
关于JavaScript数组你所不知道的3件事
2016/08/24 Javascript
js实现拖拽功能
2017/03/01 Javascript
ES6中Symbol类型用法实例详解
2017/04/06 Javascript
JavaScript运动框架 解决防抖动问题、悬浮对联(二)
2017/05/17 Javascript
浅谈webpack 自动刷新与解析
2018/04/09 Javascript
百度小程序之间的页面通信过程详解
2019/07/18 Javascript
Python中的Matplotlib模块入门教程
2015/04/15 Python
Python AutoCAD 系统设置的实现方法
2020/04/01 Python
Python图片处理模块PIL操作方法(pillow)
2020/04/07 Python
canvas像素画板的实现代码
2018/11/21 HTML / CSS
html5读取本地文件示例代码
2014/04/22 HTML / CSS
温泉秘密:Onsen Secret
2020/07/06 全球购物
高中军训感想300字
2014/03/04 职场文书
领导班子个人对照检查材料(群众路线)
2014/09/26 职场文书
接待员岗位职责
2015/02/13 职场文书
收银员岗位职责范本
2015/04/07 职场文书
2015年销售部工作总结范文
2015/04/27 职场文书