使用PyTorch将文件夹下的图片分为训练集和验证集实例


Posted in Python onJanuary 08, 2020

PyTorch提供了ImageFolder的类来加载文件结构如下的图片数据集:

root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png

使用这个类的问题在于无法将训练集(training dataset)和验证集(validation dataset)分开。我写了两个类来完成这个工作。

import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor, Resize, Compose
from PIL import Image
from sklearn.model_selection import train_test_split

class ImageFolderSplitter:
  # images should be placed in folders like:
  # --root
  # ----root\dogs
  # ----root\dogs\image1.png
  # ----root\dogs\image2.png
  # ----root\cats
  # ----root\cats\image1.png
  # ----root\cats\image2.png  
  # path: the root of the image folder
  def __init__(self, path, train_size = 0.8):
    self.path = path
    self.train_size = train_size
    self.class2num = {}
    self.num2class = {}
    self.class_nums = {}
    self.data_x_path = []
    self.data_y_label = []
    self.x_train = []
    self.x_valid = []
    self.y_train = []
    self.y_valid = []
    for root, dirs, files in os.walk(path):
      if len(files) == 0 and len(dirs) > 1:
        for i, dir1 in enumerate(dirs):
          self.num2class[i] = dir1
          self.class2num[dir1] = i
      elif len(files) > 1 and len(dirs) == 0:
        category = ""
        for key in self.class2num.keys():
          if key in root:
            category = key
            break
        label = self.class2num[category]
        self.class_nums[label] = 0
        for file1 in files:
          self.data_x_path.append(os.path.join(root, file1))
          self.data_y_label.append(label)
          self.class_nums[label] += 1
      else:
        raise RuntimeError("please check the folder structure!")
    self.x_train, self.x_valid, self.y_train, self.y_valid = train_test_split(self.data_x_path, self.data_y_label, shuffle = True, train_size = self.train_size)

  def getTrainingDataset(self):
    return self.x_train, self.y_train

  def getValidationDataset(self):
    return self.x_valid, self.y_valid

class DatasetFromFilename(Dataset):
  # x: a list of image file full path
  # y: a list of image categories
  def __init__(self, x, y, transforms = None):
    super(DatasetFromFilename, self).__init__()
    self.x = x
    self.y = y
    if transforms == None:
      self.transforms = ToTensor()
    else:
      self.transforms = transforms
    
  def __len__(self):
    return len(self.x)

  def __getitem__(self, idx):
    img = Image.open(self.x[idx])
    img = img.convert("RGB")
    return self.transforms(img), torch.tensor([[self.y[idx]]])

# test code
# splitter = ImageFolderSplitter("for_test")
# transforms = Compose([Resize((51, 51)), ToTensor()])
# x_train, y_train = splitter.getTrainingDataset()
# training_dataset = DatasetFromFilename(x_train, y_train, transforms=transforms)
# training_dataloader = DataLoader(training_dataset, batch_size=2, shuffle=True)
# x_valid, y_valid = splitter.getValidationDataset()
# validation_dataset = DatasetFromFilename(x_valid, y_valid, transforms=transforms)
# validation_dataloader = DataLoader(validation_dataset, batch_size=2, shuffle=True)
# for x, y in training_dataloader:
#   print(x.shape, y.shape)

更多的代码可以在我的Github reop下找到。

Python 相关文章推荐
python使用点操作符访问字典(dict)数据的方法
Mar 16 Python
python模块之time模块(实例讲解)
Sep 13 Python
python读取与写入csv格式文件的示例代码
Dec 16 Python
Python3安装Pillow与PIL的方法
Apr 03 Python
OpenCV 边缘检测
Jul 10 Python
使用 Python 快速实现 HTTP 和 FTP 服务器的方法
Jul 22 Python
python获取Pandas列名的几种方法
Aug 07 Python
pygame实现非图片按钮效果
Oct 29 Python
pytorch载入预训练模型后,实现训练指定层
Jan 06 Python
基于Django OneToOneField和ForeignKey的区别详解
Mar 30 Python
使用Python实现NBA球员数据查询小程序功能
Nov 09 Python
Django利用elasticsearch(搜索引擎)实现搜索功能
Nov 26 Python
使用 PyTorch 实现 MLP 并在 MNIST 数据集上验证方式
Jan 08 #Python
Pycharm 2020最新永久激活码(附最新激活码和插件)
Sep 17 #Python
将matplotlib绘图嵌入pyqt的方法示例
Jan 08 #Python
pyinstaller还原python代码过程图解
Jan 08 #Python
python Tensor和Array对比分析
Jan 08 #Python
Pycharm小白级简单使用教程
Jan 08 #Python
python如何实现不可变字典inmutabledict
Jan 08 #Python
You might like
PHP 多维数组排序(usort,uasort)
2010/06/30 PHP
Codeigniter注册登录代码示例
2014/06/12 PHP
zend framework重定向方法小结
2016/05/28 PHP
PHP文件操作实例总结【文件上传、下载、分页】
2018/12/08 PHP
js 弹出菜单/窗口效果
2011/10/30 Javascript
jQuery中对节点进行操作的相关介绍
2013/04/16 Javascript
IE下window.onresize 多次调用与死循环bug处理方法介绍
2013/11/12 Javascript
jQuery实现异步获取json数据的2种方式
2014/08/29 Javascript
利用JavaScript的AngularJS库制作电子名片的方法
2015/06/18 Javascript
jquery控制页面部分刷新的方法
2015/06/24 Javascript
深入理解Ajax的get和post请求
2016/06/02 Javascript
jQuery中Nicescroll滚动条插件的用法
2016/11/10 Javascript
node.js学习之交互式解释器REPL详解
2016/12/08 Javascript
详解webpack + vue + node 打造单页面(入门篇)
2017/09/23 Javascript
前端必备插件之纯原生JS的瀑布流插件Macy.js
2017/11/22 Javascript
nodejs nedb 封装库与使用方法示例
2020/02/06 NodeJs
详解vue路由
2020/08/05 Javascript
Vue实现导航栏菜单
2020/08/19 Javascript
SpringBoot+Vue 前后端合并部署的配置方法
2020/12/30 Vue.js
[00:20]TI9观赛名额抽取Ⅱ
2019/07/24 DOTA
python用ConfigObj读写配置文件的实现代码
2013/03/04 Python
Python之re操作方法(详解)
2017/06/14 Python
python使用Plotly绘图工具绘制柱状图
2019/04/01 Python
PyQt5显示GIF图片的方法
2019/06/17 Python
Python使用pyserial进行串口通信的实例
2019/07/02 Python
Python线程指南分享
2019/11/19 Python
django模型动态修改参数,增加 filter 字段的方式
2020/03/16 Python
CSS3制作酷炫的三维相册效果
2016/07/01 HTML / CSS
HTML5的Geolocation地理位置定位API使用教程
2016/05/12 HTML / CSS
芝加哥牛排公司:Chicago Steak Company
2018/10/31 全球购物
大学生找工作推荐信范文
2013/11/28 职场文书
酒店端午节活动方案
2014/08/26 职场文书
2015年实习班主任工作总结
2015/04/23 职场文书
2016年119消防宣传日活动总结
2016/04/05 职场文书
导游词之上海杜莎夫人蜡像馆
2019/11/22 职场文书
前端传参数进行Mybatis调用mysql存储过程执行返回值详解
2022/08/14 MySQL