python+numpy+matplotalib实现梯度下降法


Posted in Python onAugust 31, 2018

这个阶段一直在做和梯度一类算法相关的东西,索性在这儿做个汇总:

一、算法论述

梯度下降法(gradient  descent)别名最速下降法(曾经我以为这是两个不同的算法-.-),是用来求解无约束最优化问题的一种常用算法。下面以求解线性回归为题来叙述:

设:一般的线性回归方程(拟合函数)为:(其中python+numpy+matplotalib实现梯度下降法的值为1)

python+numpy+matplotalib实现梯度下降法

python+numpy+matplotalib实现梯度下降法这一组向量参数选择的好与坏就需要一个机制来评估,据此我们提出了其损失函数为(选择均方误差):

python+numpy+matplotalib实现梯度下降法

我们现在的目的就是使得损失函数python+numpy+matplotalib实现梯度下降法取得最小值,即目标函数为:

python+numpy+matplotalib实现梯度下降法

如果python+numpy+matplotalib实现梯度下降法的值取到了0,意味着我们构造出了极好的拟合函数,也即选择出了最好的python+numpy+matplotalib实现梯度下降法值,但这基本是达不到的,我们只能使得其无限的接近于0,当满足一定精度时停止迭代。

那么问题来了如何调整python+numpy+matplotalib实现梯度下降法使得python+numpy+matplotalib实现梯度下降法取得的值越来越小呢?方法很多,此处以梯度下降法为例:

分为两步:(1)初始化python+numpy+matplotalib实现梯度下降法的值。

(2)改变python+numpy+matplotalib实现梯度下降法的值,使得python+numpy+matplotalib实现梯度下降法按梯度下降的方向减少。

python+numpy+matplotalib实现梯度下降法值的更新使用如下的方式来完成:

python+numpy+matplotalib实现梯度下降法

python+numpy+matplotalib实现梯度下降法

其中python+numpy+matplotalib实现梯度下降法为步长因子,这里我们取定值,但注意如果python+numpy+matplotalib实现梯度下降法取得过小会导致收敛速度过慢,python+numpy+matplotalib实现梯度下降法过大则损失函数可能不会收敛,甚至逐渐变大,可以在下述的代码中修改python+numpy+matplotalib实现梯度下降法的值来进行验证。后面我会再写一篇关于随机梯度下降法的文章,其实与梯度下降法最大的不同就在于一个求和符号。

二、代码实现

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d
from matplotlib import style
 
 
#构造数据
def get_data(sample_num=10000):
 """
 拟合函数为
 y = 5*x1 + 7*x2
 :return:
 """
 x1 = np.linspace(0, 9, sample_num)
 x2 = np.linspace(4, 13, sample_num)
 x = np.concatenate(([x1], [x2]), axis=0).T
 y = np.dot(x, np.array([5, 7]).T) 
 return x, y
#梯度下降法
def GD(samples, y, step_size=0.01, max_iter_count=1000):
 """
 :param samples: 样本
 :param y: 结果value
 :param step_size: 每一接迭代的步长
 :param max_iter_count: 最大的迭代次数
 :param batch_size: 随机选取的相对于总样本的大小
 :return:
 """
 #确定样本数量以及变量的个数初始化theta值
 m, var = samples.shape
 theta = np.zeros(2)
 y = y.flatten()
 #进入循环内
 print(samples)
 loss = 1
 iter_count = 0
 iter_list=[]
 loss_list=[]
 theta1=[]
 theta2=[]
 #当损失精度大于0.01且迭代此时小于最大迭代次数时,进行
 while loss > 0.001 and iter_count < max_iter_count:
 loss = 0
 #梯度计算
 theta1.append(theta[0])
 theta2.append(theta[1])
 for i in range(m):
  h = np.dot(theta,samples[i].T) 
 #更新theta的值,需要的参量有:步长,梯度
  for j in range(len(theta)):
  theta[j] = theta[j] - step_size*(1/m)*(h - y[i])*samples[i,j]
 #计算总体的损失精度,等于各个样本损失精度之和
 for i in range(m):
  h = np.dot(theta.T, samples[i])
  #每组样本点损失的精度
  every_loss = (1/(var*m))*np.power((h - y[i]), 2)
  loss = loss + every_loss
 
 print("iter_count: ", iter_count, "the loss:", loss)
 
 iter_list.append(iter_count)
 loss_list.append(loss)
 
 iter_count += 1
 plt.plot(iter_list,loss_list)
 plt.xlabel("iter")
 plt.ylabel("loss")
 plt.show()
 return theta1,theta2,theta,loss_list
def painter3D(theta1,theta2,loss):
 style.use('ggplot')
 fig = plt.figure()
 ax1 = fig.add_subplot(111, projection='3d')
 x,y,z = theta1,theta2,loss
 ax1.plot_wireframe(x,y,z, rstride=5, cstride=5)
 ax1.set_xlabel("theta1")
 ax1.set_ylabel("theta2")
 ax1.set_zlabel("loss")
 plt.show()
def predict(x, theta):
 y = np.dot(theta, x.T)
 return y 
if __name__ == '__main__':
 samples, y = get_data()
 theta1,theta2,theta,loss_list = GD(samples, y)
 print(theta) # 会很接近[5, 7] 
 painter3D(theta1,theta2,loss_list)
 predict_y = predict(theta, [7,8])
 print(predict_y)

三、绘制的图像如下:

迭代次数与损失精度间的关系图如下:步长为0.01

python+numpy+matplotalib实现梯度下降法

变量python+numpy+matplotalib实现梯度下降法python+numpy+matplotalib实现梯度下降法与损失函数loss之间的关系:(从初始化之后会一步步收敛到loss满足精度,之后python+numpy+matplotalib实现梯度下降法python+numpy+matplotalib实现梯度下降法会变的稳定下来)

python+numpy+matplotalib实现梯度下降法

下面我们来看一副当步长因子变大后的图像:步长因子为0.5(很明显其收敛速度变缓了)

python+numpy+matplotalib实现梯度下降法

python+numpy+matplotalib实现梯度下降法

当步长因子设置为1.8左右时,其损失值已经开始震荡

python+numpy+matplotalib实现梯度下降法

python+numpy+matplotalib实现梯度下降法

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python的string模块中的Template类字符串模板用法
Jun 27 Python
Python如何为图片添加水印
Nov 25 Python
python3中zip()函数使用详解
Jun 29 Python
Python使用pickle模块报错EOFError Ran out of input的解决方法
Aug 16 Python
python实现反转部分单向链表
Sep 27 Python
python里 super类的工作原理详解
Jun 19 Python
python装饰器练习题及答案
Nov 01 Python
matplotlib实现显示伪彩色图像及色度条
Dec 07 Python
win10安装tesserocr配置 Python使用tesserocr识别字母数字验证码
Jan 16 Python
结束运行python的方法
Jun 16 Python
Python实现播放和录制声音的功能
Aug 12 Python
python实现三次密码验证的示例
Apr 29 Python
python实现随机梯度下降法
Mar 24 #Python
python实现决策树分类(2)
Aug 30 #Python
python实现决策树分类
Aug 30 #Python
python实现多人聊天室
Mar 31 #Python
Python实现将数据写入netCDF4中的方法示例
Aug 30 #Python
Python使用爬虫抓取美女图片并保存到本地的方法【测试可用】
Aug 30 #Python
Python使用一行代码获取上个月是几月
Aug 30 #Python
You might like
实用PHP会员权限控制实现原理分析
2011/05/29 PHP
PHP中“简单工厂模式”实例代码讲解
2012/09/04 PHP
提高PHP性能的编码技巧以及性能优化详细解析
2013/08/24 PHP
PHP网页游戏学习之Xnova(ogame)源码解读(八)
2014/06/23 PHP
smarty模板判断数组为空的方法
2015/06/10 PHP
codeigniter实现get分页的方法
2015/07/10 PHP
Django中的cookie与session操作实例代码
2017/08/17 PHP
jquery URL参数判断,确定菜单样式
2010/05/31 Javascript
解析使用js判断只能输入数字、字母等验证的方法(总结)
2013/05/14 Javascript
JS如何判断移动端访问设备并解析对应CSS
2013/11/27 Javascript
使用jquery实现放大镜效果
2014/09/02 Javascript
jquery实现拖拽调整Div大小
2015/01/30 Javascript
JS解析XML文件和XML字符串详解
2015/04/17 Javascript
jquery判断输入密码两次是否相等
2020/04/22 Javascript
详解js中class的多种函数封装方法
2016/01/03 Javascript
学习JavaScript设计模式之模板方法模式
2016/01/20 Javascript
浅谈javascript控制HTML5的全屏操控,浏览器兼容的问题
2016/10/10 Javascript
最好用的Bootstrap fileinput.js文件上传组件
2016/12/12 Javascript
详解AngularJS 模块化
2017/06/14 Javascript
浅谈vue的iview列表table render函数设置DOM属性值的方法
2017/09/30 Javascript
JS回调函数原理与用法详解【附PHP回调函数】
2019/07/20 Javascript
Vue检测屏幕变化来改变不同的charts样式实例
2020/10/26 Javascript
[03:17]DOTA2英雄基础教程 剧毒术士
2013/12/12 DOTA
python对url格式解析的方法
2015/05/13 Python
使用Python进行二进制文件读写的简单方法(推荐)
2016/09/12 Python
Python 40行代码实现人脸识别功能
2017/04/02 Python
Python使用reportlab模块生成PDF格式的文档
2019/03/11 Python
python实现图像拼接
2020/03/05 Python
python 制作网站筛选工具(附源码)
2021/01/21 Python
台湾网购生鲜第一品牌:i3Fresh爱上新鲜
2017/10/26 全球购物
汇智创新科技发展有限公司
2015/12/06 面试题
人力资源作业细则
2014/03/03 职场文书
信息技术教研组工作总结
2015/08/13 职场文书
公安忠诚教育心得体会
2016/01/23 职场文书
Python获取指定日期是"星期几"的6种方法
2022/03/13 Python
抖音动画片,皮皮虾,《治愈系》动画在用这首REMIX作为背景音乐,Anak ,The last world with you完整版
2022/03/16 杂记