pytorch 准备、训练和测试自己的图片数据的方法


Posted in Python onJanuary 10, 2020

大部分的pytorch入门教程,都是使用torchvision里面的数据进行训练和测试。如果我们是自己的图片数据,又该怎么做呢?

一、我的数据

我在学习的时候,使用的是fashion-mnist。这个数据比较小,我的电脑没有GPU,还能吃得消。关于fashion-mnist数据,可以百度,也可以点此 了解一下,数据就像这个样子:

pytorch 准备、训练和测试自己的图片数据的方法

下载地址:https://github.com/zalandoresearch/fashion-mnist

pytorch 准备、训练和测试自己的图片数据的方法

但是下载下来是一种二进制文件,并不是图片,因此我先转换成了图片。

我先解压gz文件到e:/fashion_mnist/文件夹

然后运行代码:

import os
from skimage import io
import torchvision.datasets.mnist as mnist

root="E:/fashion_mnist/"
train_set = (
  mnist.read_image_file(os.path.join(root, 'train-images-idx3-ubyte')),
  mnist.read_label_file(os.path.join(root, 'train-labels-idx1-ubyte'))
    )
test_set = (
  mnist.read_image_file(os.path.join(root, 't10k-images-idx3-ubyte')),
  mnist.read_label_file(os.path.join(root, 't10k-labels-idx1-ubyte'))
    )
print("training set :",train_set[0].size())
print("test set :",test_set[0].size())

def convert_to_img(train=True):
  if(train):
    f=open(root+'train.txt','w')
    data_path=root+'/train/'
    if(not os.path.exists(data_path)):
      os.makedirs(data_path)
    for i, (img,label) in enumerate(zip(train_set[0],train_set[1])):
      img_path=data_path+str(i)+'.jpg'
      io.imsave(img_path,img.numpy())
      f.write(img_path+' '+str(label)+'\n')
    f.close()
  else:
    f = open(root + 'test.txt', 'w')
    data_path = root + '/test/'
    if (not os.path.exists(data_path)):
      os.makedirs(data_path)
    for i, (img,label) in enumerate(zip(test_set[0],test_set[1])):
      img_path = data_path+ str(i) + '.jpg'
      io.imsave(img_path, img.numpy())
      f.write(img_path + ' ' + str(label) + '\n')
    f.close()

convert_to_img(True)
convert_to_img(False)

这样就会在e:/fashion_mnist/目录下分别生成train和test文件夹,用于存放图片。还在该目录下生成了标签文件train.txt和test.txt.

二、进行CNN分类训练和测试

先要将图片读取出来,准备成torch专用的dataset格式,再通过Dataloader进行分批次训练。

代码如下:

import torch
from torch.autograd import Variable
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
root="E:/fashion_mnist/"

# -----------------ready the dataset--------------------------
def default_loader(path):
  return Image.open(path).convert('RGB')
class MyDataset(Dataset):
  def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
    fh = open(txt, 'r')
    imgs = []
    for line in fh:
      line = line.strip('\n')
      line = line.rstrip()
      words = line.split()
      imgs.append((words[0],int(words[1])))
    self.imgs = imgs
    self.transform = transform
    self.target_transform = target_transform
    self.loader = loader

  def __getitem__(self, index):
    fn, label = self.imgs[index]
    img = self.loader(fn)
    if self.transform is not None:
      img = self.transform(img)
    return img,label

  def __len__(self):
    return len(self.imgs)

train_data=MyDataset(txt=root+'train.txt', transform=transforms.ToTensor())
test_data=MyDataset(txt=root+'test.txt', transform=transforms.ToTensor())
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=64)


#-----------------create the Net and training------------------------

class Net(torch.nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.conv1 = torch.nn.Sequential(
      torch.nn.Conv2d(3, 32, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2))
    self.conv2 = torch.nn.Sequential(
      torch.nn.Conv2d(32, 64, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2)
    )
    self.conv3 = torch.nn.Sequential(
      torch.nn.Conv2d(64, 64, 3, 1, 1),
      torch.nn.ReLU(),
      torch.nn.MaxPool2d(2)
    )
    self.dense = torch.nn.Sequential(
      torch.nn.Linear(64 * 3 * 3, 128),
      torch.nn.ReLU(),
      torch.nn.Linear(128, 10)
    )

  def forward(self, x):
    conv1_out = self.conv1(x)
    conv2_out = self.conv2(conv1_out)
    conv3_out = self.conv3(conv2_out)
    res = conv3_out.view(conv3_out.size(0), -1)
    out = self.dense(res)
    return out


model = Net()
print(model)

optimizer = torch.optim.Adam(model.parameters())
loss_func = torch.nn.CrossEntropyLoss()

for epoch in range(10):
  print('epoch {}'.format(epoch + 1))
  # training-----------------------------
  train_loss = 0.
  train_acc = 0.
  for batch_x, batch_y in train_loader:
    batch_x, batch_y = Variable(batch_x), Variable(batch_y)
    out = model(batch_x)
    loss = loss_func(out, batch_y)
    train_loss += loss.data[0]
    pred = torch.max(out, 1)[1]
    train_correct = (pred == batch_y).sum()
    train_acc += train_correct.data[0]
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
  print('Train Loss: {:.6f}, Acc: {:.6f}'.format(train_loss / (len(
    train_data)), train_acc / (len(train_data))))

  # evaluation--------------------------------
  model.eval()
  eval_loss = 0.
  eval_acc = 0.
  for batch_x, batch_y in test_loader:
    batch_x, batch_y = Variable(batch_x, volatile=True), Variable(batch_y, volatile=True)
    out = model(batch_x)
    loss = loss_func(out, batch_y)
    eval_loss += loss.data[0]
    pred = torch.max(out, 1)[1]
    num_correct = (pred == batch_y).sum()
    eval_acc += num_correct.data[0]
  print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(
    test_data)), eval_acc / (len(test_data))))

打印出来的网络模型:

pytorch 准备、训练和测试自己的图片数据的方法

训练和测试结果:

pytorch 准备、训练和测试自己的图片数据的方法

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python基础教程之字典操作详解
Mar 25 Python
python在linux中输出带颜色的文字的方法
Jun 19 Python
python测试mysql写入性能完整实例
Jan 18 Python
Django使用HttpResponse返回图片并显示的方法
May 22 Python
numpy给array增加维度np.newaxis的实例
Nov 01 Python
python实现日志按天分割
Jul 22 Python
Python学习笔记之集合的概念和简单使用示例
Aug 22 Python
使用Python的datetime库处理时间(RPA流程)
Nov 24 Python
Anconda环境下Vscode安装Python的方法详解
Mar 29 Python
Python实现aes加密解密多种方法解析
May 15 Python
TensorFlow的自动求导原理分析
May 26 Python
python APScheduler执行定时任务介绍
Apr 19 Python
pytorch GAN伪造手写体mnist数据集方式
Jan 10 #Python
MNIST数据集转化为二维图片的实现示例
Jan 10 #Python
pytorch:实现简单的GAN示例(MNIST数据集)
Jan 10 #Python
pytorch GAN生成对抗网络实例
Jan 10 #Python
解决pytorch报错:AssertionError: Invalid device id的问题
Jan 10 #Python
python3中关于excel追加写入格式被覆盖问题(实例代码)
Jan 10 #Python
mac使用python识别图形验证码功能
Jan 10 #Python
You might like
帖几个PHP的无限分类实现想法~
2007/01/02 PHP
php过滤html标记属性类用法实例
2014/09/23 PHP
通过JAVASCRIPT读取ASP设定的COOKIE
2006/11/24 Javascript
Javascript合并表格中具有相同内容单元格示例
2013/08/11 Javascript
js数组的基本用法及数组根据下标(数值或字符)移除元素
2013/10/20 Javascript
jQuery表单验证简单示例
2016/10/17 Javascript
ES6新特性之类(Class)和继承(Extends)相关概念与用法分析
2017/05/24 Javascript
创建简单的node服务器实例(分享)
2017/06/23 Javascript
Node使用Sequlize连接Mysql报错:Access denied for user ‘xxx’@‘localhost’
2018/01/03 Javascript
Router解决跨模块下的页面跳转示例
2018/01/11 Javascript
vue 组件中使用 transition 和 transition-group实现过渡动画
2019/07/09 Javascript
javascript History对象原理解析
2020/02/17 Javascript
详解python里使用正则表达式的分组命名方式
2017/10/24 Python
python 移除字符串尾部的数字方法
2018/07/17 Python
Python机器学习之scikit-learn库中KNN算法的封装与使用方法
2018/12/14 Python
python获取当前文件路径以及父文件路径的方法
2019/07/10 Python
PyCharm中代码字体大小调整方法
2019/07/29 Python
python 浅谈serial与stm32通信的编码问题
2019/12/18 Python
html5理解head_动力节点Java学院整理
2017/07/13 HTML / CSS
HTML5 微格式和相关的属性名称
2010/02/10 HTML / CSS
canvas进阶之如何画出平滑的曲线
2018/10/15 HTML / CSS
Interhome丹麦:在线预订度假屋和公寓
2019/07/18 全球购物
俄罗斯电子产品、计算机和家用电器购物网站:OLDI
2019/10/27 全球购物
什么是smarty? Smarty的优点是什么?
2013/08/11 面试题
自考生自我鉴定范文
2013/10/01 职场文书
日语翻译个人求职的自我评价
2013/10/14 职场文书
教师实习自我鉴定
2013/12/18 职场文书
家长给孩子的表扬信
2014/01/17 职场文书
母亲节演讲稿
2014/05/27 职场文书
旅游与环境专业求职信
2014/06/05 职场文书
刑事代理授权委托书
2014/09/17 职场文书
查摆问题整改措施
2014/10/24 职场文书
门市房租房协议书
2014/12/04 职场文书
推荐信范文大全
2015/03/27 职场文书
运动会闭幕式主持词
2015/07/01 职场文书
十大冰系宝可梦排名,颜值最高的阿罗拉九尾,第三使用率第一
2022/03/18 日漫