pytorch实现mnist分类的示例讲解


Posted in Python onJanuary 10, 2020

torchvision包 包含了目前流行的数据集,模型结构和常用的图片转换工具。

torchvision.datasets中包含了以下数据集

MNIST
COCO(用于图像标注和目标检测)(Captioning and Detection)
LSUN Classification
ImageFolder
Imagenet-12
CIFAR10 and CIFAR100
STL10

torchvision.models

torchvision.models模块的 子模块中包含以下模型结构。
AlexNet
VGG
ResNet
SqueezeNet
DenseNet You can construct a model with random weights by calling its constructor:

pytorch torchvision transform

对PIL.Image进行变换

from __future__ import print_function
import argparse #Python 命令行解析工具
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim 
from torchvision import datasets, transforms

class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
    self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
    self.conv2_drop = nn.Dropout2d()
    self.fc1 = nn.Linear(320, 50)
    self.fc2 = nn.Linear(50, 10)

  def forward(self, x):
    x = F.relu(F.max_pool2d(self.conv1(x), 2))
    x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
    x = x.view(-1, 320)
    x = F.relu(self.fc1(x))
    x = F.dropout(x, training=self.training)
    x = self.fc2(x)
    return F.log_softmax(x, dim=1)

def train(args, model, device, train_loader, optimizer, epoch):
  model.train()
  for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()
    output = model(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    optimizer.step()
    if batch_idx % args.log_interval == 0:
      print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        epoch, batch_idx * len(data), len(train_loader.dataset),
        100. * batch_idx / len(train_loader), loss.item()))

def test(args, model, device, test_loader):
  model.eval()
  test_loss = 0
  correct = 0
  with torch.no_grad():
    for data, target in test_loader:
      data, target = data.to(device), target.to(device)
      output = model(data)
      test_loss += F.nll_loss(output, target, size_average=False).item() # sum up batch loss
      pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
      correct += pred.eq(target.view_as(pred)).sum().item()

  test_loss /= len(test_loader.dataset)
  print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))

def main():
  # Training settings
  parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
  parser.add_argument('--batch-size', type=int, default=64, metavar='N',
            help='input batch size for training (default: 64)')
  parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
            help='input batch size for testing (default: 1000)')
  parser.add_argument('--epochs', type=int, default=10, metavar='N',
            help='number of epochs to train (default: 10)')
  parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
            help='learning rate (default: 0.01)')
  parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
            help='SGD momentum (default: 0.5)')
  parser.add_argument('--no-cuda', action='store_true', default=False,
            help='disables CUDA training')
  parser.add_argument('--seed', type=int, default=1, metavar='S',
            help='random seed (default: 1)')
  parser.add_argument('--log-interval', type=int, default=10, metavar='N',
            help='how many batches to wait before logging training status')
  args = parser.parse_args()
  use_cuda = not args.no_cuda and torch.cuda.is_available()

  torch.manual_seed(args.seed)

  device = torch.device("cuda" if use_cuda else "cpu")

  kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
  train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
            transform=transforms.Compose([
              transforms.ToTensor(),
              transforms.Normalize((0.1307,), (0.3081,))
            ])),
    batch_size=args.batch_size, shuffle=True, **kwargs)
  test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
              transforms.ToTensor(),
              transforms.Normalize((0.1307,), (0.3081,))
            ])),
    batch_size=args.test_batch_size, shuffle=True, **kwargs)


  model = Net().to(device)
  optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

  for epoch in range(1, args.epochs + 1):
    train(args, model, device, train_loader, optimizer, epoch)
    test(args, model, device, test_loader)


if __name__ == '__main__':
  main()

以上这篇pytorch实现mnist分类的示例讲解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解Python中类的定义与使用
Apr 11 Python
python获取代理IP的实例分享
May 07 Python
解决Django migrate No changes detected 不能创建表的问题
May 27 Python
TensorFlow用expand_dim()来增加维度的方法
Jul 26 Python
程序员写Python时的5个坏习惯,你有几条?
Nov 26 Python
OpenCV 边缘检测
Jul 10 Python
Python3.7 读取 mp3 音频文件生成波形图效果
Nov 05 Python
手把手教你进行Python虚拟环境配置教程
Feb 03 Python
pytorch模型存储的2种实现方法
Feb 14 Python
Python面向对象程序设计之私有变量,私有方法原理与用法分析
Mar 23 Python
Python 高效编程技巧分享
Sep 10 Python
python使用selenium爬虫知乎的方法示例
Oct 28 Python
pytorch 准备、训练和测试自己的图片数据的方法
Jan 10 #Python
pytorch GAN伪造手写体mnist数据集方式
Jan 10 #Python
MNIST数据集转化为二维图片的实现示例
Jan 10 #Python
pytorch:实现简单的GAN示例(MNIST数据集)
Jan 10 #Python
pytorch GAN生成对抗网络实例
Jan 10 #Python
解决pytorch报错:AssertionError: Invalid device id的问题
Jan 10 #Python
python3中关于excel追加写入格式被覆盖问题(实例代码)
Jan 10 #Python
You might like
一步一步学习PHP(8) php 数组
2010/03/05 PHP
php输入流php://input使用示例(php发送图片流到服务器)
2013/12/25 PHP
php二维数组转成字符串示例
2014/02/17 PHP
PHP实现的sqlite数据库连接类
2014/12/12 PHP
腾讯微博提示missing parameter errorcode 102 错误的解决方法
2014/12/22 PHP
laravel实现简单用户权限的示例代码
2019/05/28 PHP
js 分页全选或反选标识实现代码
2011/08/09 Javascript
FusionCharts图表显示双Y轴双(多)曲线
2012/11/22 Javascript
JavaScript原型链示例分享
2014/01/26 Javascript
Jquery给基本控件的取值、赋值示例
2014/05/23 Javascript
Backbone.js中的集合详解
2015/01/14 Javascript
JS实现日期时间动态显示的方法
2015/12/07 Javascript
jQuery插件echarts去掉垂直网格线用法示例
2017/03/03 Javascript
Web前端框架Angular4.0.0 正式版发布
2017/03/28 Javascript
浅谈vue的几种绑定变量的值 防止其改变的方法
2018/03/01 Javascript
jQuery实现的点击按钮改变样式功能示例
2018/07/21 jQuery
Python正则表达式匹配中文用法示例
2017/01/17 Python
Python中函数及默认参数的定义与调用操作实例分析
2017/07/25 Python
Python读写/追加excel文件Demo分享
2018/05/03 Python
Python使用numpy模块创建数组操作示例
2018/06/20 Python
使用TensorFlow实现SVM
2018/09/06 Python
python实现顺序表的简单代码
2018/09/28 Python
详解Python requests 超时和重试的方法
2018/12/18 Python
关于PyTorch 自动求导机制详解
2019/08/18 Python
《风娃娃》教学反思
2014/04/19 职场文书
企业委托书范本
2014/09/13 职场文书
2016大学自主招生推荐信范文
2015/03/23 职场文书
幼儿园百日安全活动总结
2015/05/07 职场文书
2015年电厂工作总结范文
2015/05/13 职场文书
教师年度考核自我评鉴
2015/08/11 职场文书
2016年秋季运动会加油稿
2015/12/21 职场文书
SQL Server基本使用和简单的CRUD操作
2021/04/05 SQL Server
MySQL8.0.24版本Release Note的一些改进点
2021/04/22 MySQL
MySQL Router的安装部署
2021/04/24 MySQL
Keras在mnist上的CNN实践,并且自定义loss函数曲线图操作
2021/05/25 Python
SQL Server代理:理解SQL代理错误日志处理方法
2021/06/30 SQL Server