Pytorch入门之mnist分类实例


Posted in Python onApril 14, 2018

本文实例为大家分享了Pytorch入门之mnist分类的具体代码,供大家参考,具体内容如下

#!/usr/bin/env python
# -*- coding: utf-8 -*-
__author__ = 'denny'
__time__ = '2017-9-9 9:03'

import torch
import torchvision
from torch.autograd import Variable
import torch.utils.data.dataloader as Data

train_data = torchvision.datasets.MNIST(
 './mnist', train=True, transform=torchvision.transforms.ToTensor(), download=True
)
test_data = torchvision.datasets.MNIST(
 './mnist', train=False, transform=torchvision.transforms.ToTensor()
)
print("train_data:", train_data.train_data.size())
print("train_labels:", train_data.train_labels.size())
print("test_data:", test_data.test_data.size())

train_loader = Data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = Data.DataLoader(dataset=test_data, batch_size=64)


class Net(torch.nn.Module):
 def __init__(self):
 super(Net, self).__init__()
 self.conv1 = torch.nn.Sequential(
  torch.nn.Conv2d(1, 32, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2))
 self.conv2 = torch.nn.Sequential(
  torch.nn.Conv2d(32, 64, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2)
 )
 self.conv3 = torch.nn.Sequential(
  torch.nn.Conv2d(64, 64, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2)
 )
 self.dense = torch.nn.Sequential(
  torch.nn.Linear(64 * 3 * 3, 128),
  torch.nn.ReLU(),
  torch.nn.Linear(128, 10)
 )

 def forward(self, x):
 conv1_out = self.conv1(x)
 conv2_out = self.conv2(conv1_out)
 conv3_out = self.conv3(conv2_out)
 res = conv3_out.view(conv3_out.size(0), -1)
 out = self.dense(res)
 return out


model = Net()
print(model)

optimizer = torch.optim.Adam(model.parameters())
loss_func = torch.nn.CrossEntropyLoss()

for epoch in range(10):
 print('epoch {}'.format(epoch + 1))
 # training-----------------------------
 train_loss = 0.
 train_acc = 0.
 for batch_x, batch_y in train_loader:
 batch_x, batch_y = Variable(batch_x), Variable(batch_y)
 out = model(batch_x)
 loss = loss_func(out, batch_y)
 train_loss += loss.data[0]
 pred = torch.max(out, 1)[1]
 train_correct = (pred == batch_y).sum()
 train_acc += train_correct.data[0]
 optimizer.zero_grad()
 loss.backward()
 optimizer.step()
 print('Train Loss: {:.6f}, Acc: {:.6f}'.format(train_loss / (len(
 train_data)), train_acc / (len(train_data))))

 # evaluation--------------------------------
 model.eval()
 eval_loss = 0.
 eval_acc = 0.
 for batch_x, batch_y in test_loader:
 batch_x, batch_y = Variable(batch_x, volatile=True), Variable(batch_y, volatile=True)
 out = model(batch_x)
 loss = loss_func(out, batch_y)
 eval_loss += loss.data[0]
 pred = torch.max(out, 1)[1]
 num_correct = (pred == batch_y).sum()
 eval_acc += num_correct.data[0]
 print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(
 test_data)), eval_acc / (len(test_data))))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python与php实现分割文件代码
Mar 06 Python
python通过pip更新所有已安装的包实现方法
May 19 Python
CentOS中升级Python版本的方法详解
Jul 10 Python
利用python 更新ssh 远程代码 操作远程服务器的实现代码
Feb 08 Python
TensorFlow数据输入的方法示例
Jun 19 Python
python 调用钉钉机器人的方法
Feb 20 Python
python redis 批量设置过期key过程解析
Nov 26 Python
python中的split()函数和os.path.split()函数使用详解
Dec 21 Python
Pytorch 多块GPU的使用详解
Dec 31 Python
python 截取XML中bndbox的坐标中的图像,另存为jpg的实例
Mar 10 Python
Django数据模型中on_delete使用详解
Nov 30 Python
撤回我也能看到!教你用Python制作微信防撤回脚本
Jun 11 Python
pytorch构建网络模型的4种方法
Apr 13 #Python
Python输入二维数组方法
Apr 13 #Python
Python基于递归实现电话号码映射功能示例
Apr 13 #Python
Python的多维空数组赋值方法
Apr 13 #Python
python多维数组切片方法
Apr 13 #Python
Python实现判断并移除列表指定位置元素的方法
Apr 13 #Python
Python中的二维数组实例(list与numpy.array)
Apr 13 #Python
You might like
PHP HTML代码串截取代码
2008/12/29 PHP
zend框架实现支持sql server的操作方法
2016/12/08 PHP
php正则表达式基本知识与应用详解【经典教程】
2017/04/17 PHP
详解php框架Yaf路由重写
2017/06/20 PHP
通过PHP的Wrapper无缝迁移原有项目到新服务的实现方法
2020/04/02 PHP
javascript开发技术大全 第4章 直接量与字符集
2011/07/03 Javascript
网页防止tab键的使用快速解决方法
2013/11/07 Javascript
setTimeout内不支持jquery的选择器的解决方案
2015/04/28 Javascript
一看就懂:jsonp详解
2015/06/01 Javascript
jQuery实现可用于博客的动态滑动菜单完整实例
2015/09/17 Javascript
移动端翻页插件dropload.js(支持Zepto和jQuery)
2016/07/27 Javascript
jQuery 遍历map()方法详解
2016/11/04 Javascript
jQuery实现Select下拉列表进行状态选择功能
2017/03/30 jQuery
Nodejs之TCP服务端与客户端聊天程序详解
2017/07/07 NodeJs
webpack实用小功能介绍
2018/01/02 Javascript
webpack中的热刷新与热加载的区别
2018/04/09 Javascript
详解vue组件基础
2018/05/04 Javascript
ES6基础之默认参数值
2019/02/21 Javascript
js实现图片局部放大效果详解
2019/03/18 Javascript
JQuery常用简单动画操作方法回顾与总结
2019/12/07 jQuery
JavaScript实现HSL拾色器
2020/05/21 Javascript
JavaScript 防抖和节流遇见的奇怪问题及解决
2020/11/20 Javascript
[57:22]2018DOTA2亚洲邀请赛 4.7总决赛 LGD vs Mineski 第五场
2018/04/10 DOTA
详解C++编程中一元运算符的重载
2016/01/19 Python
Python笔试面试题小结
2019/09/07 Python
pyinstaller打包程序exe踩过的坑
2019/11/19 Python
解决torch.autograd.backward中的参数问题
2020/01/07 Python
python json.dumps() json.dump()的区别详解
2020/07/14 Python
css3实现动画的三种方式
2020/08/24 HTML / CSS
如何在Oracle中查看各个表、表空间占用空间的大小
2015/10/31 面试题
EJB的角色和三个对象
2015/12/31 面试题
文明礼仪演讲稿
2014/05/12 职场文书
企业法人授权委托书范本
2014/09/23 职场文书
习近平在党的群众路线教育实践活动总结大会上的讲话全文
2014/10/25 职场文书
一次性工伤赔偿协议书范本
2014/11/25 职场文书
Golang数据类型和相互转换
2022/04/12 Golang