Pytorch入门之mnist分类实例


Posted in Python onApril 14, 2018

本文实例为大家分享了Pytorch入门之mnist分类的具体代码,供大家参考,具体内容如下

#!/usr/bin/env python
# -*- coding: utf-8 -*-
__author__ = 'denny'
__time__ = '2017-9-9 9:03'

import torch
import torchvision
from torch.autograd import Variable
import torch.utils.data.dataloader as Data

train_data = torchvision.datasets.MNIST(
 './mnist', train=True, transform=torchvision.transforms.ToTensor(), download=True
)
test_data = torchvision.datasets.MNIST(
 './mnist', train=False, transform=torchvision.transforms.ToTensor()
)
print("train_data:", train_data.train_data.size())
print("train_labels:", train_data.train_labels.size())
print("test_data:", test_data.test_data.size())

train_loader = Data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = Data.DataLoader(dataset=test_data, batch_size=64)


class Net(torch.nn.Module):
 def __init__(self):
 super(Net, self).__init__()
 self.conv1 = torch.nn.Sequential(
  torch.nn.Conv2d(1, 32, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2))
 self.conv2 = torch.nn.Sequential(
  torch.nn.Conv2d(32, 64, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2)
 )
 self.conv3 = torch.nn.Sequential(
  torch.nn.Conv2d(64, 64, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2)
 )
 self.dense = torch.nn.Sequential(
  torch.nn.Linear(64 * 3 * 3, 128),
  torch.nn.ReLU(),
  torch.nn.Linear(128, 10)
 )

 def forward(self, x):
 conv1_out = self.conv1(x)
 conv2_out = self.conv2(conv1_out)
 conv3_out = self.conv3(conv2_out)
 res = conv3_out.view(conv3_out.size(0), -1)
 out = self.dense(res)
 return out


model = Net()
print(model)

optimizer = torch.optim.Adam(model.parameters())
loss_func = torch.nn.CrossEntropyLoss()

for epoch in range(10):
 print('epoch {}'.format(epoch + 1))
 # training-----------------------------
 train_loss = 0.
 train_acc = 0.
 for batch_x, batch_y in train_loader:
 batch_x, batch_y = Variable(batch_x), Variable(batch_y)
 out = model(batch_x)
 loss = loss_func(out, batch_y)
 train_loss += loss.data[0]
 pred = torch.max(out, 1)[1]
 train_correct = (pred == batch_y).sum()
 train_acc += train_correct.data[0]
 optimizer.zero_grad()
 loss.backward()
 optimizer.step()
 print('Train Loss: {:.6f}, Acc: {:.6f}'.format(train_loss / (len(
 train_data)), train_acc / (len(train_data))))

 # evaluation--------------------------------
 model.eval()
 eval_loss = 0.
 eval_acc = 0.
 for batch_x, batch_y in test_loader:
 batch_x, batch_y = Variable(batch_x, volatile=True), Variable(batch_y, volatile=True)
 out = model(batch_x)
 loss = loss_func(out, batch_y)
 eval_loss += loss.data[0]
 pred = torch.max(out, 1)[1]
 num_correct = (pred == batch_y).sum()
 eval_acc += num_correct.data[0]
 print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(
 test_data)), eval_acc / (len(test_data))))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python使用正则匹配实现抓图代码分享
Apr 02 Python
python xlsxwriter库生成图表的应用示例
Mar 16 Python
python 中文件输入输出及os模块对文件系统的操作方法
Aug 27 Python
python模块导入的细节详解
Dec 10 Python
Python txt文件加入字典并查询的方法
Jan 15 Python
提升Python程序性能的7个习惯
Apr 14 Python
Python3中urlencode和urldecode的用法详解
Jul 23 Python
python多线程分块读取文件
Aug 29 Python
解决pycharm debug时界面下方不出现step等按钮及变量值的问题
Jun 09 Python
python实现批量命名照片
Jun 18 Python
python文件读取失败怎么处理
Jun 23 Python
python基于scrapy爬取京东笔记本电脑数据并进行简单处理和分析
Apr 14 Python
pytorch构建网络模型的4种方法
Apr 13 #Python
Python输入二维数组方法
Apr 13 #Python
Python基于递归实现电话号码映射功能示例
Apr 13 #Python
Python的多维空数组赋值方法
Apr 13 #Python
python多维数组切片方法
Apr 13 #Python
Python实现判断并移除列表指定位置元素的方法
Apr 13 #Python
Python中的二维数组实例(list与numpy.array)
Apr 13 #Python
You might like
虹吸壶煮咖啡26个注意事项
2021/03/03 冲泡冲煮
PHP json_encode中文乱码问题的解决办法
2013/09/09 PHP
在win7中搭建Linux+PHP 开发环境
2014/10/08 PHP
php微信公众号开发(4)php实现自定义关键字回复
2016/12/15 PHP
PHP脚本自动识别验证码查询汽车违章
2016/12/20 PHP
PHP获取中国时间(上海时区时间)及美国时间的方法
2017/02/23 PHP
PHP程序员简单的开展服务治理架构操作详解(三)
2020/05/14 PHP
JQuery获取各种宽度、高度(format函数)实例
2013/03/04 Javascript
chrome下jq width()方法取值为0的解决方法
2014/05/26 Javascript
jquery mobile界面数据刷新的实现方法
2016/05/28 Javascript
微信小程序 前端源码逻辑和工作流详解
2016/10/08 Javascript
js编写的treeview使用方法
2016/11/11 Javascript
基于JS实现仿京东搜索栏随滑动透明度渐变效果
2017/07/10 Javascript
JS实现的贪吃蛇游戏完整实例
2019/01/18 Javascript
layui点击数据表格添加或删除一行的例子
2019/09/12 Javascript
jQuery操作动画完整实例分析
2020/01/10 jQuery
Python 获取新浪微博的最新公共微博实例分享
2014/07/03 Python
python之文件读取一行一行的方法
2018/07/12 Python
Python 从一个文件中调用另一个文件的类方法
2019/01/10 Python
Python搭建Keras CNN模型破解网站验证码的实现
2020/04/07 Python
python和php学习哪个更有发展
2020/06/17 Python
总结30个CSS3选择器
2017/04/13 HTML / CSS
CSS3实现王者匹配时的粒子动画效果
2019/04/12 HTML / CSS
html5 Canvas画图教程(4)—未闭合的路径及渐变色的填充方法
2013/01/09 HTML / CSS
英国标志性奢侈品牌:Burberry
2016/07/28 全球购物
预订全球最佳旅行体验:Viator
2018/03/30 全球购物
Coltorti Boutique官网:来自意大利的设计师品牌买手店
2018/11/09 全球购物
英国家喻户晓的家居商店:The Range
2019/03/25 全球购物
Guess荷兰官网:美国服饰品牌
2020/01/22 全球购物
机电专业个人自荐信格式模板
2013/09/23 职场文书
个人简历自我评价八例
2013/10/31 职场文书
《雨霖铃》听课反思
2014/02/13 职场文书
团党委领导干部党的群众路线教育实践活动个人对照检查材料思想汇
2014/10/05 职场文书
2016学习雷锋精神活动倡议书
2015/04/27 职场文书
小学英语教师2015年度个人工作总结
2015/10/14 职场文书
房地产置业顾问工作总结
2015/10/23 职场文书