python实现随机梯度下降(SGD)


Posted in Python onMarch 24, 2020

使用神经网络进行样本训练,要实现随机梯度下降算法。这里我根据麦子学院彭亮老师的讲解,总结如下,(神经网络的结构在另一篇博客中已经定义):

def SGD(self, training_data, epochs, mini_batch_size, eta, test_data=None):
 if test_data:
  n_test = len(test_data)#有多少个测试集
  n = len(training_data)
  for j in xrange(epochs):
   random.shuffle(training_data)
   mini_batches = [
    training_data[k:k+mini_batch_size] 
    for k in xrange(0,n,mini_batch_size)]
   for mini_batch in mini_batches:
    self.update_mini_batch(mini_batch, eta)
   if test_data:
    print "Epoch {0}: {1}/{2}".format(j, self.evaluate(test_data),n_test)
   else:
    print "Epoch {0} complete".format(j)

其中training_data是训练集,是由很多的tuples(元组)组成。每一个元组(x,y)代表一个实例,x是图像的向量表示,y是图像的类别。
epochs表示训练多少轮。
mini_batch_size表示每一次训练的实例个数。
eta表示学习率。
test_data表示测试集。
比较重要的函数是self.update_mini_batch,他是更新权重和偏置的关键函数,接下来就定义这个函数。

def update_mini_batch(self, mini_batch,eta): 
 nabla_b = [np.zeros(b.shape) for b in self.biases]
 nabla_w = [np.zeros(w.shape) for w in self.weights]
 for x,y in mini_batch:
  delta_nabla_b, delta_nable_w = self.backprop(x,y)#目标函数对b和w的偏导数
  nabla_b = [nb+dnb for nb,dnb in zip(nabla_b,delta_nabla_b)]
  nabla_w = [nw+dnw for nw,dnw in zip(nabla_w,delta_nabla_w)]#累加b和w
 #最终更新权重为
 self.weights = [w-(eta/len(mini_batch))*nw for w, nw in zip(self.weights, nabla_w)]
 self.baises = [b-(eta/len(mini_batch))*nb for b, nb in zip(self.baises, nabla_b)]

这个update_mini_batch函数根据你传入的一些数据进行更新神经网络的权重和偏置。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
分析Python的Django框架的运行方式及处理流程
Apr 08 Python
python字典的常用操作方法小结
May 16 Python
Python正则替换字符串函数re.sub用法示例
Jan 19 Python
python虚拟环境virtualenv的安装与使用
Sep 21 Python
python处理csv中的空值方法
Jun 22 Python
对Python信号处理模块signal详解
Jan 09 Python
python opencv实现证件照换底功能
Aug 19 Python
Python面向对象魔法方法和单例模块代码实例
Mar 25 Python
Django添加bootstrap框架时无法加载静态文件的解决方式
Mar 27 Python
python 常见的排序算法实现汇总
Aug 21 Python
python/golang实现循环链表的示例代码
Sep 14 Python
Jupyter Notebook添加代码自动补全功能的实现
Jan 07 Python
Python实现将一个正整数分解质因数的方法分析
Dec 14 #Python
Python随机生成均匀分布在三角形内或者任意多边形内的点
Dec 14 #Python
rabbitmq(中间消息代理)在python中的使用详解
Dec 14 #Python
用python的requests第三方模块抓取王者荣耀所有英雄的皮肤实例
Dec 14 #Python
用Python删除本地目录下某一时间点之前创建的所有文件的实例
Dec 14 #Python
python编程通过蒙特卡洛法计算定积分详解
Dec 13 #Python
Python编程产生非均匀随机数的几种方法代码分享
Dec 13 #Python
You might like
php程序的国际化实现方法(利用gettext)
2011/08/14 PHP
ThinkPHP让分页保持搜索状态的方法
2014/07/02 PHP
php调用新浪短链接API的方法
2014/11/08 PHP
php实现QQ空间获取当前用户的用户名并生成图片
2015/07/25 PHP
你不知道的文件上传漏洞php代码分析
2016/09/29 PHP
详解PHP归并排序的实现
2016/10/18 PHP
Thinkphp5 如何隐藏入口文件index.php(URL重写)
2019/10/16 PHP
laravel5.6中的外键约束示例
2019/10/23 PHP
JS 字符串连接[性能比较]
2009/05/10 Javascript
JQ获取动态加载的图片大小的正确方法分享
2013/11/08 Javascript
js或jquery实现页面打印可局部打印
2014/03/27 Javascript
JS按回车键实现登录的方法
2014/08/25 Javascript
移动设备web开发首选框架:zeptojs介绍
2015/01/29 Javascript
javascript模拟C#格式化字符串
2015/08/26 Javascript
微信小程序开发之大转盘 仿天猫超市抽奖实例
2016/12/08 Javascript
bootstrap datetimepicker日期插件超详细使用方法介绍
2017/02/23 Javascript
vue轮播图插件vue-awesome-swiper的使用代码实例
2017/07/10 Javascript
微信小程序图片右边加两行文字的代码
2020/04/23 Javascript
vue路由的配置和页面切换详解
2020/09/09 Javascript
vue打开其他项目页面并传入数据详解
2020/11/25 Vue.js
python抓取网页图片示例(python爬虫)
2014/04/27 Python
Python对象体系深入分析
2014/10/28 Python
python实现模拟按键,自动翻页看u17漫画
2015/03/17 Python
python使用正则表达式提取网页URL的方法
2015/05/26 Python
python flask实现分页效果
2017/06/27 Python
详解Python进阶之切片的误区与高级用法
2018/12/24 Python
Python类型转换的魔术方法详解
2020/12/23 Python
世界上最大的在线汽车租赁预订平台:Rentalcars.com(支持中文)
2018/10/12 全球购物
你们项目是如何进行变更控制的
2015/08/26 面试题
自动化工程专业个人应聘自荐信
2013/09/26 职场文书
八年级英语教学计划
2015/01/23 职场文书
警告通知
2015/04/25 职场文书
小学四年级班主任工作经验交流材料
2015/11/02 职场文书
Redis持久化与主从复制的实践
2021/04/27 Redis
CSS3实现指纹特效代码
2022/03/17 HTML / CSS
Win10/Win11 任务栏替换成经典样式
2022/04/19 数码科技