Pytorch入门之mnist分类实例


Posted in Python onApril 14, 2018

本文实例为大家分享了Pytorch入门之mnist分类的具体代码,供大家参考,具体内容如下

#!/usr/bin/env python
# -*- coding: utf-8 -*-
__author__ = 'denny'
__time__ = '2017-9-9 9:03'

import torch
import torchvision
from torch.autograd import Variable
import torch.utils.data.dataloader as Data

train_data = torchvision.datasets.MNIST(
 './mnist', train=True, transform=torchvision.transforms.ToTensor(), download=True
)
test_data = torchvision.datasets.MNIST(
 './mnist', train=False, transform=torchvision.transforms.ToTensor()
)
print("train_data:", train_data.train_data.size())
print("train_labels:", train_data.train_labels.size())
print("test_data:", test_data.test_data.size())

train_loader = Data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = Data.DataLoader(dataset=test_data, batch_size=64)


class Net(torch.nn.Module):
 def __init__(self):
 super(Net, self).__init__()
 self.conv1 = torch.nn.Sequential(
  torch.nn.Conv2d(1, 32, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2))
 self.conv2 = torch.nn.Sequential(
  torch.nn.Conv2d(32, 64, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2)
 )
 self.conv3 = torch.nn.Sequential(
  torch.nn.Conv2d(64, 64, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2)
 )
 self.dense = torch.nn.Sequential(
  torch.nn.Linear(64 * 3 * 3, 128),
  torch.nn.ReLU(),
  torch.nn.Linear(128, 10)
 )

 def forward(self, x):
 conv1_out = self.conv1(x)
 conv2_out = self.conv2(conv1_out)
 conv3_out = self.conv3(conv2_out)
 res = conv3_out.view(conv3_out.size(0), -1)
 out = self.dense(res)
 return out


model = Net()
print(model)

optimizer = torch.optim.Adam(model.parameters())
loss_func = torch.nn.CrossEntropyLoss()

for epoch in range(10):
 print('epoch {}'.format(epoch + 1))
 # training-----------------------------
 train_loss = 0.
 train_acc = 0.
 for batch_x, batch_y in train_loader:
 batch_x, batch_y = Variable(batch_x), Variable(batch_y)
 out = model(batch_x)
 loss = loss_func(out, batch_y)
 train_loss += loss.data[0]
 pred = torch.max(out, 1)[1]
 train_correct = (pred == batch_y).sum()
 train_acc += train_correct.data[0]
 optimizer.zero_grad()
 loss.backward()
 optimizer.step()
 print('Train Loss: {:.6f}, Acc: {:.6f}'.format(train_loss / (len(
 train_data)), train_acc / (len(train_data))))

 # evaluation--------------------------------
 model.eval()
 eval_loss = 0.
 eval_acc = 0.
 for batch_x, batch_y in test_loader:
 batch_x, batch_y = Variable(batch_x, volatile=True), Variable(batch_y, volatile=True)
 out = model(batch_x)
 loss = loss_func(out, batch_y)
 eval_loss += loss.data[0]
 pred = torch.max(out, 1)[1]
 num_correct = (pred == batch_y).sum()
 eval_acc += num_correct.data[0]
 print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(
 test_data)), eval_acc / (len(test_data))))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python程序设计入门(3)数组的使用
Jun 16 Python
Python实现基于权重的随机数2种方法
Apr 28 Python
Tensorflow 实现修改张量特定元素的值方法
Jul 30 Python
python儿童学游戏编程知识点总结
Jun 03 Python
华为校园招聘上机笔试题 扑克牌大小(python)
Apr 22 Python
Python 数据可视化pyecharts的使用详解
Jun 26 Python
Python中拆分字符串的操作方法
Jul 23 Python
Django实现WebSSH操作物理机或虚拟机的方法
Nov 06 Python
Python使用qrcode二维码库生成二维码方法详解
Feb 17 Python
python3.7添加dlib模块的方法
Jul 01 Python
python实现excel公式格式化的示例代码
Dec 23 Python
使用numpngw和matplotlib生成png动画的示例代码
Jan 24 Python
pytorch构建网络模型的4种方法
Apr 13 #Python
Python输入二维数组方法
Apr 13 #Python
Python基于递归实现电话号码映射功能示例
Apr 13 #Python
Python的多维空数组赋值方法
Apr 13 #Python
python多维数组切片方法
Apr 13 #Python
Python实现判断并移除列表指定位置元素的方法
Apr 13 #Python
Python中的二维数组实例(list与numpy.array)
Apr 13 #Python
You might like
PHP一些常用的正则表达式字符的一些转换
2008/07/29 PHP
一个非常完美的读写ini格式的PHP配置类分享
2015/02/12 PHP
PHP工程师VIM配置分享
2015/12/15 PHP
Smarty变量用法详解
2016/05/11 PHP
PHP+Ajax实现的检测用户名功能简单示例
2019/02/12 PHP
js 动态选中下拉框
2009/11/26 Javascript
jQuery 源码分析笔记(6) jQuery.data
2011/06/08 Javascript
鼠标悬浮停留三秒后自动显示大图js代码
2014/09/09 Javascript
Bootstrap学习笔记之css样式设计(2)
2016/06/07 Javascript
javascript事件冒泡简单示例
2016/06/20 Javascript
javascript滚轮控制模拟滚动条
2016/10/19 Javascript
浅析如何利用angular结合translate为项目实现国际化
2016/12/08 Javascript
Js自定义多选框效果的实例代码
2017/07/05 Javascript
微信小程序-滚动消息通知的实例代码
2017/08/03 Javascript
在Vue中使用axios请求拦截的实现方法
2018/10/25 Javascript
axios 实现post请求时把对象obj数据转为formdata
2019/10/31 Javascript
vue页面跳转实现页面缓存操作
2020/07/22 Javascript
通过实例了解Render Props回调地狱解决方案
2020/11/04 Javascript
用Python操作字符串之rindex()方法的使用
2015/05/19 Python
全面了解Nginx, WSGI, Flask之间的关系
2018/01/09 Python
python3判断url链接是否为404的方法
2018/08/10 Python
Python构建图像分类识别器的方法
2019/01/12 Python
flask应用部署到服务器的方法
2019/07/12 Python
jupyter notebook读取/导出文件/图片实例
2020/04/16 Python
python中类与对象之间的关系详解
2020/12/16 Python
护理专业自荐信
2013/12/03 职场文书
个性大学生自我评价
2013/12/04 职场文书
打架检讨书100字
2014/01/08 职场文书
总经理助理职责
2014/02/04 职场文书
我们的节日中秋活动方案
2014/08/19 职场文书
企业授权委托书范本
2014/09/22 职场文书
反腐倡廉剖析材料
2014/09/30 职场文书
详解TS数字分隔符和更严格的类属性检查
2021/05/06 Javascript
每日六道java新手入门面试题,通往自由的道路
2021/06/30 Java/Android
python非标准时间的转换
2021/07/25 Python
如何利用python实现列表嵌套字典取值
2022/06/10 Python