pytorch实现建立自己的数据集(以mnist为例)


Posted in Python onJanuary 18, 2020

本文将原始的numpy array数据在pytorch下封装为Dataset类的数据集,为后续深度网络训练提供数据。

加载并保存图像信息

首先导入需要的库,定义各种路径。

import os
import matplotlib
from keras.datasets import mnist
import numpy as np
from torch.utils.data.dataset import Dataset
from PIL import Image
import scipy.misc

root_path = 'E:/coding_ex/pytorch/Alexnet/data/'
base_path = 'baseset/'
training_path = 'trainingset/'
test_path = 'testset/'

这里将数据集分为三类,baseset为所有数据(trainingset+testset),trainingset是训练集,testset是测试集。直接通过keras.dataset加载mnist数据集,不能自动下载的话可以手动下载.npz并保存至相应目录下。

def LoadData(root_path, base_path, training_path, test_path):
  (x_train, y_train), (x_test, y_test) = mnist.load_data()
  x_baseset = np.concatenate((x_train, x_test))
  y_baseset = np.concatenate((y_train, y_test))
  train_num = len(x_train)
  test_num = len(x_test)
  
  #baseset
  file_img = open((os.path.join(root_path, base_path)+'baseset_img.txt'),'w')
  file_label = open((os.path.join(root_path, base_path)+'baseset_label.txt'),'w')
  for i in range(train_num + test_num):
    file_img.write(root_path + base_path + 'img/' + str(i) + '.png\n') #name
    file_label.write(str(y_baseset[i])+'\n') #label
#    scipy.misc.imsave(root_path + base_path + '/img/'+str(i) + '.png', x_baseset[i])
    matplotlib.image.imsave(root_path + base_path + 'img/'+str(i) + '.png', x_baseset[i])
  file_img.close()
  file_label.close()
  
  #trainingset
  file_img = open((os.path.join(root_path, training_path)+'trainingset_img.txt'),'w')
  file_label = open((os.path.join(root_path, training_path)+'trainingset_label.txt'),'w')
  for i in range(train_num):
    file_img.write(root_path + training_path + 'img/' + str(i) + '.png\n') #name
    file_label.write(str(y_train[i])+'\n') #label
#    scipy.misc.imsave(root_path + training_path + '/img/'+str(i) + '.png', x_train[i])
    matplotlib.image.imsave(root_path + training_path + 'img/'+str(i) + '.png', x_train[i])
  file_img.close()
  file_label.close()
  
  #testset
  file_img = open((os.path.join(root_path, test_path)+'testset_img.txt'),'w')
  file_label = open((os.path.join(root_path, test_path)+'testset_label.txt'),'w')
  for i in range(test_num):
    file_img.write(root_path + test_path + 'img/' + str(i) + '.png\n') #name
    file_label.write(str(y_test[i])+'\n') #label
#    scipy.misc.imsave(root_path + test_path + '/img/'+str(i) + '.png', x_test[i])
    matplotlib.image.imsave(root_path + test_path + 'img/'+str(i) + '.png', x_test[i])
  file_img.close()
  file_label.close()

使用这段代码时,需要建立相应的文件夹及.txt文件,./data文件夹结构如下:

pytorch实现建立自己的数据集(以mnist为例)

/img文件夹

由于mnist数据集其实是灰度图,这里用matplotlib保存的图像是伪彩色图像。

pytorch实现建立自己的数据集(以mnist为例)

如果用scipy.misc.imsave的话保存的则是灰度图像。

xxx_img.txt文件

xxx_img.txt文件中存放的是每张图像的名字

pytorch实现建立自己的数据集(以mnist为例)

xxx_label.txt文件

xxx_label.txt文件中存放的是类别标记

pytorch实现建立自己的数据集(以mnist为例)

这里记得保存的时候一行为一个图像信息,便于后续读取。

定义自己的Dataset类

pytorch训练数据时需要数据集为Dataset类,便于迭代等等,这里将加载保存之后的数据封装成Dataset类,继承该类需要写初始化方法(__init__),获取指定下标数据的方法__getitem__),获取数据个数的方法(__len__)。这里尤其需要注意的是要把label转为LongTensor类型的。

class DataProcessingMnist(Dataset):
  def __init__(self, root_path, imgfile_path, labelfile_path, imgdata_path, transform = None):
    self.root_path = root_path
    self.transform = transform
    self.imagedata_path = imgdata_path
    img_file = open((root_path + imgfile_path),'r')
    self.image_name = [x.strip() for x in img_file]
    img_file.close()
    label_file = open((root_path + labelfile_path), 'r')
    label = [int(x.strip()) for x in label_file]
    label_file.close()
    self.label = torch.LongTensor(label)#这句很重要,一定要把label转为LongTensor类型的
    
  def __getitem__(self, idx):
    image = Image.open(str(self.image_name[idx]))
    image = image.convert('RGB')
    if self.transform is not None:
      image = self.transform(image)
    label = self.label[idx]
    return image, label
  def __len__(self):
    return len(self.image_name)

定义完自己的类之后可以测试一下。

LoadData(root_path, base_path, training_path, test_path)
  training_imgfile = training_path + 'trainingset_img.txt'
  training_labelfile = training_path + 'trainingset_label.txt'
  training_imgdata = training_path + 'img/'
  #实例化一个类
  dataset = DataProcessingMnist(root_path, training_imgfile, training_labelfile, training_imgdata)

得到图像名称

name = dataset.image_name

pytorch实现建立自己的数据集(以mnist为例)

这里我们可以单独输出某一个名称看一下是否有换行符

print(name[0])
>>>'E:/coding_ex/pytorch/Alexnet/data/trainingset/img/0.png'

如果定义类的时候self.image_name = [x.strip() for x in img_file]这句没有strip掉,则输出的值将为'E:/coding_ex/pytorch/Alexnet/data/trainingset/img/0.png\n'

获取固定下标的图像

im, label = dataset.__getitem__(0)

得到结果

pytorch实现建立自己的数据集(以mnist为例)

以上这篇pytorch实现建立自己的数据集(以mnist为例)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python基于mysql实现的简单队列以及跨进程锁实例详解
Jul 07 Python
Python探索之修改Python搜索路径
Oct 25 Python
Python实现删除时保留特定文件夹和文件的示例
Apr 27 Python
python 获取当天凌晨零点的时间戳方法
May 22 Python
python通过微信发送邮件实现电脑关机
Jun 20 Python
python检测主机的连通性并记录到文件的实例
Jun 21 Python
python引入不同文件夹下的自定义模块方法
Oct 27 Python
Django中的forms组件实例详解
Nov 08 Python
Pycharm如何打断点的方法步骤
Jun 13 Python
Python 使用 Pillow 模块给图片添加文字水印的方法
Aug 30 Python
详细整理python 字符串(str)与列表(list)以及数组(array)之间的转换方法
Aug 30 Python
Django Model层F,Q对象和聚合函数原理解析
Nov 12 Python
使用PyTorch实现MNIST手写体识别代码
Jan 18 #Python
Pytorch之finetune使用详解
Jan 18 #Python
pytorch 修改预训练model实例
Jan 18 #Python
Pytorch自己加载单通道图片用作数据集训练的实例
Jan 18 #Python
pyinstaller 3.6版本通过pip安装失败的解决办法(推荐)
Jan 18 #Python
Python实现点云投影到平面显示
Jan 18 #Python
Pytorch 实现计算分类器准确率(总分类及子分类)
Jan 18 #Python
You might like
PHP define函数的使用说明
2008/08/27 PHP
CI框架中zip类应用示例
2014/06/17 PHP
WordPress中设置Post Type自定义文章类型的实例教程
2016/05/10 PHP
PHP如何开启Opcache功能提升程序处理效率
2020/04/27 PHP
JavaScript 序列化对象实现代码
2009/12/18 Javascript
JQuery学习笔记 nt-child的使用
2011/01/17 Javascript
javascript错误的认识不用关心内存管理
2012/12/15 Javascript
JavaScript使用addEventListener添加事件监听用法实例
2015/06/01 Javascript
js变形金刚文字特效代码分享
2015/08/20 Javascript
jQuery插件实现静态HTML验证码校验
2015/11/06 Javascript
easyui datagrid 大数据加载效率慢,优化解决方法(推荐)
2016/11/09 Javascript
jQuery插件HighCharts绘制2D带有Legend的饼图效果示例【附demo源码下载】
2017/03/10 Javascript
详解angular中的作用域及继承
2017/05/31 Javascript
Angular2管道Pipe及自定义管道格式数据用法实例分析
2017/11/29 Javascript
React 高阶组件入门介绍
2018/01/11 Javascript
在vue里面设置全局变量或数据的方法
2018/03/09 Javascript
使用layui实现的左侧菜单栏以及动态操作tab项方法
2019/09/10 Javascript
如何优雅地取消 JavaScript 异步任务
2020/03/22 Javascript
[03:07]【DOTA2亚洲邀请赛】我们,梦开始的地方
2017/03/07 DOTA
举例介绍Python中的25个隐藏特性
2015/03/30 Python
Python同时向控制台和文件输出日志logging的方法
2015/05/26 Python
Python编程之Re模块下的函数介绍
2017/10/28 Python
Python面向对象class类属性及子类用法分析
2018/02/02 Python
Tensorflow卷积神经网络实例
2018/05/24 Python
Python模拟自动存取款机的查询、存取款、修改密码等操作
2018/09/02 Python
YUV转为jpg图像的实现
2019/12/09 Python
解决Python import docx出错DLL load failed的问题
2020/02/13 Python
解决img标签上下出现间隙的方法
2016/12/14 HTML / CSS
Allen Edmonds官方网站:一家美国优质男士鞋类及配饰制造商
2019/03/12 全球购物
建筑设计所实习生自我鉴定
2013/09/25 职场文书
银行门卫岗位职责
2013/12/29 职场文书
节电标语大全
2014/06/23 职场文书
有关西游记的读书笔记
2015/06/25 职场文书
小学毕业教师寄语
2019/06/21 职场文书
Go 在 MongoDB 中常用查询与修改的操作
2021/05/07 Golang
OpenCV3.3+Python3.6实现图片高斯模糊
2021/05/18 Python