python实现随机梯度下降(SGD)


Posted in Python onMarch 24, 2020

使用神经网络进行样本训练,要实现随机梯度下降算法。这里我根据麦子学院彭亮老师的讲解,总结如下,(神经网络的结构在另一篇博客中已经定义):

def SGD(self, training_data, epochs, mini_batch_size, eta, test_data=None):
 if test_data:
  n_test = len(test_data)#有多少个测试集
  n = len(training_data)
  for j in xrange(epochs):
   random.shuffle(training_data)
   mini_batches = [
    training_data[k:k+mini_batch_size] 
    for k in xrange(0,n,mini_batch_size)]
   for mini_batch in mini_batches:
    self.update_mini_batch(mini_batch, eta)
   if test_data:
    print "Epoch {0}: {1}/{2}".format(j, self.evaluate(test_data),n_test)
   else:
    print "Epoch {0} complete".format(j)

其中training_data是训练集,是由很多的tuples(元组)组成。每一个元组(x,y)代表一个实例,x是图像的向量表示,y是图像的类别。
epochs表示训练多少轮。
mini_batch_size表示每一次训练的实例个数。
eta表示学习率。
test_data表示测试集。
比较重要的函数是self.update_mini_batch,他是更新权重和偏置的关键函数,接下来就定义这个函数。

def update_mini_batch(self, mini_batch,eta): 
 nabla_b = [np.zeros(b.shape) for b in self.biases]
 nabla_w = [np.zeros(w.shape) for w in self.weights]
 for x,y in mini_batch:
  delta_nabla_b, delta_nable_w = self.backprop(x,y)#目标函数对b和w的偏导数
  nabla_b = [nb+dnb for nb,dnb in zip(nabla_b,delta_nabla_b)]
  nabla_w = [nw+dnw for nw,dnw in zip(nabla_w,delta_nabla_w)]#累加b和w
 #最终更新权重为
 self.weights = [w-(eta/len(mini_batch))*nw for w, nw in zip(self.weights, nabla_w)]
 self.baises = [b-(eta/len(mini_batch))*nb for b, nb in zip(self.baises, nabla_b)]

这个update_mini_batch函数根据你传入的一些数据进行更新神经网络的权重和偏置。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python解析模块(ConfigParser)使用方法
Dec 10 Python
浅析Python中else语句块的使用技巧
Jun 16 Python
Python列表和元组的定义与使用操作示例
Jul 26 Python
Python实现确认字符串是否包含指定字符串的实例
May 02 Python
分享vim python缩进等一些配置
Jul 02 Python
Python3 安装PyQt5及exe打包图文教程
Jan 08 Python
解决django后台样式丢失,css资源加载失败的问题
Jun 11 Python
win10安装tensorflow-gpu1.8.0详细完整步骤
Jan 20 Python
django的403/404/500错误自定义页面的配置方式
May 21 Python
python+pygame实现坦克大战小游戏的示例代码(可以自定义子弹速度)
Aug 11 Python
Python和Bash结合在一起的方法
Nov 13 Python
MoviePy简介及Python视频剪辑自动化
Dec 18 Python
Python实现将一个正整数分解质因数的方法分析
Dec 14 #Python
Python随机生成均匀分布在三角形内或者任意多边形内的点
Dec 14 #Python
rabbitmq(中间消息代理)在python中的使用详解
Dec 14 #Python
用python的requests第三方模块抓取王者荣耀所有英雄的皮肤实例
Dec 14 #Python
用Python删除本地目录下某一时间点之前创建的所有文件的实例
Dec 14 #Python
python编程通过蒙特卡洛法计算定积分详解
Dec 13 #Python
Python编程产生非均匀随机数的几种方法代码分享
Dec 13 #Python
You might like
Windows下PHP的任意文件执行漏洞
2006/10/09 PHP
PHP学习之整理字符串
2011/04/17 PHP
Ubuntu中搭建Nginx、PHP环境最简单的方法
2015/03/05 PHP
必须收藏的23个php实用代码片段
2016/02/02 PHP
thinkphp3.2实现上传图片的控制器方法
2016/04/28 PHP
Yii统计不同类型邮箱数量的方法
2016/10/18 PHP
Whatever:hover 无需javascript让IE支持丰富伪类
2010/06/29 Javascript
Jquery实现带动画效果的经典二级导航菜单
2013/03/22 Javascript
JQUERY对单选框(radio)操作的小例子
2013/04/25 Javascript
js实现简单随机抽奖的方法
2015/01/27 Javascript
Bootstrap每天必学之媒体对象
2015/11/30 Javascript
原生js实现轮播图的示例代码
2017/02/20 Javascript
angular.JS实现网页禁用调试、复制和剪切
2017/03/31 Javascript
Vue+ElementUI table实现表格分页
2019/12/14 Javascript
vue实现在进行增删改操作后刷新页面
2020/08/05 Javascript
如何利用node转发请求详解
2020/09/17 Javascript
Python中的getopt函数使用详解
2015/07/28 Python
python 动态加载的实现方法
2017/12/22 Python
使用Python通过win32 COM打开Excel并添加Sheet的方法
2018/05/02 Python
Python选择网卡发包及接收数据包
2019/04/04 Python
python对csv文件追加写入列的方法
2019/08/01 Python
pygame实现俄罗斯方块游戏(基础篇3)
2019/10/29 Python
天猫超市:阿里巴巴打造的网上超市
2016/11/02 全球购物
英国家用电器折扣网站:Electrical Discount UK
2018/09/17 全球购物
经典c++面试题六
2012/01/18 面试题
可贵的沉默教学反思
2014/02/06 职场文书
十周年庆典策划方案
2014/06/03 职场文书
工会优秀工作者事迹
2014/08/17 职场文书
办公室主任四风问题对照检查材料思想汇报
2014/09/28 职场文书
2015年幼儿园学前班工作总结
2015/05/18 职场文书
肖申克救赎观后感
2015/06/02 职场文书
学校食堂管理制度
2015/08/04 职场文书
祝福语集锦:给妹妹结婚的祝福语
2019/12/18 职场文书
500字作文之难忘的同学
2019/12/20 职场文书
Hive常用日期格式转换语法
2022/06/25 数据库
jdbc中自带MySQL 连接池实践示例
2022/07/23 MySQL