pytorch  网络参数 weight bias 初始化详解


Posted in Python onJune 24, 2020

权重初始化对于训练神经网络至关重要,好的初始化权重可以有效的避免梯度消失等问题的发生。

在pytorch的使用过程中有几种权重初始化的方法供大家参考。

注意:第一种方法不推荐。尽量使用后两种方法。

# not recommend
def weights_init(m):
 classname = m.__class__.__name__
 if classname.find('Conv') != -1:
  m.weight.data.normal_(0.0, 0.02)
 elif classname.find('BatchNorm') != -1:
  m.weight.data.normal_(1.0, 0.02)
  m.bias.data.fill_(0)
# recommend
def initialize_weights(m):
 if isinstance(m, nn.Conv2d):
  m.weight.data.normal_(0, 0.02)
  m.bias.data.zero_()
 elif isinstance(m, nn.Linear):
  m.weight.data.normal_(0, 0.02)
  m.bias.data.zero_()
# recommend
def weights_init(m): 
 if isinstance(m, nn.Conv2d): 
  nn.init.xavier_normal_(m.weight.data) 
  nn.init.xavier_normal_(m.bias.data)
 elif isinstance(m, nn.BatchNorm2d):
  nn.init.constant_(m.weight,1)
  nn.init.constant_(m.bias, 0)
 elif isinstance(m, nn.BatchNorm1d):
  nn.init.constant_(m.weight,1)
  nn.init.constant_(m.bias, 0)

编写好weights_init函数后,可以使用模型的apply方法对模型进行权重初始化。

net = Residual() # generate an instance network from the Net class

net.apply(weights_init) # apply weight init

补充知识:Pytorch权值初始化及参数分组

1. 模型参数初始化

# ————————————————— 利用model.apply(weights_init)实现初始化
def weights_init(m):
  classname = m.__class__.__name__
  if classname.find('Conv') != -1:
    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
    m.weight.data.normal_(0, math.sqrt(2. / n))
    if m.bias is not None:
      m.bias.data.zero_()
  elif classname.find('BatchNorm') != -1:
    m.weight.data.fill_(1)
    m.bias.data.zero_()
  elif classname.find('Linear') != -1:
    n = m.weight.size(1)
    m.weight.data.normal_(0, 0.01)
    m.bias.data = torch.ones(m.bias.data.size())
    
# ————————————————— 直接放在__init__构造函数中实现初始化
for m in self.modules():
  if isinstance(m, nn.Conv2d):
    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
    m.weight.data.normal_(0, math.sqrt(2. / n))
    if m.bias is not None:
      m.bias.data.zero_()
  elif isinstance(m, nn.BatchNorm2d):
    m.weight.data.fill_(1)
    m.bias.data.zero_()
  elif isinstance(m, nn.BatchNorm1d):
    m.weight.data.fill_(1)
    m.bias.data.zero_()
  elif isinstance(m, nn.Linear):
    nn.init.xavier_uniform_(m.weight.data)
    if m.bias is not None:
      m.bias.data.zero_()
    
# —————————————————
self.weight = Parameter(torch.Tensor(out_features, in_features))
self.bias = Parameter(torch.FloatTensor(out_features))
nn.init.xavier_uniform_(self.weight)
nn.init.zero_(self.bias)
nn.init.constant_(m, initm)
# nn.init.kaiming_uniform_()
# self.weight.data.normal_(std=0.001)

2. 模型参数分组weight_decay

def separate_bn_prelu_params(model, ignored_params=[]):
  bn_prelu_params = []
  for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
      ignored_params += list(map(id, m.parameters()))  
      bn_prelu_params += m.parameters()
    if isinstance(m, nn.BatchNorm1d):
      ignored_params += list(map(id, m.parameters()))  
      bn_prelu_params += m.parameters()
    elif isinstance(m, nn.PReLU):
      ignored_params += list(map(id, m.parameters()))
      bn_prelu_params += m.parameters()
  base_params = list(filter(lambda p: id(p) not in ignored_params, model.parameters()))

  return base_params, bn_prelu_params, ignored_params

OPTIMIZER = optim.SGD([
    {'params': base_params, 'weight_decay': WEIGHT_DECAY},     
    {'params': fc_head_param, 'weight_decay': WEIGHT_DECAY * 10},
    {'params': bn_prelu_params, 'weight_decay': 0.0}
    ], lr=LR, momentum=MOMENTUM ) # , nesterov=True

Note 1:PReLU(x) = max(0,x) + a * min(0,x). Here a is a learnable parameter. When called without arguments, nn.PReLU() uses a single parameter a across all input channels. If called with nn.PReLU(nChannels), a separate a is used for each input channel.

Note 2: weight decay should not be used when learning a for good performance.

Note 3: The default number of a to learn is 1, the default initial value of a is 0.25.

3. 参数分组weight_decay?其他

第2节中的内容可以满足一般的参数分组需求,此部分可以满足更个性化的分组需求。参考:face_evoLVe_Pytorch-master

自定义schedule

def schedule_lr(optimizer):
  for params in optimizer.param_groups:
    params['lr'] /= 10.
  print(optimizer)

方法一:利用model.modules()和obj.__class__ (更普适)

# model.modules()和model.children()的区别:model.modules()会迭代地遍历模型的所有子层,而model.children()只会遍历模型下的一层
# 下面的关键词if 'model',源于模型定义文件。如model_resnet.py中自定义的所有nn.Module子类,都会前缀'model_resnet',所以可通过这种方式一次性筛选出自定义的模块
def separate_irse_bn_paras(model):
  paras_only_bn = []         
  paras_no_bn = []
  for layer in model.modules():
    if 'model' in str(layer.__class__):		      # eg. a=[1,2] type(a): <class 'list'> a.__class__: <class 'list'>
      continue
    if 'container' in str(layer.__class__):       # 去掉Sequential型的模块
      continue
    else:
      if 'batchnorm' in str(layer.__class__):
        paras_only_bn.extend([*layer.parameters()])
      else:
        paras_no_bn.extend([*layer.parameters()])  # extend()用于在列表末尾一次性追加另一个序列中的多个值(用新列表扩展原来的列表)

  return paras_only_bn, paras_no_bn

方法二:调用modules.parameters和named_parameters()

但是本质上,parameters()是根据named_parameters()获取,named_parameters()是根据modules()获取。使用此方法的前提是,须按下文1,2中的方式定义模型,或者利用Sequential+OrderedDict定义模型。

def separate_resnet_bn_paras(model):
  all_parameters = model.parameters()
  paras_only_bn = []

  for pname, p in model.named_parameters():
    if pname.find('bn') >= 0:
      paras_only_bn.append(p)
      
  paras_only_bn_id = list(map(id, paras_only_bn))
  paras_no_bn = list(filter(lambda p: id(p) not in paras_only_bn_id, all_parameters))
  
  return paras_only_bn, paras_no_bn

两种方法的区别

参数分组的区别,其实对应了模型构造时的区别。举例:

1、构造ResNet的basic block,在__init__()函数中定义了

self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = BatchNorm2d(planes)
self.relu = ReLU(inplace = True)
…

2、在forward()中定义

out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
…

3、对ResNet取model.name_parameters()返回的pname形如:

‘layer1.0.conv1.weight'
‘layer1.0.bn1.weight'
‘layer1.0.bn1.bias'
# layer对应conv2_x, …, conv5_x; '0'对应各layer中的block索引,比如conv2_x有3个block,对应索引为layer1.0, …, layer1.2; 'conv1'就是__init__()中定义的self.conv1

4、若构造model时采用了Sequential(),则model.name_parameters()返回的pname形如:

‘body.3.res_layer.1.weight',此处的1.weight实际对应了BN的weight,无法通过pname.find(‘bn')找到该模块。

self.res_layer = Sequential(
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
BatchNorm2d(depth),
ReLU(depth),
Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
BatchNorm2d(depth)
)

5、针对4中的情况,两种解决办法:利用OrderedDict修饰Sequential,或利用方法一

downsample = Sequential( OrderedDict([
(‘conv_ds', conv1x1(self.inplanes, planes * block.expansion, stride)),
(‘bn_ds', BatchNorm2d(planes * block.expansion)),
]))
# 如此,相应模块的pname将会带有'conv_ds',‘bn_ds'字样

以上这篇pytorch 网络参数 weight bias 初始化详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python程序设计入门(4)模块和包
Jun 16 Python
Python实现对一个函数应用多个装饰器的方法示例
Feb 09 Python
Python3.5内置模块之os模块、sys模块、shutil模块用法实例分析
Apr 27 Python
Python定时发送天气预报邮件代码实例
Sep 09 Python
python安装本地whl的实例步骤
Oct 12 Python
django 框架实现的用户注册、登录、退出功能示例
Nov 28 Python
python全局变量引用与修改过程解析
Jan 07 Python
Tensorflow不支持AVX2指令集的解决方法
Feb 03 Python
django实现更改数据库某个字段以及字段段内数据
Mar 31 Python
Opencv python 图片生成视频的方法示例
Nov 18 Python
安装pyinstaller遇到的各种问题(小结)
Nov 20 Python
python爬虫中url管理器去重操作实例
Nov 30 Python
可视化pytorch 模型中不同BN层的running mean曲线实例
Jun 24 #Python
python3.x中安装web.py步骤方法
Jun 23 #Python
python如何删除文件、目录
Jun 23 #Python
TensorFlow保存TensorBoard图像操作
Jun 23 #Python
python和js交互调用的方法
Jun 23 #Python
virtualenv介绍及简明教程
Jun 23 #Python
python不同系统中打开方法
Jun 23 #Python
You might like
在PHP的图形函数中显示汉字
2006/10/09 PHP
PHP动态变静态原理
2006/11/25 PHP
php四种基础算法代码实例
2013/10/29 PHP
PHP中is_file不能替代file_exists的理由
2014/03/04 PHP
php读取目录所有文件信息dir示例
2014/03/18 PHP
CI框架中通过hook的方式实现简单的权限控制
2015/01/07 PHP
Yii2处理密码加密及验证的方法
2019/05/12 PHP
js下判断 iframe 是否加载完成的完美方法
2010/10/26 Javascript
Web开发者必备的12款超赞jQuery插件
2010/12/03 Javascript
js数组的基本用法及数组根据下标(数值或字符)移除元素
2013/10/20 Javascript
js实现数组冒泡排序、快速排序原理
2016/03/08 Javascript
原生JS实现层叠轮播图
2017/05/17 Javascript
实现div内部滚动条滚动到底部和顶部的代码
2017/11/15 Javascript
Vue进度条progressbar组件功能
2018/04/17 Javascript
element-ui循环显示radio控件信息的方法
2018/08/24 Javascript
微信小程序class封装http代码实例
2019/08/24 Javascript
[00:26]TI7不朽珍藏III——冥界亚龙不朽展示
2017/07/15 DOTA
Python实现对PPT文件进行截图操作的方法
2015/04/28 Python
python中子类继承父类的__init__方法实例
2016/12/15 Python
mac系统安装Python3初体验
2018/01/02 Python
python检测空间储存剩余大小和指定文件夹内存占用的实例
2018/06/11 Python
TensorFlow基于MNIST数据集实现车牌识别(初步演示版)
2019/08/05 Python
windows环境中利用celery实现简单任务队列过程解析
2019/11/29 Python
Python unittest单元测试openpyxl实现过程解析
2020/05/27 Python
解决python对齐错误的方法
2020/07/16 Python
手把手教你将Flask应用封装成Docker服务的实现
2020/08/19 Python
用sleep间隔进行python反爬虫的实例讲解
2020/11/30 Python
CSS改变网页中鼠标选中文字背景颜色例子
2014/04/23 HTML / CSS
分享CSS3制作卡片式图片的方法
2016/07/08 HTML / CSS
欧洲最大的滑雪假期供应商之一:Sunweb Holidays
2018/01/06 全球购物
e路東瀛(JAPANiCAN)香港:日本旅游、日本酒店和温泉旅馆预订
2018/11/21 全球购物
Kappa英国官方在线商店:服装和运动器材
2020/11/22 全球购物
设计专业自荐信
2014/06/19 职场文书
党的生日活动方案
2014/08/15 职场文书
2015年教研员工作总结
2015/05/26 职场文书
一文搞懂python异常处理、模块与包
2021/06/26 Python