python梯度下降算法的实现


Posted in Python onFebruary 24, 2020

本文实例为大家分享了python实现梯度下降算法的具体代码,供大家参考,具体内容如下

简介

本文使用python实现了梯度下降算法,支持y = Wx+b的线性回归
目前支持批量梯度算法和随机梯度下降算法(bs=1)
也支持输入特征向量的x维度小于3的图像可视化
代码要求python版本>3.4

代码

'''
梯度下降算法
Batch Gradient Descent
Stochastic Gradient Descent SGD
'''
__author__ = 'epleone'
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import sys

# 使用随机数种子, 让每次的随机数生成相同,方便调试
# np.random.seed(111111111)


class GradientDescent(object):
 eps = 1.0e-8
 max_iter = 1000000 # 暂时不需要
 dim = 1
 func_args = [2.1, 2.7] # [w_0, .., w_dim, b]

 def __init__(self, func_arg=None, N=1000):
 self.data_num = N
 if func_arg is not None:
 self.FuncArgs = func_arg
 self._getData()

 def _getData(self):
 x = 20 * (np.random.rand(self.data_num, self.dim) - 0.5)
 b_1 = np.ones((self.data_num, 1), dtype=np.float)
 # x = np.concatenate((x, b_1), axis=1)
 self.x = np.concatenate((x, b_1), axis=1)

 def func(self, x):
 # noise太大的话, 梯度下降法失去作用
 noise = 0.01 * np.random.randn(self.data_num) + 0
 w = np.array(self.func_args)
 # y1 = w * self.x[0, ] # 直接相乘
 y = np.dot(self.x, w) # 矩阵乘法
 y += noise
 return y

 @property
 def FuncArgs(self):
 return self.func_args

 @FuncArgs.setter
 def FuncArgs(self, args):
 if not isinstance(args, list):
 raise Exception(
 'args is not list, it should be like [w_0, ..., w_dim, b]')
 if len(args) == 0:
 raise Exception('args is empty list!!')
 if len(args) == 1:
 args.append(0.0)
 self.func_args = args
 self.dim = len(args) - 1
 self._getData()

 @property
 def EPS(self):
 return self.eps

 @EPS.setter
 def EPS(self, value):
 if not isinstance(value, float) and not isinstance(value, int):
 raise Exception("The type of eps should be an float number")
 self.eps = value

 def plotFunc(self):
 # 一维画图
 if self.dim == 1:
 # x = np.sort(self.x, axis=0)
 x = self.x
 y = self.func(x)
 fig, ax = plt.subplots()
 ax.plot(x, y, 'o')
 ax.set(xlabel='x ', ylabel='y', title='Loss Curve')
 ax.grid()
 plt.show()
 # 二维画图
 if self.dim == 2:
 # x = np.sort(self.x, axis=0)
 x = self.x
 y = self.func(x)
 xs = x[:, 0]
 ys = x[:, 1]
 zs = y
 fig = plt.figure()
 ax = fig.add_subplot(111, projection='3d')
 ax.scatter(xs, ys, zs, c='r', marker='o')

 ax.set_xlabel('X Label')
 ax.set_ylabel('Y Label')
 ax.set_zlabel('Z Label')
 plt.show()
 else:
 # plt.axis('off')
 plt.text(
 0.5,
 0.5,
 "The dimension(x.dim > 2) \n is too high to draw",
 size=17,
 rotation=0.,
 ha="center",
 va="center",
 bbox=dict(
 boxstyle="round",
 ec=(1., 0.5, 0.5),
 fc=(1., 0.8, 0.8), ))
 plt.draw()
 plt.show()
 # print('The dimension(x.dim > 2) is too high to draw')

 # 梯度下降法只能求解凸函数
 def _gradient_descent(self, bs, lr, epoch):
 x = self.x
 # shuffle数据集没有必要
 # np.random.shuffle(x)
 y = self.func(x)
 w = np.ones((self.dim + 1, 1), dtype=float)
 for e in range(epoch):
 print('epoch:' + str(e), end=',')
 # 批量梯度下降,bs为1时 等价单样本梯度下降
 for i in range(0, self.data_num, bs):
 y_ = np.dot(x[i:i + bs], w)
 loss = y_ - y[i:i + bs].reshape(-1, 1)
 d = loss * x[i:i + bs]
 d = d.sum(axis=0) / bs
 d = lr * d
 d.shape = (-1, 1)
 w = w - d

 y_ = np.dot(self.x, w)
 loss_ = abs((y_ - y).sum())
 print('\tLoss = ' + str(loss_))
 print('拟合的结果为:', end=',')
 print(sum(w.tolist(), []))
 print()
 if loss_ < self.eps:
 print('The Gradient Descent algorithm has converged!!\n')
 break
 pass

 def __call__(self, bs=1, lr=0.1, epoch=10):
 if sys.version_info < (3, 4):
 raise RuntimeError('At least Python 3.4 is required')
 if not isinstance(bs, int) or not isinstance(epoch, int):
 raise Exception(
 "The type of BatchSize/Epoch should be an integer number")
 self._gradient_descent(bs, lr, epoch)
 pass

 pass


if __name__ == "__main__":
 if sys.version_info < (3, 4):
 raise RuntimeError('At least Python 3.4 is required')

 gd = GradientDescent([1.2, 1.4, 2.1, 4.5, 2.1])
 # gd = GradientDescent([1.2, 1.4, 2.1])
 print("要拟合的参数结果是: ")
 print(gd.FuncArgs)
 print("===================\n\n")
 # gd.EPS = 0.0
 gd.plotFunc()
 gd(10, 0.01)
 print("Finished!")

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python socket 超时设置 errno 10054
Jul 01 Python
Python3中的2to3转换工具使用示例
Jun 12 Python
Python中对元组和列表按条件进行排序的方法示例
Nov 10 Python
Python使用Windows API创建窗口示例【基于win32gui模块】
May 09 Python
pandas 使用均值填充缺失值列的小技巧分享
Jul 04 Python
用Cython加速Python到“起飞”(推荐)
Aug 01 Python
在pycharm中配置Anaconda以及pip源配置详解
Sep 09 Python
Python Django中的STATIC_URL 设置和使用方式
Mar 27 Python
解决python中import文件夹下面py文件报错问题
Jun 01 Python
python 将列表里的字典元素合并为一个字典实例
Sep 01 Python
pytorch--之halfTensor的使用详解
May 24 Python
浅析Python中的套接字编程
Jun 22 Python
利用python实现逐步回归
Feb 24 #Python
python数据分析:关键字提取方式
Feb 24 #Python
python数据预处理 :数据共线性处理详解
Feb 24 #Python
使用python实现多维数据降维操作
Feb 24 #Python
python数据预处理 :数据抽样解析
Feb 24 #Python
Python找出列表中出现次数最多的元素三种方式
Feb 24 #Python
Python流程控制常用工具详解
Feb 24 #Python
You might like
PHP为表单获取的URL 地址预设 http 字符串函数代码
2010/05/26 PHP
在windows平台上构建自己的PHP实现方法(仅适用于php5.2)
2013/07/05 PHP
在Linux系统下一键重新安装WordPress的脚本示例
2015/06/30 PHP
Prototype1.6 JS 官方下载地址
2007/11/30 Javascript
asp.net 30分钟掌握无刷新 Repeater
2011/09/16 Javascript
javascript for循环从入门到偏门(效率优化+奇特用法)
2012/08/01 Javascript
从数据结构分析看:用for each...in 比 for...in 要快些
2013/04/17 Javascript
js调用图片隐藏&amp;显示实现代码
2013/09/13 Javascript
详谈javascript中的cookie
2015/06/03 Javascript
JS实现滑动菜单效果代码(包括Tab,选项卡,横向等效果)
2015/09/24 Javascript
JS实现图片高亮展示效果实例
2015/11/24 Javascript
jQuery实现的placeholder效果完整实例
2016/08/02 Javascript
javascript+jQuery实现360开机时间显示效果
2017/11/03 jQuery
JavaScript实现多叉树的递归遍历和非递归遍历算法操作示例
2018/02/08 Javascript
jquery判断滚动条距离顶部的距离方法
2018/09/05 jQuery
[00:23]DOTA2群星共贺开放测试 25日无码时代来袭
2013/09/23 DOTA
[05:39]2014DOTA2国际邀请赛 DK晋级胜者组专访战队国士无双
2014/07/14 DOTA
[16:21]教你分分钟做大人:圣堂刺客
2014/12/03 DOTA
python实现中文转换url编码的方法
2016/06/14 Python
python实现对指定输入的字符串逆序输出的6种方法
2018/04/26 Python
Python实现 PS 图像调整中的亮度调整
2019/06/28 Python
Python Pandas对缺失值的处理方法
2019/09/27 Python
python+requests接口自动化框架的实现
2020/08/31 Python
Python基于pillow库实现生成图片水印
2020/09/14 Python
详解基于python的全局与局部序列比对的实现(DNA)
2020/10/07 Python
美国专业级皮肤病和spa品质护肤品的高级零售网站:SkinCareRx
2017/02/06 全球购物
亚洲在线旅行门户网站:Expedia.com.hk(智游网)
2020/04/14 全球购物
应聘编辑自荐信范文
2014/03/12 职场文书
校长寄语大全
2014/04/09 职场文书
资助贫困学生倡议书
2014/05/16 职场文书
讲文明懂礼貌演讲稿
2014/09/11 职场文书
2014国庆节幼儿园亲子活动方案
2014/09/16 职场文书
群众路线教育实践活动思想汇报(2014特荐篇)
2014/09/16 职场文书
认真学习保证书
2015/02/26 职场文书
2015年圣诞节寄语
2015/08/17 职场文书
Nginx搭建rtmp直播服务器实现代码
2021/03/31 Servers