pytorch 准备、训练和测试自己的图片数据的方法


Posted in Python onJanuary 10, 2020

大部分的pytorch入门教程,都是使用torchvision里面的数据进行训练和测试。如果我们是自己的图片数据,又该怎么做呢?

一、我的数据

我在学习的时候,使用的是fashion-mnist。这个数据比较小,我的电脑没有GPU,还能吃得消。关于fashion-mnist数据,可以百度,也可以点此 了解一下,数据就像这个样子:

pytorch 准备、训练和测试自己的图片数据的方法

下载地址:https://github.com/zalandoresearch/fashion-mnist

pytorch 准备、训练和测试自己的图片数据的方法

但是下载下来是一种二进制文件,并不是图片,因此我先转换成了图片。

我先解压gz文件到e:/fashion_mnist/文件夹

然后运行代码:

import os
from skimage import io
import torchvision.datasets.mnist as mnist

root="E:/fashion_mnist/"
train_set = (
  mnist.read_image_file(os.path.join(root, 'train-images-idx3-ubyte')),
  mnist.read_label_file(os.path.join(root, 'train-labels-idx1-ubyte'))
    )
test_set = (
  mnist.read_image_file(os.path.join(root, 't10k-images-idx3-ubyte')),
  mnist.read_label_file(os.path.join(root, 't10k-labels-idx1-ubyte'))
    )
print("training set :",train_set[0].size())
print("test set :",test_set[0].size())

def convert_to_img(train=True):
  if(train):
    f=open(root+'train.txt','w')
    data_path=root+'/train/'
    if(not os.path.exists(data_path)):
      os.makedirs(data_path)
    for i, (img,label) in enumerate(zip(train_set[0],train_set[1])):
      img_path=data_path+str(i)+'.jpg'
      io.imsave(img_path,img.numpy())
      f.write(img_path+' '+str(label)+'\n')
    f.close()
  else:
    f = open(root + 'test.txt', 'w')
    data_path = root + '/test/'
    if (not os.path.exists(data_path)):
      os.makedirs(data_path)
    for i, (img,label) in enumerate(zip(test_set[0],test_set[1])):
      img_path = data_path+ str(i) + '.jpg'
      io.imsave(img_path, img.numpy())
      f.write(img_path + ' ' + str(label) + '\n')
    f.close()

convert_to_img(True)
convert_to_img(False)

这样就会在e:/fashion_mnist/目录下分别生成train和test文件夹,用于存放图片。还在该目录下生成了标签文件train.txt和test.txt.

二、进行CNN分类训练和测试

先要将图片读取出来,准备成torch专用的dataset格式,再通过Dataloader进行分批次训练。

代码如下:

import torch
from torch.autograd import Variable
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
root="E:/fashion_mnist/"

# -----------------ready the dataset--------------------------
def default_loader(path):
  return Image.open(path).convert('RGB')
class MyDataset(Dataset):
  def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
    fh = open(txt, 'r')
    imgs = []
    for line in fh:
      line = line.strip('\n')
      line = line.rstrip()
      words = line.split()
      imgs.append((words[0],int(words[1])))
    self.imgs = imgs
    self.transform = transform
    self.target_transform = target_transform
    self.loader = loader

  def __getitem__(self, index):
    fn, label = self.imgs[index]
    img = self.loader(fn)
    if self.transform is not None:
      img = self.transform(img)
    return img,label

  def __len__(self):
    return len(self.imgs)

train_data=MyDataset(txt=root+'train.txt', transform=transforms.ToTensor())
test_data=MyDataset(txt=root+'test.txt', transform=transforms.ToTensor())
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=64)


#-----------------create the Net and training------------------------

class Net(torch.nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.conv1 = torch.nn.Sequential(
      torch.nn.Conv2d(3, 32, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2))
    self.conv2 = torch.nn.Sequential(
      torch.nn.Conv2d(32, 64, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2)
    )
    self.conv3 = torch.nn.Sequential(
      torch.nn.Conv2d(64, 64, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2)
    )
    self.dense = torch.nn.Sequential(
      torch.nn.Linear(64 * 3 * 3, 128),
      torch.nn.ReLU(),
      torch.nn.Linear(128, 10)
    )

  def forward(self, x):
    conv1_out = self.conv1(x)
    conv2_out = self.conv2(conv1_out)
    conv3_out = self.conv3(conv2_out)
    res = conv3_out.view(conv3_out.size(0), -1)
    out = self.dense(res)
    return out


model = Net()
print(model)

optimizer = torch.optim.Adam(model.parameters())
loss_func = torch.nn.CrossEntropyLoss()

for epoch in range(10):
  print('epoch {}'.format(epoch + 1))
  # training-----------------------------
  train_loss = 0.
  train_acc = 0.
  for batch_x, batch_y in train_loader:
    batch_x, batch_y = Variable(batch_x), Variable(batch_y)
    out = model(batch_x)
    loss = loss_func(out, batch_y)
    train_loss += loss.data[0]
    pred = torch.max(out, 1)[1]
    train_correct = (pred == batch_y).sum()
    train_acc += train_correct.data[0]
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
  print('Train Loss: {:.6f}, Acc: {:.6f}'.format(train_loss / (len(
    train_data)), train_acc / (len(train_data))))

  # evaluation--------------------------------
  model.eval()
  eval_loss = 0.
  eval_acc = 0.
  for batch_x, batch_y in test_loader:
    batch_x, batch_y = Variable(batch_x, volatile=True), Variable(batch_y, volatile=True)
    out = model(batch_x)
    loss = loss_func(out, batch_y)
    eval_loss += loss.data[0]
    pred = torch.max(out, 1)[1]
    num_correct = (pred == batch_y).sum()
    eval_acc += num_correct.data[0]
  print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(
    test_data)), eval_acc / (len(test_data))))

打印出来的网络模型:

pytorch 准备、训练和测试自己的图片数据的方法

训练和测试结果:

pytorch 准备、训练和测试自己的图片数据的方法

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python正则分组的应用
Nov 10 Python
Python random模块常用方法
Nov 03 Python
Python Web框架Flask信号机制(signals)介绍
Jan 01 Python
Python 实现链表实例代码
Apr 07 Python
python中defaultdict的用法详解
Jun 07 Python
在Django中输出matplotlib生成的图片方法
May 24 Python
Python2.7版os.path.isdir中文路径返回false的解决方法
Jun 21 Python
Python3远程监控程序的实现方法
Jul 15 Python
解决pyinstaller 打包exe文件太大,用pipenv 缩小exe的问题
Jul 13 Python
如何利用pygame实现打飞机小游戏
May 30 Python
Python 中random 库的详细使用
Jun 03 Python
Python中glob库实现文件名的匹配
Jun 18 Python
pytorch GAN伪造手写体mnist数据集方式
Jan 10 #Python
MNIST数据集转化为二维图片的实现示例
Jan 10 #Python
pytorch:实现简单的GAN示例(MNIST数据集)
Jan 10 #Python
pytorch GAN生成对抗网络实例
Jan 10 #Python
解决pytorch报错:AssertionError: Invalid device id的问题
Jan 10 #Python
python3中关于excel追加写入格式被覆盖问题(实例代码)
Jan 10 #Python
mac使用python识别图形验证码功能
Jan 10 #Python
You might like
discuz Passport 通行证 整合笔记
2008/06/30 PHP
php 将字符串按大写字母分隔成字符串数组
2010/04/30 PHP
利用PHP生成静态HTML文档的原理
2012/10/29 PHP
解析php根据ip查询所在地区(非常有用,赶集网就用到)
2013/07/01 PHP
php判断表是否存在的方法
2015/06/18 PHP
PHP基于工厂模式实现的计算器实例
2015/07/16 PHP
Yii列表定义与使用分页方法小结(3种方法)
2016/07/15 PHP
利用javascript中的call实现继承
2007/01/22 Javascript
Jquery iframe内部出滚动条
2010/02/11 Javascript
工作需要写的一个js拖拽组件
2011/07/28 Javascript
js中的replace方法使用介绍
2013/10/28 Javascript
JS:window.onload的使用介绍
2013/11/13 Javascript
jquery validate在ie8下的bug解决方法
2013/11/13 Javascript
javascript中字符串的定义示例代码
2013/12/19 Javascript
js支持键盘控制的左右切换立体式图片轮播效果代码分享
2015/08/26 Javascript
jQuery的ajax下载blob文件
2016/07/21 Javascript
一篇文章搞定JavaScript类型转换(面试常见)
2017/01/21 Javascript
selenium 与 chrome 进行qq登录并发邮件操作实例详解
2017/04/06 Javascript
JavaScript简介_动力节点Java学院整理
2017/06/26 Javascript
express 项目分层实践详解
2018/12/10 Javascript
用Vue.js方法创建模板并使用多个模板合成
2019/06/28 Javascript
[02:40]DOTA2英雄基础教程 先知
2013/11/29 DOTA
Python操作MySQL数据库的三种方法总结
2018/01/30 Python
python的xpath获取div标签内html内容,实现innerhtml功能的方法
2019/01/02 Python
djang常用查询SQL语句的使用代码
2019/02/15 Python
python下载微信公众号相关文章
2019/02/26 Python
Python基于BeautifulSoup和requests实现的爬虫功能示例
2019/08/02 Python
在python Numpy中求向量和矩阵的范数实例
2019/08/26 Python
tensorflow查看ckpt各节点名称实例
2020/01/21 Python
keras做CNN的训练误差loss的下降操作
2020/06/22 Python
python eventlet绿化和patch原理
2020/11/21 Python
美国最大的存储市场:SpareFoot
2018/07/23 全球购物
Crocs欧洲官网:Crocs Europe
2020/01/14 全球购物
大学生通用个人的自我评价
2014/02/10 职场文书
社会学专业求职信
2014/07/17 职场文书
OpenCV图像变换之傅里叶变换的一些应用
2021/07/26 Python