python实现随机梯度下降法


Posted in Python onMarch 24, 2020

看这篇文章前强烈建议你看看上一篇python实现梯度下降法:

一、为什么要提出随机梯度下降算法

注意看梯度下降法权值的更新方式(推导过程在上一篇文章中有)

python实现随机梯度下降法

 也就是说每次更新权值python实现随机梯度下降法都需要遍历整个数据集(注意那个求和符号),当数据量小的时候,我们还能够接受这种算法,一旦数据量过大,那么使用该方法会使得收敛过程极度缓慢,并且当存在多个局部极小值时,无法保证搜索到全局最优解。为了解决这样的问题,引入了梯度下降法的进阶形式:随机梯度下降法。

二、核心思想

对于权值的更新不再通过遍历全部的数据集,而是选择其中的一个样本即可(对于程序员来说你的第一反应一定是:在这里需要一个随机函数来选择一个样本,不是吗?),一般来说其步长的选择比梯度下降法的步长要小一点,因为梯度下降法使用的是准确梯度,所以它可以朝着全局最优解(当问题为凸问题时)较大幅度的迭代下去,但是随机梯度法不行,因为它使用的是近似梯度,或者对于全局来说有时候它走的也许根本不是梯度下降的方向,故而它走的比较缓,同样这样带来的好处就是相比于梯度下降法,它不是那么容易陷入到局部最优解中去。

三、权值更新方式

python实现随机梯度下降法

(i表示样本标号下标,j表示样本维数下标)

四、代码实现(大体与梯度下降法相同,不同在于while循环中的内容)

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d
from matplotlib import style
 
 
#构造数据
def get_data(sample_num=1000):
 """
 拟合函数为
 y = 5*x1 + 7*x2
 :return:
 """
 x1 = np.linspace(0, 9, sample_num)
 x2 = np.linspace(4, 13, sample_num)
 x = np.concatenate(([x1], [x2]), axis=0).T
 y = np.dot(x, np.array([5, 7]).T) 
 return x, y
#梯度下降法
def SGD(samples, y, step_size=2, max_iter_count=1000):
 """
 :param samples: 样本
 :param y: 结果value
 :param step_size: 每一接迭代的步长
 :param max_iter_count: 最大的迭代次数
 :param batch_size: 随机选取的相对于总样本的大小
 :return:
 """
 #确定样本数量以及变量的个数初始化theta值
 
 m, var = samples.shape
 theta = np.zeros(2)
 y = y.flatten()
 #进入循环内
 loss = 1
 iter_count = 0
 iter_list=[]
 loss_list=[]
 theta1=[]
 theta2=[]
 #当损失精度大于0.01且迭代此时小于最大迭代次数时,进行
 while loss > 0.01 and iter_count < max_iter_count:
 loss = 0
 #梯度计算
 theta1.append(theta[0])
 theta2.append(theta[1]) 
 #样本维数下标
 rand1 = np.random.randint(0,m,1)
 h = np.dot(theta,samples[rand1].T)
 #关键点,只需要一个样本点来更新权值
 for i in range(len(theta)):
 theta[i] =theta[i] - step_size*(1/m)*(h - y[rand1])*samples[rand1,i]
 #计算总体的损失精度,等于各个样本损失精度之和
 for i in range(m):
 h = np.dot(theta.T, samples[i])
 #每组样本点损失的精度
 every_loss = (1/(var*m))*np.power((h - y[i]), 2)
 loss = loss + every_loss
 
 print("iter_count: ", iter_count, "the loss:", loss)
 
 iter_list.append(iter_count)
 loss_list.append(loss)
 
 iter_count += 1
 plt.plot(iter_list,loss_list)
 plt.xlabel("iter")
 plt.ylabel("loss")
 plt.show()
 return theta1,theta2,theta,loss_list
 
def painter3D(theta1,theta2,loss):
 style.use('ggplot')
 fig = plt.figure()
 ax1 = fig.add_subplot(111, projection='3d')
 x,y,z = theta1,theta2,loss
 ax1.plot_wireframe(x,y,z, rstride=5, cstride=5)
 ax1.set_xlabel("theta1")
 ax1.set_ylabel("theta2")
 ax1.set_zlabel("loss")
 plt.show()
 
if __name__ == '__main__':
 samples, y = get_data()
 theta1,theta2,theta,loss_list = SGD(samples, y)
 print(theta) # 会很接近[5, 7]
 
 painter3D(theta1,theta2,loss_list)

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
讲解Python中for循环下的索引变量的作用域
Apr 15 Python
举例区分Python中的浅复制与深复制
Jul 02 Python
python通过socket查询whois的方法
Jul 18 Python
利用django如何解析用户上传的excel文件
Jul 24 Python
Python 基础教程之str和repr的详解
Aug 20 Python
Python输出带颜色的字符串实例
Oct 10 Python
python中requests使用代理proxies方法介绍
Oct 25 Python
django之使用celery-把耗时程序放到celery里面执行的方法
Jul 12 Python
Python3 venv搭建轻量级虚拟环境的步骤(图文)
Aug 09 Python
python面向对象 反射原理解析
Aug 12 Python
python 利用openpyxl读取Excel表格中指定的行或列教程
Feb 06 Python
PyQt 如何创建自定义QWidget
Mar 24 Python
python实现决策树分类(2)
Aug 30 #Python
python实现决策树分类
Aug 30 #Python
python实现多人聊天室
Mar 31 #Python
Python实现将数据写入netCDF4中的方法示例
Aug 30 #Python
Python使用爬虫抓取美女图片并保存到本地的方法【测试可用】
Aug 30 #Python
Python使用一行代码获取上个月是几月
Aug 30 #Python
Python实现的读取/更改/写入xml文件操作示例
Aug 30 #Python
You might like
操作Oracle的php类
2006/10/09 PHP
PHP开发环境配置(MySQL数据库安装图文教程)
2010/04/28 PHP
PHP文件操作方法汇总
2015/07/01 PHP
Zend Framework教程之Autoloading用法详解
2016/03/08 PHP
关于PHP中字符串与多进制转换函数的实例代码
2016/11/03 PHP
JS子父窗口互相操作取值赋值的方法介绍
2013/05/11 Javascript
jquery左边浮动到一定位置时显示返回顶部按钮
2014/06/05 Javascript
node.js中的fs.unlinkSync方法使用说明
2014/12/15 Javascript
JavaScript实现自动变换表格边框颜色
2015/05/08 Javascript
JS实现的图片预览插件与用法示例【不上传图片】
2016/11/25 Javascript
JavaScript实现旋转轮播图
2020/08/18 Javascript
在Vue组件化中利用axios处理ajax请求的使用方法
2017/08/25 Javascript
vue vue-Router默认hash模式修改为history需要做的修改详解
2018/09/13 Javascript
Vue动态生成表格的行和列
2019/07/18 Javascript
详解vue-cli项目开发/生产环境代理实现跨域请求
2019/07/23 Javascript
[02:57]2014DOTA2国际邀请赛 选手辛苦解说更辛苦
2014/07/10 DOTA
Python中使用item()方法遍历字典的例子
2014/08/26 Python
Python深入学习之装饰器
2014/08/31 Python
Python下rrdtool模块的基本使用方法
2015/11/13 Python
Python 爬虫学习笔记之多线程爬虫
2016/09/21 Python
对TensorFlow的assign赋值用法详解
2018/07/30 Python
简单了解python调用其他脚本方法实例
2020/03/26 Python
Python 给下载文件显示进度条和下载时间的实现
2020/04/02 Python
python 代码运行时间获取方式详解
2020/09/18 Python
Python 解析xml文件的示例
2020/09/29 Python
美国复古街头服饰精品店:Need Supply Co.
2017/02/22 全球购物
Viking比利时:购买办公用品
2019/10/30 全球购物
Fanatics官网:运动服装、球衣、运动装备
2020/10/12 全球购物
给定一个时间点,希望得到其他时间点
2013/11/07 面试题
行政管理人员精品工作推荐信
2013/11/04 职场文书
国际语言毕业生求职信
2014/07/08 职场文书
离职证明范本(5篇)
2014/09/19 职场文书
2015秋季运动会通讯稿
2015/07/18 职场文书
2015年度考核个人工作总结
2015/10/24 职场文书
CSS中妙用 drop-shadow 实现线条光影效果
2021/11/11 HTML / CSS
如何基于python实现单目三维重建详解
2022/06/25 Python