python实现梯度下降算法


Posted in Python onMarch 24, 2020

梯度下降(Gradient Descent)算法是机器学习中使用非常广泛的优化算法。当前流行的机器学习库或者深度学习库都会包括梯度下降算法的不同变种实现。

本文主要以线性回归算法损失函数求极小值来说明如何使用梯度下降算法并给出python实现。若有不正确的地方,希望读者能指出。 

梯度下降

梯度下降原理:将函数比作一座山,我们站在某个山坡上,往四周看,从哪个方向向下走一小步,能够下降的最快。

python实现梯度下降算法

在线性回归算法中,损失函数为python实现梯度下降算法

在求极小值时,在数据量很小的时候,可以使用矩阵求逆的方式求最优的θ值。但当数据量和特征值非常大,例如几万甚至上亿时,使用矩阵求逆根本就不现实。而梯度下降法就是很好的一个选择了。

使用梯度下降算法的步骤

1)对θ赋初始值,这个值可以是随机的,也可以让θ是一个全零的向量。

2)改变θ的值,使得目标损失函数J(θ)按梯度下降的方向进行减少。

python实现梯度下降算法

其中为学习率或步长,需要人为指定,若过大会导致震荡即不收敛,若过小收敛速度会很慢。

3)当下降的高度小于某个定义的值,则停止下降。

另外,对上面线性回归算法损失函数求梯度,结果如下:

python实现梯度下降算法

在实际应用的过程中,梯度下降算法有三类,它们不同之处在于每次学习(更新模型参数)使用的样本个数,每次更新使用不同的样本会导致每次学习的准确性和学习时间不同。下面将分别介绍原理及python实现。

 批量梯度下降(Batch gradient descent)   

每次使用全量的训练集样本来更新模型参数,即给定一个步长,然后对所有的样本的梯度的和进行迭代: 

python实现梯度下降算法

梯度下降算法最终得到的是局部极小值。而线性回归的损失函数为凸函数,有且只有一个局部最小,则这个局部最小一定是全局最小。所以线性回归中使用批量梯度下降算法,一定可以找到一个全局最优解。

优点全局最优解;易于并行实现;总体迭代次数不多
缺点当样本数目很多时,训练过程会很慢,每次迭代需要耗费大量的时间。

随机梯度下降(Stochastic gradient descent) 

随机梯度下降算法每次从训练集中随机选择一个样本来进行迭代,即:

python实现梯度下降算法

随机梯度下降算法每次只随机选择一个样本来更新模型参数,因此每次的学习是非常快速的,并且可以进行在线更新。 

随机梯度下降最大的缺点在于每次更新可能并不会按照正确的方向进行,因此可以带来优化波动(扰动)。不过从另一个方面来看,随机梯度下降所带来的波动有个好处就是,对于类似盆地区域(即很多局部极小值点)那么这个波动的特点可能会使得优化的方向从当前的局部极小值点跳到另一个更好的局部极小值点,这样便可能对于非凸函数,最终收敛于一个较好的局部极值点,甚至全局极值点。 

优点训练速度快,每次迭代计算量不大
缺点准确度下降,并不是全局最优;不易于并行实现;总体迭代次数比较多。

Mini-batch梯度下降算法

 Mini-batch梯度下降综合了batch梯度下降与stochastic梯度下降,在每次更新速度与更新次数中间取得一个平衡,其每次更新从训练集中随机选择b,b<m个样本进行学习,即:

python实现梯度下降算法

python代码实现

批量梯度下降算法

#!/usr/bin/python
#coding=utf-8
import numpy as np
from scipy import stats
import matplotlib.pyplot as plt
 
# 构造训练数据
x = np.arange(0., 10., 0.2)
m = len(x) # 训练数据点数目
print m
x0 = np.full(m, 1.0)
input_data = np.vstack([x0, x]).T # 将偏置b作为权向量的第一个分量
target_data = 2 * x + 5 + np.random.randn(m)
 
# 两种终止条件
loop_max = 10000 # 最大迭代次数(防止死循环)
epsilon = 1e-3
 
# 初始化权值
np.random.seed(0)
theta = np.random.randn(2)
 
alpha = 0.001 # 步长(注意取值过大会导致振荡即不收敛,过小收敛速度变慢)
diff = 0.
error = np.zeros(2)
count = 0 # 循环次数
finish = 0 # 终止标志
 
while count < loop_max:
 count += 1
 
 # 标准梯度下降是在权值更新前对所有样例汇总误差,而随机梯度下降的权值是通过考查某个训练样例来更新的
 # 在标准梯度下降中,权值更新的每一步对多个样例求和,需要更多的计算
 sum_m = np.zeros(2)
 for i in range(m):
 dif = (np.dot(theta, input_data[i]) - target_data[i]) * input_data[i]
 sum_m = sum_m + dif # 当alpha取值过大时,sum_m会在迭代过程中会溢出
 
 theta = theta - alpha * sum_m # 注意步长alpha的取值,过大会导致振荡
 # theta = theta - 0.005 * sum_m # alpha取0.005时产生振荡,需要将alpha调小
 
 # 判断是否已收敛
 if np.linalg.norm(theta - error) < epsilon:
 finish = 1
 break
 else:
 error = theta
 print 'loop count = %d' % count, '\tw:',theta
print 'loop count = %d' % count, '\tw:',theta
 
# check with scipy linear regression
slope, intercept, r_value, p_value, slope_std_error = stats.linregress(x, target_data)
print 'intercept = %s slope = %s' % (intercept, slope)
 
plt.plot(x, target_data, 'g*')
plt.plot(x, theta[1] * x + theta[0], 'r')
plt.show()

运行结果截图:

python实现梯度下降算法

随机梯度下降算法

#!/usr/bin/python
#coding=utf-8
import numpy as np
from scipy import stats
import matplotlib.pyplot as plt
 
# 构造训练数据
x = np.arange(0., 10., 0.2)
m = len(x) # 训练数据点数目
x0 = np.full(m, 1.0)
input_data = np.vstack([x0, x]).T # 将偏置b作为权向量的第一个分量
target_data = 2 * x + 5 + np.random.randn(m)
 
# 两种终止条件
loop_max = 10000 # 最大迭代次数(防止死循环)
epsilon = 1e-3
 
# 初始化权值
np.random.seed(0)
theta = np.random.randn(2)
# w = np.zeros(2)
 
alpha = 0.001 # 步长(注意取值过大会导致振荡,过小收敛速度变慢)
diff = 0.
error = np.zeros(2)
count = 0 # 循环次数
finish = 0 # 终止标志
######-随机梯度下降算法
while count < loop_max:
 count += 1
 
 # 遍历训练数据集,不断更新权值
 for i in range(m):
 diff = np.dot(theta, input_data[i]) - target_data[i] # 训练集代入,计算误差值
 
 # 采用随机梯度下降算法,更新一次权值只使用一组训练数据
 theta = theta - alpha * diff * input_data[i]
 
 # ------------------------------终止条件判断-----------------------------------------
 # 若没终止,则继续读取样本进行处理,如果所有样本都读取完毕了,则循环重新从头开始读取样本进行处理。
 
 # ----------------------------------终止条件判断-----------------------------------------
 # 注意:有多种迭代终止条件,和判断语句的位置。终止判断可以放在权值向量更新一次后,也可以放在更新m次后。
 if np.linalg.norm(theta - error) < epsilon: # 终止条件:前后两次计算出的权向量的绝对误差充分小
 finish = 1
 break
 else:
 error = theta
print 'loop count = %d' % count, '\tw:',theta
 
 
# check with scipy linear regression
slope, intercept, r_value, p_value, slope_std_error = stats.linregress(x, target_data)
print 'intercept = %s slope = %s' % (intercept, slope)
 
plt.plot(x, target_data, 'g*')
plt.plot(x, theta[1] * x + theta[0], 'r')
plt.show()

运行结果截图:

python实现梯度下降算法

Mini-batch梯度下降

#!/usr/bin/python
#coding=utf-8
import numpy as np
from scipy importstats
import matplotlib.pyplot as plt
 
# 构造训练数据
x = np.arange(0.,10.,0.2)
m = len(x) # 训练数据点数目
print m
x0 = np.full(m, 1.0)
input_data = np.vstack([x0, x]).T # 将偏置b作为权向量的第一个分量
target_data = 2 *x + 5 +np.random.randn(m)
 
# 两种终止条件
loop_max = 10000 #最大迭代次数(防止死循环)
epsilon = 1e-3
 
# 初始化权值
np.random.seed(0)
theta = np.random.randn(2)
 
alpha = 0.001 #步长(注意取值过大会导致振荡即不收敛,过小收敛速度变慢)
diff = 0.
error = np.zeros(2)
count = 0 #循环次数
finish = 0 #终止标志
minibatch_size = 5 #每次更新的样本数
while count < loop_max:
 count += 1
 
 # minibatch梯度下降是在权值更新前对所有样例汇总误差,而随机梯度下降的权值是通过考查某个训练样例来更新的
 # 在minibatch梯度下降中,权值更新的每一步对多个样例求和,需要更多的计算
 
 for i inrange(1,m,minibatch_size):
 sum_m = np.zeros(2)
 for k inrange(i-1,i+minibatch_size-1,1):
  dif = (np.dot(theta, input_data[k]) - target_data[k]) *input_data[k]
  sum_m = sum_m + dif #当alpha取值过大时,sum_m会在迭代过程中会溢出
 
 theta = theta- alpha * (1.0/minibatch_size) * sum_m #注意步长alpha的取值,过大会导致振荡
 
 # 判断是否已收敛
 if np.linalg.norm(theta- error) < epsilon:
 finish = 1
 break
 else:
 error = theta
 print 'loopcount = %d'% count, '\tw:',theta
print 'loop count = %d'% count, '\tw:',theta
 
# check with scipy linear regression
slope, intercept, r_value, p_value,slope_std_error = stats.linregress(x, target_data)
print 'intercept = %s slope = %s'% (intercept, slope)
 
plt.plot(x, target_data, 'g*')
plt.plot(x, theta[1]* x +theta[0],'r')
plt.show()

运行结果:

python实现梯度下降算法

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python启动办公软件进程(word、excel、ppt、以及wps的et、wps、wpp)
Apr 09 Python
Python wxPython库消息对话框MessageDialog用法示例
Sep 03 Python
python判断所输入的任意一个正整数是否为素数的两种方法
Jun 27 Python
Python实现 PS 图像调整中的亮度调整
Jun 28 Python
python3获取当前目录的实现方法
Jul 29 Python
Pytorch中accuracy和loss的计算知识点总结
Sep 10 Python
python实现QQ邮箱发送邮件
Mar 06 Python
Python 读取xml数据,cv2裁剪图片实例
Mar 10 Python
python matplotlib实现将图例放在图外
Apr 17 Python
Python 实现一个计时器
Jul 28 Python
如何使用python-opencv批量生成带噪点噪线的数字验证码
Dec 21 Python
matplotlib制作雷达图报错ValueError的实现
Jan 05 Python
wtfPython—Python中一组有趣微妙的代码【收藏】
Aug 31 #Python
opencv python 图像去噪的实现方法
Aug 31 #Python
python+numpy+matplotalib实现梯度下降法
Aug 31 #Python
python实现随机梯度下降法
Mar 24 #Python
python实现决策树分类(2)
Aug 30 #Python
python实现决策树分类
Aug 30 #Python
python实现多人聊天室
Mar 31 #Python
You might like
PHP实现的随机IP函数【国内IP段】
2016/07/20 PHP
php+js实现的拖动滑块验证码验证表单操作示例【附源码下载】
2020/05/27 PHP
一些常用的JS功能函数代码
2009/06/23 Javascript
JavaScript 异步调用框架 (Part 3 - 代码实现)
2009/08/04 Javascript
JavaScript Cookie的读取和写入函数
2009/12/08 Javascript
添加JavaScript重载函数的辅助方法2
2010/07/04 Javascript
关于JavaScript中原型继承中的一点思考
2012/07/25 Javascript
在JS中解析HTML字符串示例代码
2014/04/16 Javascript
JS仿Windows开机启动Loading进度条的方法
2015/02/26 Javascript
jQuery链式调用与show知识浅析
2016/05/11 Javascript
javascript实现任务栏消息提示的简单实例
2016/05/31 Javascript
Nodejs下DNS缓存问题浅析
2016/11/16 NodeJs
js css自定义分页效果
2017/02/24 Javascript
浅谈JavaScript作用域和闭包
2017/09/18 Javascript
详解在vue-cli中引用jQuery、bootstrap以及使用sass、less编写css
2017/11/08 jQuery
JSONP原理及应用实例详解
2018/09/13 Javascript
深入理解 TypeScript Reflect Metadata
2019/12/12 Javascript
js回调函数原理与用法案例分析
2020/03/04 Javascript
JavaScript如何判断对象有某属性
2020/07/03 Javascript
何时/使用 Vue3 render 函数的教程详解
2020/07/25 Javascript
Vue全局使用less样式,组件使用全局样式文件中定义的变量操作
2020/10/21 Javascript
[01:45]DOTA2众星出演!DSPL刀塔次级职业联赛宣传片
2014/11/21 DOTA
[08:08]DOTA2-DPC中国联赛2月28日Recap集锦
2021/03/11 DOTA
Python和Perl绘制中国北京跑步地图的方法
2016/03/03 Python
通过Python 获取Android设备信息的轻量级框架
2017/12/18 Python
python微信跳一跳游戏辅助代码解析
2018/01/29 Python
TensorFlow变量管理详解
2018/03/10 Python
Python实现定时精度可调节的定时器
2018/04/15 Python
HTML5画渐变背景图片并自动下载实现步骤
2013/11/18 HTML / CSS
Lulu Guinness露露·吉尼斯官网:红唇包
2019/02/03 全球购物
俄罗斯园林植物网上商店:Garshinka
2020/07/16 全球购物
护理专科毕业推荐信
2013/11/10 职场文书
初中数学教学反思
2014/01/16 职场文书
MySQL 可扩展设计的基本原则
2021/05/14 MySQL
用Python实现屏幕截图详解
2022/01/22 Python
redis数据结构之压缩列表
2022/03/21 Redis