pytorch 图像中的数据预处理和批标准化实例


Posted in Python onJanuary 15, 2020

目前数据预处理最常见的方法就是中心化和标准化。

中心化相当于修正数据的中心位置,实现方法非常简单,就是在每个特征维度上减去对应的均值,最后得到 0 均值的特征。

标准化也非常简单,在数据变成 0 均值之后,为了使得不同的特征维度有着相同的规模,可以除以标准差近似为一个标准正态分布,也可以依据最大值和最小值将其转化为 -1 ~ 1 之间

批标准化:BN

在数据预处理的时候,我们尽量输入特征不相关且满足一个标准的正态分布,这样模型的表现一般也较好。但是对于很深的网路结构,网路的非线性层会使得输出的结果变得相关,且不再满足一个标准的 N(0, 1) 的分布,甚至输出的中心已经发生了偏移,这对于模型的训练,特别是深层的模型训练非常的困难。

所以在 2015 年一篇论文提出了这个方法,批标准化,简而言之,就是对于每一层网络的输出,对其做一个归一化,使其服从标准的正态分布,这样后一层网络的输入也是一个标准的正态分布,所以能够比较好的进行训练,加快收敛速度。

batch normalization 的实现非常简单,接下来写一下对应的python代码:

import sys
sys.path.append('..')
 
import torch
 
def simple_batch_norm_1d(x, gamma, beta):
  eps = 1e-5
  x_mean = torch.mean(x, dim=0, keepdim=True) # 保留维度进行 broadcast
  x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True)
  x_hat = (x - x_mean) / torch.sqrt(x_var + eps)
  return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)
   
x = torch.arange(15).view(5, 3)
gamma = torch.ones(x.shape[1])
beta = torch.zeros(x.shape[1])
print('before bn: ')
print(x)
y = simple_batch_norm_1d(x, gamma, beta)
print('after bn: ')
print(y)

测试的时候该使用批标准化吗?

答案是肯定的,因为训练的时候使用了,而测试的时候不使用肯定会导致结果出现偏差,但是测试的时候如果只有一个数据集,那么均值不就是这个值,方差为 0 吗?这显然是随机的,所以测试的时候不能用测试的数据集去算均值和方差,而是用训练的时候算出的移动平均均值和方差去代替

下面我们实现以下能够区分训练状态和测试状态的批标准化方法

def batch_norm_1d(x, gamma, beta, is_training, moving_mean, moving_var, moving_momentum=0.1):
  eps = 1e-5
  x_mean = torch.mean(x, dim=0, keepdim=True) # 保留维度进行 broadcast
  x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True)
  if is_training:
    x_hat = (x - x_mean) / torch.sqrt(x_var + eps)
    moving_mean[:] = moving_momentum * moving_mean + (1. - moving_momentum) * x_mean
    moving_var[:] = moving_momentum * moving_var + (1. - moving_momentum) * x_var
  else:
    x_hat = (x - moving_mean) / torch.sqrt(moving_var + eps)
  return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)

下面我们在卷积网络下试用一下批标准化看看效果

def data_tf(x):
  x = np.array(x, dtype='float32') / 255
  x = (x - 0.5) / 0.5 # 数据预处理,标准化
  x = torch.from_numpy(x)
  x = x.unsqueeze(0)
  return x
 
train_set = mnist.MNIST('./data', train=True, transform=data_tf, download=True) # 重新载入数据集,申明定义的数据变换
test_set = mnist.MNIST('./data', train=False, transform=data_tf, download=True)
train_data = DataLoader(train_set, batch_size=64, shuffle=True)
test_data = DataLoader(test_set, batch_size=128, shuffle=False)
# 使用批标准化
class conv_bn_net(nn.Module):
  def __init__(self):
    super(conv_bn_net, self).__init__()
    self.stage1 = nn.Sequential(
      nn.Conv2d(1, 6, 3, padding=1),
      nn.BatchNorm2d(6),
      nn.ReLU(True),
      nn.MaxPool2d(2, 2),
      nn.Conv2d(6, 16, 5),
      nn.BatchNorm2d(16),
      nn.ReLU(True),
      nn.MaxPool2d(2, 2)
    )
    
    self.classfy = nn.Linear(400, 10)
  def forward(self, x):
    x = self.stage1(x)
    x = x.view(x.shape[0], -1)
    x = self.classfy(x)
    return x
 
net = conv_bn_net()
optimizer = torch.optim.SGD(net.parameters(), 1e-1) # 使用随机梯度下降,学习率 0.1
 
 
train(net, train_data, test_data, 5, optimizer, criterion)

以上这篇pytorch 图像中的数据预处理和批标准化实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中精确输出JSON浮点数的方法
Apr 18 Python
python图像处理之反色实现方法
May 30 Python
Python+django实现文件上传
Jan 17 Python
Python中基础的socket编程实战攻略
Jun 01 Python
Swift 3.0在集合类数据结构上的一些新变化总结
Jul 11 Python
Python算法之图的遍历
Nov 16 Python
python中set()函数简介及实例解析
Jan 09 Python
PyQt5每天必学之单行文本框
Apr 19 Python
使用11行Python代码盗取了室友的U盘内容
Oct 23 Python
python Cartopy的基础使用详解
Nov 01 Python
10个示例带你掌握python中的元组
Nov 23 Python
教你用Python+selenium搭建自动化测试环境
Jun 18 Python
pytorch实现特殊的Module--Sqeuential三种写法
Jan 15 #Python
python实现删除列表中某个元素的3种方法
Jan 15 #Python
python opencv根据颜色进行目标检测的方法示例
Jan 15 #Python
Python基于Tensor FLow的图像处理操作详解
Jan 15 #Python
OpenCV哈里斯(Harris)角点检测的实现
Jan 15 #Python
Pytorch模型转onnx模型实例
Jan 15 #Python
Python通过TensorFLow进行线性模型训练原理与实现方法详解
Jan 15 #Python
You might like
粗略计算在线时间,bug:ip相同
2006/12/09 PHP
PHP的简易冒泡法代码分享
2012/08/28 PHP
PHP生成不同颜色、不同大小的tag标签函数
2013/09/23 PHP
PHP yield关键字功能与用法分析
2019/01/03 PHP
jQuery异步获取json数据方法汇总
2014/12/22 Javascript
js对象的复制继承实例
2015/01/10 Javascript
JavaScript动态设置div的样式的方法
2015/12/26 Javascript
JavaScript中数组添加值和访问值常见问题
2016/02/06 Javascript
Jquery Easyui验证组件ValidateBox使用详解(20)
2016/12/18 Javascript
jQuery实现遍历复选框的方法示例
2017/03/06 Javascript
Vue2.0表单校验组件vee-validate的使用详解
2017/05/02 Javascript
vue的状态管理模式vuex
2017/11/30 Javascript
解决iview多表头动态更改列元素发生的错误的方法
2018/11/02 Javascript
Vue props 单向数据流的实现
2018/11/06 Javascript
浅谈layer弹出层按钮颜色修改方法
2019/09/11 Javascript
vue实现移动端省市区选择
2019/09/27 Javascript
[01:46]DOTA2上海特锦赛小组赛英文解说KotlGuy采访
2016/02/27 DOTA
[36:22]VP vs Serenity 2018国际邀请赛小组赛BO2 第一场 8.16
2018/08/17 DOTA
[01:03:03]VP vs Mineski 2018国际邀请赛淘汰赛BO3 第一场 8.22
2018/08/23 DOTA
Python连接mysql数据库的正确姿势
2016/02/03 Python
Python二维码生成识别实例详解
2019/07/16 Python
python+numpy实现的基本矩阵操作示例
2019/07/19 Python
Python 内置变量和函数的查看及说明介绍
2019/12/25 Python
基于python实现对文件进行切分行
2020/04/26 Python
Html5页面在微信端的分享的实现方法
2018/08/30 HTML / CSS
在线购买廉价折扣书籍和小说:BookOutlet.com
2018/02/19 全球购物
Giglio俄罗斯奢侈品购物网:男士、女士、儿童高级时装
2018/07/27 全球购物
业务部主管岗位职责
2014/01/29 职场文书
培训主管的职业生涯规划
2014/03/06 职场文书
幼儿园秋游感想
2014/03/12 职场文书
恶搞卫生巾广告词
2014/03/18 职场文书
服务承诺书范文
2014/05/19 职场文书
2016年第二十五次全国助残日活动总结
2016/04/01 职场文书
2019年工作总结范文
2019/05/21 职场文书
python 破解加密zip文件的密码
2021/04/22 Python
使用CSS实现按钮边缘跑马灯动画
2023/05/07 HTML / CSS