Tensorflow使用支持向量机拟合线性回归


Posted in Python onSeptember 07, 2018

支持向量机可以用来拟合线性回归。

相同的最大间隔(maximum margin)的概念应用到线性回归拟合。代替最大化分割两类目标是,最大化分割包含大部分的数据点(x,y)。我们将用相同的iris数据集,展示用刚才的概念来进行花萼长度与花瓣宽度之间的线性拟合。

相关的损失函数类似于max(0,|yi-(Axi+b)|-ε)。ε这里,是间隔宽度的一半,这意味着如果一个数据点在该区域,则损失等于0。

# SVM Regression
#----------------------------------
#
# This function shows how to use TensorFlow to
# solve support vector regression. We are going
# to find the line that has the maximum margin
# which INCLUDES as many points as possible
#
# We will use the iris data, specifically:
# y = Sepal Length
# x = Pedal Width

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn import datasets
from tensorflow.python.framework import ops
ops.reset_default_graph()

# Create graph
sess = tf.Session()

# Load the data
# iris.data = [(Sepal Length, Sepal Width, Petal Length, Petal Width)]
iris = datasets.load_iris()
x_vals = np.array([x[3] for x in iris.data])
y_vals = np.array([y[0] for y in iris.data])

# Split data into train/test sets
train_indices = np.random.choice(len(x_vals), round(len(x_vals)*0.8), replace=False)
test_indices = np.array(list(set(range(len(x_vals))) - set(train_indices)))
x_vals_train = x_vals[train_indices]
x_vals_test = x_vals[test_indices]
y_vals_train = y_vals[train_indices]
y_vals_test = y_vals[test_indices]

# Declare batch size
batch_size = 50

# Initialize placeholders
x_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# Create variables for linear regression
A = tf.Variable(tf.random_normal(shape=[1,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))

# Declare model operations
model_output = tf.add(tf.matmul(x_data, A), b)

# Declare loss function
# = max(0, abs(target - predicted) + epsilon)
# 1/2 margin width parameter = epsilon
epsilon = tf.constant([0.5])
# Margin term in loss
loss = tf.reduce_mean(tf.maximum(0., tf.subtract(tf.abs(tf.subtract(model_output, y_target)), epsilon)))

# Declare optimizer
my_opt = tf.train.GradientDescentOptimizer(0.075)
train_step = my_opt.minimize(loss)

# Initialize variables
init = tf.global_variables_initializer()
sess.run(init)

# Training loop
train_loss = []
test_loss = []
for i in range(200):
  rand_index = np.random.choice(len(x_vals_train), size=batch_size)
  rand_x = np.transpose([x_vals_train[rand_index]])
  rand_y = np.transpose([y_vals_train[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})

  temp_train_loss = sess.run(loss, feed_dict={x_data: np.transpose([x_vals_train]), y_target: np.transpose([y_vals_train])})
  train_loss.append(temp_train_loss)

  temp_test_loss = sess.run(loss, feed_dict={x_data: np.transpose([x_vals_test]), y_target: np.transpose([y_vals_test])})
  test_loss.append(temp_test_loss)
  if (i+1)%50==0:
    print('-----------')
    print('Generation: ' + str(i+1))
    print('A = ' + str(sess.run(A)) + ' b = ' + str(sess.run(b)))
    print('Train Loss = ' + str(temp_train_loss))
    print('Test Loss = ' + str(temp_test_loss))

# Extract Coefficients
[[slope]] = sess.run(A)
[[y_intercept]] = sess.run(b)
[width] = sess.run(epsilon)

# Get best fit line
best_fit = []
best_fit_upper = []
best_fit_lower = []
for i in x_vals:
 best_fit.append(slope*i+y_intercept)
 best_fit_upper.append(slope*i+y_intercept+width)
 best_fit_lower.append(slope*i+y_intercept-width)

# Plot fit with data
plt.plot(x_vals, y_vals, 'o', label='Data Points')
plt.plot(x_vals, best_fit, 'r-', label='SVM Regression Line', linewidth=3)
plt.plot(x_vals, best_fit_upper, 'r--', linewidth=2)
plt.plot(x_vals, best_fit_lower, 'r--', linewidth=2)
plt.ylim([0, 10])
plt.legend(loc='lower right')
plt.title('Sepal Length vs Pedal Width')
plt.xlabel('Pedal Width')
plt.ylabel('Sepal Length')
plt.show()

# Plot loss over time
plt.plot(train_loss, 'k-', label='Train Set Loss')
plt.plot(test_loss, 'r--', label='Test Set Loss')
plt.title('L2 Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('L2 Loss')
plt.legend(loc='upper right')
plt.show()

输出结果:

-----------
Generation: 50
A = [[ 2.91328382]] b = [[ 1.18453276]]
Train Loss = 1.17104
Test Loss = 1.1143
-----------
Generation: 100
A = [[ 2.42788291]] b = [[ 2.3755331]]
Train Loss = 0.703519
Test Loss = 0.715295
-----------
Generation: 150
A = [[ 1.84078252]] b = [[ 3.40453291]]
Train Loss = 0.338596
Test Loss = 0.365562
-----------
Generation: 200
A = [[ 1.35343242]] b = [[ 4.14853334]]
Train Loss = 0.125198
Test Loss = 0.16121

 

Tensorflow使用支持向量机拟合线性回归

基于iris数据集(花萼长度和花瓣宽度)的支持向量机回归,间隔宽度为0.5

Tensorflow使用支持向量机拟合线性回归

每次迭代的支持向量机回归的损失值(训练集和测试集)

直观地讲,我们认为SVM回归算法试图把更多的数据点拟合到直线两边2ε宽度的间隔内。这时拟合的直线对于ε参数更有意义。如果选择太小的ε值,SVM回归算法在间隔宽度内不能拟合更多的数据点;如果选择太大的ε值,将有许多条直线能够在间隔宽度内拟合所有的数据点。作者更倾向于选取更小的ε值,因为在间隔宽度附近的数据点比远处的数据点贡献更少的损失。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python2.x版本中基本的中文编码问题解决
Oct 12 Python
Python调用SQLPlus来操作和解析Oracle数据库的方法
Apr 09 Python
详解python中executemany和序列的使用方法
Aug 12 Python
Python读取YUV文件,并显示的方法
Dec 04 Python
Python图像滤波处理操作示例【基于ImageFilter类】
Jan 03 Python
Python字典遍历操作实例小结
Mar 05 Python
使用selenium模拟登录解决滑块验证问题的实现
May 10 Python
Python如何爬取微信公众号文章和评论(基于 Fiddler 抓包分析)
Jun 28 Python
python关于矩阵重复赋值覆盖问题的解决方法
Jul 19 Python
Python爬虫实现使用beautifulSoup4爬取名言网功能案例
Sep 15 Python
Python使用Pyqt5实现简易浏览器(最新版本测试过)
Apr 27 Python
Python 如何对文件目录操作
Jul 10 Python
TensorFlow实现iris数据集线性回归
Sep 07 #Python
TensorFlow实现模型评估
Sep 07 #Python
使用tensorflow实现线性svm
Sep 07 #Python
Python多进程池 multiprocessing Pool用法示例
Sep 07 #Python
详解python while 函数及while和for的区别
Sep 07 #Python
使用TensorFlow实现SVM
Sep 06 #Python
使用Python制作自动推送微信消息提醒的备忘录功能
Sep 06 #Python
You might like
Terran剧情介绍
2020/03/14 星际争霸
通过具体程序来理解PHP里面的抽象类
2010/01/28 PHP
apache配置虚拟主机的方法详解
2013/06/17 PHP
php利用事务处理转账问题
2015/04/22 PHP
thinkphp5.1 文件引入路径问题及注意事项
2018/06/13 PHP
用户注册常用javascript代码
2009/08/29 Javascript
利用腾讯的ip地址库做ip物理地址定位
2010/07/24 Javascript
关闭时刷新父窗口两种方法
2014/05/07 Javascript
jquery 获取 outerHtml 包含当前节点本身的代码
2014/10/30 Javascript
Bootstrap三种表单布局的使用方法
2016/06/21 Javascript
在百度搜索结果中去除掉一些网站的资料(通过js控制不让显示)
2017/05/02 Javascript
vue实现引入本地json的方法分析
2018/07/12 Javascript
微信小程序实现录音功能
2019/11/22 Javascript
9个JavaScript日常开发小技巧
2020/10/06 Javascript
10个易被忽视但应掌握的Python基本用法
2015/04/01 Python
python中xrange用法分析
2015/04/15 Python
Python判断列表是否已排序的各种方法及其性能分析
2016/06/20 Python
Python3 伪装浏览器的方法示例
2017/11/23 Python
Python实现登陆文件验证方法
2018/10/06 Python
Python使用MyQR制作专属动态彩色二维码功能
2019/06/04 Python
Django实现web端tailf日志文件功能及实例详解
2019/07/28 Python
django ajax发送post请求的两种方法
2020/01/05 Python
python GUI库图形界面开发之PyQt5动态(可拖动控件大小)布局控件QSplitter详细使用方法与实例
2020/03/06 Python
Python多进程编程multiprocessing代码实例
2020/03/12 Python
TensorFLow 数学运算的示例代码
2020/04/21 Python
用 Python 制作地球仪的方法
2020/04/24 Python
Python PyQt5运行程序把输出信息展示到GUI图形界面上
2020/04/27 Python
PyCharm vs VSCode,作为python开发者,你更倾向哪种IDE呢?
2020/08/17 Python
函授本科毕业生自我鉴定
2013/10/16 职场文书
统计员岗位职责
2013/11/14 职场文书
结婚典礼证婚词
2014/01/11 职场文书
高一历史教学反思
2014/01/13 职场文书
数控专业毕业生自荐信范文
2014/03/04 职场文书
返乡农民工证明
2015/06/24 职场文书
2016入党积极分子考察评语
2015/12/01 职场文书
Win11怎么跳过联网验机 ?Win11跳过联网验机激活教程
2022/04/05 数码科技