python实现随机梯度下降(SGD)


Posted in Python onMarch 24, 2020

使用神经网络进行样本训练,要实现随机梯度下降算法。这里我根据麦子学院彭亮老师的讲解,总结如下,(神经网络的结构在另一篇博客中已经定义):

def SGD(self, training_data, epochs, mini_batch_size, eta, test_data=None):
 if test_data:
  n_test = len(test_data)#有多少个测试集
  n = len(training_data)
  for j in xrange(epochs):
   random.shuffle(training_data)
   mini_batches = [
    training_data[k:k+mini_batch_size] 
    for k in xrange(0,n,mini_batch_size)]
   for mini_batch in mini_batches:
    self.update_mini_batch(mini_batch, eta)
   if test_data:
    print "Epoch {0}: {1}/{2}".format(j, self.evaluate(test_data),n_test)
   else:
    print "Epoch {0} complete".format(j)

其中training_data是训练集,是由很多的tuples(元组)组成。每一个元组(x,y)代表一个实例,x是图像的向量表示,y是图像的类别。
epochs表示训练多少轮。
mini_batch_size表示每一次训练的实例个数。
eta表示学习率。
test_data表示测试集。
比较重要的函数是self.update_mini_batch,他是更新权重和偏置的关键函数,接下来就定义这个函数。

def update_mini_batch(self, mini_batch,eta): 
 nabla_b = [np.zeros(b.shape) for b in self.biases]
 nabla_w = [np.zeros(w.shape) for w in self.weights]
 for x,y in mini_batch:
  delta_nabla_b, delta_nable_w = self.backprop(x,y)#目标函数对b和w的偏导数
  nabla_b = [nb+dnb for nb,dnb in zip(nabla_b,delta_nabla_b)]
  nabla_w = [nw+dnw for nw,dnw in zip(nabla_w,delta_nabla_w)]#累加b和w
 #最终更新权重为
 self.weights = [w-(eta/len(mini_batch))*nw for w, nw in zip(self.weights, nabla_w)]
 self.baises = [b-(eta/len(mini_batch))*nb for b, nb in zip(self.baises, nabla_b)]

这个update_mini_batch函数根据你传入的一些数据进行更新神经网络的权重和偏置。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
你真的了解Python的random模块吗?
Dec 12 Python
python中requests和https使用简单示例
Jan 18 Python
python中subprocess批量执行linux命令
Apr 27 Python
python下载微信公众号相关文章
Feb 26 Python
Python后台开发Django会话控制的实现
Apr 15 Python
Python爬虫 bilibili视频弹幕提取过程详解
Jul 31 Python
python实现高斯(Gauss)迭代法的例子
Nov 20 Python
Python hashlib模块实例使用详解
Dec 24 Python
解决Django部署设置Debug=False时xadmin后台管理系统样式丢失
Apr 07 Python
基于python检查矩阵计算结果
May 21 Python
opencv 图像轮廓的实现示例
Jul 08 Python
Python趣味挑战之实现简易版音乐播放器
May 28 Python
Python实现将一个正整数分解质因数的方法分析
Dec 14 #Python
Python随机生成均匀分布在三角形内或者任意多边形内的点
Dec 14 #Python
rabbitmq(中间消息代理)在python中的使用详解
Dec 14 #Python
用python的requests第三方模块抓取王者荣耀所有英雄的皮肤实例
Dec 14 #Python
用Python删除本地目录下某一时间点之前创建的所有文件的实例
Dec 14 #Python
python编程通过蒙特卡洛法计算定积分详解
Dec 13 #Python
Python编程产生非均匀随机数的几种方法代码分享
Dec 13 #Python
You might like
php4的session功能评述(一)
2006/10/09 PHP
PHP 变量定义和变量替换的方法
2009/07/30 PHP
php自定文件保存session的方法
2014/12/10 PHP
PHP框架自动加载类文件原理详解
2017/06/06 PHP
PHP编程获取图片的主色调的方法【基于Imagick扩展】
2017/08/02 PHP
解决laravel上传图片之后,目录有图片,但是访问不到(404)的问题
2019/10/14 PHP
支持ie与FireFox的剪切板操作代码
2009/09/28 Javascript
javascript中有趣的反柯里化深入分析
2012/12/05 Javascript
JavaScript的instanceof运算符学习教程
2016/06/08 Javascript
Highcharts学习之坐标轴
2016/08/02 Javascript
动态生成的DOM不会触发onclick事件的原因及解决方法
2016/08/06 Javascript
jQuery插件zTree实现删除树子节点的方法示例
2017/03/08 Javascript
jQuery简单实现向列表动态添加新元素的方法示例
2017/12/25 jQuery
JS实现的input选择图片本地预览功能示例
2018/08/29 Javascript
[01:39](回顾)各路豪强针锋相对,几经鏖战四强产生
2014/07/01 DOTA
[01:16]2014DOTA2 TI专访C9战队EE:中国五强中会占三席
2014/07/10 DOTA
[03:42]2016国际邀请赛中国区预选赛首日现场玩家采访
2016/06/26 DOTA
python 多线程实现检测服务器在线情况
2015/11/25 Python
Python之Scrapy爬虫框架安装及简单使用详解
2017/12/22 Python
PyQT实现多窗口切换
2018/04/20 Python
浅谈Tensorflow由于版本问题出现的几种错误及解决方法
2018/06/13 Python
python 读取.csv文件数据到数组(矩阵)的实例讲解
2018/06/14 Python
Python发送邮件的实例代码讲解
2019/10/16 Python
Python 实现将数组/矩阵转换成Image类
2020/01/09 Python
np.random.seed() 的使用详解
2020/01/14 Python
Python如何在bool函数中取值
2020/09/21 Python
python excel多行合并的方法
2020/12/09 Python
一款纯css3实现的鼠标经过按钮特效教程
2014/11/09 HTML / CSS
毕业自我鉴定范文
2013/11/06 职场文书
写自荐信三大法宝
2014/01/24 职场文书
研讨会主持词
2014/04/02 职场文书
保护黄河倡议书
2014/05/16 职场文书
自愿离婚协议书范本
2014/09/13 职场文书
公安交警个人对照检查材料思想汇报
2014/10/01 职场文书
社区元宵节活动总结
2015/02/06 职场文书
立案决定书范文
2015/06/24 职场文书