Tensorflow使用支持向量机拟合线性回归


Posted in Python onSeptember 07, 2018

支持向量机可以用来拟合线性回归。

相同的最大间隔(maximum margin)的概念应用到线性回归拟合。代替最大化分割两类目标是,最大化分割包含大部分的数据点(x,y)。我们将用相同的iris数据集,展示用刚才的概念来进行花萼长度与花瓣宽度之间的线性拟合。

相关的损失函数类似于max(0,|yi-(Axi+b)|-ε)。ε这里,是间隔宽度的一半,这意味着如果一个数据点在该区域,则损失等于0。

# SVM Regression
#----------------------------------
#
# This function shows how to use TensorFlow to
# solve support vector regression. We are going
# to find the line that has the maximum margin
# which INCLUDES as many points as possible
#
# We will use the iris data, specifically:
# y = Sepal Length
# x = Pedal Width

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn import datasets
from tensorflow.python.framework import ops
ops.reset_default_graph()

# Create graph
sess = tf.Session()

# Load the data
# iris.data = [(Sepal Length, Sepal Width, Petal Length, Petal Width)]
iris = datasets.load_iris()
x_vals = np.array([x[3] for x in iris.data])
y_vals = np.array([y[0] for y in iris.data])

# Split data into train/test sets
train_indices = np.random.choice(len(x_vals), round(len(x_vals)*0.8), replace=False)
test_indices = np.array(list(set(range(len(x_vals))) - set(train_indices)))
x_vals_train = x_vals[train_indices]
x_vals_test = x_vals[test_indices]
y_vals_train = y_vals[train_indices]
y_vals_test = y_vals[test_indices]

# Declare batch size
batch_size = 50

# Initialize placeholders
x_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# Create variables for linear regression
A = tf.Variable(tf.random_normal(shape=[1,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))

# Declare model operations
model_output = tf.add(tf.matmul(x_data, A), b)

# Declare loss function
# = max(0, abs(target - predicted) + epsilon)
# 1/2 margin width parameter = epsilon
epsilon = tf.constant([0.5])
# Margin term in loss
loss = tf.reduce_mean(tf.maximum(0., tf.subtract(tf.abs(tf.subtract(model_output, y_target)), epsilon)))

# Declare optimizer
my_opt = tf.train.GradientDescentOptimizer(0.075)
train_step = my_opt.minimize(loss)

# Initialize variables
init = tf.global_variables_initializer()
sess.run(init)

# Training loop
train_loss = []
test_loss = []
for i in range(200):
  rand_index = np.random.choice(len(x_vals_train), size=batch_size)
  rand_x = np.transpose([x_vals_train[rand_index]])
  rand_y = np.transpose([y_vals_train[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})

  temp_train_loss = sess.run(loss, feed_dict={x_data: np.transpose([x_vals_train]), y_target: np.transpose([y_vals_train])})
  train_loss.append(temp_train_loss)

  temp_test_loss = sess.run(loss, feed_dict={x_data: np.transpose([x_vals_test]), y_target: np.transpose([y_vals_test])})
  test_loss.append(temp_test_loss)
  if (i+1)%50==0:
    print('-----------')
    print('Generation: ' + str(i+1))
    print('A = ' + str(sess.run(A)) + ' b = ' + str(sess.run(b)))
    print('Train Loss = ' + str(temp_train_loss))
    print('Test Loss = ' + str(temp_test_loss))

# Extract Coefficients
[[slope]] = sess.run(A)
[[y_intercept]] = sess.run(b)
[width] = sess.run(epsilon)

# Get best fit line
best_fit = []
best_fit_upper = []
best_fit_lower = []
for i in x_vals:
 best_fit.append(slope*i+y_intercept)
 best_fit_upper.append(slope*i+y_intercept+width)
 best_fit_lower.append(slope*i+y_intercept-width)

# Plot fit with data
plt.plot(x_vals, y_vals, 'o', label='Data Points')
plt.plot(x_vals, best_fit, 'r-', label='SVM Regression Line', linewidth=3)
plt.plot(x_vals, best_fit_upper, 'r--', linewidth=2)
plt.plot(x_vals, best_fit_lower, 'r--', linewidth=2)
plt.ylim([0, 10])
plt.legend(loc='lower right')
plt.title('Sepal Length vs Pedal Width')
plt.xlabel('Pedal Width')
plt.ylabel('Sepal Length')
plt.show()

# Plot loss over time
plt.plot(train_loss, 'k-', label='Train Set Loss')
plt.plot(test_loss, 'r--', label='Test Set Loss')
plt.title('L2 Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('L2 Loss')
plt.legend(loc='upper right')
plt.show()

输出结果:

-----------
Generation: 50
A = [[ 2.91328382]] b = [[ 1.18453276]]
Train Loss = 1.17104
Test Loss = 1.1143
-----------
Generation: 100
A = [[ 2.42788291]] b = [[ 2.3755331]]
Train Loss = 0.703519
Test Loss = 0.715295
-----------
Generation: 150
A = [[ 1.84078252]] b = [[ 3.40453291]]
Train Loss = 0.338596
Test Loss = 0.365562
-----------
Generation: 200
A = [[ 1.35343242]] b = [[ 4.14853334]]
Train Loss = 0.125198
Test Loss = 0.16121

 

Tensorflow使用支持向量机拟合线性回归

基于iris数据集(花萼长度和花瓣宽度)的支持向量机回归,间隔宽度为0.5

Tensorflow使用支持向量机拟合线性回归

每次迭代的支持向量机回归的损失值(训练集和测试集)

直观地讲,我们认为SVM回归算法试图把更多的数据点拟合到直线两边2ε宽度的间隔内。这时拟合的直线对于ε参数更有意义。如果选择太小的ε值,SVM回归算法在间隔宽度内不能拟合更多的数据点;如果选择太大的ε值,将有许多条直线能够在间隔宽度内拟合所有的数据点。作者更倾向于选取更小的ε值,因为在间隔宽度附近的数据点比远处的数据点贡献更少的损失。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python调用windows api锁定计算机示例
Apr 17 Python
tensorflow实现KNN识别MNIST
Mar 12 Python
使用 Python 实现微信群友统计器的思路详解
Sep 26 Python
python 图像平移和旋转的实例
Jan 10 Python
在Python中如何传递任意数量的实参的示例代码
Mar 21 Python
网易有道2017内推编程题 洗牌(python)
Jun 19 Python
pyqt5对用qt designer设计的窗体实现弹出子窗口的示例
Jun 19 Python
python中with语句结合上下文管理器操作详解
Dec 19 Python
python多线程使用方法实例详解
Dec 30 Python
Django-xadmin后台导入json数据及后台显示信息图标和主题更改方式
Mar 11 Python
利用python批量爬取百度任意类别的图片的实现方法
Oct 07 Python
python使用pycharm安装pyqt5以及相关配置
Apr 22 Python
TensorFlow实现iris数据集线性回归
Sep 07 #Python
TensorFlow实现模型评估
Sep 07 #Python
使用tensorflow实现线性svm
Sep 07 #Python
Python多进程池 multiprocessing Pool用法示例
Sep 07 #Python
详解python while 函数及while和for的区别
Sep 07 #Python
使用TensorFlow实现SVM
Sep 06 #Python
使用Python制作自动推送微信消息提醒的备忘录功能
Sep 06 #Python
You might like
分页详解 从此分页无忧(PHP+mysql)
2007/11/23 PHP
PHP访问MYSQL数据库封装类(附函数说明)
2010/12/04 PHP
php求一个网段开始与结束IP地址的方法
2015/07/09 PHP
PHP/HTML混写的四种方式总结
2017/02/27 PHP
Laravel 修改验证异常的响应格式实例代码详解
2020/05/25 PHP
JSQL SQLProxy 的 php 版本代码
2010/05/05 Javascript
JQuery 弹出框定位实现方法
2010/12/02 Javascript
AeroWindow 基于JQuery的弹出窗口插件
2011/06/27 Javascript
基于jQuery的图片剪切插件
2011/08/03 Javascript
五段实用的js高级技巧
2011/12/20 Javascript
web基于浏览器的本地存储方法应用
2012/11/27 Javascript
jQuery中:button选择器用法实例
2015/01/04 Javascript
JavaScript数组和循环详解
2015/04/27 Javascript
深入理解Angular2 模板语法
2016/08/07 Javascript
浅析Jquery操作select
2016/12/13 Javascript
javascript显示系统当前时间代码
2016/12/29 Javascript
解决vue-cli3 使用子目录部署问题
2018/07/19 Javascript
tweenjs缓动算法的使用实例分析
2019/08/26 Javascript
jQuery 判断元素是否存在然后按需加载内容的实现代码
2020/01/16 jQuery
[01:45]典藏宝瓶2+祈求者身心——这就是DOTA2TI9总奖金突破3000万美元的秘密
2019/07/21 DOTA
Python Tkinter简单布局实例教程
2014/09/03 Python
Python循环语句之break与continue的用法
2015/10/14 Python
Python简单获取网卡名称及其IP地址的方法【基于psutil模块】
2018/05/24 Python
python实现视频分帧效果
2019/05/31 Python
pyqt实现.ui文件批量转换为对应.py文件脚本
2019/06/19 Python
对python3 sort sorted 函数的应用详解
2019/06/27 Python
解决django-xadmin列表页filter关联对象搜索问题
2019/11/15 Python
搭建pypi私有仓库实现过程详解
2020/11/25 Python
css3旋转木马_动力节点Java学院整理
2017/07/12 HTML / CSS
基于Html5 canvas实现裁剪图片和马赛克功能及又拍云上传图片 功能
2019/07/09 HTML / CSS
阿根廷首家户外用品制造商和经销商:Montagne
2018/02/12 全球购物
学校对教师的评语
2014/04/28 职场文书
4s店活动策划方案
2014/08/25 职场文书
舞蹈社团活动总结
2015/05/07 职场文书
2015-2016年小学教导工作总结
2015/07/21 职场文书
java版 简单三子棋游戏
2022/05/04 Java/Android