pytorch实现建立自己的数据集(以mnist为例)


Posted in Python onJanuary 18, 2020

本文将原始的numpy array数据在pytorch下封装为Dataset类的数据集,为后续深度网络训练提供数据。

加载并保存图像信息

首先导入需要的库,定义各种路径。

import os
import matplotlib
from keras.datasets import mnist
import numpy as np
from torch.utils.data.dataset import Dataset
from PIL import Image
import scipy.misc

root_path = 'E:/coding_ex/pytorch/Alexnet/data/'
base_path = 'baseset/'
training_path = 'trainingset/'
test_path = 'testset/'

这里将数据集分为三类,baseset为所有数据(trainingset+testset),trainingset是训练集,testset是测试集。直接通过keras.dataset加载mnist数据集,不能自动下载的话可以手动下载.npz并保存至相应目录下。

def LoadData(root_path, base_path, training_path, test_path):
  (x_train, y_train), (x_test, y_test) = mnist.load_data()
  x_baseset = np.concatenate((x_train, x_test))
  y_baseset = np.concatenate((y_train, y_test))
  train_num = len(x_train)
  test_num = len(x_test)
  
  #baseset
  file_img = open((os.path.join(root_path, base_path)+'baseset_img.txt'),'w')
  file_label = open((os.path.join(root_path, base_path)+'baseset_label.txt'),'w')
  for i in range(train_num + test_num):
    file_img.write(root_path + base_path + 'img/' + str(i) + '.png\n') #name
    file_label.write(str(y_baseset[i])+'\n') #label
#    scipy.misc.imsave(root_path + base_path + '/img/'+str(i) + '.png', x_baseset[i])
    matplotlib.image.imsave(root_path + base_path + 'img/'+str(i) + '.png', x_baseset[i])
  file_img.close()
  file_label.close()
  
  #trainingset
  file_img = open((os.path.join(root_path, training_path)+'trainingset_img.txt'),'w')
  file_label = open((os.path.join(root_path, training_path)+'trainingset_label.txt'),'w')
  for i in range(train_num):
    file_img.write(root_path + training_path + 'img/' + str(i) + '.png\n') #name
    file_label.write(str(y_train[i])+'\n') #label
#    scipy.misc.imsave(root_path + training_path + '/img/'+str(i) + '.png', x_train[i])
    matplotlib.image.imsave(root_path + training_path + 'img/'+str(i) + '.png', x_train[i])
  file_img.close()
  file_label.close()
  
  #testset
  file_img = open((os.path.join(root_path, test_path)+'testset_img.txt'),'w')
  file_label = open((os.path.join(root_path, test_path)+'testset_label.txt'),'w')
  for i in range(test_num):
    file_img.write(root_path + test_path + 'img/' + str(i) + '.png\n') #name
    file_label.write(str(y_test[i])+'\n') #label
#    scipy.misc.imsave(root_path + test_path + '/img/'+str(i) + '.png', x_test[i])
    matplotlib.image.imsave(root_path + test_path + 'img/'+str(i) + '.png', x_test[i])
  file_img.close()
  file_label.close()

使用这段代码时,需要建立相应的文件夹及.txt文件,./data文件夹结构如下:

pytorch实现建立自己的数据集(以mnist为例)

/img文件夹

由于mnist数据集其实是灰度图,这里用matplotlib保存的图像是伪彩色图像。

pytorch实现建立自己的数据集(以mnist为例)

如果用scipy.misc.imsave的话保存的则是灰度图像。

xxx_img.txt文件

xxx_img.txt文件中存放的是每张图像的名字

pytorch实现建立自己的数据集(以mnist为例)

xxx_label.txt文件

xxx_label.txt文件中存放的是类别标记

pytorch实现建立自己的数据集(以mnist为例)

这里记得保存的时候一行为一个图像信息,便于后续读取。

定义自己的Dataset类

pytorch训练数据时需要数据集为Dataset类,便于迭代等等,这里将加载保存之后的数据封装成Dataset类,继承该类需要写初始化方法(__init__),获取指定下标数据的方法__getitem__),获取数据个数的方法(__len__)。这里尤其需要注意的是要把label转为LongTensor类型的。

class DataProcessingMnist(Dataset):
  def __init__(self, root_path, imgfile_path, labelfile_path, imgdata_path, transform = None):
    self.root_path = root_path
    self.transform = transform
    self.imagedata_path = imgdata_path
    img_file = open((root_path + imgfile_path),'r')
    self.image_name = [x.strip() for x in img_file]
    img_file.close()
    label_file = open((root_path + labelfile_path), 'r')
    label = [int(x.strip()) for x in label_file]
    label_file.close()
    self.label = torch.LongTensor(label)#这句很重要,一定要把label转为LongTensor类型的
    
  def __getitem__(self, idx):
    image = Image.open(str(self.image_name[idx]))
    image = image.convert('RGB')
    if self.transform is not None:
      image = self.transform(image)
    label = self.label[idx]
    return image, label
  def __len__(self):
    return len(self.image_name)

定义完自己的类之后可以测试一下。

LoadData(root_path, base_path, training_path, test_path)
  training_imgfile = training_path + 'trainingset_img.txt'
  training_labelfile = training_path + 'trainingset_label.txt'
  training_imgdata = training_path + 'img/'
  #实例化一个类
  dataset = DataProcessingMnist(root_path, training_imgfile, training_labelfile, training_imgdata)

得到图像名称

name = dataset.image_name

pytorch实现建立自己的数据集(以mnist为例)

这里我们可以单独输出某一个名称看一下是否有换行符

print(name[0])
>>>'E:/coding_ex/pytorch/Alexnet/data/trainingset/img/0.png'

如果定义类的时候self.image_name = [x.strip() for x in img_file]这句没有strip掉,则输出的值将为'E:/coding_ex/pytorch/Alexnet/data/trainingset/img/0.png\n'

获取固定下标的图像

im, label = dataset.__getitem__(0)

得到结果

pytorch实现建立自己的数据集(以mnist为例)

以上这篇pytorch实现建立自己的数据集(以mnist为例)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python兔子毒药问题实例分析
Mar 05 Python
Python聚类算法之基本K均值实例详解
Nov 20 Python
python实现xlsx文件分析详解
Jan 02 Python
在Python中Dataframe通过print输出多行时显示省略号的实例
Dec 22 Python
PyCharm 设置SciView工具窗口的方法
Jan 15 Python
Python中的字符串切片(截取字符串)的详解
May 15 Python
python SVM 线性分类模型的实现
Jul 19 Python
Python 自动登录淘宝并保存登录信息的方法
Sep 04 Python
python爬虫之遍历单个域名
Nov 20 Python
python中文分词库jieba使用方法详解
Feb 11 Python
Python实现计算图像RGB均值方式
Jun 04 Python
Python selenium的这三种等待方式一定要会!
Jun 10 Python
使用PyTorch实现MNIST手写体识别代码
Jan 18 #Python
Pytorch之finetune使用详解
Jan 18 #Python
pytorch 修改预训练model实例
Jan 18 #Python
Pytorch自己加载单通道图片用作数据集训练的实例
Jan 18 #Python
pyinstaller 3.6版本通过pip安装失败的解决办法(推荐)
Jan 18 #Python
Python实现点云投影到平面显示
Jan 18 #Python
Pytorch 实现计算分类器准确率(总分类及子分类)
Jan 18 #Python
You might like
PHP 图片上传实现代码 带详细注释
2010/04/29 PHP
PHP模拟登陆163邮箱发邮件及获取通讯录列表的方法
2015/03/07 PHP
php简单生成随机数的方法
2015/07/30 PHP
实例分析基于PHP微信网页获取用户信息
2017/11/24 PHP
struts2 jquery 打造无限层次的树
2009/10/23 Javascript
document.getElementById方法在Firefox与IE中的区别
2010/05/18 Javascript
js判断IE6/IE7/FF的代码[XMLHttpRequest]
2011/02/16 Javascript
JavaScript高级程序设计 阅读笔记(十七) js事件
2012/08/14 Javascript
Web跨浏览器进程通信(Web跨域)
2013/04/17 Javascript
script不刷新页面的联动前后代码
2013/09/18 Javascript
利用Jquery实现可多选的下拉框
2014/02/21 Javascript
使用ajax+jqtransform实现动态加载select
2014/12/01 Javascript
javascript中typeof操作符和constucor属性检测
2015/02/26 Javascript
JS获取Table中td值的方法
2015/03/19 Javascript
简单实现异步编程promise模式
2015/07/31 Javascript
Bootstrap CDN和本地化环境搭建
2016/10/26 Javascript
JS作用域闭包、预解释和this关键字综合实例解析
2016/12/16 Javascript
JavaScript数组push方法使用注意事项
2017/10/30 Javascript
JavaScript数据结构与算法之二叉树插入节点、生成二叉树示例
2019/02/21 Javascript
Angular利用HTTP POST下载流文件的步骤记录
2020/07/26 Javascript
js 实现碰撞检测的示例
2020/10/28 Javascript
如何利用nodejs实现命令行游戏
2020/11/24 NodeJs
python实现马耳可夫链算法实例分析
2015/05/20 Python
python验证码识别教程之滑动验证码
2018/06/04 Python
Pycharm更换python解释器的方法
2018/10/29 Python
python爬虫 urllib模块反爬虫机制UA详解
2019/08/20 Python
德国大型和小型家用电器网上商店:Energeto
2019/05/15 全球购物
Pandora德国官网:购买潘多拉手链、戒指、项链和耳环
2020/02/20 全球购物
创先争优承诺书范文
2014/03/31 职场文书
校长寄语大全
2014/04/09 职场文书
小学学雷锋活动总结
2014/04/25 职场文书
中药学专业求职信
2014/05/31 职场文书
村官2015年度工作总结
2015/10/14 职场文书
2016年学校综治宣传月活动总结
2016/03/16 职场文书
Vue和Flask通信的实现
2021/05/19 Vue.js
python操作xlsx格式文件并读取
2021/06/02 Python