python梯度下降算法的实现


Posted in Python onFebruary 24, 2020

本文实例为大家分享了python实现梯度下降算法的具体代码,供大家参考,具体内容如下

简介

本文使用python实现了梯度下降算法,支持y = Wx+b的线性回归
目前支持批量梯度算法和随机梯度下降算法(bs=1)
也支持输入特征向量的x维度小于3的图像可视化
代码要求python版本>3.4

代码

'''
梯度下降算法
Batch Gradient Descent
Stochastic Gradient Descent SGD
'''
__author__ = 'epleone'
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import sys

# 使用随机数种子, 让每次的随机数生成相同,方便调试
# np.random.seed(111111111)


class GradientDescent(object):
 eps = 1.0e-8
 max_iter = 1000000 # 暂时不需要
 dim = 1
 func_args = [2.1, 2.7] # [w_0, .., w_dim, b]

 def __init__(self, func_arg=None, N=1000):
 self.data_num = N
 if func_arg is not None:
 self.FuncArgs = func_arg
 self._getData()

 def _getData(self):
 x = 20 * (np.random.rand(self.data_num, self.dim) - 0.5)
 b_1 = np.ones((self.data_num, 1), dtype=np.float)
 # x = np.concatenate((x, b_1), axis=1)
 self.x = np.concatenate((x, b_1), axis=1)

 def func(self, x):
 # noise太大的话, 梯度下降法失去作用
 noise = 0.01 * np.random.randn(self.data_num) + 0
 w = np.array(self.func_args)
 # y1 = w * self.x[0, ] # 直接相乘
 y = np.dot(self.x, w) # 矩阵乘法
 y += noise
 return y

 @property
 def FuncArgs(self):
 return self.func_args

 @FuncArgs.setter
 def FuncArgs(self, args):
 if not isinstance(args, list):
 raise Exception(
 'args is not list, it should be like [w_0, ..., w_dim, b]')
 if len(args) == 0:
 raise Exception('args is empty list!!')
 if len(args) == 1:
 args.append(0.0)
 self.func_args = args
 self.dim = len(args) - 1
 self._getData()

 @property
 def EPS(self):
 return self.eps

 @EPS.setter
 def EPS(self, value):
 if not isinstance(value, float) and not isinstance(value, int):
 raise Exception("The type of eps should be an float number")
 self.eps = value

 def plotFunc(self):
 # 一维画图
 if self.dim == 1:
 # x = np.sort(self.x, axis=0)
 x = self.x
 y = self.func(x)
 fig, ax = plt.subplots()
 ax.plot(x, y, 'o')
 ax.set(xlabel='x ', ylabel='y', title='Loss Curve')
 ax.grid()
 plt.show()
 # 二维画图
 if self.dim == 2:
 # x = np.sort(self.x, axis=0)
 x = self.x
 y = self.func(x)
 xs = x[:, 0]
 ys = x[:, 1]
 zs = y
 fig = plt.figure()
 ax = fig.add_subplot(111, projection='3d')
 ax.scatter(xs, ys, zs, c='r', marker='o')

 ax.set_xlabel('X Label')
 ax.set_ylabel('Y Label')
 ax.set_zlabel('Z Label')
 plt.show()
 else:
 # plt.axis('off')
 plt.text(
 0.5,
 0.5,
 "The dimension(x.dim > 2) \n is too high to draw",
 size=17,
 rotation=0.,
 ha="center",
 va="center",
 bbox=dict(
 boxstyle="round",
 ec=(1., 0.5, 0.5),
 fc=(1., 0.8, 0.8), ))
 plt.draw()
 plt.show()
 # print('The dimension(x.dim > 2) is too high to draw')

 # 梯度下降法只能求解凸函数
 def _gradient_descent(self, bs, lr, epoch):
 x = self.x
 # shuffle数据集没有必要
 # np.random.shuffle(x)
 y = self.func(x)
 w = np.ones((self.dim + 1, 1), dtype=float)
 for e in range(epoch):
 print('epoch:' + str(e), end=',')
 # 批量梯度下降,bs为1时 等价单样本梯度下降
 for i in range(0, self.data_num, bs):
 y_ = np.dot(x[i:i + bs], w)
 loss = y_ - y[i:i + bs].reshape(-1, 1)
 d = loss * x[i:i + bs]
 d = d.sum(axis=0) / bs
 d = lr * d
 d.shape = (-1, 1)
 w = w - d

 y_ = np.dot(self.x, w)
 loss_ = abs((y_ - y).sum())
 print('\tLoss = ' + str(loss_))
 print('拟合的结果为:', end=',')
 print(sum(w.tolist(), []))
 print()
 if loss_ < self.eps:
 print('The Gradient Descent algorithm has converged!!\n')
 break
 pass

 def __call__(self, bs=1, lr=0.1, epoch=10):
 if sys.version_info < (3, 4):
 raise RuntimeError('At least Python 3.4 is required')
 if not isinstance(bs, int) or not isinstance(epoch, int):
 raise Exception(
 "The type of BatchSize/Epoch should be an integer number")
 self._gradient_descent(bs, lr, epoch)
 pass

 pass


if __name__ == "__main__":
 if sys.version_info < (3, 4):
 raise RuntimeError('At least Python 3.4 is required')

 gd = GradientDescent([1.2, 1.4, 2.1, 4.5, 2.1])
 # gd = GradientDescent([1.2, 1.4, 2.1])
 print("要拟合的参数结果是: ")
 print(gd.FuncArgs)
 print("===================\n\n")
 # gd.EPS = 0.0
 gd.plotFunc()
 gd(10, 0.01)
 print("Finished!")

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python+opencv轮廓检测代码解析
Jan 05 Python
在Pycharm terminal中字体大小设置的方法
Jan 16 Python
通过pycharm使用git的步骤(图文详解)
Jun 13 Python
详解Python 定时框架 Apscheduler原理及安装过程
Jun 14 Python
python turtle库画一个方格和圆实例
Jun 27 Python
对Python中小整数对象池和大整数对象池的使用详解
Jul 09 Python
教你如何编写、保存与运行Python程序的方法
Jul 12 Python
python异常处理try except过程解析
Feb 03 Python
python3.8.1+selenium实现登录滑块验证功能
May 22 Python
Python趣味挑战之用pygame实现简单的金币旋转效果
May 31 Python
python分分钟绘制精美地图海报
Feb 15 Python
详解python的异常捕获
Mar 03 Python
利用python实现逐步回归
Feb 24 #Python
python数据分析:关键字提取方式
Feb 24 #Python
python数据预处理 :数据共线性处理详解
Feb 24 #Python
使用python实现多维数据降维操作
Feb 24 #Python
python数据预处理 :数据抽样解析
Feb 24 #Python
Python找出列表中出现次数最多的元素三种方式
Feb 24 #Python
Python流程控制常用工具详解
Feb 24 #Python
You might like
PHP中动态HTML的输出技术
2006/10/09 PHP
php session和cookie使用说明
2010/04/07 PHP
php数组函数序列之array_flip() 将数组键名与值对调
2011/11/07 PHP
比较详细PHP生成静态页面教程
2012/01/10 PHP
php学习笔记之基础知识
2014/11/08 PHP
php目录拷贝实现方法
2015/07/10 PHP
php生成动态验证码gif图片
2015/10/19 PHP
golang与PHP输出excel示例
2016/07/22 PHP
javascript实现面向对象类的功能书写技巧
2010/03/07 Javascript
Extjs优化(一)删除冗余代码提高运行速度
2013/04/15 Javascript
JS兼容浏览器的导出Excel(CSV)文件的方法
2014/05/03 Javascript
AngularJS入门教程之链接与图片模板详解
2016/08/19 Javascript
微信小程序图片自适应支持多图实例详解
2017/06/21 Javascript
vue实现表格增删改查效果的实例代码
2017/07/18 Javascript
es6在react中的应用代码解析
2017/11/08 Javascript
JS+CSS实现随机点名(实例代码)
2019/11/04 Javascript
react-intl实现React国际化多语言的方法
2020/09/27 Javascript
js+h5 canvas实现图片验证码
2020/10/11 Javascript
Vue 的 v-model用法实例
2020/11/23 Vue.js
python实现mysql的单引号字符串过滤方法
2015/11/14 Python
使用Python生成随机密码的示例分享
2016/02/18 Python
Python实现网络端口转发和重定向的方法
2016/09/19 Python
深入理解NumPy简明教程---数组3(组合)
2016/12/17 Python
Python 中开发pattern的string模板(template) 实例详解
2017/04/01 Python
Python Numpy库安装与基本操作示例
2019/01/08 Python
使用Python做定时任务及时了解互联网动态
2019/05/15 Python
Python实现爬取并分析电商评论
2020/06/19 Python
python smtplib发送多个email联系人的实现
2020/10/09 Python
《灯光》教学反思
2014/02/08 职场文书
计算机维护专业推荐信
2014/02/27 职场文书
2014学雷锋活动心得体会
2014/03/10 职场文书
煤矿安全承诺书
2014/05/22 职场文书
理发店策划方案
2014/06/05 职场文书
旅游饭店管理专业自荐书
2014/06/28 职场文书
考试没考好检讨书
2015/05/06 职场文书
党风廉政建设心得体会
2019/05/21 职场文书