使用TensorFlow实现SVM


Posted in Python onSeptember 06, 2018

较基础的SVM,后续会加上多分类以及高斯核,供大家参考。

Talk is cheap, show me the code

import tensorflow as tf
from sklearn.base import BaseEstimator, ClassifierMixin
import numpy as np

class TFSVM(BaseEstimator, ClassifierMixin):

 def __init__(self, 
  C = 1, kernel = 'linear', 
  learning_rate = 0.01, 
  training_epoch = 1000, 
  display_step = 50,
  batch_size = 50,
  random_state = 42):
  #参数列表
  self.svmC = C
  self.kernel = kernel
  self.learning_rate = learning_rate
  self.training_epoch = training_epoch
  self.display_step = display_step
  self.random_state = random_state
  self.batch_size = batch_size

 def reset_seed(self):
  #重置随机数
  tf.set_random_seed(self.random_state)
  np.random.seed(self.random_state)

 def random_batch(self, X, y):
  #调用随机子集,实现mini-batch gradient descent
  indices = np.random.randint(1, X.shape[0], self.batch_size)
  X_batch = X[indices]
  y_batch = y[indices]
  return X_batch, y_batch

 def _build_graph(self, X_train, y_train):
  #创建计算图
  self.reset_seed()

  n_instances, n_inputs = X_train.shape

  X = tf.placeholder(tf.float32, [None, n_inputs], name = 'X')
  y = tf.placeholder(tf.float32, [None, 1], name = 'y')

  with tf.name_scope('trainable_variables'):
   #决策边界的两个变量
   W = tf.Variable(tf.truncated_normal(shape = [n_inputs, 1], stddev = 0.1), name = 'weights')
   b = tf.Variable(tf.truncated_normal([1]), name = 'bias')

  with tf.name_scope('training'):
   #算法核心
   y_raw = tf.add(tf.matmul(X, W), b)
   l2_norm = tf.reduce_sum(tf.square(W))
   hinge_loss = tf.reduce_mean(tf.maximum(tf.zeros(self.batch_size, 1), tf.subtract(1., tf.multiply(y_raw, y))))
   svm_loss = tf.add(hinge_loss, tf.multiply(self.svmC, l2_norm))
   training_op = tf.train.AdamOptimizer(learning_rate = self.learning_rate).minimize(svm_loss)

  with tf.name_scope('eval'):
   #正确率和预测
   prediction_class = tf.sign(y_raw)
   correct_prediction = tf.equal(y, prediction_class)
   accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

  init = tf.global_variables_initializer()

  self._X = X; self._y = y
  self._loss = svm_loss; self._training_op = training_op
  self._accuracy = accuracy; self.init = init
  self._prediction_class = prediction_class
  self._W = W; self._b = b

 def _get_model_params(self):
  #获取模型的参数,以便存储
  with self._graph.as_default():
   gvars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
  return {gvar.op.name: value for gvar, value in zip(gvars, self._session.run(gvars))}

 def _restore_model_params(self, model_params):
  #保存模型的参数
  gvar_names = list(model_params.keys())
  assign_ops = {gvar_name: self._graph.get_operation_by_name(gvar_name + '/Assign') for gvar_name in gvar_names}
  init_values = {gvar_name: assign_op.inputs[1] for gvar_name, assign_op in assign_ops.items()}
  feed_dict = {init_values[gvar_name]: model_params[gvar_name] for gvar_name in gvar_names}
  self._session.run(assign_ops, feed_dict = feed_dict)

 def fit(self, X, y, X_val = None, y_val = None):
  #fit函数,注意要输入验证集
  n_batches = X.shape[0] // self.batch_size

  self._graph = tf.Graph()
  with self._graph.as_default():
   self._build_graph(X, y)

  best_loss = np.infty
  best_accuracy = 0
  best_params = None
  checks_without_progress = 0
  max_checks_without_progress = 20

  self._session = tf.Session(graph = self._graph)

  with self._session.as_default() as sess:
   self.init.run()

   for epoch in range(self.training_epoch):
    for batch_index in range(n_batches):
     X_batch, y_batch = self.random_batch(X, y)
     sess.run(self._training_op, feed_dict = {self._X:X_batch, self._y:y_batch})
    loss_val, accuracy_val = sess.run([self._loss, self._accuracy], feed_dict = {self._X: X_val, self._y: y_val})
    accuracy_train = self._accuracy.eval(feed_dict = {self._X: X_batch, self._y: y_batch})

    if loss_val < best_loss:
     best_loss = loss_val
     best_params = self._get_model_params()
     checks_without_progress = 0
    else:
     checks_without_progress += 1
     if checks_without_progress > max_checks_without_progress:
      break

    if accuracy_val > best_accuracy:
     best_accuracy = accuracy_val
     #best_params = self._get_model_params()

    if epoch % self.display_step == 0:
     print('Epoch: {}\tValidaiton loss: {:.6f}\tValidation Accuracy: {:.4f}\tTraining Accuracy: {:.4f}'
      .format(epoch, loss_val, accuracy_val, accuracy_train))
   print('Best Accuracy: {:.4f}\tBest Loss: {:.6f}'.format(best_accuracy, best_loss))
   if best_params:
    self._restore_model_params(best_params)
    self._intercept = best_params['trainable_variables/weights']
    self._bias = best_params['trainable_variables/bias']
   return self

 def predict(self, X):
  with self._session.as_default() as sess:
   return self._prediction_class.eval(feed_dict = {self._X: X})

 def _intercept(self):
  return self._intercept

 def _bias(self):
  return self._bias

实际运行效果如下(以Iris数据集为样本):

使用TensorFlow实现SVM 

画出决策边界来看看:

使用TensorFlow实现SVM 

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python2.x和3.x下maketrans与translate函数使用上的不同
Apr 13 Python
Python 十六进制整数与ASCii编码字符串相互转换方法
Jul 09 Python
django从请求到响应的过程深入讲解
Aug 01 Python
通过python改变图片特定区域的颜色详解
Jul 15 Python
Pycharm连接远程服务器并实现远程调试的实现
Aug 02 Python
Python 图像对比度增强的几种方法(小结)
Sep 25 Python
matplotlib绘制多个子图(subplot)的方法
Dec 03 Python
使用python写一个自动浏览文章的脚本实例
Dec 05 Python
Python运行提示缺少模块问题解决方案
Apr 02 Python
Python爬虫爬取博客实现可视化过程解析
Jun 29 Python
用 Python 定义 Schema 并生成 Parquet 文件详情
Sep 25 Python
Pygame Rect区域位置的使用(图文)
Nov 17 Python
使用Python制作自动推送微信消息提醒的备忘录功能
Sep 06 #Python
python实现机器学习之多元线性回归
Sep 06 #Python
python实现机器学习之元线性回归
Sep 06 #Python
Python import与from import使用及区别介绍
Sep 06 #Python
用python实现k近邻算法的示例代码
Sep 06 #Python
python K近邻算法的kd树实现
Sep 06 #Python
pyqt5的QComboBox 使用模板的具体方法
Sep 06 #Python
You might like
php 来访国内外IP判断代码并实现页面跳转
2009/12/18 PHP
PHP字符串函数系列之nl2br(),在字符串中的每个新行 (\n) 之前插入 HTML 换行符br
2011/11/10 PHP
PHP中比较时间大小实例
2014/08/21 PHP
Cygwin中安装PHP方法步骤
2015/07/04 PHP
PHP如何获取当前主机、域名、网址、路径、端口等参数
2017/06/09 PHP
基于Laravel实现的用户动态模块开发
2017/09/21 PHP
PHP上传图片到数据库并显示的实例代码
2019/12/20 PHP
js,jquery滚动/跳转页面到指定位置的实现思路
2014/06/03 Javascript
实现非常简单的js双向数据绑定
2015/11/06 Javascript
nodejs的压缩文件模块archiver用法示例
2017/01/18 NodeJs
JavaScript 异步调用
2017/10/25 Javascript
微信小程序分享功能onShareAppMessage(options)用法分析
2019/04/24 Javascript
判断文字超过2行添加展开按钮,未超过则不显示,溢出部分显示省略号
2019/04/28 Javascript
微信小程序前端promise封装代码实例
2019/08/24 Javascript
解决qrcode.js生成二维码时必须定义一个空div的问题
2020/07/09 Javascript
[16:04]DOTA2海涛带你玩炸弹 9月5日更新内容详解
2014/09/05 DOTA
Python字符串和文件操作常用函数分析
2015/04/08 Python
python代码 if not x: 和 if x is not None: 和 if not x is None:使用介绍
2016/09/21 Python
python实现图像全景拼接
2020/03/27 Python
Python函数参数分类原理详解
2020/05/28 Python
python利用pytesseract 实现本地识别图片文字
2020/12/14 Python
目前不被任何主流浏览器支持的CSS3属性汇总
2014/07/21 HTML / CSS
超30万乐谱下载:Musicnotes.com
2016/09/24 全球购物
ORLY官网:美国专业美甲一线品牌
2019/12/11 全球购物
某公司C#程序员面试题笔试题
2014/05/26 面试题
市场营销专业推荐信
2013/11/03 职场文书
生产内勤岗位职责
2013/12/07 职场文书
自荐书模板
2013/12/19 职场文书
理发店策划方案
2014/06/05 职场文书
绿色环保家庭事迹材料
2014/08/31 职场文书
民间借贷协议书范本
2014/10/01 职场文书
个人合伙协议书范本
2014/10/14 职场文书
2014年民政工作总结
2014/11/26 职场文书
2015年村计划生育工作总结
2015/04/28 职场文书
MySQL 数据丢失排查案例
2021/05/08 MySQL
基于Python实现将列表数据生成折线图
2022/03/23 Python