python梯度下降算法的实现


Posted in Python onFebruary 24, 2020

本文实例为大家分享了python实现梯度下降算法的具体代码,供大家参考,具体内容如下

简介

本文使用python实现了梯度下降算法,支持y = Wx+b的线性回归
目前支持批量梯度算法和随机梯度下降算法(bs=1)
也支持输入特征向量的x维度小于3的图像可视化
代码要求python版本>3.4

代码

'''
梯度下降算法
Batch Gradient Descent
Stochastic Gradient Descent SGD
'''
__author__ = 'epleone'
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import sys

# 使用随机数种子, 让每次的随机数生成相同,方便调试
# np.random.seed(111111111)


class GradientDescent(object):
 eps = 1.0e-8
 max_iter = 1000000 # 暂时不需要
 dim = 1
 func_args = [2.1, 2.7] # [w_0, .., w_dim, b]

 def __init__(self, func_arg=None, N=1000):
 self.data_num = N
 if func_arg is not None:
 self.FuncArgs = func_arg
 self._getData()

 def _getData(self):
 x = 20 * (np.random.rand(self.data_num, self.dim) - 0.5)
 b_1 = np.ones((self.data_num, 1), dtype=np.float)
 # x = np.concatenate((x, b_1), axis=1)
 self.x = np.concatenate((x, b_1), axis=1)

 def func(self, x):
 # noise太大的话, 梯度下降法失去作用
 noise = 0.01 * np.random.randn(self.data_num) + 0
 w = np.array(self.func_args)
 # y1 = w * self.x[0, ] # 直接相乘
 y = np.dot(self.x, w) # 矩阵乘法
 y += noise
 return y

 @property
 def FuncArgs(self):
 return self.func_args

 @FuncArgs.setter
 def FuncArgs(self, args):
 if not isinstance(args, list):
 raise Exception(
 'args is not list, it should be like [w_0, ..., w_dim, b]')
 if len(args) == 0:
 raise Exception('args is empty list!!')
 if len(args) == 1:
 args.append(0.0)
 self.func_args = args
 self.dim = len(args) - 1
 self._getData()

 @property
 def EPS(self):
 return self.eps

 @EPS.setter
 def EPS(self, value):
 if not isinstance(value, float) and not isinstance(value, int):
 raise Exception("The type of eps should be an float number")
 self.eps = value

 def plotFunc(self):
 # 一维画图
 if self.dim == 1:
 # x = np.sort(self.x, axis=0)
 x = self.x
 y = self.func(x)
 fig, ax = plt.subplots()
 ax.plot(x, y, 'o')
 ax.set(xlabel='x ', ylabel='y', title='Loss Curve')
 ax.grid()
 plt.show()
 # 二维画图
 if self.dim == 2:
 # x = np.sort(self.x, axis=0)
 x = self.x
 y = self.func(x)
 xs = x[:, 0]
 ys = x[:, 1]
 zs = y
 fig = plt.figure()
 ax = fig.add_subplot(111, projection='3d')
 ax.scatter(xs, ys, zs, c='r', marker='o')

 ax.set_xlabel('X Label')
 ax.set_ylabel('Y Label')
 ax.set_zlabel('Z Label')
 plt.show()
 else:
 # plt.axis('off')
 plt.text(
 0.5,
 0.5,
 "The dimension(x.dim > 2) \n is too high to draw",
 size=17,
 rotation=0.,
 ha="center",
 va="center",
 bbox=dict(
 boxstyle="round",
 ec=(1., 0.5, 0.5),
 fc=(1., 0.8, 0.8), ))
 plt.draw()
 plt.show()
 # print('The dimension(x.dim > 2) is too high to draw')

 # 梯度下降法只能求解凸函数
 def _gradient_descent(self, bs, lr, epoch):
 x = self.x
 # shuffle数据集没有必要
 # np.random.shuffle(x)
 y = self.func(x)
 w = np.ones((self.dim + 1, 1), dtype=float)
 for e in range(epoch):
 print('epoch:' + str(e), end=',')
 # 批量梯度下降,bs为1时 等价单样本梯度下降
 for i in range(0, self.data_num, bs):
 y_ = np.dot(x[i:i + bs], w)
 loss = y_ - y[i:i + bs].reshape(-1, 1)
 d = loss * x[i:i + bs]
 d = d.sum(axis=0) / bs
 d = lr * d
 d.shape = (-1, 1)
 w = w - d

 y_ = np.dot(self.x, w)
 loss_ = abs((y_ - y).sum())
 print('\tLoss = ' + str(loss_))
 print('拟合的结果为:', end=',')
 print(sum(w.tolist(), []))
 print()
 if loss_ < self.eps:
 print('The Gradient Descent algorithm has converged!!\n')
 break
 pass

 def __call__(self, bs=1, lr=0.1, epoch=10):
 if sys.version_info < (3, 4):
 raise RuntimeError('At least Python 3.4 is required')
 if not isinstance(bs, int) or not isinstance(epoch, int):
 raise Exception(
 "The type of BatchSize/Epoch should be an integer number")
 self._gradient_descent(bs, lr, epoch)
 pass

 pass


if __name__ == "__main__":
 if sys.version_info < (3, 4):
 raise RuntimeError('At least Python 3.4 is required')

 gd = GradientDescent([1.2, 1.4, 2.1, 4.5, 2.1])
 # gd = GradientDescent([1.2, 1.4, 2.1])
 print("要拟合的参数结果是: ")
 print(gd.FuncArgs)
 print("===================\n\n")
 # gd.EPS = 0.0
 gd.plotFunc()
 gd(10, 0.01)
 print("Finished!")

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
把大数据数字口语化(python与js)两种实现
Feb 21 Python
使用Python下的XSLT API进行web开发的简单教程
Apr 15 Python
学习python类方法与对象方法
Mar 15 Python
Apache如何部署django项目
May 21 Python
Python简单的制作图片验证码实例
May 31 Python
python自动重试第三方包retrying模块的方法
Apr 24 Python
Python的matplotlib绘图如何修改背景颜色的实现
Jul 16 Python
使用Python实现画一个中国地图
Nov 23 Python
Pytorch实现基于CharRNN的文本分类与生成示例
Jan 08 Python
python使用openCV遍历文件夹里所有视频文件并保存成图片
Jan 14 Python
Python如何在main中调用函数内的函数方式
Jun 01 Python
python简单实现插入排序实例代码
Dec 16 Python
利用python实现逐步回归
Feb 24 #Python
python数据分析:关键字提取方式
Feb 24 #Python
python数据预处理 :数据共线性处理详解
Feb 24 #Python
使用python实现多维数据降维操作
Feb 24 #Python
python数据预处理 :数据抽样解析
Feb 24 #Python
Python找出列表中出现次数最多的元素三种方式
Feb 24 #Python
Python流程控制常用工具详解
Feb 24 #Python
You might like
一个图形显示IP的PHP程序代码
2007/10/19 PHP
PHP的PDO错误与错误处理
2019/01/27 PHP
getElementsByTagName vs selectNodes效率 及兼容的selectNodes实现
2010/02/26 Javascript
jquery获取下拉列表的值为null的解决方法
2011/03/18 Javascript
关于jQuery中.attr()和.prop()的问题探讨
2013/09/06 Javascript
jquery对单选框,多选框,文本框等常见操作小结
2014/01/08 Javascript
巧用replace将文字表情替换为图片
2014/04/17 Javascript
jQuery异步上传文件插件ajaxFileUpload详细介绍
2015/05/19 Javascript
js+html5实现canvas绘制简单矩形的方法
2015/06/05 Javascript
jQuery添加和删除输入文本框标签代码
2016/05/20 Javascript
Node.js服务器开启Gzip压缩教程
2017/08/11 Javascript
vue-cli+webpack项目 修改项目名称的方法
2018/02/28 Javascript
vue2.0实现移动端的输入框实时检索更新列表功能
2018/05/08 Javascript
vue 权限认证token的实现方法
2018/07/17 Javascript
axios全局注册,设置token,以及全局设置url请求网段的方法
2018/09/25 Javascript
vue+render+jsx实现可编辑动态多级表头table的实例代码
2020/04/01 Javascript
vue如何使用外部特殊字体的操作
2020/07/30 Javascript
jenkins自动构建发布vue项目的方法步骤
2021/01/04 Vue.js
iview实现动态表单和自定义验证时间段重叠
2021/01/10 Javascript
[01:28:44]DOTA2-DPC中国联赛定级赛 RNG vs iG BO3第一场 1月10日
2021/03/11 DOTA
django 创建过滤器的实例详解
2017/08/14 Python
python爬虫之urllib,伪装,超时设置,异常处理的方法
2018/12/19 Python
Python如何访问字符串中的值
2020/02/09 Python
python实现在内存中读写str和二进制数据代码
2020/04/24 Python
Django使用django-simple-captcha做验证码的实现示例
2021/01/07 Python
Links of London官方网站:英国标志性的珠宝品牌
2017/04/09 全球购物
彪马土耳其官网:PUMA土耳其
2019/07/14 全球购物
关于爱情的广播稿
2014/01/16 职场文书
家庭贫困证明范本(经典版)
2014/09/22 职场文书
自查自纠工作情况报告
2014/10/29 职场文书
财务人员岗位职责
2015/02/03 职场文书
2015年七夕情人节活动方案
2015/05/06 职场文书
大学社团活动总结怎么写
2019/06/21 职场文书
中国十大神话动漫电影排行榜 哪吒登顶 白蛇缘起排第七
2022/03/21 国漫
Linux中各个目录的作用与内容
2022/06/28 Servers
Navicat Premium自定义 sql 标签的创建方式
2022/09/23 数据库