python 牛顿法实现逻辑回归(Logistic Regression)


Posted in Python onOctober 15, 2020

本文采用的训练方法是牛顿法(Newton Method)。

代码

import numpy as np

class LogisticRegression(object):
 """
 Logistic Regression Classifier training by Newton Method
 """

 def __init__(self, error: float = 0.7, max_epoch: int = 100):
  """
  :param error: float, if the distance between new weight and 
      old weight is less than error, the process 
      of traing will break.
  :param max_epoch: if training epoch >= max_epoch the process 
       of traing will break.
  """
  self.error = error
  self.max_epoch = max_epoch
  self.weight = None
  self.sign = np.vectorize(lambda x: 1 if x >= 0.5 else 0)

 def p_func(self, X_):
  """Get P(y=1 | x)
  :param X_: shape = (n_samples + 1, n_features)
  :return: shape = (n_samples)
  """
  tmp = np.exp(self.weight @ X_.T)
  return tmp / (1 + tmp)

 def diff(self, X_, y, p):
  """Get derivative
  :param X_: shape = (n_samples, n_features + 1) 
  :param y: shape = (n_samples)
  :param p: shape = (n_samples) P(y=1 | x)
  :return: shape = (n_features + 1) first derivative
  """
  return -(y - p) @ X_

 def hess_mat(self, X_, p):
  """Get Hessian Matrix
  :param p: shape = (n_samples) P(y=1 | x)
  :return: shape = (n_features + 1, n_features + 1) second derivative
  """
  hess = np.zeros((X_.shape[1], X_.shape[1]))
  for i in range(X_.shape[0]):
   hess += self.X_XT[i] * p[i] * (1 - p[i])
  return hess

 def newton_method(self, X_, y):
  """Newton Method to calculate weight
  :param X_: shape = (n_samples + 1, n_features)
  :param y: shape = (n_samples)
  :return: None
  """
  self.weight = np.ones(X_.shape[1])
  self.X_XT = []
  for i in range(X_.shape[0]):
   t = X_[i, :].reshape((-1, 1))
   self.X_XT.append(t @ t.T)

  for _ in range(self.max_epoch):
   p = self.p_func(X_)
   diff = self.diff(X_, y, p)
   hess = self.hess_mat(X_, p)
   new_weight = self.weight - (np.linalg.inv(hess) @ diff.reshape((-1, 1))).flatten()

   if np.linalg.norm(new_weight - self.weight) <= self.error:
    break
   self.weight = new_weight

 def fit(self, X, y):
  """
  :param X_: shape = (n_samples, n_features)
  :param y: shape = (n_samples)
  :return: self
  """
  X_ = np.c_[np.ones(X.shape[0]), X]
  self.newton_method(X_, y)
  return self

 def predict(self, X) -> np.array:
  """
  :param X: shape = (n_samples, n_features] 
  :return: shape = (n_samples]
  """
  X_ = np.c_[np.ones(X.shape[0]), X]
  return self.sign(self.p_func(X_))

测试代码

import matplotlib.pyplot as plt
import sklearn.datasets

def plot_decision_boundary(pred_func, X, y, title=None):
 """分类器画图函数,可画出样本点和决策边界
 :param pred_func: predict函数
 :param X: 训练集X
 :param y: 训练集Y
 :return: None
 """

 # Set min and max values and give it some padding
 x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
 y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
 h = 0.01
 # Generate a grid of points with distance h between them
 xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
 # Predict the function value for the whole gid
 Z = pred_func(np.c_[xx.ravel(), yy.ravel()])
 Z = Z.reshape(xx.shape)
 # Plot the contour and training examples
 plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral)
 plt.scatter(X[:, 0], X[:, 1], s=40, c=y, cmap=plt.cm.Spectral)
 if title:
  plt.title(title)
 plt.show()

效果

python 牛顿法实现逻辑回归(Logistic Regression)

更多机器学习代码,请访问 https://github.com/WiseDoge/plume

以上就是python 牛顿法实现逻辑回归(Logistic Regression)的详细内容,更多关于python 逻辑回归的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
使用setup.py安装python包和卸载python包的方法
Nov 27 Python
跟老齐学Python之Import 模块
Oct 13 Python
python分析apache访问日志脚本分享
Feb 26 Python
python检查URL是否正常访问的小技巧
Feb 25 Python
对python中raw_input()和input()的用法详解
Apr 22 Python
对Python中plt的画图函数详解
Nov 07 Python
OpenCV HSV颜色识别及HSV基本颜色分量范围
Mar 22 Python
python中pytest收集用例规则与运行指定用例详解
Jun 27 Python
基于python监控程序是否关闭
Jan 14 Python
Python 识别12306图片验证码物品的实现示例
Jan 20 Python
python实现粒子群算法
Oct 15 Python
python 爬取京东指定商品评论并进行情感分析
May 27 Python
PyCharm 2020.2.2 x64 下载并安装的详细教程
Oct 15 #Python
Python 实现3种回归模型(Linear Regression,Lasso,Ridge)的示例
Oct 15 #Python
Python在centos7.6上安装python3.9的详细教程(默认python版本为2.7.5)
Oct 15 #Python
Pycharm编辑器功能之代码折叠效果的实现代码
Oct 15 #Python
如何用Python 实现全连接神经网络(Multi-layer Perceptron)
Oct 15 #Python
python 实现非极大值抑制算法(Non-maximum suppression, NMS)
Oct 15 #Python
解决pip安装的第三方包在PyCharm无法导入的问题
Oct 15 #Python
You might like
php 文件上传类代码
2011/08/06 PHP
zf框架的db类select查询器join链表使用示例(zend框架)
2014/03/14 PHP
ThinkPHP整合百度Ueditor图文教程
2014/10/21 PHP
PHP+jquery实时显示网站在线人数的方法
2015/01/04 PHP
php接口技术实例详解
2016/12/07 PHP
PHP观察者模式示例【Laravel框架中有用到】
2018/06/15 PHP
PHP cookie与session会话基本用法实例分析
2019/11/18 PHP
PHP扩展安装方法步骤解析
2020/11/24 PHP
用prototype实现的简单小巧的多级联动菜单
2007/03/24 Javascript
Javascript 读书笔记索引贴
2010/01/11 Javascript
jQuery插件 tabBox实现代码
2010/02/09 Javascript
按给定几率进行随机抽取的js代码
2010/12/28 Javascript
js对象的比较
2011/02/26 Javascript
理解javascript中DOM事件
2015/12/25 Javascript
全面解析node 表单的图片上传
2016/11/21 Javascript
JS中input表单隐藏域及其使用方法
2017/02/13 Javascript
vue+ElementUI实现订单页动态添加产品数据效果实例代码
2017/07/13 Javascript
JS严格模式知识点总结
2018/02/27 Javascript
JS调用安卓手机摄像头扫描二维码
2018/10/16 Javascript
Vue组件Draggable实现拖拽功能
2018/12/01 Javascript
微信小程序使用echarts获取数据并生成折线图
2019/10/16 Javascript
JS严格模式原理与用法实例分析
2020/04/27 Javascript
jQuery cookie的公共方法封装和使用示例
2020/06/01 jQuery
解决elementui表格操作列自适应列宽
2020/12/28 Javascript
[01:06]DOTA2小知识课堂 Ep.02 吹风竟可解梦境缠绕
2019/12/05 DOTA
Python统计时间内的并发数代码实例
2019/12/28 Python
利用4行Python代码监测每一行程序的运行时间和空间消耗
2020/04/22 Python
利用python 读写csv文件
2020/09/10 Python
DNA测试:Orig3n
2019/03/01 全球购物
如何处理简单的PHP错误
2015/10/14 面试题
汽车销售求职自荐信
2013/10/01 职场文书
考试退步检讨书
2014/01/15 职场文书
付款证明模板
2015/06/19 职场文书
详解如何在Canvas中添加事件的方法
2021/04/17 Javascript
OpenCV-Python实现油画效果的实例
2021/06/08 Python
python中的mysql数据库LIKE操作符详解
2021/07/01 MySQL