利用pytorch实现对CIFAR-10数据集的分类


Posted in Python onJanuary 14, 2020

步骤如下:

1.使用torchvision加载并预处理CIFAR-10数据集、

2.定义网络

3.定义损失函数和优化器

4.训练网络并更新网络参数

5.测试网络

运行环境:

windows+python3.6.3+pycharm+pytorch0.3.0

import torchvision as tv
import torchvision.transforms as transforms
import torch as t
from torchvision.transforms import ToPILImage
show=ToPILImage()    #把Tensor转成Image,方便可视化
import matplotlib.pyplot as plt
import torchvision
import numpy as np


###############数据加载与预处理
transform = transforms.Compose([transforms.ToTensor(),#转为tensor
                transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)),#归一化
                ])
#训练集
trainset=tv.datasets.CIFAR10(root='/python projects/test/data/',
               train=True,
               download=True,
               transform=transform)

trainloader=t.utils.data.DataLoader(trainset,
                  batch_size=4,
                  shuffle=True,
                  num_workers=0)
#测试集
testset=tv.datasets.CIFAR10(root='/python projects/test/data/',
               train=False,
               download=True,
               transform=transform)

testloader=t.utils.data.DataLoader(testset,
                  batch_size=4,
                  shuffle=True,
                  num_workers=0)


classes=('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')

(data,label)=trainset[100]
print(classes[label])

show((data+1)/2).resize((100,100))

# dataiter=iter(trainloader)
# images,labels=dataiter.next()
# print(''.join('11%s'%classes[labels[j]] for j in range(4)))
# show(tv.utils.make_grid(images+1)/2).resize((400,100))
def imshow(img):
  img = img / 2 + 0.5
  npimg = img.numpy()
  plt.imshow(np.transpose(npimg, (1, 2, 0)))

dataiter = iter(trainloader)
images, labels = dataiter.next()
print(images.size())
imshow(torchvision.utils.make_grid(images))
plt.show()#关掉图片才能往后继续算


#########################定义网络
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
  def __init__(self):
    super(Net,self).__init__()
    self.conv1=nn.Conv2d(3,6,5)
    self.conv2=nn.Conv2d(6,16,5)
    self.fc1=nn.Linear(16*5*5,120)
    self.fc2=nn.Linear(120,84)
    self.fc3=nn.Linear(84,10)

  def forward(self, x):
    x = F.max_pool2d(F.relu(self.conv1(x)),2)
    x = F.max_pool2d(F.relu(self.conv2(x)),2)
    x = x.view(-1, 16 * 5 * 5)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x

net=Net()
print(net)

#############定义损失函数和优化器
from torch import optim
criterion=nn.CrossEntropyLoss()
optimizer=optim.SGD(net.parameters(),lr=0.01,momentum=0.9)

##############训练网络
from torch.autograd import Variable
import time

start_time = time.time()
for epoch in range(2):
  running_loss=0.0
  for i,data in enumerate(trainloader,0):
    #输入数据
    inputs,labels=data
    inputs,labels=Variable(inputs),Variable(labels)
    #梯度清零
    optimizer.zero_grad()

    outputs=net(inputs)
    loss=criterion(outputs,labels)
    loss.backward()
    #更新参数
    optimizer.step()

    # 打印log
    running_loss += loss.data[0]
    if i % 2000 == 1999:
      print('[%d,%5d] loss:%.3f' % (epoch + 1, i + 1, running_loss / 2000))
      running_loss = 0.0
print('finished training')
end_time = time.time()
print("Spend time:", end_time - start_time)

以上这篇利用pytorch实现对CIFAR-10数据集的分类就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python创建只读属性对象的方法(ReadOnlyObject)
Feb 10 Python
python益智游戏计算汉诺塔问题示例
Mar 05 Python
Python和php通信乱码问题解决方法
Apr 15 Python
python使用marshal模块序列化实例
Sep 25 Python
使用Python实现下载网易云音乐的高清MV
Mar 16 Python
Python画图高斯分布的示例
Jul 10 Python
python实现的按要求生成手机号功能示例
Oct 08 Python
Python 操作mysql数据库查询之fetchone(), fetchmany(), fetchall()用法示例
Oct 17 Python
MNIST数据集转化为二维图片的实现示例
Jan 10 Python
Python list运算操作代码实例解析
Jan 20 Python
python3爬虫中异步协程的用法
Jul 10 Python
用OpenCV进行年龄和性别检测的实现示例
Jan 29 Python
pytorch下使用LSTM神经网络写诗实例
Jan 14 #Python
python使用openCV遍历文件夹里所有视频文件并保存成图片
Jan 14 #Python
pytorch实现mnist数据集的图像可视化及保存
Jan 14 #Python
Pytorch在dataloader类中设置shuffle的随机数种子方式
Jan 14 #Python
python3.7通过thrift操作hbase的示例代码
Jan 14 #Python
解决pytorch DataLoader num_workers出现的问题
Jan 14 #Python
PyTorch实现ResNet50、ResNet101和ResNet152示例
Jan 14 #Python
You might like
php在window iis的莫名问题的测试方法
2013/05/14 PHP
详解Yii2 rules 的验证规则
2016/12/02 PHP
php中数组最简单的使用方法
2020/12/27 PHP
JavaScript Cookie 直接浏览网站分网址
2009/12/08 Javascript
JavaScript写的一个自定义弹出式对话框代码
2010/01/17 Javascript
javascript+xml实现简单图片轮换(只支持IE)
2012/12/23 Javascript
博客侧边栏模块跟随滚动条滑动固定效果的实现方法(js+jquery等)
2013/03/24 Javascript
js时间戳格式化成日期格式的多种方法
2013/11/11 Javascript
基于javascript实现图片滑动效果
2016/05/07 Javascript
AngularJs基本特性解析(一)
2016/07/21 Javascript
Js实现京东无延迟菜单效果实例(demo)
2017/06/02 Javascript
浅谈Angular2 ng-content 指令在组件中嵌入内容
2017/08/18 Javascript
setTimeout时间设置为0详细解析
2018/03/13 Javascript
微信小程序实现tab页面切换功能
2018/07/13 Javascript
js实现轮播图的完整代码
2020/10/26 Javascript
在Vue 中使用Typescript的示例代码
2018/09/10 Javascript
ES6对象操作实例详解
2020/05/23 Javascript
[01:19:33]DOTA2-DPC中国联赛 正赛 iG vs VG BO3 第一场 2月2日
2021/03/11 DOTA
python使用arp欺骗伪造网关的方法
2015/04/24 Python
Python的爬虫包Beautiful Soup中用正则表达式来搜索
2016/01/20 Python
Python爬取腾讯视频评论的思路详解
2019/12/19 Python
Python+OpenCV图像处理——打印图片属性、设置存储路径、调用摄像头
2020/10/22 Python
如何使用Django Admin管理后台导入CSV
2020/11/06 Python
pytorch下的unsqueeze和squeeze的用法说明
2021/02/06 Python
初探CSS3中的calc()功能
2015/07/14 HTML / CSS
使用CSS3来实现滚动视差效果的教程
2015/08/24 HTML / CSS
意大利团购网站:Groupon意大利
2016/10/11 全球购物
微软新西兰官方网站:Microsoft New Zealand
2018/08/17 全球购物
英国领先的在线高尔夫设备零售商:Golfgeardirect
2020/12/11 全球购物
优秀员工表扬信
2014/01/17 职场文书
公司副总经理任命书
2014/06/05 职场文书
群众路线教育实践活动的心得体会
2014/09/03 职场文书
《领导干部从政道德启示录》学习心得体会
2016/01/20 职场文书
餐饮行业关注的9大营销策略
2019/08/26 职场文书
Python实现简单的俄罗斯方块游戏
2021/09/25 Python
《Estab Life》4月6日播出 正式PV、主视觉图公开
2022/03/20 日漫