深度学习入门之Pytorch 数据增强的实现


Posted in Python onFebruary 26, 2020

数据增强

卷积神经网络非常容易出现过拟合的问题,而数据增强的方法是对抗过拟合问题的一个重要方法。

2012 年 AlexNet 在 ImageNet 上大获全胜,图片增强方法功不可没,因为有了图片增强,使得训练的数据集比实际数据集多了很多'新'样本,减少了过拟合的问题,下面我们来具体解释一下。

常用的数据增强方法

常用的数据增强方法如下:
1.对图片进行一定比例缩放
2.对图片进行随机位置的截取
3.对图片进行随机的水平和竖直翻转
4.对图片进行随机角度的旋转
5.对图片进行亮度、对比度和颜色的随机变化

这些方法 pytorch 都已经为我们内置在了 torchvision 里面,我们在安装 pytorch 的时候也安装了 torchvision,下面我们来依次展示一下这些数据增强方法。

import sys
sys.path.append('..')

from PIL import Image
from torchvision import transforms as tfs

# 读入一张图片
im = Image.open('./cat.png')
im

深度学习入门之Pytorch 数据增强的实现

随机比例放缩

随机比例缩放主要使用的是 torchvision.transforms.Resize() 这个函数,第一个参数可以是一个整数,那么图片会保存现在的宽和高的比例,并将更短的边缩放到这个整数的大小,第一个参数也可以是一个 tuple,那么图片会直接把宽和高缩放到这个大小;第二个参数表示放缩图片使用的方法,比如最邻近法,或者双线性差值等,一般双线性差值能够保留图片更多的信息,所以 pytorch 默认使用的是双线性差值,你可以手动去改这个参数,更多的信息可以看看文档

# 比例缩放
print('before scale, shape: {}'.format(im.size))
new_im = tfs.Resize((100, 200))(im)
print('after scale, shape: {}'.format(new_im.size))
new_im

深度学习入门之Pytorch 数据增强的实现

随机位置截取

随机位置截取能够提取出图片中局部的信息,使得网络接受的输入具有多尺度的特征,所以能够有较好的效果。在 torchvision 中主要有下面两种方式,一个是 torchvision.transforms.RandomCrop(),传入的参数就是截取出的图片的长和宽,对图片在随机位置进行截取;第二个是 torchvision.transforms.CenterCrop(),同样传入介曲初的图片的大小作为参数,会在图片的中心进行截取

# 随机裁剪出 100 x 100 的区域
random_im1 = tfs.RandomCrop(100)(im)
random_im1

深度学习入门之Pytorch 数据增强的实现

# 中心裁剪出 100 x 100 的区域
center_im = tfs.CenterCrop(100)(im)
center_im

深度学习入门之Pytorch 数据增强的实现

随机的水平和竖直方向翻转

对于上面这一张猫的图片,如果我们将它翻转一下,它仍然是一张猫,但是图片就有了更多的多样性,所以随机翻转也是一种非常有效的手段。在 torchvision 中,随机翻转使用的是 torchvision.transforms.RandomHorizontalFlip()torchvision.transforms.RandomVerticalFlip()

# 随机水平翻转
h_filp = tfs.RandomHorizontalFlip()(im)
h_filp

深度学习入门之Pytorch 数据增强的实现

# 随机竖直翻转
v_flip = tfs.RandomVerticalFlip()(im)
v_flip

深度学习入门之Pytorch 数据增强的实现

随机角度旋转

一些角度的旋转仍然是非常有用的数据增强方式,在 torchvision 中,使用 torchvision.transforms.RandomRotation() 来实现,其中第一个参数就是随机旋转的角度,比如填入 10,那么每次图片就会在 -10 ~ 10 度之间随机旋转

rot_im = tfs.RandomRotation(45)(im)
rot_im

深度学习入门之Pytorch 数据增强的实现

亮度、对比度和颜色的变化

除了形状变化外,颜色变化又是另外一种增强方式,其中可以设置亮度变化,对比度变化和颜色变化等,在 torchvision 中主要使用 torchvision.transforms.ColorJitter() 来实现的,第一个参数就是亮度的比例,第二个是对比度,第三个是饱和度,第四个是颜色

# 亮度
bright_im = tfs.ColorJitter(brightness=1)(im) # 随机从 0 ~ 2 之间亮度变化,1 表示原图
bright_im

深度学习入门之Pytorch 数据增强的实现

# 对比度
contrast_im = tfs.ColorJitter(contrast=1)(im) # 随机从 0 ~ 2 之间对比度变化,1 表示原图
contrast_im

深度学习入门之Pytorch 数据增强的实现

# 颜色
color_im = tfs.ColorJitter(hue=0.5)(im) # 随机从 -0.5 ~ 0.5 之间对颜色变化
color_im

深度学习入门之Pytorch 数据增强的实现

上面我们讲了这么图片增强的方法,其实这些方法都不是孤立起来用的,可以联合起来用,比如先做随机翻转,然后随机截取,再做对比度增强等等,torchvision 里面有个非常方便的函数能够将这些变化合起来,就是 torchvision.transforms.Compose(),下面我们举个例子

im_aug = tfs.Compose([
  tfs.Resize(120),
  tfs.RandomHorizontalFlip(),
  tfs.RandomCrop(96),
  tfs.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5)
])
import matplotlib.pyplot as plt
%matplotlib inline
nrows = 3
ncols = 3
figsize = (8, 8)
_, figs = plt.subplots(nrows, ncols, figsize=figsize)
for i in range(nrows):
  for j in range(ncols):
    figs[i][j].imshow(im_aug(im))
    figs[i][j].axes.get_xaxis().set_visible(False)
    figs[i][j].axes.get_yaxis().set_visible(False)
plt.show()

深度学习入门之Pytorch 数据增强的实现

可以看到每次做完增强之后的图片都有一些变化,所以这就是我们前面讲的,增加了一些'新'数据
下面我们使用图像增强进行训练网络,看看具体的提升究竟在什么地方,使用 ResNet 进行训练

使用数据增强

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.datasets import CIFAR10
from utils import train, resnet
from torchvision import transforms as tfs
# 使用数据增强
def train_tf(x):
  im_aug = tfs.Compose([
    tfs.Resize(120),
    tfs.RandomHorizontalFlip(),
    tfs.RandomCrop(96),
    tfs.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5),
    tfs.ToTensor(),
    tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
  ])
  x = im_aug(x)
  return x

def test_tf(x):
  im_aug = tfs.Compose([
    tfs.Resize(96),
    tfs.ToTensor(),
    tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
  ])
  x = im_aug(x)
  return x

train_set = CIFAR10('./data', train=True, transform=train_tf)
train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
test_set = CIFAR10('./data', train=False, transform=test_tf)
test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)

net = resnet(3, 10)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
train(net, train_data, test_data, 10, optimizer, criterion)

深度学习入门之Pytorch 数据增强的实现

不使用数据增强

# 不使用数据增强
def data_tf(x):
  im_aug = tfs.Compose([
    tfs.Resize(96),
    tfs.ToTensor(),
    tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
  ])
  x = im_aug(x)
  return x

train_set = CIFAR10('./data', train=True, transform=data_tf)
train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
test_set = CIFAR10('./data', train=False, transform=data_tf)
test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)

net = resnet(3, 10)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
train(net, train_data, test_data, 10, optimizer, criterion)

深度学习入门之Pytorch 数据增强的实现

从上面可以看出,对于训练集,不做数据增强跑 10 次,准确率已经到了 95%,而使用了数据增强,跑 10 次准确率只有 75%,说明数据增强之后变得更难了。

而对于测试集,使用数据增强进行训练的时候,准确率会比不使用更高,因为数据增强提高了模型应对于更多的不同数据集的泛化能力,所以有更好的效果。

以上就是深度学习入门之Pytorch 数据增强的实现的详细内容,更多关于Pytorch 数据增强的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
使用Python的PEAK来适配协议的教程
Apr 14 Python
python+pyqt实现右下角弹出框
Oct 26 Python
Python随机生成均匀分布在三角形内或者任意多边形内的点
Dec 14 Python
Python网络爬虫中的同步与异步示例详解
Feb 03 Python
python自动发送邮件脚本
Jun 20 Python
Python OpenCV处理图像之图像像素点操作
Jul 10 Python
Django使用paginator插件实现翻页功能的实例
Oct 24 Python
python判断文件是否存在,不存在就创建一个的实例
Feb 18 Python
python 实现return返回多个值
Nov 19 Python
pytorch 准备、训练和测试自己的图片数据的方法
Jan 10 Python
Python3+selenium实现cookie免密登录的示例代码
Mar 18 Python
pytorch model.cuda()花费时间很长的解决
Jun 01 Python
Python基于Dlib的人脸识别系统的实现
Feb 26 #Python
python 回溯法模板详解
Feb 26 #Python
python实现信号时域统计特征提取代码
Feb 26 #Python
Python 基于FIR实现Hilbert滤波器求信号包络详解
Feb 26 #Python
python实现逆滤波与维纳滤波示例
Feb 26 #Python
Python全面分析系统的时域特性和频率域特性
Feb 26 #Python
解决pycharm每次打开项目都需要配置解释器和安装库问题
Feb 26 #Python
You might like
php 信息采集程序代码
2009/03/17 PHP
IIS下PHP连接数据库提示mysql undefined function mysql_connect()
2010/06/04 PHP
php判断页面是否是微信打开的示例(微信打开网页)
2014/04/25 PHP
php实现批量删除挂马文件及批量替换页面内容完整实例
2016/07/08 PHP
php转换上传word文件为PDF的方法【基于COM组件】
2019/06/10 PHP
ComboBox 和 DateField 在IE下消失的解决方法
2013/08/30 Javascript
javascript的parseFloat()方法精度问题探讨
2013/11/26 Javascript
IE6-IE9中tbody的innerHTML不能赋值的解决方法
2014/06/05 Javascript
jQuery中queue()方法用法实例
2014/12/29 Javascript
Angular的Bootstrap(引导)和Compiler(编译)机制
2016/06/20 Javascript
微信js-sdk预览图片接口及从拍照或手机相册中选图接口用法示例
2016/10/13 Javascript
微信小程序获取用户openId的实现方法
2017/05/23 Javascript
js时间戳转yyyy-MM-dd HH-mm-ss工具类详解
2019/04/30 Javascript
JavaScript使用表单元素验证表单的示例代码
2019/08/20 Javascript
vue element自定义表单验证请求后端接口验证
2019/12/11 Javascript
swiper自定义分页器的样式
2020/09/14 Javascript
uniapp实现横向滚动选择日期
2020/10/21 Javascript
python基础教程项目三之万能的XML
2018/04/02 Python
Python3中关于cookie的创建与保存
2018/10/21 Python
python实现图片彩色转化为素描
2019/01/15 Python
Django学习笔记之为Model添加Action
2019/04/30 Python
itchat-python搭建微信机器人(附示例)
2019/06/11 Python
django 微信网页授权登陆的实现
2019/07/30 Python
python爬虫 execjs安装配置及使用
2019/07/30 Python
Python 最强编辑器详细使用指南(PyCharm )
2019/09/16 Python
Python之Numpy的超实用基础详细教程
2019/10/23 Python
python 用 xlwings 库 生成图表的操作方法
2019/12/22 Python
英国领先的汽车轮胎和快速健康中心:Kwik Fit
2017/10/29 全球购物
荷兰游戏商店:Allyouplay
2019/03/16 全球购物
JavaScript实现前端网页版倒计时
2021/03/24 Javascript
副主任竞聘演讲稿
2014/08/18 职场文书
设备收款委托书范本
2014/10/02 职场文书
2016廉洁从业学习心得体会
2016/01/19 职场文书
Oracle 多表查询基本语法实例
2022/04/18 Oracle
python实现商品进销存管理系统
2022/05/30 Python
Vue深入理解插槽slot的使用
2022/08/05 Vue.js