python 牛顿法实现逻辑回归(Logistic Regression)


Posted in Python onOctober 15, 2020

本文采用的训练方法是牛顿法(Newton Method)。

代码

import numpy as np

class LogisticRegression(object):
 """
 Logistic Regression Classifier training by Newton Method
 """

 def __init__(self, error: float = 0.7, max_epoch: int = 100):
  """
  :param error: float, if the distance between new weight and 
      old weight is less than error, the process 
      of traing will break.
  :param max_epoch: if training epoch >= max_epoch the process 
       of traing will break.
  """
  self.error = error
  self.max_epoch = max_epoch
  self.weight = None
  self.sign = np.vectorize(lambda x: 1 if x >= 0.5 else 0)

 def p_func(self, X_):
  """Get P(y=1 | x)
  :param X_: shape = (n_samples + 1, n_features)
  :return: shape = (n_samples)
  """
  tmp = np.exp(self.weight @ X_.T)
  return tmp / (1 + tmp)

 def diff(self, X_, y, p):
  """Get derivative
  :param X_: shape = (n_samples, n_features + 1) 
  :param y: shape = (n_samples)
  :param p: shape = (n_samples) P(y=1 | x)
  :return: shape = (n_features + 1) first derivative
  """
  return -(y - p) @ X_

 def hess_mat(self, X_, p):
  """Get Hessian Matrix
  :param p: shape = (n_samples) P(y=1 | x)
  :return: shape = (n_features + 1, n_features + 1) second derivative
  """
  hess = np.zeros((X_.shape[1], X_.shape[1]))
  for i in range(X_.shape[0]):
   hess += self.X_XT[i] * p[i] * (1 - p[i])
  return hess

 def newton_method(self, X_, y):
  """Newton Method to calculate weight
  :param X_: shape = (n_samples + 1, n_features)
  :param y: shape = (n_samples)
  :return: None
  """
  self.weight = np.ones(X_.shape[1])
  self.X_XT = []
  for i in range(X_.shape[0]):
   t = X_[i, :].reshape((-1, 1))
   self.X_XT.append(t @ t.T)

  for _ in range(self.max_epoch):
   p = self.p_func(X_)
   diff = self.diff(X_, y, p)
   hess = self.hess_mat(X_, p)
   new_weight = self.weight - (np.linalg.inv(hess) @ diff.reshape((-1, 1))).flatten()

   if np.linalg.norm(new_weight - self.weight) <= self.error:
    break
   self.weight = new_weight

 def fit(self, X, y):
  """
  :param X_: shape = (n_samples, n_features)
  :param y: shape = (n_samples)
  :return: self
  """
  X_ = np.c_[np.ones(X.shape[0]), X]
  self.newton_method(X_, y)
  return self

 def predict(self, X) -> np.array:
  """
  :param X: shape = (n_samples, n_features] 
  :return: shape = (n_samples]
  """
  X_ = np.c_[np.ones(X.shape[0]), X]
  return self.sign(self.p_func(X_))

测试代码

import matplotlib.pyplot as plt
import sklearn.datasets

def plot_decision_boundary(pred_func, X, y, title=None):
 """分类器画图函数,可画出样本点和决策边界
 :param pred_func: predict函数
 :param X: 训练集X
 :param y: 训练集Y
 :return: None
 """

 # Set min and max values and give it some padding
 x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
 y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
 h = 0.01
 # Generate a grid of points with distance h between them
 xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
 # Predict the function value for the whole gid
 Z = pred_func(np.c_[xx.ravel(), yy.ravel()])
 Z = Z.reshape(xx.shape)
 # Plot the contour and training examples
 plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral)
 plt.scatter(X[:, 0], X[:, 1], s=40, c=y, cmap=plt.cm.Spectral)
 if title:
  plt.title(title)
 plt.show()

效果

python 牛顿法实现逻辑回归(Logistic Regression)

更多机器学习代码,请访问 https://github.com/WiseDoge/plume

以上就是python 牛顿法实现逻辑回归(Logistic Regression)的详细内容,更多关于python 逻辑回归的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
python实现给字典添加条目的方法
Sep 25 Python
python文件操作之目录遍历实例分析
May 20 Python
Python编程实现二叉树及七种遍历方法详解
Jun 02 Python
python snownlp情感分析简易demo(分享)
Jun 04 Python
python DataFrame 修改列的顺序实例
Apr 10 Python
python实现linux下抓包并存库功能
Jul 18 Python
Python机器学习之scikit-learn库中KNN算法的封装与使用方法
Dec 14 Python
对python内置map和six.moves.map的区别详解
Dec 19 Python
python爬虫 基于requests模块的get请求实现详解
Aug 20 Python
python配置文件写入过程详解
Oct 19 Python
Python对Excel按列值筛选并拆分表格到多个文件的代码
Nov 05 Python
python 实现查询Neo4j多节点的多层关系
Dec 23 Python
PyCharm 2020.2.2 x64 下载并安装的详细教程
Oct 15 #Python
Python 实现3种回归模型(Linear Regression,Lasso,Ridge)的示例
Oct 15 #Python
Python在centos7.6上安装python3.9的详细教程(默认python版本为2.7.5)
Oct 15 #Python
Pycharm编辑器功能之代码折叠效果的实现代码
Oct 15 #Python
如何用Python 实现全连接神经网络(Multi-layer Perceptron)
Oct 15 #Python
python 实现非极大值抑制算法(Non-maximum suppression, NMS)
Oct 15 #Python
解决pip安装的第三方包在PyCharm无法导入的问题
Oct 15 #Python
You might like
php 短链接算法收集与分析
2011/12/30 PHP
php addslashes及其他清除空格的方法是不安全的
2012/01/25 PHP
浅析php中三个等号(===)和两个等号(==)的区别
2013/08/06 PHP
PHP实现清除wordpress里恶意代码
2015/10/21 PHP
centos7上编译安装php7以php-fpm方式连接apache
2018/11/08 PHP
javascript 关于# 和 void的区别分析
2009/10/26 Javascript
浏览器兼容console对象的简要解决方案分享
2013/10/24 Javascript
addEventListener()第三个参数useCapture (Boolean)详细解析
2013/11/07 Javascript
使用JavaScript触发过渡效果的方法
2017/01/19 Javascript
bootstrap3使用bootstrap datetimepicker日期插件
2017/05/24 Javascript
关于Vue.nextTick()的正确使用方法浅析
2017/08/25 Javascript
VueJS组件之间通过props交互及验证的方式
2017/09/04 Javascript
微信小程序日期时间选择器使用方法
2018/02/01 Javascript
详解如何写出一个利于扩展的vue路由配置
2019/05/16 Javascript
Layui 带多选框表格监听事件以及按钮自动点击写法实例
2019/09/02 Javascript
vue页面更新patch的实现示例
2020/03/25 Javascript
javascript局部自定义鼠标右键菜单
2020/12/08 Javascript
[02:35]DOTA2英雄基础教程 狙击手
2014/01/14 DOTA
python判断端口是否打开的实现代码
2013/02/10 Python
使用go和python递归删除.ds store文件的方法
2014/01/22 Python
windows及linux环境下永久修改pip镜像源的方法
2016/11/28 Python
Python用户推荐系统曼哈顿算法实现完整代码
2017/12/01 Python
pandas表连接 索引上的合并方法
2018/06/08 Python
几行Python代码爬取3000+上市公司的信息
2019/01/24 Python
Python OpenCV实现鼠标画框效果
2020/08/19 Python
基于Python爬取爱奇艺资源过程解析
2020/03/02 Python
Mountain Warehouse澳大利亚官网:欧洲家庭户外品牌倡导者
2016/11/20 全球购物
英国工作场所设备购买网站:Slingsby
2019/05/03 全球购物
课程改革实施方案
2014/03/16 职场文书
2014年新生军训方案
2014/05/01 职场文书
2014年国庆节活动总结
2014/08/26 职场文书
班主任师德师风自我剖析材料
2014/10/02 职场文书
实习单位意见
2015/06/04 职场文书
CSS3 天气图标动画效果
2021/04/06 HTML / CSS
Python实现打乒乓小游戏
2021/09/25 Python
配置nginx负载均衡
2022/05/06 Servers