Pytorch入门之mnist分类实例


Posted in Python onApril 14, 2018

本文实例为大家分享了Pytorch入门之mnist分类的具体代码,供大家参考,具体内容如下

#!/usr/bin/env python
# -*- coding: utf-8 -*-
__author__ = 'denny'
__time__ = '2017-9-9 9:03'

import torch
import torchvision
from torch.autograd import Variable
import torch.utils.data.dataloader as Data

train_data = torchvision.datasets.MNIST(
 './mnist', train=True, transform=torchvision.transforms.ToTensor(), download=True
)
test_data = torchvision.datasets.MNIST(
 './mnist', train=False, transform=torchvision.transforms.ToTensor()
)
print("train_data:", train_data.train_data.size())
print("train_labels:", train_data.train_labels.size())
print("test_data:", test_data.test_data.size())

train_loader = Data.DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = Data.DataLoader(dataset=test_data, batch_size=64)


class Net(torch.nn.Module):
 def __init__(self):
 super(Net, self).__init__()
 self.conv1 = torch.nn.Sequential(
  torch.nn.Conv2d(1, 32, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2))
 self.conv2 = torch.nn.Sequential(
  torch.nn.Conv2d(32, 64, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2)
 )
 self.conv3 = torch.nn.Sequential(
  torch.nn.Conv2d(64, 64, 3, 1, 1),
  torch.nn.ReLU(),
  torch.nn.MaxPool2d(2)
 )
 self.dense = torch.nn.Sequential(
  torch.nn.Linear(64 * 3 * 3, 128),
  torch.nn.ReLU(),
  torch.nn.Linear(128, 10)
 )

 def forward(self, x):
 conv1_out = self.conv1(x)
 conv2_out = self.conv2(conv1_out)
 conv3_out = self.conv3(conv2_out)
 res = conv3_out.view(conv3_out.size(0), -1)
 out = self.dense(res)
 return out


model = Net()
print(model)

optimizer = torch.optim.Adam(model.parameters())
loss_func = torch.nn.CrossEntropyLoss()

for epoch in range(10):
 print('epoch {}'.format(epoch + 1))
 # training-----------------------------
 train_loss = 0.
 train_acc = 0.
 for batch_x, batch_y in train_loader:
 batch_x, batch_y = Variable(batch_x), Variable(batch_y)
 out = model(batch_x)
 loss = loss_func(out, batch_y)
 train_loss += loss.data[0]
 pred = torch.max(out, 1)[1]
 train_correct = (pred == batch_y).sum()
 train_acc += train_correct.data[0]
 optimizer.zero_grad()
 loss.backward()
 optimizer.step()
 print('Train Loss: {:.6f}, Acc: {:.6f}'.format(train_loss / (len(
 train_data)), train_acc / (len(train_data))))

 # evaluation--------------------------------
 model.eval()
 eval_loss = 0.
 eval_acc = 0.
 for batch_x, batch_y in test_loader:
 batch_x, batch_y = Variable(batch_x, volatile=True), Variable(batch_y, volatile=True)
 out = model(batch_x)
 loss = loss_func(out, batch_y)
 eval_loss += loss.data[0]
 pred = torch.max(out, 1)[1]
 num_correct = (pred == batch_y).sum()
 eval_acc += num_correct.data[0]
 print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(
 test_data)), eval_acc / (len(test_data))))

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python sys.path详细介绍
Oct 17 Python
python通过ftplib登录到ftp服务器的方法
May 08 Python
python读写二进制文件的方法
May 09 Python
Python3读取UTF-8文件及统计文件行数的方法
May 22 Python
Python简单连接MongoDB数据库的方法
Mar 15 Python
PyQt5实现无边框窗口的标题拖动和窗口缩放
Apr 19 Python
Python中常用的内置方法
Jan 28 Python
python实现远程控制电脑
May 23 Python
利用python开发app实战的方法
Jul 09 Python
python 用所有标点符号分隔句子的示例
Jul 15 Python
python中count函数简单用法
Jan 05 Python
Python 中 sorted 如何自定义比较逻辑
Feb 02 Python
pytorch构建网络模型的4种方法
Apr 13 #Python
Python输入二维数组方法
Apr 13 #Python
Python基于递归实现电话号码映射功能示例
Apr 13 #Python
Python的多维空数组赋值方法
Apr 13 #Python
python多维数组切片方法
Apr 13 #Python
Python实现判断并移除列表指定位置元素的方法
Apr 13 #Python
Python中的二维数组实例(list与numpy.array)
Apr 13 #Python
You might like
PHP 处理图片的类实现代码
2009/10/23 PHP
jQuery Ajax之$.get()方法和$.post()方法
2009/10/12 Javascript
网页加载时页面显示进度条加载完成之后显示网页内容
2012/12/23 Javascript
JavaScript实现在数组中查找不同顺序排列的字符串
2014/09/26 Javascript
如何实现chrome浏览器关闭页面时弹出“确定要离开此面吗?”
2015/03/05 Javascript
javascript判断并获取注册表中可信任站点的方法
2015/06/01 Javascript
JS实现兼容性好,自动置顶的淘宝悬浮工具栏效果
2015/09/18 Javascript
JQuery实现列表中复选框全选反选功能封装(推荐)
2016/11/24 Javascript
js判断手机系统是android还是ios
2017/03/07 Javascript
jquery中each循环的简单回滚操作
2017/05/05 jQuery
vue-cli脚手架的安装教程图解
2018/09/02 Javascript
微信小程序 setData 对 data数据影响问题
2019/04/18 Javascript
jQuery中使用validate插件校验表单功能
2019/05/24 jQuery
Vue 中获取当前时间并实时刷新的实现代码
2020/05/12 Javascript
[38:44]DOTA2上海特级锦标赛A组小组赛#2 Secret VS CDEC第二局
2016/02/25 DOTA
python基于multiprocessing的多进程创建方法
2015/06/04 Python
Python处理PDF及生成多层PDF实例代码
2017/04/24 Python
python中实现数组和列表读取一列的方法
2018/04/03 Python
详解将Django部署到Centos7全攻略
2018/09/26 Python
python用for循环求和的方法总结
2019/07/08 Python
基于Python和PyYAML读取yaml配置文件数据
2020/01/13 Python
Python 使用 PyQt5 开发的关机小工具分享
2020/07/16 Python
巧用HTML5给按钮背景设计不同的动画简单实例
2016/08/09 HTML / CSS
使用phonegap获取设备的一些信息方法
2017/03/31 HTML / CSS
详解如何通过H5(浏览器/WebView/其他)唤起本地app
2017/12/11 HTML / CSS
AmazeUI 列表的实现示例
2020/08/17 HTML / CSS
置业顾问岗位职责
2014/03/02 职场文书
企业党员公开承诺书
2014/03/26 职场文书
2014年五四青年节演讲稿范文
2014/04/22 职场文书
初三学生评语大全
2014/04/24 职场文书
青春励志演讲稿范文
2014/08/25 职场文书
幼儿园教师的自我评价范文
2014/09/17 职场文书
2014财产信托协议书范本
2014/11/18 职场文书
2014年生产管理工作总结
2014/12/23 职场文书
沈阳故宫导游词
2015/01/31 职场文书
Python办公自动化之教你如何用Python将任意文件转为PDF格式
2021/06/28 Python