python梯度下降算法的实现


Posted in Python onFebruary 24, 2020

本文实例为大家分享了python实现梯度下降算法的具体代码,供大家参考,具体内容如下

简介

本文使用python实现了梯度下降算法,支持y = Wx+b的线性回归
目前支持批量梯度算法和随机梯度下降算法(bs=1)
也支持输入特征向量的x维度小于3的图像可视化
代码要求python版本>3.4

代码

'''
梯度下降算法
Batch Gradient Descent
Stochastic Gradient Descent SGD
'''
__author__ = 'epleone'
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import sys

# 使用随机数种子, 让每次的随机数生成相同,方便调试
# np.random.seed(111111111)


class GradientDescent(object):
 eps = 1.0e-8
 max_iter = 1000000 # 暂时不需要
 dim = 1
 func_args = [2.1, 2.7] # [w_0, .., w_dim, b]

 def __init__(self, func_arg=None, N=1000):
 self.data_num = N
 if func_arg is not None:
 self.FuncArgs = func_arg
 self._getData()

 def _getData(self):
 x = 20 * (np.random.rand(self.data_num, self.dim) - 0.5)
 b_1 = np.ones((self.data_num, 1), dtype=np.float)
 # x = np.concatenate((x, b_1), axis=1)
 self.x = np.concatenate((x, b_1), axis=1)

 def func(self, x):
 # noise太大的话, 梯度下降法失去作用
 noise = 0.01 * np.random.randn(self.data_num) + 0
 w = np.array(self.func_args)
 # y1 = w * self.x[0, ] # 直接相乘
 y = np.dot(self.x, w) # 矩阵乘法
 y += noise
 return y

 @property
 def FuncArgs(self):
 return self.func_args

 @FuncArgs.setter
 def FuncArgs(self, args):
 if not isinstance(args, list):
 raise Exception(
 'args is not list, it should be like [w_0, ..., w_dim, b]')
 if len(args) == 0:
 raise Exception('args is empty list!!')
 if len(args) == 1:
 args.append(0.0)
 self.func_args = args
 self.dim = len(args) - 1
 self._getData()

 @property
 def EPS(self):
 return self.eps

 @EPS.setter
 def EPS(self, value):
 if not isinstance(value, float) and not isinstance(value, int):
 raise Exception("The type of eps should be an float number")
 self.eps = value

 def plotFunc(self):
 # 一维画图
 if self.dim == 1:
 # x = np.sort(self.x, axis=0)
 x = self.x
 y = self.func(x)
 fig, ax = plt.subplots()
 ax.plot(x, y, 'o')
 ax.set(xlabel='x ', ylabel='y', title='Loss Curve')
 ax.grid()
 plt.show()
 # 二维画图
 if self.dim == 2:
 # x = np.sort(self.x, axis=0)
 x = self.x
 y = self.func(x)
 xs = x[:, 0]
 ys = x[:, 1]
 zs = y
 fig = plt.figure()
 ax = fig.add_subplot(111, projection='3d')
 ax.scatter(xs, ys, zs, c='r', marker='o')

 ax.set_xlabel('X Label')
 ax.set_ylabel('Y Label')
 ax.set_zlabel('Z Label')
 plt.show()
 else:
 # plt.axis('off')
 plt.text(
 0.5,
 0.5,
 "The dimension(x.dim > 2) \n is too high to draw",
 size=17,
 rotation=0.,
 ha="center",
 va="center",
 bbox=dict(
 boxstyle="round",
 ec=(1., 0.5, 0.5),
 fc=(1., 0.8, 0.8), ))
 plt.draw()
 plt.show()
 # print('The dimension(x.dim > 2) is too high to draw')

 # 梯度下降法只能求解凸函数
 def _gradient_descent(self, bs, lr, epoch):
 x = self.x
 # shuffle数据集没有必要
 # np.random.shuffle(x)
 y = self.func(x)
 w = np.ones((self.dim + 1, 1), dtype=float)
 for e in range(epoch):
 print('epoch:' + str(e), end=',')
 # 批量梯度下降,bs为1时 等价单样本梯度下降
 for i in range(0, self.data_num, bs):
 y_ = np.dot(x[i:i + bs], w)
 loss = y_ - y[i:i + bs].reshape(-1, 1)
 d = loss * x[i:i + bs]
 d = d.sum(axis=0) / bs
 d = lr * d
 d.shape = (-1, 1)
 w = w - d

 y_ = np.dot(self.x, w)
 loss_ = abs((y_ - y).sum())
 print('\tLoss = ' + str(loss_))
 print('拟合的结果为:', end=',')
 print(sum(w.tolist(), []))
 print()
 if loss_ < self.eps:
 print('The Gradient Descent algorithm has converged!!\n')
 break
 pass

 def __call__(self, bs=1, lr=0.1, epoch=10):
 if sys.version_info < (3, 4):
 raise RuntimeError('At least Python 3.4 is required')
 if not isinstance(bs, int) or not isinstance(epoch, int):
 raise Exception(
 "The type of BatchSize/Epoch should be an integer number")
 self._gradient_descent(bs, lr, epoch)
 pass

 pass


if __name__ == "__main__":
 if sys.version_info < (3, 4):
 raise RuntimeError('At least Python 3.4 is required')

 gd = GradientDescent([1.2, 1.4, 2.1, 4.5, 2.1])
 # gd = GradientDescent([1.2, 1.4, 2.1])
 print("要拟合的参数结果是: ")
 print(gd.FuncArgs)
 print("===================\n\n")
 # gd.EPS = 0.0
 gd.plotFunc()
 gd(10, 0.01)
 print("Finished!")

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python简单调用MySQL存储过程并获得返回值的方法
Jul 20 Python
python开发之字符串string操作方法实例详解
Nov 12 Python
Python网络编程中urllib2模块的用法总结
Jul 12 Python
python opencv实现切变换 不裁减图片
Jul 26 Python
pytorch 实现打印模型的参数值
Dec 30 Python
Pycharm和Idea支持的vim插件的方法
Feb 21 Python
python 等差数列末项计算方式
May 03 Python
Python中for后接else的语法使用
May 18 Python
pytorch model.cuda()花费时间很长的解决
Jun 01 Python
Python中OpenCV实现查找轮廓的实例
Jun 08 Python
python自动化测试通过日志3分钟定位bug
Nov 20 Python
Python中的socket网络模块介绍
Jul 23 Python
利用python实现逐步回归
Feb 24 #Python
python数据分析:关键字提取方式
Feb 24 #Python
python数据预处理 :数据共线性处理详解
Feb 24 #Python
使用python实现多维数据降维操作
Feb 24 #Python
python数据预处理 :数据抽样解析
Feb 24 #Python
Python找出列表中出现次数最多的元素三种方式
Feb 24 #Python
Python流程控制常用工具详解
Feb 24 #Python
You might like
php桌面中心(二) 数据库写入
2007/03/11 PHP
超级实用的7个PHP代码片段分享
2012/01/05 PHP
基于php split()函数的用法详解
2013/06/05 PHP
两个php日期控制类实例
2014/12/09 PHP
php+javascript实现的动态显示服务器运行程序进度条功能示例
2017/08/07 PHP
IE6/7 and IE8/9/10(IE7模式)依次隐藏具有absolute或relative的父元素和子元素后再显示父元素
2011/07/31 Javascript
13 个JavaScript 性能提升技巧分享
2012/07/26 Javascript
js操作模态窗口及父子窗口间相互传值示例
2014/06/09 Javascript
深入理解JavaScript中的箭头函数
2015/07/28 Javascript
jQuery实现指定内容滚动同时左侧或其它地方不滚动的方法
2015/08/08 Javascript
Angularjs 创建可复用组件实例代码
2016/10/09 Javascript
详解Html a标签中href和onclick用法、区别、优先级别
2017/01/16 Javascript
Angular.JS内置服务$http对数据库的增删改使用教程
2017/05/07 Javascript
JavaScript使用递归和循环实现阶乘的实例代码
2018/08/28 Javascript
微信小程序实现搜索功能并跳转搜索结果页面
2019/05/18 Javascript
微信小程序template模板与component组件的区别和使用详解
2019/05/22 Javascript
vue 组件中使用 transition 和 transition-group实现过渡动画
2019/07/09 Javascript
vue移动端使用canvas签名的实现
2020/01/15 Javascript
jQuery实现移动端笔触canvas电子签名
2020/05/21 jQuery
jquery实现点击左右按钮切换图片
2021/01/27 jQuery
[00:43]DOTA2小紫本全民票选福利PA至宝全方位展示
2014/11/25 DOTA
[02:10]2018DOTA2亚洲邀请赛赛前采访-Liquid
2018/04/03 DOTA
python爬取cnvd漏洞库信息的实例
2019/02/14 Python
Python 可变类型和不可变类型及引用过程解析
2019/09/27 Python
django实现用户注册实例讲解
2019/10/30 Python
VSCode中自动为Python文件添加头部注释
2019/11/14 Python
pytorch-RNN进行回归曲线预测方式
2020/01/14 Python
如何在Windows中安装多个python解释器
2020/06/16 Python
努力学习演讲稿
2014/05/10 职场文书
小学雷锋月活动总结
2014/07/03 职场文书
老人节标语大全
2014/10/08 职场文书
继承公证书格式
2015/01/26 职场文书
2016年中秋祝酒词
2015/11/26 职场文书
2016消防宣传标语口号
2015/12/26 职场文书
残联2016年全国助残日活动总结
2016/04/01 职场文书
在校大学生才艺比赛策划书怎么写?
2019/08/26 职场文书