TensorFlow实现非线性支持向量机的实现方法


Posted in Python onApril 28, 2018

这里将加载iris数据集,创建一个山鸢尾花(I.setosa)的分类器。

# Nonlinear SVM Example
#----------------------------------
#
# This function wll illustrate how to
# implement the gaussian kernel on
# the iris dataset.
#
# Gaussian Kernel:
# K(x1, x2) = exp(-gamma * abs(x1 - x2)^2)

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn import datasets
from tensorflow.python.framework import ops
ops.reset_default_graph()

# Create graph
sess = tf.Session()

# Load the data
# iris.data = [(Sepal Length, Sepal Width, Petal Length, Petal Width)]
# 加载iris数据集,抽取花萼长度和花瓣宽度,分割每类的x_vals值和y_vals值
iris = datasets.load_iris()
x_vals = np.array([[x[0], x[3]] for x in iris.data])
y_vals = np.array([1 if y==0 else -1 for y in iris.target])
class1_x = [x[0] for i,x in enumerate(x_vals) if y_vals[i]==1]
class1_y = [x[1] for i,x in enumerate(x_vals) if y_vals[i]==1]
class2_x = [x[0] for i,x in enumerate(x_vals) if y_vals[i]==-1]
class2_y = [x[1] for i,x in enumerate(x_vals) if y_vals[i]==-1]

# Declare batch size
# 声明批量大小(偏向于更大批量大小)
batch_size = 150

# Initialize placeholders
x_data = tf.placeholder(shape=[None, 2], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)
prediction_grid = tf.placeholder(shape=[None, 2], dtype=tf.float32)

# Create variables for svm
b = tf.Variable(tf.random_normal(shape=[1,batch_size]))

# Gaussian (RBF) kernel
# 声明批量大小(偏向于更大批量大小)
gamma = tf.constant(-25.0)
sq_dists = tf.multiply(2., tf.matmul(x_data, tf.transpose(x_data)))
my_kernel = tf.exp(tf.multiply(gamma, tf.abs(sq_dists)))

# Compute SVM Model
first_term = tf.reduce_sum(b)
b_vec_cross = tf.matmul(tf.transpose(b), b)
y_target_cross = tf.matmul(y_target, tf.transpose(y_target))
second_term = tf.reduce_sum(tf.multiply(my_kernel, tf.multiply(b_vec_cross, y_target_cross)))
loss = tf.negative(tf.subtract(first_term, second_term))

# Gaussian (RBF) prediction kernel
# 创建一个预测核函数
rA = tf.reshape(tf.reduce_sum(tf.square(x_data), 1),[-1,1])
rB = tf.reshape(tf.reduce_sum(tf.square(prediction_grid), 1),[-1,1])
pred_sq_dist = tf.add(tf.subtract(rA, tf.multiply(2., tf.matmul(x_data, tf.transpose(prediction_grid)))), tf.transpose(rB))
pred_kernel = tf.exp(tf.multiply(gamma, tf.abs(pred_sq_dist)))

# 声明一个准确度函数,其为正确分类的数据点的百分比
prediction_output = tf.matmul(tf.multiply(tf.transpose(y_target),b), pred_kernel)
prediction = tf.sign(prediction_output-tf.reduce_mean(prediction_output))
accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.squeeze(prediction), tf.squeeze(y_target)), tf.float32))

# Declare optimizer
my_opt = tf.train.GradientDescentOptimizer(0.01)
train_step = my_opt.minimize(loss)

# Initialize variables
init = tf.global_variables_initializer()
sess.run(init)

# Training loop
loss_vec = []
batch_accuracy = []
for i in range(300):
  rand_index = np.random.choice(len(x_vals), size=batch_size)
  rand_x = x_vals[rand_index]
  rand_y = np.transpose([y_vals[rand_index]])
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})

  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  loss_vec.append(temp_loss)

  acc_temp = sess.run(accuracy, feed_dict={x_data: rand_x,
                       y_target: rand_y,
                       prediction_grid:rand_x})
  batch_accuracy.append(acc_temp)

  if (i+1)%75==0:
    print('Step #' + str(i+1))
    print('Loss = ' + str(temp_loss))

# Create a mesh to plot points in
# 为了绘制决策边界(Decision Boundary),我们创建一个数据点(x,y)的网格,评估预测函数
x_min, x_max = x_vals[:, 0].min() - 1, x_vals[:, 0].max() + 1
y_min, y_max = x_vals[:, 1].min() - 1, x_vals[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.02),
           np.arange(y_min, y_max, 0.02))
grid_points = np.c_[xx.ravel(), yy.ravel()]
[grid_predictions] = sess.run(prediction, feed_dict={x_data: rand_x,
                          y_target: rand_y,
                          prediction_grid: grid_points})
grid_predictions = grid_predictions.reshape(xx.shape)

# Plot points and grid
plt.contourf(xx, yy, grid_predictions, cmap=plt.cm.Paired, alpha=0.8)
plt.plot(class1_x, class1_y, 'ro', label='I. setosa')
plt.plot(class2_x, class2_y, 'kx', label='Non setosa')
plt.title('Gaussian SVM Results on Iris Data')
plt.xlabel('Pedal Length')
plt.ylabel('Sepal Width')
plt.legend(loc='lower right')
plt.ylim([-0.5, 3.0])
plt.xlim([3.5, 8.5])
plt.show()

# Plot batch accuracy
plt.plot(batch_accuracy, 'k-', label='Accuracy')
plt.title('Batch Accuracy')
plt.xlabel('Generation')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')
plt.show()

# Plot loss over time
plt.plot(loss_vec, 'k-')
plt.title('Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('Loss')
plt.show()

输出:

Step #75
Loss = -110.332
Step #150
Loss = -222.832
Step #225
Loss = -335.332
Step #300
Loss = -447.832

四种不同的gamma值(1,10,25,100):

TensorFlow实现非线性支持向量机的实现方法 

TensorFlow实现非线性支持向量机的实现方法 

TensorFlow实现非线性支持向量机的实现方法 

TensorFlow实现非线性支持向量机的实现方法 

不同gamma值的山鸢尾花(I.setosa)的分类器结果图,采用高斯核函数的SVM。

gamma值越大,每个数据点对分类边界的影响就越大。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
简单说明Python中的装饰器的用法
Apr 24 Python
Django中URL视图函数的一些高级概念介绍
Jul 20 Python
python微信跳一跳系列之棋子定位像素遍历
Feb 26 Python
python提取图像的名字*.jpg到txt文本的方法
May 10 Python
python leetcode 字符串相乘实例详解
Sep 03 Python
Python字符串匹配之6种方法的使用详解
Apr 08 Python
linux中如何使用python3获取ip地址
Jul 15 Python
python SVM 线性分类模型的实现
Jul 19 Python
Python对列表的操作知识点详解
Aug 20 Python
深入浅析Python 命令行模块 Click
Mar 11 Python
python 爬取B站原视频的实例代码
Sep 09 Python
使用Python获取字典键对应值的方法
Apr 26 Python
python 通过logging写入日志到文件和控制台的实例
Apr 28 #Python
Python实现合并同一个文件夹下所有PDF文件的方法示例
Apr 28 #Python
用TensorFlow实现多类支持向量机的示例代码
Apr 28 #Python
详谈python在windows中的文件路径问题
Apr 28 #Python
TensorFlow实现随机训练和批量训练的方法
Apr 28 #Python
对python中的logger模块全面讲解
Apr 28 #Python
详解PyTorch批训练及优化器比较
Apr 28 #Python
You might like
生成php程序的php代码
2008/04/07 PHP
php中文字母数字验证码实现代码
2008/04/25 PHP
php数组函数序列之array_flip() 将数组键名与值对调
2011/11/07 PHP
CodeIgniter图像处理类的深入解析
2013/06/17 PHP
php基于双向循环队列实现历史记录的前进后退等功能
2015/08/08 PHP
Mootools 1.2教程 函数
2009/09/15 Javascript
JS 实现Table相同行的单元格自动合并示例代码
2013/08/27 Javascript
jQuery 借助插件Lavalamp实现导航条动态美化效果
2013/09/27 Javascript
js unicode 编码解析关于数据转换为中文的两种方法
2014/04/21 Javascript
js实现jquery的offset()方法实例
2015/01/10 Javascript
javascript实现将文件保存到本地方法汇总
2015/07/26 Javascript
js字符串操作总结(必看篇)
2016/11/22 Javascript
bootstrap模态框消失问题的解决方法
2016/12/02 Javascript
jquery.validate.js 多个相同name的处理方式
2017/07/10 jQuery
Vue中组件之间数据的传递的示例代码
2017/09/08 Javascript
vue2.0 自定义组件的方法(vue组件的封装)
2018/06/05 Javascript
vuejs 动态添加input框的实例讲解
2018/08/24 Javascript
4个顶级JavaScript高级文本编辑器
2018/10/10 Javascript
JS事件流与事件处理程序实例分析
2019/08/16 Javascript
python简单获取数组元素个数的方法
2015/07/13 Python
Python实现获取磁盘剩余空间的2种方法
2017/06/07 Python
python opencv读mp4视频的实例
2018/12/07 Python
python 随机打乱 图片和对应的标签方法
2018/12/14 Python
python 实现多线程下载m3u8格式视频并使用fmmpeg合并
2019/11/15 Python
Python操作Excel工作簿的示例代码(\*.xlsx)
2020/03/23 Python
IE兼容css3圆角的实现代码
2011/07/21 HTML / CSS
纯HTML5+CSS3制作生日蛋糕代码
2016/11/16 HTML / CSS
几个Linux面试题笔试题
2016/08/01 面试题
大学旷课检讨书
2014/01/28 职场文书
财产公证书
2014/04/10 职场文书
小学生学雷锋演讲稿
2014/04/25 职场文书
综艺节目策划方案
2014/06/13 职场文书
图书馆标语
2014/06/19 职场文书
老人再婚离婚协议书范本
2014/10/27 职场文书
浅谈JavaScript浅拷贝和深拷贝
2021/11/07 Javascript
Python中re模块的元字符使用小结
2022/04/07 Python