TensorFlow实现非线性支持向量机的实现方法


Posted in Python onApril 28, 2018

这里将加载iris数据集,创建一个山鸢尾花(I.setosa)的分类器。

# Nonlinear SVM Example
#----------------------------------
#
# This function wll illustrate how to
# implement the gaussian kernel on
# the iris dataset.
#
# Gaussian Kernel:
# K(x1, x2) = exp(-gamma * abs(x1 - x2)^2)

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn import datasets
from tensorflow.python.framework import ops
ops.reset_default_graph()

# Create graph
sess = tf.Session()

# Load the data
# iris.data = [(Sepal Length, Sepal Width, Petal Length, Petal Width)]
# 加载iris数据集,抽取花萼长度和花瓣宽度,分割每类的x_vals值和y_vals值
iris = datasets.load_iris()
x_vals = np.array([[x[0], x[3]] for x in iris.data])
y_vals = np.array([1 if y==0 else -1 for y in iris.target])
class1_x = [x[0] for i,x in enumerate(x_vals) if y_vals[i]==1]
class1_y = [x[1] for i,x in enumerate(x_vals) if y_vals[i]==1]
class2_x = [x[0] for i,x in enumerate(x_vals) if y_vals[i]==-1]
class2_y = [x[1] for i,x in enumerate(x_vals) if y_vals[i]==-1]

# Declare batch size
# 声明批量大小(偏向于更大批量大小)
batch_size = 150

# Initialize placeholders
x_data = tf.placeholder(shape=[None, 2], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)
prediction_grid = tf.placeholder(shape=[None, 2], dtype=tf.float32)

# Create variables for svm
b = tf.Variable(tf.random_normal(shape=[1,batch_size]))

# Gaussian (RBF) kernel
# 声明批量大小(偏向于更大批量大小)
gamma = tf.constant(-25.0)
sq_dists = tf.multiply(2., tf.matmul(x_data, tf.transpose(x_data)))
my_kernel = tf.exp(tf.multiply(gamma, tf.abs(sq_dists)))

# Compute SVM Model
first_term = tf.reduce_sum(b)
b_vec_cross = tf.matmul(tf.transpose(b), b)
y_target_cross = tf.matmul(y_target, tf.transpose(y_target))
second_term = tf.reduce_sum(tf.multiply(my_kernel, tf.multiply(b_vec_cross, y_target_cross)))
loss = tf.negative(tf.subtract(first_term, second_term))

# Gaussian (RBF) prediction kernel
# 创建一个预测核函数
rA = tf.reshape(tf.reduce_sum(tf.square(x_data), 1),[-1,1])
rB = tf.reshape(tf.reduce_sum(tf.square(prediction_grid), 1),[-1,1])
pred_sq_dist = tf.add(tf.subtract(rA, tf.multiply(2., tf.matmul(x_data, tf.transpose(prediction_grid)))), tf.transpose(rB))
pred_kernel = tf.exp(tf.multiply(gamma, tf.abs(pred_sq_dist)))

# 声明一个准确度函数,其为正确分类的数据点的百分比
prediction_output = tf.matmul(tf.multiply(tf.transpose(y_target),b), pred_kernel)
prediction = tf.sign(prediction_output-tf.reduce_mean(prediction_output))
accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.squeeze(prediction), tf.squeeze(y_target)), tf.float32))

# Declare optimizer
my_opt = tf.train.GradientDescentOptimizer(0.01)
train_step = my_opt.minimize(loss)

# Initialize variables
init = tf.global_variables_initializer()
sess.run(init)

# Training loop
loss_vec = []
batch_accuracy = []
for i in range(300):
  rand_index = np.random.choice(len(x_vals), size=batch_size)
  rand_x = x_vals[rand_index]
  rand_y = np.transpose([y_vals[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})

  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  loss_vec.append(temp_loss)

  acc_temp = sess.run(accuracy, feed_dict={x_data: rand_x,
                       y_target: rand_y,
                       prediction_grid:rand_x})
  batch_accuracy.append(acc_temp)

  if (i+1)%75==0:
    print('Step #' + str(i+1))
    print('Loss = ' + str(temp_loss))

# Create a mesh to plot points in
# 为了绘制决策边界(Decision Boundary),我们创建一个数据点(x,y)的网格,评估预测函数
x_min, x_max = x_vals[:, 0].min() - 1, x_vals[:, 0].max() + 1
y_min, y_max = x_vals[:, 1].min() - 1, x_vals[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.02),
           np.arange(y_min, y_max, 0.02))
grid_points = np.c_[xx.ravel(), yy.ravel()]
[grid_predictions] = sess.run(prediction, feed_dict={x_data: rand_x,
                          y_target: rand_y,
                          prediction_grid: grid_points})
grid_predictions = grid_predictions.reshape(xx.shape)

# Plot points and grid
plt.contourf(xx, yy, grid_predictions, cmap=plt.cm.Paired, alpha=0.8)
plt.plot(class1_x, class1_y, 'ro', label='I. setosa')
plt.plot(class2_x, class2_y, 'kx', label='Non setosa')
plt.title('Gaussian SVM Results on Iris Data')
plt.xlabel('Pedal Length')
plt.ylabel('Sepal Width')
plt.legend(loc='lower right')
plt.ylim([-0.5, 3.0])
plt.xlim([3.5, 8.5])
plt.show()

# Plot batch accuracy
plt.plot(batch_accuracy, 'k-', label='Accuracy')
plt.title('Batch Accuracy')
plt.xlabel('Generation')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')
plt.show()

# Plot loss over time
plt.plot(loss_vec, 'k-')
plt.title('Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('Loss')
plt.show()

输出:

Step #75
Loss = -110.332
Step #150
Loss = -222.832
Step #225
Loss = -335.332
Step #300
Loss = -447.832

四种不同的gamma值(1,10,25,100):

TensorFlow实现非线性支持向量机的实现方法 

TensorFlow实现非线性支持向量机的实现方法 

TensorFlow实现非线性支持向量机的实现方法 

TensorFlow实现非线性支持向量机的实现方法 

不同gamma值的山鸢尾花(I.setosa)的分类器结果图,采用高斯核函数的SVM。

gamma值越大,每个数据点对分类边界的影响就越大。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python脚本实时处理log文件的方法
Nov 21 Python
Python实现定时任务
Feb 08 Python
Python实现两个list对应元素相减操作示例
Jun 09 Python
Python 反转字符串(reverse)的方法小结
Feb 20 Python
使用50行Python代码从零开始实现一个AI平衡小游戏
Nov 21 Python
python pandas 时间日期的处理实现
Jul 30 Python
Django中create和save方法的不同
Aug 13 Python
python实现修改固定模式的字符串内容操作示例
Dec 30 Python
python中的 zip函数详解及用法举例
Feb 16 Python
python中导入 train_test_split提示错误的解决
Jun 19 Python
Python同时迭代多个序列的方法
Jul 28 Python
python机器学习创建基于规则聊天机器人过程示例详解
Nov 02 Python
python 通过logging写入日志到文件和控制台的实例
Apr 28 #Python
Python实现合并同一个文件夹下所有PDF文件的方法示例
Apr 28 #Python
用TensorFlow实现多类支持向量机的示例代码
Apr 28 #Python
详谈python在windows中的文件路径问题
Apr 28 #Python
TensorFlow实现随机训练和批量训练的方法
Apr 28 #Python
对python中的logger模块全面讲解
Apr 28 #Python
详解PyTorch批训练及优化器比较
Apr 28 #Python
You might like
php 阴历-农历-转换类代码
2012/01/16 PHP
php中OR与|| AND与&&的区别总结
2013/10/26 PHP
PHP的伪随机数与真随机数详解
2015/05/27 PHP
中高级PHP程序员应该掌握哪些技术?
2016/09/23 PHP
php头像上传预览实例代码
2017/05/02 PHP
Javascript日期对象的dateAdd与dateDiff方法
2008/11/18 Javascript
基于jQuery的图片左右无缝滚动插件
2012/05/23 Javascript
jquery 之 $().hover(func1, funct2)使用方法
2012/06/14 Javascript
jquery动态添加删除div 具体实现
2013/07/20 Javascript
正负小数点后两位浮点数实现原理及代码
2013/09/06 Javascript
JavaScript运行过程中的“预编译阶段”和“执行阶段”
2015/12/16 Javascript
javascript css红色经典选项卡效果实现代码
2016/05/17 Javascript
浅谈window.onbeforeunload() 事件调用ajax
2016/06/29 Javascript
jquery拼接ajax 的json和字符串拼接的方法
2017/03/11 Javascript
Vue.js上下滚动加载组件的实例代码
2017/07/17 Javascript
vue中页面跳转拦截器的实现方法
2017/08/23 Javascript
[js高手之路]图解javascript的原型(prototype)对象,原型链实例
2017/08/28 Javascript
原生JS实现多个小球碰撞反弹效果示例
2018/01/31 Javascript
vue超时计算的组件实例代码
2018/07/09 Javascript
VUE 配置vue-devtools调试工具及安装方法
2018/09/30 Javascript
jquery使用FormData实现异步上传文件
2018/10/25 jQuery
微信小程序五子棋游戏的棋盘,重置,对弈实现方法【附demo源码下载】
2019/02/20 Javascript
jQuery实现的导航条点击后高亮显示功能示例
2019/03/04 jQuery
JS中async/await实现异步调用的方法
2019/08/28 Javascript
VUE 单页面使用 echart 窗口变化时的用法
2020/07/30 Javascript
[01:09:40]Newbee vs Pain 2018国际邀请赛小组赛BO2 第一场 8.16
2018/08/17 DOTA
python3图片转换二进制存入mysql
2013/12/06 Python
python虚拟环境virtualenv的使用教程
2017/10/20 Python
python 密码学示例——理解哈希(Hash)算法
2020/09/21 Python
Giglio美国站:意大利奢侈品购物网
2018/02/10 全球购物
请问如下代码执行后a和b的值分别是什么
2016/05/05 面试题
咖啡厅创业计划书范本
2014/01/22 职场文书
李开复演讲稿
2014/05/24 职场文书
读群众路线的心得体会
2014/09/03 职场文书
企业党的群众路线教育实践活动学习心得体会
2014/10/31 职场文书
运动会加油稿50字
2015/07/21 职场文书