pytorch 图像中的数据预处理和批标准化实例


Posted in Python onJanuary 15, 2020

目前数据预处理最常见的方法就是中心化和标准化。

中心化相当于修正数据的中心位置,实现方法非常简单,就是在每个特征维度上减去对应的均值,最后得到 0 均值的特征。

标准化也非常简单,在数据变成 0 均值之后,为了使得不同的特征维度有着相同的规模,可以除以标准差近似为一个标准正态分布,也可以依据最大值和最小值将其转化为 -1 ~ 1 之间

批标准化:BN

在数据预处理的时候,我们尽量输入特征不相关且满足一个标准的正态分布,这样模型的表现一般也较好。但是对于很深的网路结构,网路的非线性层会使得输出的结果变得相关,且不再满足一个标准的 N(0, 1) 的分布,甚至输出的中心已经发生了偏移,这对于模型的训练,特别是深层的模型训练非常的困难。

所以在 2015 年一篇论文提出了这个方法,批标准化,简而言之,就是对于每一层网络的输出,对其做一个归一化,使其服从标准的正态分布,这样后一层网络的输入也是一个标准的正态分布,所以能够比较好的进行训练,加快收敛速度。

batch normalization 的实现非常简单,接下来写一下对应的python代码:

import sys
sys.path.append('..')
 
import torch
 
def simple_batch_norm_1d(x, gamma, beta):
  eps = 1e-5
  x_mean = torch.mean(x, dim=0, keepdim=True) # 保留维度进行 broadcast
  x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True)
  x_hat = (x - x_mean) / torch.sqrt(x_var + eps)
  return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)
   
x = torch.arange(15).view(5, 3)
gamma = torch.ones(x.shape[1])
beta = torch.zeros(x.shape[1])
print('before bn: ')
print(x)
y = simple_batch_norm_1d(x, gamma, beta)
print('after bn: ')
print(y)

测试的时候该使用批标准化吗?

答案是肯定的,因为训练的时候使用了,而测试的时候不使用肯定会导致结果出现偏差,但是测试的时候如果只有一个数据集,那么均值不就是这个值,方差为 0 吗?这显然是随机的,所以测试的时候不能用测试的数据集去算均值和方差,而是用训练的时候算出的移动平均均值和方差去代替

下面我们实现以下能够区分训练状态和测试状态的批标准化方法

def batch_norm_1d(x, gamma, beta, is_training, moving_mean, moving_var, moving_momentum=0.1):
  eps = 1e-5
  x_mean = torch.mean(x, dim=0, keepdim=True) # 保留维度进行 broadcast
  x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True)
  if is_training:
    x_hat = (x - x_mean) / torch.sqrt(x_var + eps)
    moving_mean[:] = moving_momentum * moving_mean + (1. - moving_momentum) * x_mean
    moving_var[:] = moving_momentum * moving_var + (1. - moving_momentum) * x_var
  else:
    x_hat = (x - moving_mean) / torch.sqrt(moving_var + eps)
  return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)

下面我们在卷积网络下试用一下批标准化看看效果

def data_tf(x):
  x = np.array(x, dtype='float32') / 255
  x = (x - 0.5) / 0.5 # 数据预处理,标准化
  x = torch.from_numpy(x)
  x = x.unsqueeze(0)
  return x
 
train_set = mnist.MNIST('./data', train=True, transform=data_tf, download=True) # 重新载入数据集,申明定义的数据变换
test_set = mnist.MNIST('./data', train=False, transform=data_tf, download=True)
train_data = DataLoader(train_set, batch_size=64, shuffle=True)
test_data = DataLoader(test_set, batch_size=128, shuffle=False)
# 使用批标准化
class conv_bn_net(nn.Module):
  def __init__(self):
    super(conv_bn_net, self).__init__()
    self.stage1 = nn.Sequential(
      nn.Conv2d(1, 6, 3, padding=1),
      nn.BatchNorm2d(6),
      nn.ReLU(True),
      nn.MaxPool2d(2, 2),
      nn.Conv2d(6, 16, 5),
      nn.BatchNorm2d(16),
      nn.ReLU(True),
      nn.MaxPool2d(2, 2)
    )
    
    self.classfy = nn.Linear(400, 10)
  def forward(self, x):
    x = self.stage1(x)
    x = x.view(x.shape[0], -1)
    x = self.classfy(x)
    return x
 
net = conv_bn_net()
optimizer = torch.optim.SGD(net.parameters(), 1e-1) # 使用随机梯度下降,学习率 0.1
 
 
train(net, train_data, test_data, 5, optimizer, criterion)

以上这篇pytorch 图像中的数据预处理和批标准化实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python单元测试unittest实例详解
May 11 Python
Python中elasticsearch插入和更新数据的实现方法
Apr 01 Python
Python第三方库h5py_读取mat文件并显示值的方法
Feb 08 Python
Pycharm新建模板默认添加个人信息的实例
Jul 15 Python
自适应线性神经网络Adaline的python实现详解
Sep 30 Python
tensorflow获取预训练模型某层参数并赋值到当前网络指定层方式
Jan 24 Python
浅谈python 中的 type(), dtype(), astype()的区别
Apr 09 Python
使用Pycharm(Python工具)新建项目及创建Python文件的教程
Apr 26 Python
Python使用Chrome插件实现爬虫过程图解
Jun 09 Python
Python不支持 i ++ 语法的原因解析
Jul 22 Python
python 发送get请求接口详解
Nov 17 Python
pytest fixtures装饰器的使用和如何控制用例的执行顺序
Jan 28 Python
pytorch实现特殊的Module--Sqeuential三种写法
Jan 15 #Python
python实现删除列表中某个元素的3种方法
Jan 15 #Python
python opencv根据颜色进行目标检测的方法示例
Jan 15 #Python
Python基于Tensor FLow的图像处理操作详解
Jan 15 #Python
OpenCV哈里斯(Harris)角点检测的实现
Jan 15 #Python
Pytorch模型转onnx模型实例
Jan 15 #Python
Python通过TensorFLow进行线性模型训练原理与实现方法详解
Jan 15 #Python
You might like
PHP实现的下载远程图片自定义函数分享
2015/01/28 PHP
javascript 播放器 控制
2007/01/22 Javascript
kmock javascript 单元测试代码
2011/02/06 Javascript
javascript判断是否按回车键并解决浏览器之间的差异
2014/05/13 Javascript
NODE.JS加密模块CRYPTO常用方法介绍
2014/06/05 Javascript
JavaScript中匿名函数用法实例
2015/03/23 Javascript
jquery动画效果学习笔记(8种效果)
2015/11/13 Javascript
JavaScript操作HTML DOM节点的基础教程
2016/03/11 Javascript
AngularJs 弹出模态框(model)
2016/04/07 Javascript
Javascript 动态改变imput type属性
2016/11/01 Javascript
jQuery中的100个技巧汇总
2016/12/15 Javascript
使用JavaScript为一张图片设置备选路径的方法
2017/01/04 Javascript
vue-cli+webpack项目 修改项目名称的方法
2018/02/28 Javascript
微信小程序自定义toast弹窗效果的实现代码
2018/11/15 Javascript
vue excel上传预览和table内容下载到excel文件中
2019/12/10 Javascript
js实现列表向上无限滚动
2020/01/13 Javascript
python 二分查找和快速排序实例详解
2017/10/13 Python
Tornado高并发处理方法实例代码
2018/01/15 Python
python读取csv文件并把文件放入一个list中的实例讲解
2018/04/27 Python
Python自定义函数实现求两个数最大公约数、最小公倍数示例
2018/05/21 Python
Flask web开发处理POST请求实现(登录案例)
2018/07/26 Python
Python实现的爬取豆瓣电影信息功能案例
2019/09/15 Python
python读取word 中指定位置的表格及表格数据
2019/10/23 Python
Alba Moda德国网上商店:意大利时尚女装销售
2016/11/14 全球购物
菲律宾酒店预订网站:Hotels.com菲律宾
2017/07/12 全球购物
德国网上花店:Valentins
2018/08/15 全球购物
速卖通欧盟:Aliexpress EU
2020/08/19 全球购物
入党自我鉴定
2014/03/25 职场文书
学习实践科学发展观心得体会
2014/09/10 职场文书
财政局党的群众路线教育实践活动剖析材料
2014/10/13 职场文书
德能勤绩廉个人总结
2015/02/14 职场文书
2015新生加入学生会自荐书
2015/03/24 职场文书
保护地球的宣传语
2015/07/13 职场文书
运动会100米广播稿
2015/08/19 职场文书
python脚本框架webpy模板控制结构
2021/11/20 Python
Android 中的类文件和类加载器详情
2022/06/05 Java/Android