使用PyTorch将文件夹下的图片分为训练集和验证集实例


Posted in Python onJanuary 08, 2020

PyTorch提供了ImageFolder的类来加载文件结构如下的图片数据集:

root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png

使用这个类的问题在于无法将训练集(training dataset)和验证集(validation dataset)分开。我写了两个类来完成这个工作。

import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor, Resize, Compose
from PIL import Image
from sklearn.model_selection import train_test_split

class ImageFolderSplitter:
  # images should be placed in folders like:
  # --root
  # ----root\dogs
  # ----root\dogs\image1.png
  # ----root\dogs\image2.png
  # ----root\cats
  # ----root\cats\image1.png
  # ----root\cats\image2.png  
  # path: the root of the image folder
  def __init__(self, path, train_size = 0.8):
    self.path = path
    self.train_size = train_size
    self.class2num = {}
    self.num2class = {}
    self.class_nums = {}
    self.data_x_path = []
    self.data_y_label = []
    self.x_train = []
    self.x_valid = []
    self.y_train = []
    self.y_valid = []
    for root, dirs, files in os.walk(path):
      if len(files) == 0 and len(dirs) > 1:
        for i, dir1 in enumerate(dirs):
          self.num2class[i] = dir1
          self.class2num[dir1] = i
      elif len(files) > 1 and len(dirs) == 0:
        category = ""
        for key in self.class2num.keys():
          if key in root:
            category = key
            break
        label = self.class2num[category]
        self.class_nums[label] = 0
        for file1 in files:
          self.data_x_path.append(os.path.join(root, file1))
          self.data_y_label.append(label)
          self.class_nums[label] += 1
      else:
        raise RuntimeError("please check the folder structure!")
    self.x_train, self.x_valid, self.y_train, self.y_valid = train_test_split(self.data_x_path, self.data_y_label, shuffle = True, train_size = self.train_size)

  def getTrainingDataset(self):
    return self.x_train, self.y_train

  def getValidationDataset(self):
    return self.x_valid, self.y_valid

class DatasetFromFilename(Dataset):
  # x: a list of image file full path
  # y: a list of image categories
  def __init__(self, x, y, transforms = None):
    super(DatasetFromFilename, self).__init__()
    self.x = x
    self.y = y
    if transforms == None:
      self.transforms = ToTensor()
    else:
      self.transforms = transforms
    
  def __len__(self):
    return len(self.x)

  def __getitem__(self, idx):
    img = Image.open(self.x[idx])
    img = img.convert("RGB")
    return self.transforms(img), torch.tensor([[self.y[idx]]])

# test code
# splitter = ImageFolderSplitter("for_test")
# transforms = Compose([Resize((51, 51)), ToTensor()])
# x_train, y_train = splitter.getTrainingDataset()
# training_dataset = DatasetFromFilename(x_train, y_train, transforms=transforms)
# training_dataloader = DataLoader(training_dataset, batch_size=2, shuffle=True)
# x_valid, y_valid = splitter.getValidationDataset()
# validation_dataset = DatasetFromFilename(x_valid, y_valid, transforms=transforms)
# validation_dataloader = DataLoader(validation_dataset, batch_size=2, shuffle=True)
# for x, y in training_dataloader:
#   print(x.shape, y.shape)

更多的代码可以在我的Github reop下找到。

Python 相关文章推荐
Python 面向对象 成员的访问约束
Dec 23 Python
python 采集中文乱码问题的完美解决方法
Sep 27 Python
python3.4用循环往mysql5.7中写数据并输出的实现方法
Jun 20 Python
简单了解Django模板的使用
Dec 20 Python
Python之ReportLab绘制条形码和二维码的实例
Jan 15 Python
Tensorflow环境搭建的方法步骤
Feb 07 Python
Python 批量合并多个txt文件的实例讲解
May 08 Python
浅析python中while循环和for循环
Nov 19 Python
python用pip install时安装失败的一系列问题及解决方法
Feb 24 Python
Python matplotlib模块及柱状图用法解析
Aug 10 Python
python处理写入数据代码讲解
Oct 22 Python
使用Pytorch训练two-head网络的操作
May 28 Python
使用 PyTorch 实现 MLP 并在 MNIST 数据集上验证方式
Jan 08 #Python
Pycharm 2020最新永久激活码(附最新激活码和插件)
Sep 17 #Python
将matplotlib绘图嵌入pyqt的方法示例
Jan 08 #Python
pyinstaller还原python代码过程图解
Jan 08 #Python
python Tensor和Array对比分析
Jan 08 #Python
Pycharm小白级简单使用教程
Jan 08 #Python
python如何实现不可变字典inmutabledict
Jan 08 #Python
You might like
浅析PHP微信支付通知的处理方式
2014/05/25 PHP
php实现搜索类封装示例
2016/03/31 PHP
基于jQuery中对数组进行操作的方法
2013/04/16 Javascript
JavaScript自执行闭包的小例子
2013/06/29 Javascript
jquery mobile事件多次绑定示例代码
2013/09/13 Javascript
Javascript中使用A标签获取当前目录的绝对路径方法
2015/03/02 Javascript
jQuery AjaxUpload 上传图片代码
2016/02/02 Javascript
jQuery 获取多选框的值及多选框中文的函数
2016/05/16 Javascript
Js调用Java方法并互相传参的简单实例
2016/08/11 Javascript
使用Ajax生成的Excel文件并下载的实例
2016/11/21 Javascript
jQuery Easy UI中根据第一个下拉框选中的值设置第二个下拉框是否可以编辑
2016/11/29 Javascript
如何解决React官方脚手架不支持Less的问题(小结)
2018/09/12 Javascript
常见的浏览器存储方式(cookie、localStorage、sessionStorage)
2019/05/07 Javascript
vue实现在线翻译功能
2019/09/27 Javascript
Vue Render函数原理及代码实例解析
2020/07/30 Javascript
vue 解决provide和inject响应的问题
2020/11/12 Javascript
Vue项目打包部署到apache服务器的方法步骤
2021/02/01 Vue.js
深入学习Python中的装饰器使用
2016/06/20 Python
python xlsxwriter库生成图表的应用示例
2018/03/16 Python
python 实现对数据集的归一化的方法(0-1之间)
2018/07/17 Python
Python+Pandas 获取数据库并加入DataFrame的实例
2018/07/25 Python
windows支持哪个版本的python
2020/07/03 Python
python opencv角点检测连线功能的实现代码
2020/11/24 Python
python excel和yaml文件的读取封装
2021/01/12 Python
利用CSS的Sass预处理器(框架)来制作居中效果
2016/03/10 HTML / CSS
美国经典刺绣和字母儿童服装特卖:Smocked Auctions
2018/07/16 全球购物
日本索尼音乐商店:Sony Music Shop
2018/07/17 全球购物
公司行政经理岗位职责
2013/12/24 职场文书
吸烟检讨书2000字
2014/02/13 职场文书
初二学习计划书范文
2014/04/27 职场文书
安全责任书
2015/01/29 职场文书
幼儿园食品安全责任书
2015/05/08 职场文书
孝女彩金观后感
2015/06/10 职场文书
分享几个JavaScript运算符的使用技巧
2021/04/24 Javascript
python中if和elif的区别介绍
2021/11/07 Python
Javascript中Microtask和Macrotask鲜为人知的知识点
2022/04/02 Javascript