python 牛顿法实现逻辑回归(Logistic Regression)


Posted in Python onOctober 15, 2020

本文采用的训练方法是牛顿法(Newton Method)。

代码

import numpy as np

class LogisticRegression(object):
 """
 Logistic Regression Classifier training by Newton Method
 """

 def __init__(self, error: float = 0.7, max_epoch: int = 100):
  """
  :param error: float, if the distance between new weight and 
      old weight is less than error, the process 
      of traing will break.
  :param max_epoch: if training epoch >= max_epoch the process 
       of traing will break.
  """
  self.error = error
  self.max_epoch = max_epoch
  self.weight = None
  self.sign = np.vectorize(lambda x: 1 if x >= 0.5 else 0)

 def p_func(self, X_):
  """Get P(y=1 | x)
  :param X_: shape = (n_samples + 1, n_features)
  :return: shape = (n_samples)
  """
  tmp = np.exp(self.weight @ X_.T)
  return tmp / (1 + tmp)

 def diff(self, X_, y, p):
  """Get derivative
  :param X_: shape = (n_samples, n_features + 1) 
  :param y: shape = (n_samples)
  :param p: shape = (n_samples) P(y=1 | x)
  :return: shape = (n_features + 1) first derivative
  """
  return -(y - p) @ X_

 def hess_mat(self, X_, p):
  """Get Hessian Matrix
  :param p: shape = (n_samples) P(y=1 | x)
  :return: shape = (n_features + 1, n_features + 1) second derivative
  """
  hess = np.zeros((X_.shape[1], X_.shape[1]))
  for i in range(X_.shape[0]):
   hess += self.X_XT[i] * p[i] * (1 - p[i])
  return hess

 def newton_method(self, X_, y):
  """Newton Method to calculate weight
  :param X_: shape = (n_samples + 1, n_features)
  :param y: shape = (n_samples)
  :return: None
  """
  self.weight = np.ones(X_.shape[1])
  self.X_XT = []
  for i in range(X_.shape[0]):
   t = X_[i, :].reshape((-1, 1))
   self.X_XT.append(t @ t.T)

  for _ in range(self.max_epoch):
   p = self.p_func(X_)
   diff = self.diff(X_, y, p)
   hess = self.hess_mat(X_, p)
   new_weight = self.weight - (np.linalg.inv(hess) @ diff.reshape((-1, 1))).flatten()

   if np.linalg.norm(new_weight - self.weight) <= self.error:
    break
   self.weight = new_weight

 def fit(self, X, y):
  """
  :param X_: shape = (n_samples, n_features)
  :param y: shape = (n_samples)
  :return: self
  """
  X_ = np.c_[np.ones(X.shape[0]), X]
  self.newton_method(X_, y)
  return self

 def predict(self, X) -> np.array:
  """
  :param X: shape = (n_samples, n_features] 
  :return: shape = (n_samples]
  """
  X_ = np.c_[np.ones(X.shape[0]), X]
  return self.sign(self.p_func(X_))

测试代码

import matplotlib.pyplot as plt
import sklearn.datasets

def plot_decision_boundary(pred_func, X, y, title=None):
 """分类器画图函数,可画出样本点和决策边界
 :param pred_func: predict函数
 :param X: 训练集X
 :param y: 训练集Y
 :return: None
 """

 # Set min and max values and give it some padding
 x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5
 y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5
 h = 0.01
 # Generate a grid of points with distance h between them
 xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
 # Predict the function value for the whole gid
 Z = pred_func(np.c_[xx.ravel(), yy.ravel()])
 Z = Z.reshape(xx.shape)
 # Plot the contour and training examples
 plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral)
 plt.scatter(X[:, 0], X[:, 1], s=40, c=y, cmap=plt.cm.Spectral)
 if title:
  plt.title(title)
 plt.show()

效果

python 牛顿法实现逻辑回归(Logistic Regression)

更多机器学习代码,请访问 https://github.com/WiseDoge/plume

以上就是python 牛顿法实现逻辑回归(Logistic Regression)的详细内容,更多关于python 逻辑回归的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
Python网站验证码识别
Jan 25 Python
在Django中进行用户注册和邮箱验证的方法
May 09 Python
Python中实现switch功能实例解析
Jan 11 Python
小白如何入门Python? 制作一个网站为例
Mar 06 Python
强悍的Python读取大文件的解决方案
Feb 16 Python
Python Django框架模板渲染功能示例
Nov 08 Python
python实现自动化报表功能(Oracle/plsql/Excel/多线程)
Dec 02 Python
Tensorflow 实现将图像与标签数据转化为tfRecord文件
Feb 17 Python
Jupyter notebook设置背景主题,字体大小及自动补全代码的操作
Apr 13 Python
Python yield生成器和return对比代码实例
Apr 20 Python
python3从网络摄像机解析mjpeg http流的示例
Nov 13 Python
Django项目如何获得SSL证书与配置HTTPS
Apr 30 Python
PyCharm 2020.2.2 x64 下载并安装的详细教程
Oct 15 #Python
Python 实现3种回归模型(Linear Regression,Lasso,Ridge)的示例
Oct 15 #Python
Python在centos7.6上安装python3.9的详细教程(默认python版本为2.7.5)
Oct 15 #Python
Pycharm编辑器功能之代码折叠效果的实现代码
Oct 15 #Python
如何用Python 实现全连接神经网络(Multi-layer Perceptron)
Oct 15 #Python
python 实现非极大值抑制算法(Non-maximum suppression, NMS)
Oct 15 #Python
解决pip安装的第三方包在PyCharm无法导入的问题
Oct 15 #Python
You might like
PHPWind与Discuz截取字符函数substrs与cutstr性能比较
2011/12/05 PHP
PHP_SELF,SCRIPT_NAME,REQUEST_URI区别
2014/12/24 PHP
php使用curl简单抓取远程url的方法
2015/03/13 PHP
PHP记录搜索引擎蜘蛛访问网站足迹的方法
2015/04/15 PHP
Symfony2之session与cookie用法小结
2016/03/18 PHP
PHP中的Trait 特性及作用
2016/04/03 PHP
详解json在php中的应用
2018/09/30 PHP
php使用Swoole实现毫秒级定时任务的方法
2020/09/04 PHP
有道JavaScript监听浏览器的问题
2010/06/23 Javascript
JS遮罩层效果 兼容ie firefox jQuery遮罩层
2010/07/26 Javascript
网易JS面试题与Javascript词法作用域说明
2010/11/09 Javascript
jquery制作弹窗提示窗口代码分享
2014/03/02 Javascript
浅析javascript的间隔调用和延时调用
2014/11/12 Javascript
arguments对象验证函数的参数是否合法
2015/06/26 Javascript
Bootstrap多级导航栏(级联导航)的实现代码
2016/03/08 Javascript
jqGrid用法汇总(全经典)
2016/06/28 Javascript
JavaScript基础之this详解
2017/06/04 Javascript
详解nodejs模板引擎制作
2017/06/14 NodeJs
原生js中ajax访问的实例详解
2017/09/19 Javascript
利用js将ajax获取到的后台数据动态加载至网页中的方法
2018/08/08 Javascript
[55:35]VGJ.S vs Mski Supermajor小组赛C组 BO3 第二场 6.3
2018/06/04 DOTA
Python Tkinter GUI编程入门介绍
2015/03/10 Python
python打开url并按指定块读取网页内容的方法
2015/04/29 Python
python使用tcp实现局域网内文件传输
2020/03/20 Python
Python3实现从排序数组中删除重复项算法分析
2019/04/03 Python
Django通过dwebsocket实现websocket的例子
2019/11/15 Python
Python 中@property的用法详解
2020/01/15 Python
简述DNS进行域名解析的过程
2013/12/02 面试题
办公室文秘自我评价
2013/09/21 职场文书
自我鉴定的范文
2013/10/03 职场文书
后勤部长岗位职责
2013/12/14 职场文书
班主任班级寄语大全
2014/04/04 职场文书
小学毕业演讲稿
2014/04/25 职场文书
公安学专业求职信
2014/07/27 职场文书
社区个人对照检查材料(群众路线)
2014/09/26 职场文书
四风问题个人对照检查材料
2014/09/26 职场文书