pytorch实现建立自己的数据集(以mnist为例)


Posted in Python onJanuary 18, 2020

本文将原始的numpy array数据在pytorch下封装为Dataset类的数据集,为后续深度网络训练提供数据。

加载并保存图像信息

首先导入需要的库,定义各种路径。

import os
import matplotlib
from keras.datasets import mnist
import numpy as np
from torch.utils.data.dataset import Dataset
from PIL import Image
import scipy.misc

root_path = 'E:/coding_ex/pytorch/Alexnet/data/'
base_path = 'baseset/'
training_path = 'trainingset/'
test_path = 'testset/'

这里将数据集分为三类,baseset为所有数据(trainingset+testset),trainingset是训练集,testset是测试集。直接通过keras.dataset加载mnist数据集,不能自动下载的话可以手动下载.npz并保存至相应目录下。

def LoadData(root_path, base_path, training_path, test_path):
  (x_train, y_train), (x_test, y_test) = mnist.load_data()
  x_baseset = np.concatenate((x_train, x_test))
  y_baseset = np.concatenate((y_train, y_test))
  train_num = len(x_train)
  test_num = len(x_test)
  
  #baseset
  file_img = open((os.path.join(root_path, base_path)+'baseset_img.txt'),'w')
  file_label = open((os.path.join(root_path, base_path)+'baseset_label.txt'),'w')
  for i in range(train_num + test_num):
    file_img.write(root_path + base_path + 'img/' + str(i) + '.png\n') #name
    file_label.write(str(y_baseset[i])+'\n') #label
#    scipy.misc.imsave(root_path + base_path + '/img/'+str(i) + '.png', x_baseset[i])
    matplotlib.image.imsave(root_path + base_path + 'img/'+str(i) + '.png', x_baseset[i])
  file_img.close()
  file_label.close()
  
  #trainingset
  file_img = open((os.path.join(root_path, training_path)+'trainingset_img.txt'),'w')
  file_label = open((os.path.join(root_path, training_path)+'trainingset_label.txt'),'w')
  for i in range(train_num):
    file_img.write(root_path + training_path + 'img/' + str(i) + '.png\n') #name
    file_label.write(str(y_train[i])+'\n') #label
#    scipy.misc.imsave(root_path + training_path + '/img/'+str(i) + '.png', x_train[i])
    matplotlib.image.imsave(root_path + training_path + 'img/'+str(i) + '.png', x_train[i])
  file_img.close()
  file_label.close()
  
  #testset
  file_img = open((os.path.join(root_path, test_path)+'testset_img.txt'),'w')
  file_label = open((os.path.join(root_path, test_path)+'testset_label.txt'),'w')
  for i in range(test_num):
    file_img.write(root_path + test_path + 'img/' + str(i) + '.png\n') #name
    file_label.write(str(y_test[i])+'\n') #label
#    scipy.misc.imsave(root_path + test_path + '/img/'+str(i) + '.png', x_test[i])
    matplotlib.image.imsave(root_path + test_path + 'img/'+str(i) + '.png', x_test[i])
  file_img.close()
  file_label.close()

使用这段代码时,需要建立相应的文件夹及.txt文件,./data文件夹结构如下:

pytorch实现建立自己的数据集(以mnist为例)

/img文件夹

由于mnist数据集其实是灰度图,这里用matplotlib保存的图像是伪彩色图像。

pytorch实现建立自己的数据集(以mnist为例)

如果用scipy.misc.imsave的话保存的则是灰度图像。

xxx_img.txt文件

xxx_img.txt文件中存放的是每张图像的名字

pytorch实现建立自己的数据集(以mnist为例)

xxx_label.txt文件

xxx_label.txt文件中存放的是类别标记

pytorch实现建立自己的数据集(以mnist为例)

这里记得保存的时候一行为一个图像信息,便于后续读取。

定义自己的Dataset类

pytorch训练数据时需要数据集为Dataset类,便于迭代等等,这里将加载保存之后的数据封装成Dataset类,继承该类需要写初始化方法(__init__),获取指定下标数据的方法__getitem__),获取数据个数的方法(__len__)。这里尤其需要注意的是要把label转为LongTensor类型的。

class DataProcessingMnist(Dataset):
  def __init__(self, root_path, imgfile_path, labelfile_path, imgdata_path, transform = None):
    self.root_path = root_path
    self.transform = transform
    self.imagedata_path = imgdata_path
    img_file = open((root_path + imgfile_path),'r')
    self.image_name = [x.strip() for x in img_file]
    img_file.close()
    label_file = open((root_path + labelfile_path), 'r')
    label = [int(x.strip()) for x in label_file]
    label_file.close()
    self.label = torch.LongTensor(label)#这句很重要,一定要把label转为LongTensor类型的
    
  def __getitem__(self, idx):
    image = Image.open(str(self.image_name[idx]))
    image = image.convert('RGB')
    if self.transform is not None:
      image = self.transform(image)
    label = self.label[idx]
    return image, label
  def __len__(self):
    return len(self.image_name)

定义完自己的类之后可以测试一下。

LoadData(root_path, base_path, training_path, test_path)
  training_imgfile = training_path + 'trainingset_img.txt'
  training_labelfile = training_path + 'trainingset_label.txt'
  training_imgdata = training_path + 'img/'
  #实例化一个类
  dataset = DataProcessingMnist(root_path, training_imgfile, training_labelfile, training_imgdata)

得到图像名称

name = dataset.image_name

pytorch实现建立自己的数据集(以mnist为例)

这里我们可以单独输出某一个名称看一下是否有换行符

print(name[0])
>>>'E:/coding_ex/pytorch/Alexnet/data/trainingset/img/0.png'

如果定义类的时候self.image_name = [x.strip() for x in img_file]这句没有strip掉,则输出的值将为'E:/coding_ex/pytorch/Alexnet/data/trainingset/img/0.png\n'

获取固定下标的图像

im, label = dataset.__getitem__(0)

得到结果

pytorch实现建立自己的数据集(以mnist为例)

以上这篇pytorch实现建立自己的数据集(以mnist为例)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python读取Android permission文件
Nov 01 Python
轻松实现python搭建微信公众平台
Feb 16 Python
python解决汉字编码问题:Unicode Decode Error
Jan 19 Python
python使用PyCharm进行远程开发和调试
Nov 02 Python
Python基于OpenCV库Adaboost实现人脸识别功能详解
Aug 25 Python
Python实现的删除重复文件或图片功能示例【去重】
Apr 23 Python
如何利用python web框架做文件流下载的实现示例
Jun 02 Python
Python 如何实现数据库表结构同步
Sep 29 Python
Python importlib模块重载使用方法详解
Oct 13 Python
Python 高级库15 个让新手爱不释手(推荐)
May 15 Python
一文搞懂python异常处理、模块与包
Jun 26 Python
简单谈谈Python面向对象的相关知识
Jun 28 Python
使用PyTorch实现MNIST手写体识别代码
Jan 18 #Python
Pytorch之finetune使用详解
Jan 18 #Python
pytorch 修改预训练model实例
Jan 18 #Python
Pytorch自己加载单通道图片用作数据集训练的实例
Jan 18 #Python
pyinstaller 3.6版本通过pip安装失败的解决办法(推荐)
Jan 18 #Python
Python实现点云投影到平面显示
Jan 18 #Python
Pytorch 实现计算分类器准确率(总分类及子分类)
Jan 18 #Python
You might like
用PHP实现小型站点广告管理(修正版)
2006/10/09 PHP
php开发工具有哪五款
2015/11/09 PHP
微信支付PHP SDK之微信公众号支付代码详解
2015/12/09 PHP
php简单实现文件或图片强制下载的方法
2016/12/06 PHP
提升你网站水平的jQuery插件集合推荐
2011/04/19 Javascript
js控制frameSet示例
2013/09/10 Javascript
解析img图片没找到onerror事件 Stack overflow at line: 0
2013/12/23 Javascript
jQuery中:checkbox选择器用法实例
2015/01/03 Javascript
jQuery实现自定义checkbox和radio样式
2015/07/13 Javascript
详解Webwork中Action 调用的方法
2016/02/02 Javascript
JS获取鼠标相对位置的方法
2016/09/20 Javascript
JavaScript解析JSON格式数据的方法示例
2017/01/24 Javascript
ES6深入理解之“let”能替代”var“吗?
2017/06/28 Javascript
nodejs 日志模块winston的使用方法
2018/05/02 NodeJs
浅谈如何使用webpack构建多页面应用
2018/05/30 Javascript
js布局实现单选按钮控件
2020/01/17 Javascript
vue项目中使用bpmn-自定义platter的示例代码
2020/05/11 Javascript
wxpython中利用线程防止假死的实现方法
2014/08/11 Python
Python实现将通信达.day文件读取为DataFrame
2018/12/22 Python
python3射线法判断点是否在多边形内
2019/06/28 Python
python绘制双Y轴折线图以及单Y轴双变量柱状图的实例
2019/07/08 Python
Python命令行参数定义及需要注意的地方
2020/11/30 Python
We Fashion荷兰:一家国际时装公司
2018/04/18 全球购物
The Beach People美国:澳洲海滨奢华品牌
2018/07/05 全球购物
千禧酒店及度假村官方网站:Millennium Hotels and Resorts
2019/05/10 全球购物
建筑工程技术应届生求职信
2013/11/17 职场文书
情人节活动策划方案
2014/02/27 职场文书
毕业论文评语大全
2014/04/29 职场文书
教育合作协议范本
2014/10/17 职场文书
作风建设年度心得体会
2014/10/29 职场文书
班主任先进事迹材料
2014/12/17 职场文书
参观邀请函范文
2015/02/02 职场文书
灵魂歌王观后感
2015/06/17 职场文书
学生会副主席竞选稿
2015/11/19 职场文书
CSS预处理框架——Stylus
2021/04/21 HTML / CSS
mysql中整数数据类型tinyint详解
2021/12/06 MySQL