pytorch ImageFolder的覆写实例


Posted in Python onFebruary 20, 2020

在为数据分类训练分类器的时候,比如猫狗分类时,我们经常会使用pytorch的ImageFolder:

CLASS torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=<function default_loader>, is_valid_file=None)

使用可见pytorch torchvision.ImageFolder的用法介绍

这里想实现的是如果想要覆写该函数,即能使用它的特性,又可以实现自己的功能

首先先分析下其源代码:

IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', 'webp']

class ImageFolder(DatasetFolder):
 """A generic data loader where the images are arranged in this way: ::

  root/dog/xxx.png
  root/dog/xxy.png
  root/dog/xxz.png

  root/cat/123.png
  root/cat/nsdf3.png
  root/cat/asd932_.png

 Args:
  root (string): Root directory path.
  transform (callable, optional): A function/transform that takes in an PIL image
   and returns a transformed version. E.g, ``transforms.RandomCrop``
  target_transform (callable, optional): A function/transform that takes in the
   target and transforms it.
  loader (callable, optional): A function to load an image given its path.

  Attributes:
  classes (list): List of the class names.
  class_to_idx (dict): Dict with items (class_name, class_index).
  imgs (list): List of (image path, class_index) tuples
 """
 def __init__(self, root, transform=None, target_transform=None,
     loader=default_loader):
  super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,
           transform=transform,
           target_transform=target_transform)
  self.imgs = self.samples

ImageFolder的代码很简单,主要是继承了DatasetFolder:

def has_file_allowed_extension(filename, extensions):
 """查看文件是否是支持的可扩展类型

 Args:
  filename (string): 文件路径
  extensions (iterable of strings): 可扩展类型列表,即能接受的图像文件类型

 Returns:
  bool: True if the filename ends with one of given extensions
 """
 filename_lower = filename.lower()
 return any(filename_lower.endswith(ext) for ext in extensions) # 返回True或False列表


def make_dataset(dir, class_to_idx, extensions):
 """
  返回形如[(图像路径, 该图像对应的类别索引值),(),...]
 """
 images = []
 dir = os.path.expanduser(dir)
 for target in sorted(class_to_idx.keys()):
  d = os.path.join(dir, target)
  if not os.path.isdir(d):
   continue

  for root, _, fnames in sorted(os.walk(d)): #层层遍历文件夹,返回当前文件夹路径,存在的所有文件夹名,存在的所有文件名
   for fname in sorted(fnames):
    if has_file_allowed_extension(fname, extensions):查看文件是否是支持的可扩展类型,是则继续
     path = os.path.join(root, fname)
     item = (path, class_to_idx[target])
     images.append(item)

 return images

class DatasetFolder(data.Dataset):
 """A generic data loader where the samples are arranged in this way: ::

  root/class_x/xxx.ext
  root/class_x/xxy.ext
  root/class_x/xxz.ext

  root/class_y/123.ext
  root/class_y/nsdf3.ext
  root/class_y/asd932_.ext

 Args:
  root (string): 根目录路径
  loader (callable): 根据给定的路径来加载样本的可调用函数
  extensions (list[string]): 可扩展类型列表,即能接受的图像文件类型.
  transform (callable, optional): 用于样本的transform函数,然后返回样本transform后的版本
   E.g, ``transforms.RandomCrop`` for images.
  target_transform (callable, optional): 用于样本标签的transform函数

  Attributes:
  classes (list): 类别名列表
  class_to_idx (dict): 项目(class_name, class_index)字典,如{'cat': 0, 'dog': 1}
  samples (list): (sample path, class_index) 元组列表,即(样本路径, 类别索引)
  targets (list): 在数据集中每张图片的类索引值,为列表
 """

 def __init__(self, root, loader, extensions, transform=None, target_transform=None):
  classes, class_to_idx = self._find_classes(root) # 得到类名和类索引,如['cat', 'dog']和{'cat': 0, 'dog': 1}
  # 返回形如[(图像路径, 该图像对应的类别索引值),(),...],即对每个图像进行标记
  samples = make_dataset(root, class_to_idx, extensions) 
  if len(samples) == 0:
   raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n"
        "Supported extensions are: " + ",".join(extensions)))

  self.root = root
  self.loader = loader
  self.extensions = extensions

  self.classes = classes
  self.class_to_idx = class_to_idx
  self.samples = samples
  self.targets = [s[1] for s in samples] #所有图像的类索引值组成的列表

  self.transform = transform
  self.target_transform = target_transform

 def _find_classes(self, dir):
  """
  在数据集中查找类文件夹。

  Args:
   dir (string): 根目录路径

  Returns:
   返回元组: (classes, class_to_idx)即(类名, 类索引),其中classes即相应的目录名,如['cat', 'dog'];class_to_idx为形如{类名:类索引}的字典,如{'cat': 0, 'dog': 1}.

  Ensures:
   保证没有类名是另一个类目录的子目录
  """
  if sys.version_info >= (3, 5):
   # Faster and available in Python 3.5 and above
   classes = [d.name for d in os.scandir(dir) if d.is_dir()] #获得根目录dir的所有第一层子目录名
  else:
   classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] #效果和上面的一样,只是版本不同方法不同
  classes.sort() #然后对类名进行排序
  class_to_idx = {classes[i]: i for i in range(len(classes))} #然后将类名和索引值一一对应的到相应字典,如{'cat': 0, 'dog': 1}
  return classes, class_to_idx #然后返回类名和类索引

 def __getitem__(self, index):
  """
  Args:
   index (int): Index

  Returns:
   tuple: (sample, target) where target is class_index of the target class.
  """
  path, target = self.samples[index]
  sample = self.loader(path) # 加载图片
  if self.transform is not None:
   sample = self.transform(sample)
  if self.target_transform is not None:
   target = self.target_transform(target)

  return sample, target

 def __len__(self):
  return len(self.samples)

 def __repr__(self):
  fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
  fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
  fmt_str += ' Root Location: {}\n'.format(self.root)
  tmp = ' Transforms (if any): '
  fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
  tmp = ' Target Transforms (if any): '
  fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
  return fmt_str

此时想要覆写ImageFolder,代码为:

class CustomImageFolder(ImageFolder):
 """
  为了得到两张图(其中一张是随机选取的)的图像和索引值信息
 """
 def __init__(self, root, transform=None):
  super(CustomImageFolder, self).__init__(root, transform)
  self.indices = range(len(self)) #该文件夹中的长度

 def __getitem__(self, index1):
  index2 = random.choice(self.indices) #从[0,indices]中随机抽取一个数字,为了随机选取一张图

  path1 = self.imgs[index1][0] #此时的self.imgs等于self.samples,即内容为[(图像路径, 该图像对应的类别索引值),(),...]
  label1 = self.imgs[index1][1]
  path2 = self.imgs[index2][0]
  label2 = self.imgs[index2][1]

  img1 = self.loader(path1)
  img2 = self.loader(path2)
  if self.transform is not None:
   img1 = self.transform(img1)
   img2 = self.transform(img2)

  return img1, img2, label1, label2

以上这篇pytorch ImageFolder的覆写实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中处理字符串之endswith()方法的使用简介
May 18 Python
python查找指定具有相同内容文件的方法
Jun 28 Python
Windows系统下使用flup搭建Nginx和Python环境的方法
Dec 25 Python
Django实现自定义404,500页面教程
Mar 26 Python
Python实现抓取网页生成Excel文件的方法示例
Aug 05 Python
用tensorflow搭建CNN的方法
Mar 05 Python
一文带你了解Python中的字符串是什么
Nov 20 Python
使用Matplotlib绘制不同颜色的带箭头的线实例
Apr 17 Python
Python flask框架端口失效解决方案
Jun 04 Python
python实现按键精灵找色点击功能教程,使用pywin32和Pillow库
Jun 04 Python
Python2与Python3关于字符串编码处理的差别总结
Sep 07 Python
python如何获得list或numpy数组中最大元素对应的索引
Nov 16 Python
pytorch torchvision.ImageFolder的用法介绍
Feb 20 #Python
详解python常用命令行选项与环境变量
Feb 20 #Python
用什么库写 Python 命令行程序(示例代码详解)
Feb 20 #Python
在 Linux/Mac 下为Python函数添加超时时间的方法
Feb 20 #Python
Python os模块常用方法和属性总结
Feb 20 #Python
Python requests获取网页常用方法解析
Feb 20 #Python
pytorch实现保证每次运行使用的随机数都相同
Feb 20 #Python
You might like
PHP生成不重复标识符的方法
2014/11/21 PHP
PHP获取Exif缩略图的方法
2015/07/13 PHP
利用php + Laravel如何实现部署自动化详解
2017/10/11 PHP
laravel框架中路由设置,路由参数和路由命名实例分析
2019/11/23 PHP
TP框架实现上传一张图片和批量上传图片的方法分析
2020/04/23 PHP
laravel开发环境homestead搭建过程详解
2020/07/03 PHP
让iframe框架网页在任何浏览器下自动伸缩
2006/08/18 Javascript
js几个验证函数代码
2010/03/25 Javascript
在jQuery1.5中使用deferred对象 着放大镜看Promise
2011/03/12 Javascript
Jquery同辈元素选中/未选中效果的实例代码
2013/08/01 Javascript
文本有关的样式和jQuery求对象的高宽问题分别说明
2013/08/30 Javascript
javascript列表框操作函数集合汇总
2013/11/28 Javascript
javascript操作excel生成报表全攻略
2014/05/04 Javascript
初识Node.js
2015/03/20 Javascript
基于JavaScript实现生成名片、链接等二维码
2015/09/20 Javascript
js如何改变文章的字体大小
2016/01/08 Javascript
jQuery异步提交表单实例
2017/05/30 jQuery
nodejs mysql 实现分页的方法
2017/06/06 NodeJs
JavaScript实现提交模式窗口后刷新父窗口数据的方法
2017/06/16 Javascript
微信JSAPI Ticket接口签名详解
2020/06/28 Javascript
解决vue移动端适配问题
2018/12/12 Javascript
在vue中使用vuex,修改state的值示例
2019/11/08 Javascript
es6中let和const的使用方法详解
2020/02/24 Javascript
js实现磁性吸附的示例
2020/10/26 Javascript
Python3批量移动指定文件到指定文件夹方法示例
2019/09/02 Python
pytorch多GPU并行运算的实现
2019/09/27 Python
Win10下安装并使用tensorflow-gpu1.8.0+python3.6全过程分析(显卡MX250+CUDA9.0+cudnn)
2020/02/17 Python
对Python中 \r, \n, \r\n的彻底理解
2020/03/06 Python
一篇文章教你用python画动态爱心表白
2020/11/22 Python
matplotlib部件之矩形选区(RectangleSelector)的实现
2021/02/01 Python
Pat McGrath Labs官网:世界上最有影响力的化妆师推出的彩妆品牌
2018/01/07 全球购物
Nike比利时官网:Nike.com (BE)
2019/02/07 全球购物
JD Sports丹麦:英国领先的运动时尚零售商
2020/11/24 全球购物
文秘大学生求职信
2014/02/25 职场文书
中等生评语大全
2014/05/04 职场文书
Android移动应用开发指南之六种布局详解
2022/09/23 Java/Android