TensorFlow实现非线性支持向量机的实现方法


Posted in Python onApril 28, 2018

这里将加载iris数据集,创建一个山鸢尾花(I.setosa)的分类器。

# Nonlinear SVM Example
#----------------------------------
#
# This function wll illustrate how to
# implement the gaussian kernel on
# the iris dataset.
#
# Gaussian Kernel:
# K(x1, x2) = exp(-gamma * abs(x1 - x2)^2)

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn import datasets
from tensorflow.python.framework import ops
ops.reset_default_graph()

# Create graph
sess = tf.Session()

# Load the data
# iris.data = [(Sepal Length, Sepal Width, Petal Length, Petal Width)]
# 加载iris数据集,抽取花萼长度和花瓣宽度,分割每类的x_vals值和y_vals值
iris = datasets.load_iris()
x_vals = np.array([[x[0], x[3]] for x in iris.data])
y_vals = np.array([1 if y==0 else -1 for y in iris.target])
class1_x = [x[0] for i,x in enumerate(x_vals) if y_vals[i]==1]
class1_y = [x[1] for i,x in enumerate(x_vals) if y_vals[i]==1]
class2_x = [x[0] for i,x in enumerate(x_vals) if y_vals[i]==-1]
class2_y = [x[1] for i,x in enumerate(x_vals) if y_vals[i]==-1]

# Declare batch size
# 声明批量大小(偏向于更大批量大小)
batch_size = 150

# Initialize placeholders
x_data = tf.placeholder(shape=[None, 2], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)
prediction_grid = tf.placeholder(shape=[None, 2], dtype=tf.float32)

# Create variables for svm
b = tf.Variable(tf.random_normal(shape=[1,batch_size]))

# Gaussian (RBF) kernel
# 声明批量大小(偏向于更大批量大小)
gamma = tf.constant(-25.0)
sq_dists = tf.multiply(2., tf.matmul(x_data, tf.transpose(x_data)))
my_kernel = tf.exp(tf.multiply(gamma, tf.abs(sq_dists)))

# Compute SVM Model
first_term = tf.reduce_sum(b)
b_vec_cross = tf.matmul(tf.transpose(b), b)
y_target_cross = tf.matmul(y_target, tf.transpose(y_target))
second_term = tf.reduce_sum(tf.multiply(my_kernel, tf.multiply(b_vec_cross, y_target_cross)))
loss = tf.negative(tf.subtract(first_term, second_term))

# Gaussian (RBF) prediction kernel
# 创建一个预测核函数
rA = tf.reshape(tf.reduce_sum(tf.square(x_data), 1),[-1,1])
rB = tf.reshape(tf.reduce_sum(tf.square(prediction_grid), 1),[-1,1])
pred_sq_dist = tf.add(tf.subtract(rA, tf.multiply(2., tf.matmul(x_data, tf.transpose(prediction_grid)))), tf.transpose(rB))
pred_kernel = tf.exp(tf.multiply(gamma, tf.abs(pred_sq_dist)))

# 声明一个准确度函数,其为正确分类的数据点的百分比
prediction_output = tf.matmul(tf.multiply(tf.transpose(y_target),b), pred_kernel)
prediction = tf.sign(prediction_output-tf.reduce_mean(prediction_output))
accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.squeeze(prediction), tf.squeeze(y_target)), tf.float32))

# Declare optimizer
my_opt = tf.train.GradientDescentOptimizer(0.01)
train_step = my_opt.minimize(loss)

# Initialize variables
init = tf.global_variables_initializer()
sess.run(init)

# Training loop
loss_vec = []
batch_accuracy = []
for i in range(300):
  rand_index = np.random.choice(len(x_vals), size=batch_size)
  rand_x = x_vals[rand_index]
  rand_y = np.transpose([y_vals[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})

  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  loss_vec.append(temp_loss)

  acc_temp = sess.run(accuracy, feed_dict={x_data: rand_x,
                       y_target: rand_y,
                       prediction_grid:rand_x})
  batch_accuracy.append(acc_temp)

  if (i+1)%75==0:
    print('Step #' + str(i+1))
    print('Loss = ' + str(temp_loss))

# Create a mesh to plot points in
# 为了绘制决策边界(Decision Boundary),我们创建一个数据点(x,y)的网格,评估预测函数
x_min, x_max = x_vals[:, 0].min() - 1, x_vals[:, 0].max() + 1
y_min, y_max = x_vals[:, 1].min() - 1, x_vals[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.02),
           np.arange(y_min, y_max, 0.02))
grid_points = np.c_[xx.ravel(), yy.ravel()]
[grid_predictions] = sess.run(prediction, feed_dict={x_data: rand_x,
                          y_target: rand_y,
                          prediction_grid: grid_points})
grid_predictions = grid_predictions.reshape(xx.shape)

# Plot points and grid
plt.contourf(xx, yy, grid_predictions, cmap=plt.cm.Paired, alpha=0.8)
plt.plot(class1_x, class1_y, 'ro', label='I. setosa')
plt.plot(class2_x, class2_y, 'kx', label='Non setosa')
plt.title('Gaussian SVM Results on Iris Data')
plt.xlabel('Pedal Length')
plt.ylabel('Sepal Width')
plt.legend(loc='lower right')
plt.ylim([-0.5, 3.0])
plt.xlim([3.5, 8.5])
plt.show()

# Plot batch accuracy
plt.plot(batch_accuracy, 'k-', label='Accuracy')
plt.title('Batch Accuracy')
plt.xlabel('Generation')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')
plt.show()

# Plot loss over time
plt.plot(loss_vec, 'k-')
plt.title('Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('Loss')
plt.show()

输出:

Step #75
Loss = -110.332
Step #150
Loss = -222.832
Step #225
Loss = -335.332
Step #300
Loss = -447.832

四种不同的gamma值(1,10,25,100):

TensorFlow实现非线性支持向量机的实现方法 

TensorFlow实现非线性支持向量机的实现方法 

TensorFlow实现非线性支持向量机的实现方法 

TensorFlow实现非线性支持向量机的实现方法 

不同gamma值的山鸢尾花(I.setosa)的分类器结果图,采用高斯核函数的SVM。

gamma值越大,每个数据点对分类边界的影响就越大。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python根据开头和结尾字符串获取中间字符串的方法
Mar 26 Python
Python中分数的相关使用教程
Mar 30 Python
Python的IDEL增加清屏功能实例
Jun 19 Python
浅谈python编译pyc工程--导包问题解决
Mar 20 Python
python实现图像检索的三种(直方图/OpenCV/哈希法)
Aug 08 Python
python实现静态服务器
Sep 05 Python
Python使用Pandas库常见操作详解
Jan 16 Python
深入浅析Python 函数注解与匿名函数
Feb 24 Python
Python flask框架实现查询数据库并显示数据
Jun 04 Python
Python通过zookeeper实现分布式服务代码解析
Jul 22 Python
PyQt QMainWindow的使用示例
Mar 24 Python
Python可视化神器pyecharts之绘制箱形图
Jul 07 Python
python 通过logging写入日志到文件和控制台的实例
Apr 28 #Python
Python实现合并同一个文件夹下所有PDF文件的方法示例
Apr 28 #Python
用TensorFlow实现多类支持向量机的示例代码
Apr 28 #Python
详谈python在windows中的文件路径问题
Apr 28 #Python
TensorFlow实现随机训练和批量训练的方法
Apr 28 #Python
对python中的logger模块全面讲解
Apr 28 #Python
详解PyTorch批训练及优化器比较
Apr 28 #Python
You might like
通过对服务器端特性的配置加强php的安全
2006/10/09 PHP
Yii2 加载css、js 载静态资源的方法
2017/03/10 PHP
使用JQUERY Tabs插件宿主IFRAMES
2010/01/01 Javascript
JQuery 插件模板 制作jquery插件的朋友可以参考下
2010/03/17 Javascript
node.js中RPC(远程过程调用)的实现原理介绍
2014/12/05 Javascript
JavaScript多并发问题如何处理
2015/10/28 Javascript
JS基于ocanvas插件实现的简单画板效果代码(附demo源码下载)
2016/04/05 Javascript
JS中innerHTML和pasteHTML的区别实例分析
2016/06/22 Javascript
总结JavaScript的正则与其他语言的不同之处
2016/08/25 Javascript
微信小程序 购物车简单实例
2016/10/24 Javascript
jQuery中$.grep() 过滤函数 数组过滤
2016/11/22 Javascript
Vue概念及常见命令介绍(1)
2016/12/08 Javascript
浅谈react-router HashRouter和BrowserRouter的使用
2017/12/29 Javascript
Angular4 ElementRef的应用
2018/02/26 Javascript
详解Webpack loader 之 file-loader
2018/11/07 Javascript
JS使用iView的Dropdown实现一个右键菜单
2019/05/06 Javascript
基于redis的小程序登录实现方法流程分析
2020/05/25 Javascript
利用Anaconda完美解决Python 2与python 3的共存问题
2017/05/25 Python
Python实现动态加载模块、类、函数的方法分析
2017/07/18 Python
Python读取mat文件,并保存为pickle格式的方法
2018/10/23 Python
利用Python进行图像的加法,图像混合(附代码)
2019/07/14 Python
python解析多层json操作示例
2019/12/30 Python
Python HTMLTestRunner可视化报告实现过程解析
2020/04/10 Python
tensorflow模型文件(ckpt)转pb文件的方法(不知道输出节点名)
2020/04/22 Python
CSS3让登陆面板3D旋转起来
2016/05/03 HTML / CSS
苹果美国官方商城:Apple美国
2016/08/24 全球购物
选购国际女性时装设计师品牌:IFCHIC(支持中文)
2018/04/12 全球购物
日本最佳原创设计品牌:Felissimo(芬理希梦)
2019/03/19 全球购物
安德玛菲律宾官网:Under Armour菲律宾
2020/07/28 全球购物
北京天润融通.net面试题笔试题
2012/02/20 面试题
什么是方法的重载
2013/06/24 面试题
办公室综合文员岗位职责范本
2014/02/13 职场文书
小学生运动会报道稿
2014/09/12 职场文书
先进个人评语大全
2015/01/04 职场文书
承诺书模板大全
2015/05/04 职场文书
2019年励志签名:致拼搏路上的自己
2019/10/11 职场文书