Pytorch 如何实现常用正则化


Posted in Python onMay 27, 2021

Stochastic Depth

论文:Deep Networks with Stochastic Depth

本文的正则化针对于ResNet中的残差结构,类似于dropout的原理,训练时对模块进行随机的删除,从而提升模型的泛化能力。

Pytorch 如何实现常用正则化

对于上述的ResNet网络,模块越在后面被drop掉的概率越大。

作者直觉上认为前期提取的低阶特征会被用于后面的层。

第一个模块保留的概率为1,之后保留概率随着深度线性递减。

对一个模块的drop函数可以采用如下的方式实现:

def drop_connect(inputs, p, training):
    """ Drop connect. """
    if not training: return inputs # 测试阶段
    batch_size = inputs.shape[0]
    keep_prob = 1 - p
    random_tensor = keep_prob
    random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
    # 以样本为单位生成模块是否被drop的01向量
    binary_tensor = torch.floor(random_tensor) 
    # 因为越往后越容易被drop,所以没有被drop的值就要通过除keep_prob来放大
    output = inputs / keep_prob * binary_tensor
    return output

在Pytorch建立的Module类中,具有forward函数

可以在forward函数中进行drop:

def forward(self, x):
 x=...
 if stride == 1 and in_planes == out_planes:
        if drop_connect_rate:
            x = drop_connect(x, p=drop_connect_rate, training=self.training)
        x = x + inputs  # skip connection
    return x

主函数:

for idx, block in enumerate(self._blocks):
    drop_connect_rate = self._global_params.drop_connect_rate
    if drop_connect_rate:
        drop_connect_rate *= float(idx) / len(self._blocks)
    x = block(x, drop_connect_rate=drop_connect_rate)

补充:pytorch中的L2正则化实现方法

搭建神经网络时需要使用L2正则化等操作来防止过拟合,而pytorch不像TensorFlow能在任意卷积函数中添加L2正则化的超参,那怎么在pytorch中实现L2正则化呢?

方法如下:超级简单!

optimizer = torch.optim.Adam(net.parameters(), lr=0.001, weight_decay=5.0)

torch.optim.Adam()参数中的 weight_decay=5.0 即为L2正则化(只是pytorch换了名字),其数值即为L2正则化的惩罚系数,一般设置为1、5、10(根据需要设置,默认为0,不使用L2正则化)。

注:

pytorch中的优化函数L2正则化默认对所有网络参数进行惩罚,且只能实现L2正则化,如需只惩罚指定网络层参数或采用L1正则化,只能自己定义。。。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用xmlrpclib模块实现对百度google的ping功能
Jun 02 Python
Python设计模式编程中Adapter适配器模式的使用实例
Mar 02 Python
python利用rsa库做公钥解密的方法教程
Dec 10 Python
基于Django用户认证系统详解
Feb 21 Python
python批量设置多个Excel文件页眉页脚的脚本
Mar 14 Python
python自动登录12306并自动点击验证码完成登录的实现源代码
Apr 25 Python
Python实现获取邮箱内容并解析的方法示例
Jun 16 Python
对Python2与Python3中__bool__方法的差异详解
Nov 01 Python
Python 运行 shell 获取输出结果的实例
Jan 07 Python
python调用动态链接库的基本过程详解
Jun 19 Python
Python列表删除元素del、pop()和remove()的区别小结
Sep 11 Python
python except异常处理之后不退出,解决异常继续执行的实现
Apr 25 Python
PyTorch 实现L2正则化以及Dropout的操作
Python开发之QT解决无边框界面拖动卡屏问题(附带源码)
pytorch 实现在测试的时候启用dropout
使用Python脚本对GiteePages进行一键部署的使用说明
教你使用Python pypinyin库实现汉字转拼音
基于tensorflow权重文件的解读
May 26 #Python
解决Python字典查找报Keyerror的问题
You might like
彻底杜绝PHP的session cookie错误
2009/08/09 PHP
PHP文件上传原理简单分析
2011/05/29 PHP
php使用array_rand()函数从数组中随机选择一个或多个元素
2014/04/28 PHP
四个常见html网页乱码问题及解决办法
2015/09/08 PHP
适用于初学者的简易PHP文件上传类
2015/10/29 PHP
Laravel框架验证码类用法实例分析
2019/09/11 PHP
jQuery 自定义函数写法分享
2012/03/30 Javascript
点击按钮或链接不跳转只刷新页面的脚本整理
2013/10/22 Javascript
使用jQuery简单实现模拟浏览器搜索功能
2014/12/21 Javascript
jQuery中queue()方法用法实例
2014/12/29 Javascript
setinterval()与clearInterval()JS函数的调用方法
2015/01/21 Javascript
Bootstrap免费字体和图标网站(值得收藏)
2017/03/16 Javascript
详解Vue 中 extend 、component 、mixins 、extends 的区别
2017/12/20 Javascript
vue滚动插件better-scroll使用详解
2019/10/18 Javascript
JavaScript enum枚举类型定义及使用方法
2020/05/15 Javascript
在Python中关于中文编码问题的处理建议
2015/04/08 Python
Python实现将HTML转换成doc格式文件的方法示例
2017/11/20 Python
Python八大常见排序算法定义、实现及时间消耗效率分析
2018/04/27 Python
在Python中合并字典模块ChainMap的隐藏坑【推荐】
2019/06/27 Python
python切片(获取一个子列表(数组))详解
2019/08/09 Python
python 接口实现 供第三方调用的例子
2019/08/13 Python
python语言线程标准库threading.local解读总结
2019/11/10 Python
Python如何基于smtplib发不同格式的邮件
2019/12/30 Python
意大利综合购物网站:Giordano Shop
2016/10/21 全球购物
澳大利亚药房在线:ThePharmacy
2017/10/04 全球购物
感恩祖国演讲稿
2014/09/09 职场文书
司机工作自我鉴定
2014/09/19 职场文书
2015欢度元旦标语口号
2014/12/09 职场文书
幼儿园小班个人总结
2015/02/12 职场文书
2015年大学组织委员个人工作总结
2015/10/23 职场文书
民事调解协议书
2016/03/21 职场文书
2016年社区“6.26”禁毒日宣传活动总结
2016/04/05 职场文书
Redis6.0搭建集群Redis-cluster的方法
2021/05/08 Redis
python 如何将两个实数矩阵合并为一个复数矩阵
2021/05/19 Python
新手入门Mysql--sql执行过程
2021/06/20 MySQL
Tomcat执行startup.bat出现闪退的原因及解决办法
2022/04/20 Servers