pytorch 准备、训练和测试自己的图片数据的方法


Posted in Python onJanuary 10, 2020

大部分的pytorch入门教程,都是使用torchvision里面的数据进行训练和测试。如果我们是自己的图片数据,又该怎么做呢?

一、我的数据

我在学习的时候,使用的是fashion-mnist。这个数据比较小,我的电脑没有GPU,还能吃得消。关于fashion-mnist数据,可以百度,也可以点此 了解一下,数据就像这个样子:

pytorch 准备、训练和测试自己的图片数据的方法

下载地址:https://github.com/zalandoresearch/fashion-mnist

pytorch 准备、训练和测试自己的图片数据的方法

但是下载下来是一种二进制文件,并不是图片,因此我先转换成了图片。

我先解压gz文件到e:/fashion_mnist/文件夹

然后运行代码:

import os
from skimage import io
import torchvision.datasets.mnist as mnist

root="E:/fashion_mnist/"
train_set = (
  mnist.read_image_file(os.path.join(root, 'train-images-idx3-ubyte')),
  mnist.read_label_file(os.path.join(root, 'train-labels-idx1-ubyte'))
    )
test_set = (
  mnist.read_image_file(os.path.join(root, 't10k-images-idx3-ubyte')),
  mnist.read_label_file(os.path.join(root, 't10k-labels-idx1-ubyte'))
    )
print("training set :",train_set[0].size())
print("test set :",test_set[0].size())

def convert_to_img(train=True):
  if(train):
    f=open(root+'train.txt','w')
    data_path=root+'/train/'
    if(not os.path.exists(data_path)):
      os.makedirs(data_path)
    for i, (img,label) in enumerate(zip(train_set[0],train_set[1])):
      img_path=data_path+str(i)+'.jpg'
      io.imsave(img_path,img.numpy())
      f.write(img_path+' '+str(label)+'\n')
    f.close()
  else:
    f = open(root + 'test.txt', 'w')
    data_path = root + '/test/'
    if (not os.path.exists(data_path)):
      os.makedirs(data_path)
    for i, (img,label) in enumerate(zip(test_set[0],test_set[1])):
      img_path = data_path+ str(i) + '.jpg'
      io.imsave(img_path, img.numpy())
      f.write(img_path + ' ' + str(label) + '\n')
    f.close()

convert_to_img(True)
convert_to_img(False)

这样就会在e:/fashion_mnist/目录下分别生成train和test文件夹,用于存放图片。还在该目录下生成了标签文件train.txt和test.txt.

二、进行CNN分类训练和测试

先要将图片读取出来,准备成torch专用的dataset格式,再通过Dataloader进行分批次训练。

代码如下:

import torch
from torch.autograd import Variable
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
root="E:/fashion_mnist/"

# -----------------ready the dataset--------------------------
def default_loader(path):
  return Image.open(path).convert('RGB')
class MyDataset(Dataset):
  def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
    fh = open(txt, 'r')
    imgs = []
    for line in fh:
      line = line.strip('\n')
      line = line.rstrip()
      words = line.split()
      imgs.append((words[0],int(words[1])))
    self.imgs = imgs
    self.transform = transform
    self.target_transform = target_transform
    self.loader = loader

  def __getitem__(self, index):
    fn, label = self.imgs[index]
    img = self.loader(fn)
    if self.transform is not None:
      img = self.transform(img)
    return img,label

  def __len__(self):
    return len(self.imgs)

train_data=MyDataset(txt=root+'train.txt', transform=transforms.ToTensor())
test_data=MyDataset(txt=root+'test.txt', transform=transforms.ToTensor())
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=64)


#-----------------create the Net and training------------------------

class Net(torch.nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.conv1 = torch.nn.Sequential(
      torch.nn.Conv2d(3, 32, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2))
    self.conv2 = torch.nn.Sequential(
      torch.nn.Conv2d(32, 64, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2)
    )
    self.conv3 = torch.nn.Sequential(
      torch.nn.Conv2d(64, 64, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2)
    )
    self.dense = torch.nn.Sequential(
      torch.nn.Linear(64 * 3 * 3, 128),
      torch.nn.ReLU(),
      torch.nn.Linear(128, 10)
    )

  def forward(self, x):
    conv1_out = self.conv1(x)
    conv2_out = self.conv2(conv1_out)
    conv3_out = self.conv3(conv2_out)
    res = conv3_out.view(conv3_out.size(0), -1)
    out = self.dense(res)
    return out


model = Net()
print(model)

optimizer = torch.optim.Adam(model.parameters())
loss_func = torch.nn.CrossEntropyLoss()

for epoch in range(10):
  print('epoch {}'.format(epoch + 1))
  # training-----------------------------
  train_loss = 0.
  train_acc = 0.
  for batch_x, batch_y in train_loader:
    batch_x, batch_y = Variable(batch_x), Variable(batch_y)
    out = model(batch_x)
    loss = loss_func(out, batch_y)
    train_loss += loss.data[0]
    pred = torch.max(out, 1)[1]
    train_correct = (pred == batch_y).sum()
    train_acc += train_correct.data[0]
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
  print('Train Loss: {:.6f}, Acc: {:.6f}'.format(train_loss / (len(
    train_data)), train_acc / (len(train_data))))

  # evaluation--------------------------------
  model.eval()
  eval_loss = 0.
  eval_acc = 0.
  for batch_x, batch_y in test_loader:
    batch_x, batch_y = Variable(batch_x, volatile=True), Variable(batch_y, volatile=True)
    out = model(batch_x)
    loss = loss_func(out, batch_y)
    eval_loss += loss.data[0]
    pred = torch.max(out, 1)[1]
    num_correct = (pred == batch_y).sum()
    eval_acc += num_correct.data[0]
  print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(
    test_data)), eval_acc / (len(test_data))))

打印出来的网络模型:

pytorch 准备、训练和测试自己的图片数据的方法

训练和测试结果:

pytorch 准备、训练和测试自己的图片数据的方法

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python基础教程之数字处理(math)模块详解
Mar 25 Python
零基础写python爬虫之urllib2使用指南
Nov 05 Python
初步解析Python中的yield函数的用法
Apr 03 Python
Python多线程编程简单介绍
Apr 13 Python
使用Python 正则匹配两个特定字符之间的字符方法
Dec 24 Python
Python3安装Pillow与PIL的方法
Apr 03 Python
PyQt5实现简易电子词典
Jun 25 Python
实例详解Python装饰器与闭包
Jul 29 Python
python 对任意数据和曲线进行拟合并求出函数表达式的三种解决方案
Feb 18 Python
Python3 ID3决策树判断申请贷款是否成功的实现代码
May 21 Python
Django中使用Json返回数据的实现方法
Jun 03 Python
简单了解Python变量作用域正确使用方法
Jun 12 Python
pytorch GAN伪造手写体mnist数据集方式
Jan 10 #Python
MNIST数据集转化为二维图片的实现示例
Jan 10 #Python
pytorch:实现简单的GAN示例(MNIST数据集)
Jan 10 #Python
pytorch GAN生成对抗网络实例
Jan 10 #Python
解决pytorch报错:AssertionError: Invalid device id的问题
Jan 10 #Python
python3中关于excel追加写入格式被覆盖问题(实例代码)
Jan 10 #Python
mac使用python识别图形验证码功能
Jan 10 #Python
You might like
我的论坛源代码(九)
2006/10/09 PHP
php5.3中连接sqlserver2000的两种方法(com与ODBC)
2012/12/29 PHP
php设计模式之单例模式使用示例
2014/01/20 PHP
nginx+thinkphp下解决不支持pathinfo模式
2015/07/01 PHP
PHP和MySql中32位和64位的整形范围是多少
2016/02/18 PHP
js option删除代码集合
2008/11/12 Javascript
JavaScript中对象property的删除方法介绍
2014/12/30 Javascript
浅谈js基本数据类型和typeof
2016/08/09 Javascript
JavaScript学习笔记整理_用于模式匹配的String方法
2016/09/19 Javascript
js addDqmForPP给标签内属性值加上双引号的函数
2016/12/24 Javascript
Reactjs实现通用分页组件的实例代码
2017/01/19 Javascript
使用Vue写一个datepicker的示例
2018/01/27 Javascript
JavaScript的数据类型转换原则(干货)
2018/03/15 Javascript
js+SVG实现动态时钟效果
2018/07/14 Javascript
分享8个JavaScript库可更好地处理本地存储
2020/10/12 Javascript
浅析微信小程序自定义日历组件及flex布局最后一行对齐问题
2020/10/29 Javascript
es5 类与es6中class的区别小结
2020/11/09 Javascript
[06:43]DAC2018 4.5 SOLO赛 Maybe vs Paparazi
2018/04/06 DOTA
Python中的生成器和yield详细介绍
2015/01/09 Python
Python产生Gnuplot绘图数据的方法
2018/11/09 Python
Python字符串逆序的实现方法【一题多解】
2019/02/18 Python
Python3中urlencode和urldecode的用法详解
2019/07/23 Python
使用Windows批处理和WMI设置Python的环境变量方法
2019/08/14 Python
python利用Excel读取和存储测试数据完成接口自动化教程
2020/04/30 Python
html5的画布canvas——画出弧线、旋转的图形实例代码+效果图
2013/06/09 HTML / CSS
巧克力领导品牌瑞士莲美国官网:Lindt Chocolate美国
2016/08/25 全球购物
美国在线乐器和设备商店:Musician’s Friend
2018/07/06 全球购物
阿巴庭院:Abba Patio
2019/06/18 全球购物
世界上最好的野生海鲜和有机食品:Vital Choice
2020/01/16 全球购物
物理教育专业毕业生推荐信
2013/11/03 职场文书
竞选体育委员演讲稿
2014/04/26 职场文书
反腐倡廉学习心得体会范文
2015/08/15 职场文书
一文读懂navicat for mysql基础知识
2021/05/31 MySQL
Java中CyclicBarrier和CountDownLatch的用法与区别
2021/08/23 Java/Android
golang中的struct操作
2021/11/11 Golang
springboot入门 之profile设置方式
2022/04/04 Java/Android