使用TensorFlow实现SVM


Posted in Python onSeptember 06, 2018

较基础的SVM,后续会加上多分类以及高斯核,供大家参考。

Talk is cheap, show me the code

import tensorflow as tf
from sklearn.base import BaseEstimator, ClassifierMixin
import numpy as np

class TFSVM(BaseEstimator, ClassifierMixin):

 def __init__(self, 
  C = 1, kernel = 'linear', 
  learning_rate = 0.01, 
  training_epoch = 1000, 
  display_step = 50,
  batch_size = 50,
  random_state = 42):
  #参数列表
  self.svmC = C
  self.kernel = kernel
  self.learning_rate = learning_rate
  self.training_epoch = training_epoch
  self.display_step = display_step
  self.random_state = random_state
  self.batch_size = batch_size

 def reset_seed(self):
  #重置随机数
  tf.set_random_seed(self.random_state)
  np.random.seed(self.random_state)

 def random_batch(self, X, y):
  #调用随机子集,实现mini-batch gradient descent
  indices = np.random.randint(1, X.shape[0], self.batch_size)
  X_batch = X[indices]
  y_batch = y[indices]
  return X_batch, y_batch

 def _build_graph(self, X_train, y_train):
  #创建计算图
  self.reset_seed()

  n_instances, n_inputs = X_train.shape

  X = tf.placeholder(tf.float32, [None, n_inputs], name = 'X')
  y = tf.placeholder(tf.float32, [None, 1], name = 'y')

  with tf.name_scope('trainable_variables'):
   #决策边界的两个变量
   W = tf.Variable(tf.truncated_normal(shape = [n_inputs, 1], stddev = 0.1), name = 'weights')
   b = tf.Variable(tf.truncated_normal([1]), name = 'bias')

  with tf.name_scope('training'):
   #算法核心
   y_raw = tf.add(tf.matmul(X, W), b)
   l2_norm = tf.reduce_sum(tf.square(W))
   hinge_loss = tf.reduce_mean(tf.maximum(tf.zeros(self.batch_size, 1), tf.subtract(1., tf.multiply(y_raw, y))))
   svm_loss = tf.add(hinge_loss, tf.multiply(self.svmC, l2_norm))
   training_op = tf.train.AdamOptimizer(learning_rate = self.learning_rate).minimize(svm_loss)

  with tf.name_scope('eval'):
   #正确率和预测
   prediction_class = tf.sign(y_raw)
   correct_prediction = tf.equal(y, prediction_class)
   accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

  init = tf.global_variables_initializer()

  self._X = X; self._y = y
  self._loss = svm_loss; self._training_op = training_op
  self._accuracy = accuracy; self.init = init
  self._prediction_class = prediction_class
  self._W = W; self._b = b

 def _get_model_params(self):
  #获取模型的参数,以便存储
  with self._graph.as_default():
   gvars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
  return {gvar.op.name: value for gvar, value in zip(gvars, self._session.run(gvars))}

 def _restore_model_params(self, model_params):
  #保存模型的参数
  gvar_names = list(model_params.keys())
  assign_ops = {gvar_name: self._graph.get_operation_by_name(gvar_name + '/Assign') for gvar_name in gvar_names}
  init_values = {gvar_name: assign_op.inputs[1] for gvar_name, assign_op in assign_ops.items()}
  feed_dict = {init_values[gvar_name]: model_params[gvar_name] for gvar_name in gvar_names}
  self._session.run(assign_ops, feed_dict = feed_dict)

 def fit(self, X, y, X_val = None, y_val = None):
  #fit函数,注意要输入验证集
  n_batches = X.shape[0] // self.batch_size

  self._graph = tf.Graph()
  with self._graph.as_default():
   self._build_graph(X, y)

  best_loss = np.infty
  best_accuracy = 0
  best_params = None
  checks_without_progress = 0
  max_checks_without_progress = 20

  self._session = tf.Session(graph = self._graph)

  with self._session.as_default() as sess:
   self.init.run()

   for epoch in range(self.training_epoch):
    for batch_index in range(n_batches):
     X_batch, y_batch = self.random_batch(X, y)
     sess.run(self._training_op, feed_dict = {self._X:X_batch, self._y:y_batch})
    loss_val, accuracy_val = sess.run([self._loss, self._accuracy], feed_dict = {self._X: X_val, self._y: y_val})
    accuracy_train = self._accuracy.eval(feed_dict = {self._X: X_batch, self._y: y_batch})

    if loss_val < best_loss:
     best_loss = loss_val
     best_params = self._get_model_params()
     checks_without_progress = 0
    else:
     checks_without_progress += 1
     if checks_without_progress > max_checks_without_progress:
      break

    if accuracy_val > best_accuracy:
     best_accuracy = accuracy_val
     #best_params = self._get_model_params()

    if epoch % self.display_step == 0:
     print('Epoch: {}\tValidaiton loss: {:.6f}\tValidation Accuracy: {:.4f}\tTraining Accuracy: {:.4f}'
      .format(epoch, loss_val, accuracy_val, accuracy_train))
   print('Best Accuracy: {:.4f}\tBest Loss: {:.6f}'.format(best_accuracy, best_loss))
   if best_params:
    self._restore_model_params(best_params)
    self._intercept = best_params['trainable_variables/weights']
    self._bias = best_params['trainable_variables/bias']
   return self

 def predict(self, X):
  with self._session.as_default() as sess:
   return self._prediction_class.eval(feed_dict = {self._X: X})

 def _intercept(self):
  return self._intercept

 def _bias(self):
  return self._bias

实际运行效果如下(以Iris数据集为样本):

使用TensorFlow实现SVM 

画出决策边界来看看:

使用TensorFlow实现SVM 

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python数据结构之二叉树的统计与转换实例
Apr 29 Python
python实现在目录中查找指定文件的方法
Nov 11 Python
python创建临时文件夹的方法
Jul 06 Python
python的pdb调试命令的命令整理及实例
Jul 12 Python
Python实现学生成绩管理系统
Apr 05 Python
Python 实现交换矩阵的行示例
Jun 26 Python
pytorch梯度剪裁方式
Feb 04 Python
python检查目录文件权限并修改目录文件权限的操作
Mar 11 Python
python 连续不等式语法糖实例
Apr 15 Python
appium+python自动化配置(adk、jdk、node.js)
Nov 17 Python
20行代码教你用python给证件照换底色的方法示例
Feb 05 Python
python爬虫beautifulsoup库使用操作教程全解(python爬虫基础入门)
Feb 19 Python
使用Python制作自动推送微信消息提醒的备忘录功能
Sep 06 #Python
python实现机器学习之多元线性回归
Sep 06 #Python
python实现机器学习之元线性回归
Sep 06 #Python
Python import与from import使用及区别介绍
Sep 06 #Python
用python实现k近邻算法的示例代码
Sep 06 #Python
python K近邻算法的kd树实现
Sep 06 #Python
pyqt5的QComboBox 使用模板的具体方法
Sep 06 #Python
You might like
虹吸式咖啡壶操作
2021/03/03 冲泡冲煮
PHP获取http请求的头信息实现步骤
2012/12/16 PHP
php设置session值和cookies的学习示例
2014/03/21 PHP
PHP模块memcached使用指南
2014/12/08 PHP
PHP获取访问页面HTTP状态码的实现代码
2016/11/03 PHP
php5.5使用PHPMailer-5.2发送邮件的完整步骤
2018/10/14 PHP
php面向对象基础详解【星际争霸游戏案例】
2020/01/23 PHP
用javascript实现自定义标签
2007/05/08 Javascript
JavaScript格式化日期时间的方法和自定义格式化函数示例
2014/04/04 Javascript
D3.js封装文本实现自动换行和旋转平移等功能
2016/10/14 Javascript
轻松学习Javascript闭包
2017/03/01 Javascript
基于vue v-for 循环复选框-默认勾选第一个的实现方法
2018/03/03 Javascript
vue中Element-ui 输入银行账号每四位加一个空格的实现代码
2018/09/14 Javascript
js实现打字小游戏
2019/12/17 Javascript
JS数组扁平化、去重、排序操作实例详解
2020/02/24 Javascript
vue实现商品列表的添加删除实例讲解
2020/05/14 Javascript
[07:38]2014DOTA2国际邀请赛 Newbee顺利挺进胜者组赛后专访
2014/07/15 DOTA
[00:10]DOTA2全国高校联赛速递
2018/05/30 DOTA
Python中不同进制互相转换(二进制、八进制、十进制和十六进制)
2015/04/05 Python
Python实现简单的代理服务器
2015/07/25 Python
python数据类型_元组、字典常用操作方法(介绍)
2017/05/30 Python
django初始化数据库的实例
2018/05/27 Python
在python中获取div的文本内容并和想定结果进行对比详解
2019/01/02 Python
django Admin文档生成器使用详解
2019/07/22 Python
python自动保存百度盘资源到百度盘中的实例代码
2019/08/26 Python
python实现小世界网络生成
2019/11/21 Python
Python 实现向word(docx)中输出
2020/02/13 Python
通过实例简单了解Python sys.argv[]使用方法
2020/08/04 Python
css3实现背景动态渐变效果
2019/12/10 HTML / CSS
马克华菲官方商城:Mark Fairwhale
2016/09/04 全球购物
采购主管的岗位职责
2013/12/17 职场文书
无财产无子女离婚协议书范文
2014/09/14 职场文书
个人融资协议书范本两则
2014/10/15 职场文书
小学班级管理心得体会
2016/01/07 职场文书
超级详细实用的pycharm常用快捷键
2021/05/12 Python
Keras多线程机制与flask多线程冲突的解决方案
2021/05/28 Python