python实现感知机模型的示例


Posted in Python onSeptember 30, 2020
from sklearn.linear_model import Perceptron
import argparse #一个好用的参数传递模型
import numpy as np
from sklearn.datasets import load_iris #数据集
from sklearn.model_selection import train_test_split #训练集和测试集分割
from loguru import logger #日志输出,不清楚用法

#python is also oop 
class PerceptronToby():
  """
  n_epoch:迭代次数
  learning_rate:学习率
  loss_tolerance:损失阈值,即损失函数达到极小值的变化量
  """
  def __init__(self, n_epoch = 500, learning_rate = 0.1, loss_tolerance = 0.01):
    self._n_epoch = n_epoch
    self._lr = learning_rate
    self._loss_tolerance = loss_tolerance
  
  """训练模型,即找到每个数据最合适的权重以得到最小的损失函数"""
  def fit(self, X, y):
    # X:训练集,即数据集,每一行是样本,每一列是数据或标签,一样本包括一数据和一标签
    # y:标签,即1或-1
    n_sample, n_feature = X.shape #剥离矩阵的方法真帅

    #均匀初始化参数
    rnd_val = 1/np.sqrt(n_feature)
    rng = np.random.default_rng()
    self._w = rng.uniform(-rnd_val,rnd_val,size = n_feature)
    #偏置初始化为0
    self._b = 0

    #开始训练了,迭代n_epoch次
    num_epoch = 0 #记录迭代次数
    prev_loss = 0 #前损失值
    while True:
      curr_loss = 0 #现在损失值
      wrong_classify = 0 #误分类样本

      #一次迭代对每个样本操作一次
      for i in range(n_sample):
        #输出函数
        y_pred = np.dot(self._w,X[i]) + self._b
        #损失函数
        curr_loss += -y[i] * y_pred
        # 感知机只对误分类样本进行参数更新,使用梯度下降法
        if y[i] * y_pred <= 0:
          self._w += self._lr * y[i] * X[i]
          self._b += self._lr * y[i]
          wrong_classify += 1

      num_epoch += 1
      loss_diff = curr_loss - prev_loss
      prev_loss = curr_loss
      # 训练终止条件:
      # 1. 训练epoch数达到指定的epoch数时停止训练
      # 2. 本epoch损失与上一个epoch损失差异小于指定的阈值时停止训练
      # 3. 训练过程中不再存在误分类点时停止训练
      if num_epoch >= self._n_epoch or abs(loss_diff) < self._loss_tolerance or wrong_classify == 0:
        break


  """预测模型,顾名思义"""
  def predict(self, x):
    """给定输入样本,预测其类别"""
    y_pred = np.dot(self._w, x) + self._b
    return 1 if y_pred >= 0 else -1

#主函数
def main():
  #参数数组生成
  parser = argparse.ArgumentParser(description="感知机算法实现命令行参数")
  parser.add_argument("--nepoch", type=int, default=500, help="训练多少个epoch后终止训练")
  parser.add_argument("--lr", type=float, default=0.1, help="学习率")
  parser.add_argument("--loss_tolerance", type=float, default=0.001, help="当前损失与上一个epoch损失之差的绝对值小于该值时终止训练")
  args = parser.parse_args()
  #导入数据
  X, y = load_iris(return_X_y=True)
  # print(y)
  y[:50] = -1
  # 分割数据
  xtrain, xtest, ytrain, ytest = train_test_split(X[:100], y[:100], train_size=0.8, shuffle=True)
  # print(xtest)
  #调用并训练模型
  model = PerceptronToby(args.nepoch, args.lr, args.loss_tolerance)
  model.fit(xtrain, ytrain)

  n_test = xtest.shape[0]
  # print(n_test)
  n_right = 0
  for i in range(n_test):
    y_pred = model.predict(xtest[i])
    if y_pred == ytest[i]:
      n_right += 1
    else:
      logger.info("该样本真实标签为:{},但是toby模型预测标签为:{}".format(ytest[i], y_pred))
  logger.info("toby模型在测试集上的准确率为:{}%".format(n_right * 100 / n_test))

  skmodel = Perceptron(max_iter=args.nepoch)
  skmodel.fit(xtrain, ytrain)
  logger.info("sklearn模型在测试集上准确率为:{}%".format(100 * skmodel.score(xtest, ytest)))
if __name__ == "__main__":
  main()```

视频参考地址

以上就是python实现感知机模型的示例的详细内容,更多关于python 实现感知机模型的示例代码的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
Python通过属性手段实现只允许调用一次的示例讲解
Apr 21 Python
浅述python中深浅拷贝原理
Sep 18 Python
在python环境下运用kafka对数据进行实时传输的方法
Dec 27 Python
pycharm+PyQt5+python最新开发环境配置(踩坑)
Feb 11 Python
使用Python Pandas处理亿级数据的方法
Jun 24 Python
Python识别html主要文本框过程解析
Feb 18 Python
解决pycharm不能自动补全第三方库的函数和属性问题
Mar 12 Python
django xadmin action兼容自定义model权限教程
Mar 30 Python
Python判断字符串是否为空和null方法实例
Apr 26 Python
PyCharm中如何直接使用Anaconda已安装的库
May 28 Python
python GUI计算器的实现
Oct 09 Python
python 标准库原理与用法详解之os.path篇
Oct 24 Python
python 实现关联规则算法Apriori的示例
Sep 30 #Python
Python之字典添加元素的几种方法
Sep 30 #Python
Python之字典对象的几种创建方法
Sep 30 #Python
python 实现朴素贝叶斯算法的示例
Sep 30 #Python
Python字典取键、值对的方法步骤
Sep 30 #Python
Python根据字典的值查询出对应的键的方法
Sep 30 #Python
python字典通过值反查键的实现(简洁写法)
Sep 30 #Python
You might like
支持中文的php加密解密类代码
2011/11/27 PHP
解析php中获取系统信息的方法
2013/06/25 PHP
php在数据库抽象层简单使用PDO的方法
2015/11/03 PHP
PHP中set_include_path()函数相关用法分析
2016/07/18 PHP
跟随鼠标旋转的文字
2006/11/30 Javascript
用js实现控制内容的向上向下滚动效果
2007/06/26 Javascript
JavaScript 注册事件代码
2011/01/27 Javascript
JavaScript判断一个URL链接是否有效的实现方法
2011/10/08 Javascript
JS中获取函数调用链所有参数的方法
2015/05/07 Javascript
javascript限制文本框输入值类型的方法
2015/05/07 Javascript
Ajax清除浏览器js、css、图片缓存的方法
2015/08/06 Javascript
angularjs封装bootstrap时间插件datetimepicker
2016/06/20 Javascript
Vue.js快速入门实例教程
2016/10/15 Javascript
Nodejs 搭建简单的Web服务器详解及实例
2016/11/30 NodeJs
JavaScript字符串对象(string)基本用法示例
2017/01/18 Javascript
React Native时间转换格式工具类分享
2017/10/24 Javascript
Vue二次封装axios为插件使用详解
2018/05/21 Javascript
关于vue-router的那些事儿
2018/05/23 Javascript
js input输入百分号保存数据库失败的解决方法
2018/05/26 Javascript
Vue EventBus自定义组件事件传递
2018/06/25 Javascript
vue兄弟组件传递数据的实例
2018/09/06 Javascript
node.js express框架实现文件上传与下载功能实例详解
2019/10/15 Javascript
微信小程序云开发获取文件夹下所有文件(推荐)
2019/11/14 Javascript
Python学习笔记(二)基础语法
2014/06/06 Python
Python缩进和冒号详解
2016/06/01 Python
Python自定义进程池实例分析【生产者、消费者模型问题】
2016/09/19 Python
Python实现动态加载模块、类、函数的方法分析
2017/07/18 Python
在Python中给Nan值更改为0的方法
2018/10/30 Python
pygame游戏之旅 计算游戏中躲过的障碍数量
2018/11/20 Python
python处理两种分隔符的数据集方法
2018/12/12 Python
keras打印loss对权重的导数方式
2020/06/10 Python
Django实现微信小程序支付的示例代码
2020/09/03 Python
欧舒丹比利时官网:L’OCCITANE比利时
2017/04/25 全球购物
经济纠纷起诉状
2015/05/20 职场文书
2015企业年终工作总结范文
2015/05/27 职场文书
js实现上传图片到服务器
2021/04/11 Javascript