pytorch 图像中的数据预处理和批标准化实例


Posted in Python onJanuary 15, 2020

目前数据预处理最常见的方法就是中心化和标准化。

中心化相当于修正数据的中心位置,实现方法非常简单,就是在每个特征维度上减去对应的均值,最后得到 0 均值的特征。

标准化也非常简单,在数据变成 0 均值之后,为了使得不同的特征维度有着相同的规模,可以除以标准差近似为一个标准正态分布,也可以依据最大值和最小值将其转化为 -1 ~ 1 之间

批标准化:BN

在数据预处理的时候,我们尽量输入特征不相关且满足一个标准的正态分布,这样模型的表现一般也较好。但是对于很深的网路结构,网路的非线性层会使得输出的结果变得相关,且不再满足一个标准的 N(0, 1) 的分布,甚至输出的中心已经发生了偏移,这对于模型的训练,特别是深层的模型训练非常的困难。

所以在 2015 年一篇论文提出了这个方法,批标准化,简而言之,就是对于每一层网络的输出,对其做一个归一化,使其服从标准的正态分布,这样后一层网络的输入也是一个标准的正态分布,所以能够比较好的进行训练,加快收敛速度。

batch normalization 的实现非常简单,接下来写一下对应的python代码:

import sys
sys.path.append('..')
 
import torch
 
def simple_batch_norm_1d(x, gamma, beta):
  eps = 1e-5
  x_mean = torch.mean(x, dim=0, keepdim=True) # 保留维度进行 broadcast
  x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True)
  x_hat = (x - x_mean) / torch.sqrt(x_var + eps)
  return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)
   
x = torch.arange(15).view(5, 3)
gamma = torch.ones(x.shape[1])
beta = torch.zeros(x.shape[1])
print('before bn: ')
print(x)
y = simple_batch_norm_1d(x, gamma, beta)
print('after bn: ')
print(y)

测试的时候该使用批标准化吗?

答案是肯定的,因为训练的时候使用了,而测试的时候不使用肯定会导致结果出现偏差,但是测试的时候如果只有一个数据集,那么均值不就是这个值,方差为 0 吗?这显然是随机的,所以测试的时候不能用测试的数据集去算均值和方差,而是用训练的时候算出的移动平均均值和方差去代替

下面我们实现以下能够区分训练状态和测试状态的批标准化方法

def batch_norm_1d(x, gamma, beta, is_training, moving_mean, moving_var, moving_momentum=0.1):
  eps = 1e-5
  x_mean = torch.mean(x, dim=0, keepdim=True) # 保留维度进行 broadcast
  x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True)
  if is_training:
    x_hat = (x - x_mean) / torch.sqrt(x_var + eps)
    moving_mean[:] = moving_momentum * moving_mean + (1. - moving_momentum) * x_mean
    moving_var[:] = moving_momentum * moving_var + (1. - moving_momentum) * x_var
  else:
    x_hat = (x - moving_mean) / torch.sqrt(moving_var + eps)
  return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)

下面我们在卷积网络下试用一下批标准化看看效果

def data_tf(x):
  x = np.array(x, dtype='float32') / 255
  x = (x - 0.5) / 0.5 # 数据预处理,标准化
  x = torch.from_numpy(x)
  x = x.unsqueeze(0)
  return x
 
train_set = mnist.MNIST('./data', train=True, transform=data_tf, download=True) # 重新载入数据集,申明定义的数据变换
test_set = mnist.MNIST('./data', train=False, transform=data_tf, download=True)
train_data = DataLoader(train_set, batch_size=64, shuffle=True)
test_data = DataLoader(test_set, batch_size=128, shuffle=False)
# 使用批标准化
class conv_bn_net(nn.Module):
  def __init__(self):
    super(conv_bn_net, self).__init__()
    self.stage1 = nn.Sequential(
      nn.Conv2d(1, 6, 3, padding=1),
      nn.BatchNorm2d(6),
      nn.ReLU(True),
      nn.MaxPool2d(2, 2),
      nn.Conv2d(6, 16, 5),
      nn.BatchNorm2d(16),
      nn.ReLU(True),
      nn.MaxPool2d(2, 2)
    )
    
    self.classfy = nn.Linear(400, 10)
  def forward(self, x):
    x = self.stage1(x)
    x = x.view(x.shape[0], -1)
    x = self.classfy(x)
    return x
 
net = conv_bn_net()
optimizer = torch.optim.SGD(net.parameters(), 1e-1) # 使用随机梯度下降,学习率 0.1
 
 
train(net, train_data, test_data, 5, optimizer, criterion)

以上这篇pytorch 图像中的数据预处理和批标准化实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用PyFetion来发送短信的例子
Apr 22 Python
python任务调度实例分析
May 19 Python
Python 装饰器深入理解
Mar 16 Python
浅谈Matplotlib简介和pyplot的简单使用——文本标注和箭头
Jan 09 Python
python爬虫爬取网页表格数据
Mar 07 Python
Python实现输出某区间范围内全部素数的方法
May 02 Python
python之pyqt5通过按钮改变Label的背景颜色方法
Jun 13 Python
通过python实现随机交换礼物程序详解
Jul 10 Python
python+openCV调用摄像头拍摄和处理图片的实现
Aug 06 Python
Python字符串中删除特定字符的方法
Jan 15 Python
Python3加密解密库Crypto的RSA加解密和签名/验签实现方法实例
Feb 11 Python
浅谈Python的方法解析顺序(MRO)
Mar 05 Python
pytorch实现特殊的Module--Sqeuential三种写法
Jan 15 #Python
python实现删除列表中某个元素的3种方法
Jan 15 #Python
python opencv根据颜色进行目标检测的方法示例
Jan 15 #Python
Python基于Tensor FLow的图像处理操作详解
Jan 15 #Python
OpenCV哈里斯(Harris)角点检测的实现
Jan 15 #Python
Pytorch模型转onnx模型实例
Jan 15 #Python
Python通过TensorFLow进行线性模型训练原理与实现方法详解
Jan 15 #Python
You might like
转换中文日期的PHP程序
2006/10/09 PHP
给php新手谈谈我的学习心得
2007/02/25 PHP
php实现mysql数据库备份类
2008/03/20 PHP
php self,$this,const,static,->的使用
2009/10/22 PHP
PHP用星号隐藏部份用户名、身份证、IP、手机号等实例
2014/04/08 PHP
php利用scws实现mysql全文搜索功能的方法
2014/12/25 PHP
PHP实现搜索地理位置及计算两点地理位置间距离的实例
2016/01/08 PHP
PHP中利用sleep函数实现定时执行功能实现代码
2016/08/25 PHP
php基于闭包实现函数的自调用(递归)实例分析
2016/11/11 PHP
PHP获取日期对应星期、一周日期、星期开始与结束日期的方法
2018/06/22 PHP
日期 时间js控件
2009/05/07 Javascript
Dom与浏览器兼容性说明
2010/10/25 Javascript
jquery中页面Ajax方法$.load的功能使用介绍
2014/10/20 Javascript
使用javascript实现简单的选项卡切换
2015/01/09 Javascript
Angular.JS判断复选框checkbox是否选中并实时显示
2016/11/30 Javascript
详解Angular的双向数据绑定(MV-VM)
2016/12/26 Javascript
详解vue-cli 构建Vue项目遇到的坑
2017/08/30 Javascript
JS去掉字符串末尾的标点符号及删除最后一个字符的方法
2017/10/24 Javascript
JavaScript学习笔记之图片库案例分析
2019/01/08 Javascript
微信小程序和百度的语音识别接口详解
2019/05/06 Javascript
vue项目中将element-ui table表格写成组件的实现代码
2019/06/12 Javascript
解决layui使用layui-icon出现默认图标的问题
2019/09/11 Javascript
详解JavaScript中new操作符的解析和实现
2020/09/04 Javascript
[09:13]2014DOTA2国际邀请赛 中国区预选赛coser表演
2014/05/23 DOTA
[43:57]Liquid vs Mineski 2019国际邀请赛小组赛 BO2 第二场 8.16
2019/08/19 DOTA
Python可变参数用法实例分析
2017/04/02 Python
Python实现数据可视化看如何监控你的爬虫状态【推荐】
2018/08/10 Python
Python try except异常捕获机制原理解析
2020/04/18 Python
python模块内置属性概念及实例
2021/02/18 Python
金融专业毕业生推荐信
2013/11/26 职场文书
大二学生学习个人自我评价
2014/01/19 职场文书
《一本男孩子必读的书》教学反思
2014/02/19 职场文书
高校自主招生自荐信2015
2015/03/04 职场文书
2015年光棍节活动总结
2015/03/24 职场文书
未婚证明范本
2015/06/15 职场文书
《用字母表示数》教学反思
2016/02/17 职场文书