pytorch::Dataloader中的迭代器和生成器应用详解


Posted in Python onJanuary 03, 2020

在使用pytorch训练模型,经常需要加载大量图片数据,因此pytorch提供了好用的数据加载工具Dataloader。

为了实现小批量循环读取大型数据集,在Dataloader类具体实现中,使用了迭代器和生成器。

这一应用场景正是python中迭代器模式的意义所在,因此本文对Dataloader中代码进行解读,可以更好的理解python中迭代器和生成器的概念。

本文的内容主要有:

  • 解释python中的迭代器和生成器概念
  • 解读pytorch中Dataloader代码,如何使用迭代器和生成器实现数据加载

python迭代基础

python中围绕着迭代有以下概念:

  1. 可迭代对象 iterables
  2. 迭代器 iterator
  3. 生成器 generator

这三个概念互相关联,并不是孤立的。在可迭代对象的基础上发展了迭代器,在迭代器的基础上又发展了生成器。

学习这些概念的名词解释没有多大意义。编程中很多的抽象概念都是为了更好的实现某些功能,才去人为创造的协议和模式。

因此,要理解它们,需要探究概念背后的逻辑,为什么这样设计?要解决的真正问题是什么?在哪些场景下应用是最好的?

迭代模式首先要解决的基础问题是,需要按一定顺序获取集合内部数据,比如循环某个list。

当数据很小时,不会有问题。但当读取大量数据时,一次性读取会超出内存限制,因此想出以下方法:

  • 把大的数据分成几个小块,分批处理
  • 惰性的取值方式,按需取值

循环读数据可分为下面三种应用场景,对应着容器(可迭代对象),迭代器和生成器:

  • for x in container: 为了遍历python内部序列容器(如list), 这些类型内部实现了__getitem__() 方法,可以从0开始按顺序遍历序列容器中的元素。
  • for x in iterator: 为了循环用户自定义的迭代器,需要实现__iter__和__next__方法,__iter__是迭代协议,具体每次迭代的执行逻辑在 __next__或next方法里
  • for x in generator: 为了节省循环的内存和加速,使用生成器来实现惰性加载,在迭代器的基础上加入了yield语句,最简单的例子是 range(5)

代码示例:

# 普通循环 for x in list
numbers = [1, 2, 3,]
for n in numbers:
  print(n) # 1,2,3

# for循环实际干的事情
# iter输入一个可迭代对象list,返回迭代器
# next方法取数据
my_iterator = iter(numbers)
next(my_iterator) # 1
next(my_iterator) # 2
next(my_iterator) # 3
next(my_iterator) # StopIteration exception

# 迭代器循环 for x in iterator
for i,n in enumerate(numbers):
  print(i,n) # 0,1 / 1,3 / 2,3

# 生成器循环 for x in generator
for i in range(3):
  print(i) # 0,1,2

上面示例代码中python内置函数iter和next的用法:

  • iter函数,调用__iter__,返回一个迭代器
  • next函数,输入迭代器,调用__next__,取出数据

比较容易混淆的是__iter__和__next__两个方法。它们的区别是:

  • __iter__是为了可以迭代,真正执行取数据的逻辑是__next__方法实现的,实际调用是通过next(iterator)完成
  • __iter__可以返回自身(return self),实际读取数据的实现放在__next__方法
  • __iter__可以和yield搭配,返回生成器对象

__iter__返回自身的做法有点类似 python中的类型系统。为了保持一致性,python中一切皆对象。

每个对象创建后,都有类型指针,而类型对象的指针指向元对象,元对象的指针指向自身。

生成器,是在__iter__方法中加入yield语句,好处有:

  • 减少循环判断逻辑的复杂度
  • 惰性取值,节省内存和时间

yield作用:

  • 代替函数中的return语句
  • 记住上一次循环迭代器内部元素的位置

三种循环模式常用函数

for x in container 方法:

  • list, deque, …
  • set, frozensets, …
  • dict, defaultdict, OrderedDict, Counter, …
  • tuple, namedtuple, …
  • str

for x in iterator 方法:

  • enumerate() # 加上list的index
  • sorted() # 排序list
  • reversed() # 倒序list
  • zip() # 合并list

for x in generator 方法:

  • range()
  • map()
  • filter()
  • reduce()
  • [x for x in list(...)]

Dataloder源码分析

pytorch采用 for x in iterator 模式,从Dataloader类中读取数据。

  1. 为了实现该迭代模式,在Dataloader内部实现__iter__方法,实际返回的是_DataLoaderIter类。
  2. _DataLoaderIter类里面,实现了 __iter__方法,返回自身,具体执行读数据的逻辑,在__next__方法中。

以下代码只截取了单线程下的数据读取。

class DataLoader(object):
  r"""
  Data loader. Combines a dataset and a sampler, and provides
  single- or multi-process iterators over the dataset.
  """
  def __init__(self, dataset, batch_size=1, shuffle=False, ...):
    self.dataset = dataset
    self.batch_sampler = batch_sampler
    ...
  
  def __iter__(self):
    return _DataLoaderIter(self)

  def __len__(self):
    return len(self.batch_sampler)

class _DataLoaderIter(object):
  r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""
  def __init__(self, loader):
    self.sample_iter = iter(self.batch_sampler)
    ...

  def __next__(self):
    if self.num_workers == 0: # same-process loading
      indices = next(self.sample_iter) # may raise StopIteration
      batch = self.collate_fn([self.dataset[i] for i in indices])
      if self.pin_memory:
        batch = pin_memory_batch(batch)
      return batch
    ...

  def __iter__(self):
    return self

Dataloader类中读取数据Index的方法,采用了 for x in generator 方式,但是调用采用iter和next函数

  1. 构建随机采样类RandomSampler,内部实现了 __iter__方法
  2. __iter__方法内部使用了 yield,循环遍历数据集,当数量达到batch_size大小时,就返回
  3. 实例化随机采样类,传入iter函数,返回一个迭代器
  4. next会调用随机采样类中生成器,返回相应的index数据
class RandomSampler(object):
  """random sampler to yield a mini-batch of indices."""
  def __init__(self, batch_size, dataset, drop_last=False):
    self.dataset = dataset
    self.batch_size = batch_size
    self.num_imgs = len(dataset)
    self.drop_last = drop_last

  def __iter__(self):
    indices = np.random.permutation(self.num_imgs)
    batch = []
    for i in indices:
      batch.append(i)
      if len(batch) == self.batch_size:
        yield batch
        batch = []
    ## if images not to yield a batch
    if len(batch)>0 and not self.drop_last:
      yield batch


  def __len__(self):
    if self.drop_last:
      return self.num_imgs // self.batch_size
    else:
      return (self.num_imgs + self.batch_size - 1) // self.batch_size

batch_sampler = RandomSampler(batch_size. dataset)
sample_iter = iter(batch_sampler)
indices = next(sample_iter)

总结

本文总结了python中循环的三种模式:

  • for x in container 可迭代对象
  • for x in iterator 迭代器
  • for x in generator 生成器

pytorch中的数据加载模块 Dataloader,使用生成器来返回数据的索引,使用迭代器来返回需要的张量数据,可以在大量数据情况下,实现小批量循环迭代式的读取,避免了内存不足问题。

参考文章

迭代器和生成器
流畅的Python-第14章:可迭代的对象、迭代器和生成器
pytorch-dataloader源码

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现的简单猜数字游戏
Apr 04 Python
利用Python批量生成任意尺寸的图片
Aug 29 Python
Python计算两个日期相差天数的方法示例
May 23 Python
Python实现模拟分割大文件及多线程处理的方法
Oct 10 Python
tensorflow1.0学习之模型的保存与恢复(Saver)
Apr 23 Python
详谈Python3 操作系统与路径 模块(os / os.path / pathlib)
Apr 26 Python
python绘制立方体的方法
Jul 02 Python
python 3.7.0 安装配置方法图文教程
Aug 27 Python
python的set处理二维数组转一维数组的方法示例
May 31 Python
Python 批量刷博客园访问量脚本过程解析
Aug 30 Python
opencv 阈值分割的具体使用
Jul 08 Python
Python 绘制多因子柱状图
May 11 Python
django商品分类及商品数据建模实例详解
Jan 03 #Python
PyTorch和Keras计算模型参数的例子
Jan 02 #Python
Pytorch中实现只导入部分模型参数的方式
Jan 02 #Python
PyTorch中topk函数的用法详解
Jan 02 #Python
Pytorch训练过程出现nan的解决方式
Jan 02 #Python
pytorch绘制并显示loss曲线和acc曲线,LeNet5识别图像准确率
Jan 02 #Python
基于MSELoss()与CrossEntropyLoss()的区别详解
Jan 02 #Python
You might like
PHP之uniqid()函数用法
2014/11/03 PHP
3种php生成唯一id的方法
2015/11/23 PHP
JavaScript实现删除电脑的关机键
2016/07/26 PHP
fireworks菜单生成器mm_menu.js在 IE 7.0 显示问题的解决方法
2009/10/20 Javascript
javascript(jquery)利用函数修改全局变量的代码
2009/11/02 Javascript
Kibo 用于处理键盘事件的Javascript工具库
2011/10/28 Javascript
JS实现距离上次刷新已过多少秒示例
2014/05/23 Javascript
jquery获取img的src值的简单实例
2016/05/17 Javascript
picLazyLoad 实现图片延时加载(包含背景图片)
2016/07/21 Javascript
JS正则RegExp.test()使用注意事项(不具有重复性)
2016/12/28 Javascript
微信小程序滚动Tab实现左右可滑动切换
2017/08/17 Javascript
原生JS 购物车及购物页面的cookie使用方法
2017/08/21 Javascript
AngularJS监听ng-repeat渲染完成的两种方法
2018/01/16 Javascript
Vue官网todoMVC示例代码
2018/01/29 Javascript
9102年webpack4搭建vue项目的方法步骤
2019/02/20 Javascript
原生js实现抽奖小游戏
2019/06/27 Javascript
[02:23]完美世界全国高校联赛街访DOTA2第一期
2019/11/28 DOTA
python在windows下实现备份程序实例
2014/07/04 Python
编写简单的Python程序来判断文本的语种
2015/04/07 Python
Python psutil模块简单使用实例
2015/04/28 Python
Python探索之静态方法和类方法的区别详解
2017/10/27 Python
Python Tkinter模块实现时钟功能应用示例
2018/07/23 Python
Python通过TensorFLow进行线性模型训练原理与实现方法详解
2020/01/15 Python
如何使用Python处理HDF格式数据及可视化问题
2020/06/24 Python
Django Admin后台模型列表页面如何添加自定义操作按钮
2020/11/11 Python
在线学习西班牙语、法语或其他语言:Babbel.com
2018/02/07 全球购物
纽约市的奢华内衣目的地:Anya Lust
2019/08/02 全球购物
W Hamond官网:始于1979年的钻石专家
2020/07/20 全球购物
请编写一个 C 函数,该函数在给定的内存区域搜索给定的字符,并返回该字符所在位置索引值
2014/09/15 面试题
西安当代医院管理研究院笔试题
2015/12/11 面试题
建筑工程专业学生的自我评价
2013/12/25 职场文书
单位介绍信范文
2014/01/18 职场文书
函授毕业个人自我评价
2014/02/20 职场文书
六一儿童节标语
2014/10/08 职场文书
药品开票员岗位职责
2015/04/15 职场文书
2016秋季小学开学寄语
2015/12/03 职场文书