python实现随机梯度下降(SGD)


Posted in Python onMarch 24, 2020

使用神经网络进行样本训练,要实现随机梯度下降算法。这里我根据麦子学院彭亮老师的讲解,总结如下,(神经网络的结构在另一篇博客中已经定义):

def SGD(self, training_data, epochs, mini_batch_size, eta, test_data=None):
 if test_data:
  n_test = len(test_data)#有多少个测试集
  n = len(training_data)
  for j in xrange(epochs):
   random.shuffle(training_data)
   mini_batches = [
    training_data[k:k+mini_batch_size] 
    for k in xrange(0,n,mini_batch_size)]
   for mini_batch in mini_batches:
    self.update_mini_batch(mini_batch, eta)
   if test_data:
    print "Epoch {0}: {1}/{2}".format(j, self.evaluate(test_data),n_test)
   else:
    print "Epoch {0} complete".format(j)

其中training_data是训练集,是由很多的tuples(元组)组成。每一个元组(x,y)代表一个实例,x是图像的向量表示,y是图像的类别。
epochs表示训练多少轮。
mini_batch_size表示每一次训练的实例个数。
eta表示学习率。
test_data表示测试集。
比较重要的函数是self.update_mini_batch,他是更新权重和偏置的关键函数,接下来就定义这个函数。

def update_mini_batch(self, mini_batch,eta): 
 nabla_b = [np.zeros(b.shape) for b in self.biases]
 nabla_w = [np.zeros(w.shape) for w in self.weights]
 for x,y in mini_batch:
  delta_nabla_b, delta_nable_w = self.backprop(x,y)#目标函数对b和w的偏导数
  nabla_b = [nb+dnb for nb,dnb in zip(nabla_b,delta_nabla_b)]
  nabla_w = [nw+dnw for nw,dnw in zip(nabla_w,delta_nabla_w)]#累加b和w
 #最终更新权重为
 self.weights = [w-(eta/len(mini_batch))*nw for w, nw in zip(self.weights, nabla_w)]
 self.baises = [b-(eta/len(mini_batch))*nb for b, nb in zip(self.baises, nabla_b)]

这个update_mini_batch函数根据你传入的一些数据进行更新神经网络的权重和偏置。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python with用法实例
Apr 14 Python
使用python遍历指定城市的一周气温
Mar 31 Python
Django重装mysql后启动报错:No module named ‘MySQLdb’的解决方法
Apr 22 Python
Python OpenCV处理图像之滤镜和图像运算
Jul 10 Python
基于Python对数据shape的常见操作详解
Dec 25 Python
python安装pywin32clipboard的操作方法
Jan 24 Python
PYTHON绘制雷达图代码实例
Oct 15 Python
Python编译为二进制so可执行文件实例
Dec 23 Python
关于tf.reverse_sequence()简述
Jan 20 Python
python 已知三条边求三角形的角度案例
Apr 12 Python
用Python写一个简易版弹球游戏
Apr 13 Python
linux中nohup和后台运行进程查看及终止
Jun 24 Python
Python实现将一个正整数分解质因数的方法分析
Dec 14 #Python
Python随机生成均匀分布在三角形内或者任意多边形内的点
Dec 14 #Python
rabbitmq(中间消息代理)在python中的使用详解
Dec 14 #Python
用python的requests第三方模块抓取王者荣耀所有英雄的皮肤实例
Dec 14 #Python
用Python删除本地目录下某一时间点之前创建的所有文件的实例
Dec 14 #Python
python编程通过蒙特卡洛法计算定积分详解
Dec 13 #Python
Python编程产生非均匀随机数的几种方法代码分享
Dec 13 #Python
You might like
php外部执行命令函数用法小结
2016/10/11 PHP
Ubuntu VPS中wordpress网站打开时提示”建立数据库连接错误”的解决办法
2016/11/03 PHP
yii框架使用分页的方法分析
2019/07/25 PHP
javascript编码的几个方法详细介绍
2013/01/06 Javascript
jQuery之字体大小的设置方法
2014/02/27 Javascript
input:checkbox多选框实现单选效果跟radio一样
2014/06/16 Javascript
图片放大镜jquery.jqzoom.js使用实例附放大镜图标
2014/06/19 Javascript
用jquery实现动画跳到顶部和底部(这个比较简单)
2014/09/01 Javascript
jQuery基本选择器之标签名选择器
2016/09/03 Javascript
原生js实现网易轮播图效果
2020/04/10 Javascript
Bootstrap CDN和本地化环境搭建
2016/10/26 Javascript
JavaScript实现自定义媒体播放器方法介绍
2017/01/03 Javascript
jquery实现手机端单店铺购物车结算删除功能
2017/02/22 Javascript
详解VUE Element-UI多级菜单动态渲染的组件
2019/04/25 Javascript
配置node服务器并且链接微信公众号接口配置步骤详解
2019/06/21 Javascript
微信小程序登录对接Django后端实现JWT方式验证登录详解
2019/07/29 Javascript
JS实现导航栏楼层特效
2020/01/01 Javascript
JavaScript十大取整方法实例教程
2020/12/03 Javascript
Python实现的视频播放器功能完整示例
2018/02/01 Python
python中(str,list,tuple)基础知识汇总
2018/02/20 Python
对python函数签名的方法详解
2019/01/22 Python
详解python中的线程与线程池
2019/05/10 Python
解决python 文本过滤和清理问题
2019/08/28 Python
使用py-spy解决scrapy卡死的问题方法
2020/09/29 Python
利用纯CSS3实现文字向右循环闪过效果实例(可用于移动端)
2017/06/15 HTML / CSS
俄罗斯旅游网站:Tripadvisor俄罗斯
2017/03/21 全球购物
优秀学生获奖感言
2014/02/15 职场文书
竞选劳动委员演讲稿
2014/04/28 职场文书
企业承诺书怎么写
2014/05/24 职场文书
帮一个朋友写的求职信
2014/08/09 职场文书
2014领导干部学习焦裕禄同志先进事迹思想汇报
2014/09/19 职场文书
国庆65周年演讲稿:回首往昔,展望未来
2014/09/21 职场文书
中秋节作文(五年级)之关于月亮
2019/09/11 职场文书
一道JS算法面试题——冒泡、选择排序
2021/04/21 Javascript
Python基础之Socket通信原理
2021/04/22 Python
Java实现二分搜索树的示例代码
2022/03/17 Java/Android