使用PyTorch训练一个图像分类器实例


Posted in Python onJanuary 08, 2020

如下所示:

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

print("torch: %s" % torch.__version__)
print("tortorchvisionch: %s" % torchvision.__version__)
print("numpy: %s" % np.__version__)

Out:

torch: 1.0.0
tortorchvisionch: 0.2.1
numpy: 1.15.4

数据从哪儿来?

通常来说,你可以通过一些python包来把图像、文本、音频和视频数据加载为numpy array。然后将其转换为torch.*Tensor。

图像。Pillow、OpenCV是用得比较多的

音频。scipy和librosa

文本。纯Python或者Cython就可以完成数据加载,可以在NLTK和SpaCy找到数据

对于计算机视觉而言,我们有torchvision包,它可以用来加载一下常用数据集如Imagenet、CIFAR10、MINIST等等,也有一些常用的为图像准备数据转换例如torchvision.datasets和torch.utils.data.DataLoader。

这次的教程中,我们使用CIFAR10数据集,他有‘airplane', ‘automobile', ‘bird', ‘cat', ‘deer', ‘dog', ‘frog', ‘horse', ‘ship', ‘truck'这几个类别的图像。图像大小都是3x32x32的。也就是说,图像都是三通道的,每一张图的尺寸都是32x32。

使用PyTorch训练一个图像分类器实例

训练一个图像分类器

步骤如下:

使用torchvision加载、归一化训练集和测试集

定义卷积神经网络

定义损失函数

使用训练集训练网络

使用测试集测试网络

1. 加载、归一化CIFAR10

我们可以使用torchvision很轻松的完成

torchvision的数据集是基于PILImage的,数值是[0, 1],我们需要将其转成范围为[-1, 1]的Tensor

transform = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, 
                    download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, 
                     shuffle=True, num_workers=4)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, 
                    download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, 
                     shuffle=True, num_workers=4)
classes = ('plane', 'car', 'bird', 'cat', 
      'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Out:

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz
Files already downloaded and verified

让我们来看看训练集的图片

# 显示一张图片
def imshow(img):
  img = img / 2 + 0.5   # 逆归一化
  npimg = img.numpy()
  plt.imshow(np.transpose(npimg, (1, 2, 0)))
  plt.show()


# 任意地拿到一些图片
dataiter = iter(trainloader)
images, labels = dataiter.next()

# 显示图片
imshow(torchvision.utils.make_grid(images))
# 显示类标
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))

Out:

使用PyTorch训练一个图像分类器实例

truck  dog ship  dog

2. 定义卷积神经网络

可以直接复制神经网络的代码,修改里面的几层即可。

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.conv1 = nn.Conv2d(3, 6, 5)
    self.pool = nn.MaxPool2d(2, 2)
    self.conv2 = nn.Conv2d(6, 16, 5)
    self.fc1 = nn.Linear(16 * 5 * 5, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)
    
  def forward(self, x):
    x = self.pool(F.relu(self.conv1(x)))
    x = self.pool(F.relu(self.conv2(x)))
    x = x.view(-1, 16 * 5 * 5)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x

net = Net()

3. 定义损失函数和优化器

使用多分类交叉熵损失函数,和带有momentum的SGD作为优化器

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=1e-3, momentum=0.9)

4. 训练网络

我们直接使用循环语句遍历数据集即可完成训练

nums_epoch = 2
for epoch in range(nums_epoch):
  _loss = 0.0
  for i, (inputs, labels) in enumerate(trainloader, 0):
    inputs, labels = inputs.to(device), labels.to(device)
    optimizer.zero_grad()
    
    outputs = net(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    
    _loss += loss.item()
    if i % 2000 == 1999:  # 每2000步打印一次损失值
      print('[%d, %5d] loss: %.3f' %
         (epoch + 1, i + 1, _loss / 2000))
      _loss = 0.0

print('Finished Training')

Out:

[1, 2000] loss: 1.178
[1, 4000] loss: 1.200
[1, 6000] loss: 1.168
[1, 8000] loss: 1.175
[1, 10000] loss: 1.185
[1, 12000] loss: 1.165
[2, 2000] loss: 1.073
[2, 4000] loss: 1.066
[2, 6000] loss: 1.100
[2, 8000] loss: 1.107
[2, 10000] loss: 1.083
[2, 12000] loss: 1.103
Finished Training

5. 测试网络

这个网络已经训练了两个epoch,我们现在来看看这个网络是不是学到了一些什么东西。

我们让这个神经网络预测几张图片,看看它的答案与真实答案的差别。

下面我们选取一些测试数据集中的数据,看看他们的真实标签。

# 展示测试数据集
dataiter = iter(testloader)
images, labels = dataiter.next()
imshow(torchvision.utils.make_grid(images))
print('GraoundTruth: ', ' '.join(['%5s' % classes[labels[j]] for j in range(4)]))

Out:

使用PyTorch训练一个图像分类器实例

GraoundTruth:  ship ship deer ship

接着我们让神经网络来给出预测标签

神经网络的输出是10个信号值,信号值最高的那个神经元表示整个网络的预测值,所以我们需要拿到信号最强的那个节点的索引值

# 展示预测值
outputs = net(images)
_, predicted = torch.max(outputs, 1)
print('Predicted: ', ' '.join(['%5s' % classes[predicted[j]] for j in range(4)]))

Out:

Predicted:  car ship horse ship

下面我们对整个测试集做一次评估:

# 评估测试数据集
correct, total = 0, 0
with torch.no_grad():
  for images, labels in testloader:
    outputs = net(images)
    _, predicted = torch.max(outputs, 1)
    total += labels.size(0)
    correct += (labels == predicted).sum().item()
  
print('Accuracy of the network on the 10000 test images: %d %%' % (
  100 * correct / total))

Out:

Accuracy of the network on the 10000 test images: 58 %

整个结果比随机猜要好得多(随机猜是10%的概率)。看来我们的神经网络还是学到了点东西。

下面我们来看看它在哪一个类别的分类上做得最好:

# 按类标评估
n_classes = len(classes)
class_correct, class_total = [0]*n_classes, [0]* n_classes

with torch.no_grad():
  for images, labels in testloader:
    outputs = net(images)
    _, predicted = torch.max(outputs, 1)
    is_correct = (labels == predicted).squeeze()
    for i in range(len(labels)):
      label = labels[i]
      class_total[label] += 1
      class_correct[label] += is_correct[i].item()

for i in range(n_classes):
  print('Accuracy of %5s: %.2f %%' % (
    classes[i], 100.0 * class_correct[i] / class_total[i]
  ))

Out:

Accuracy of plane: 67.00 %
Accuracy of  car: 71.50 %
Accuracy of bird: 55.20 %
Accuracy of  cat: 45.60 %
Accuracy of deer: 38.20 %
Accuracy of  dog: 47.00 %
Accuracy of frog: 78.80 %
Accuracy of horse: 55.90 %
Accuracy of ship: 72.70 %
Accuracy of truck: 57.50 %

在GPU上训练

就像把Tensor从CPU转移到GPU一样,神经网络也可以转移到GPU上

首先需要检查是否有可用的GPU

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 假设我们在支持CUDA的机器上,我们可以打印出CUDA设备:

print(device)

Out:

cuda:0

我们假设device已经是CUDA设备了

下面命令将递归的将所有模块和参数、缓存转移到CUDA设备上去

net.to(device)

Out:

Net(
 (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
 (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
 (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
 (fc1): Linear(in_features=400, out_features=120, bias=True)
 (fc2): Linear(in_features=120, out_features=84, bias=True)
 (fc3): Linear(in_features=84, out_features=10, bias=True)
)

注意,在训练过程中的传入输入数据时,也需要转移到GPU上

并且,需要重新实例化优化器,否则会报错

inputs, labels = inputs.to(device), labels.to(device)

练习:尝试增加神经网络的宽度。第一个nn.Conv2d的第二个参数和第二个nn.Conv2d的第一个参数的值必须一样。看看会有什么样的效果。

以上这篇使用PyTorch训练一个图像分类器实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现识别手写数字 简易图片存储管理系统
Jan 29 Python
python使用numpy读取、保存txt数据的实例
Oct 14 Python
Python实现繁?转为简体的方法示例
Dec 18 Python
Python实现的统计文章单词次数功能示例
Jul 08 Python
Django上线部署之IIS的配置方法
Aug 22 Python
python实现差分隐私Laplace机制详解
Nov 25 Python
解决python 读取 log日志的编码问题
Dec 24 Python
python字典setdefault方法和get方法使用实例
Dec 25 Python
基于python3实现倒叙字符串
Feb 18 Python
Numpy中ndim、shape、dtype、astype的用法详解
Jun 14 Python
pytorch分类模型绘制混淆矩阵以及可视化详解
Apr 07 Python
Python编写车票订购系统 Python实现快递收费系统
Aug 14 Python
pytorch 实现将自己的图片数据处理成可以训练的图片类型
Jan 08 #Python
pytorch下大型数据集(大型图片)的导入方式
Jan 08 #Python
Python 实现训练集、测试集随机划分
Jan 08 #Python
Pyecharts绘制全球流向图的示例代码
Jan 08 #Python
PyTorch 解决Dataset和Dataloader遇到的问题
Jan 08 #Python
使用PyTorch将文件夹下的图片分为训练集和验证集实例
Jan 08 #Python
使用 PyTorch 实现 MLP 并在 MNIST 数据集上验证方式
Jan 08 #Python
You might like
深入解析PHP的引用计数机制
2013/06/14 PHP
PHP中的插件机制原理和实例
2014/07/08 PHP
PHP中使用foreach()遍历二维数组的简单实例
2016/06/13 PHP
javaScript复制功能调用实现方案
2012/12/13 Javascript
javascript调试过程中找不到哪里出错的可能原因
2013/12/16 Javascript
js Dialog 去掉右上角的X关闭功能
2014/04/23 Javascript
使用jQuery不判断浏览器高度解决iframe自适应高度问题
2014/12/16 Javascript
angularjs基础教程
2014/12/25 Javascript
基于bootstrap插件实现autocomplete自动完成表单
2016/05/07 Javascript
详解AngularJs HTTP响应拦截器实现登陆、权限校验
2017/04/11 Javascript
jquery 禁止鼠标右键并监听右键事件
2017/04/27 jQuery
vue bootstrap小例子一枚
2017/06/09 Javascript
原生javascript实现文件异步上传的实例讲解
2017/10/26 Javascript
vue2.0之多页面的开发的示例
2018/01/30 Javascript
jQuery实现表格的增、删、改操作示例
2019/01/27 jQuery
Python处理json字符串转化为字典的简单实现
2016/07/07 Python
详解Python里使用正则表达式的ASCII模式
2017/11/02 Python
python检索特定内容的文本文件实例
2018/06/05 Python
对Python 窗体(tkinter)文本编辑器(Text)详解
2018/10/11 Python
python如何查看微信消息撤回
2018/11/27 Python
梅尔频率倒谱系数(mfcc)及Python实现
2019/06/18 Python
如何使用Python处理HDF格式数据及可视化问题
2020/06/24 Python
Pycharm 设置默认解释器路径和编码格式的操作
2021/02/05 Python
英国复古和经典球衣网站:Vintage Football Shirts
2018/10/05 全球购物
个人简历自我鉴定
2013/10/11 职场文书
喷漆工的岗位职责
2014/03/17 职场文书
校庆活动方案
2014/03/31 职场文书
《彩色世界》教学反思
2014/04/12 职场文书
《红军不怕远征难》教学反思
2014/04/14 职场文书
员工安全生产承诺书
2014/05/22 职场文书
婚礼嘉宾致辞
2015/07/28 职场文书
教师病假条范文
2015/08/17 职场文书
2016年第104个国际护士节活动总结
2016/04/06 职场文书
HTML+CSS制作心跳特效的实现
2021/05/26 HTML / CSS
html+css实现赛博朋克风格按钮
2021/05/26 HTML / CSS
Win10玩csgo闪退如何解决?Win10玩csgo闪退的解决方法
2022/07/23 数码科技