pytorch实现mnist分类的示例讲解


Posted in Python onJanuary 10, 2020

torchvision包 包含了目前流行的数据集,模型结构和常用的图片转换工具。

torchvision.datasets中包含了以下数据集

MNIST
COCO(用于图像标注和目标检测)(Captioning and Detection)
LSUN Classification
ImageFolder
Imagenet-12
CIFAR10 and CIFAR100
STL10

torchvision.models

torchvision.models模块的 子模块中包含以下模型结构。
AlexNet
VGG
ResNet
SqueezeNet
DenseNet You can construct a model with random weights by calling its constructor:

pytorch torchvision transform

对PIL.Image进行变换

from __future__ import print_function
import argparse #Python 命令行解析工具
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim 
from torchvision import datasets, transforms

class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
    self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
    self.conv2_drop = nn.Dropout2d()
    self.fc1 = nn.Linear(320, 50)
    self.fc2 = nn.Linear(50, 10)

  def forward(self, x):
    x = F.relu(F.max_pool2d(self.conv1(x), 2))
    x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
    x = x.view(-1, 320)
    x = F.relu(self.fc1(x))
    x = F.dropout(x, training=self.training)
    x = self.fc2(x)
    return F.log_softmax(x, dim=1)

def train(args, model, device, train_loader, optimizer, epoch):
  model.train()
  for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()
    output = model(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    optimizer.step()
    if batch_idx % args.log_interval == 0:
      print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        epoch, batch_idx * len(data), len(train_loader.dataset),
        100. * batch_idx / len(train_loader), loss.item()))

def test(args, model, device, test_loader):
  model.eval()
  test_loss = 0
  correct = 0
  with torch.no_grad():
    for data, target in test_loader:
      data, target = data.to(device), target.to(device)
      output = model(data)
      test_loss += F.nll_loss(output, target, size_average=False).item() # sum up batch loss
      pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
      correct += pred.eq(target.view_as(pred)).sum().item()

  test_loss /= len(test_loader.dataset)
  print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))

def main():
  # Training settings
  parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
  parser.add_argument('--batch-size', type=int, default=64, metavar='N',
            help='input batch size for training (default: 64)')
  parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
            help='input batch size for testing (default: 1000)')
  parser.add_argument('--epochs', type=int, default=10, metavar='N',
            help='number of epochs to train (default: 10)')
  parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
            help='learning rate (default: 0.01)')
  parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
            help='SGD momentum (default: 0.5)')
  parser.add_argument('--no-cuda', action='store_true', default=False,
            help='disables CUDA training')
  parser.add_argument('--seed', type=int, default=1, metavar='S',
            help='random seed (default: 1)')
  parser.add_argument('--log-interval', type=int, default=10, metavar='N',
            help='how many batches to wait before logging training status')
  args = parser.parse_args()
  use_cuda = not args.no_cuda and torch.cuda.is_available()

  torch.manual_seed(args.seed)

  device = torch.device("cuda" if use_cuda else "cpu")

  kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
  train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
            transform=transforms.Compose([
              transforms.ToTensor(),
              transforms.Normalize((0.1307,), (0.3081,))
            ])),
    batch_size=args.batch_size, shuffle=True, **kwargs)
  test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
              transforms.ToTensor(),
              transforms.Normalize((0.1307,), (0.3081,))
            ])),
    batch_size=args.test_batch_size, shuffle=True, **kwargs)


  model = Net().to(device)
  optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

  for epoch in range(1, args.epochs + 1):
    train(args, model, device, train_loader, optimizer, epoch)
    test(args, model, device, test_loader)


if __name__ == '__main__':
  main()

以上这篇pytorch实现mnist分类的示例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
在Python中使用zlib模块进行数据压缩的教程
Jun 26 Python
Python的Flask框架中使用Flask-Migrate扩展迁移数据库的教程
Jun 14 Python
windows下安装Python的XlsxWriter模块方法
May 03 Python
python入门:这篇文章带你直接学会python
Sep 14 Python
python requests爬取高德地图数据的实例
Nov 10 Python
对Pandas MultiIndex(多重索引)详解
Nov 16 Python
在Pycharm中对代码进行注释和缩进的方法详解
Jan 20 Python
PyTorch的深度学习入门之PyTorch安装和配置
Jun 27 Python
解决Python中回文数和质数的问题
Nov 24 Python
python json load json 数据后出现乱序的解决方案
Feb 27 Python
python 的numpy库中的mean()函数用法介绍
Mar 03 Python
python3爬虫中异步协程的用法
Jul 10 Python
pytorch 准备、训练和测试自己的图片数据的方法
Jan 10 #Python
pytorch GAN伪造手写体mnist数据集方式
Jan 10 #Python
MNIST数据集转化为二维图片的实现示例
Jan 10 #Python
pytorch:实现简单的GAN示例(MNIST数据集)
Jan 10 #Python
pytorch GAN生成对抗网络实例
Jan 10 #Python
解决pytorch报错:AssertionError: Invalid device id的问题
Jan 10 #Python
python3中关于excel追加写入格式被覆盖问题(实例代码)
Jan 10 #Python
You might like
PHP通过COM使用ADODB的简单例子
2006/12/31 PHP
常见的PHP五种设计模式小结
2011/03/23 PHP
PHP中preg_match正则匹配中的/u、/i、/s含义
2015/04/17 PHP
PHP 实现的将图片转换为TXT
2015/10/21 PHP
php构造函数与析构函数
2016/04/23 PHP
PHP 用session与gd库实现简单验证码生成与验证的类方法
2016/11/15 PHP
一组JS创建和操作表格的函数集合
2009/05/07 Javascript
jquery 简短右键菜单 多浏览器兼容
2010/01/01 Javascript
JS 如果改变span标签的是否隐藏属性
2011/10/06 Javascript
jQuery学习笔记 操作jQuery对象 CSS处理
2012/09/19 Javascript
纯JavaScript实现HTML5 Canvas六种特效滤镜示例
2013/06/28 Javascript
JS简单实现登陆验证附效果图
2013/11/19 Javascript
jquery delay()介绍及使用指南
2014/09/02 Javascript
jQuery 复合选择器应用的几个例子
2014/09/11 Javascript
学习JavaScript设计模式(多态)
2015/11/25 Javascript
js删除局部变量的实现方法
2016/06/25 Javascript
Vue.js仿Metronic高级表格(二)数据渲染
2017/04/19 Javascript
Vue之Watcher源码解析(2)
2017/07/19 Javascript
jquery实现自定义树形表格的方法【自定义树形结构table】
2019/07/12 jQuery
如何在vue中使用HTML 5 拖放API
2021/01/14 Vue.js
[06:35]2014DOTA2国际邀请赛 老男孩梦圆西雅图中国军团世界最强
2014/07/22 DOTA
[09:47]2018DOTA2亚洲邀请赛4.5SOLO赛 No[o]ne vs Sumail
2018/04/06 DOTA
[48:22]VGJ.S vs VG 2018国际邀请赛小组赛BO2 第一场 8.16
2018/08/17 DOTA
在Gnumeric下使用Python脚本操作表格的教程
2015/04/14 Python
python实现查找excel里某一列重复数据并且剔除后打印的方法
2015/05/26 Python
Python实现的自定义多线程多进程类示例
2018/03/23 Python
python 爬取疫情数据的源码
2020/02/09 Python
xadmin使用formfield_for_dbfield函数过滤下拉表单实例
2020/04/07 Python
python中列表的含义及用法
2020/05/26 Python
keras slice layer 层实现方式
2020/06/11 Python
jupyter notebook 写代码自动补全的实现
2020/11/02 Python
HTML5的革新 结构之美
2011/06/20 HTML / CSS
java程序员面试交流
2012/11/29 面试题
故意伤害人身损害赔偿协议书
2014/11/19 职场文书
安全学习心得体会范文
2016/01/18 职场文书
Mysql systemctl start mysqld报错的问题解决
2021/06/03 MySQL