pytorch制作自己的LMDB数据操作示例


Posted in Python onDecember 18, 2019

本文实例讲述了pytorch制作自己的LMDB数据操作。分享给大家供大家参考,具体如下:

前言

记录下pytorch里如何使用lmdb的code,自用

制作部分的Code

code就是ASTER里数据制作部分的代码改了点,aster_train.txt里面就算图片的完整路径每行一个,图片同目录下有同名的txt,里面记着jpg的标签

import os
import lmdb # install lmdb by "pip install lmdb"
import cv2
import numpy as np
from tqdm import tqdm
import six
from PIL import Image
import scipy.io as sio
from tqdm import tqdm
import re
def checkImageIsValid(imageBin):
 if imageBin is None:
  return False
 imageBuf = np.fromstring(imageBin, dtype=np.uint8)
 img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
 imgH, imgW = img.shape[0], img.shape[1]
 if imgH * imgW == 0:
  return False
 return True
def writeCache(env, cache):
 with env.begin(write=True) as txn:
  for k, v in cache.items():
   txn.put(k.encode(), v)
def _is_difficult(word):
 assert isinstance(word, str)
 return not re.match('^[\w]+$', word)
def createDataset(outputPath, imagePathList, labelList, lexiconList=None, checkValid=True):
 """
 Create LMDB dataset for CRNN training.
 ARGS:
   outputPath  : LMDB output path
   imagePathList : list of image path
   labelList   : list of corresponding groundtruth texts
   lexiconList  : (optional) list of lexicon lists
   checkValid  : if true, check the validity of every image
 """
 assert(len(imagePathList) == len(labelList))
 nSamples = len(imagePathList)
 env = lmdb.open(outputPath, map_size=1099511627776)#最大空间1048576GB
 cache = {}
 cnt = 1
 for i in range(nSamples):
  imagePath = imagePathList[i]
  label = labelList[i]
  if len(label) == 0:
   continue
  if not os.path.exists(imagePath):
   print('%s does not exist' % imagePath)
   continue
  with open(imagePath, 'rb') as f:
   imageBin = f.read()
  if checkValid:
   if not checkImageIsValid(imageBin):
    print('%s is not a valid image' % imagePath)
    continue
  #数据库中都是二进制数据
  imageKey = 'image-%09d' % cnt#9位数不足填零
  labelKey = 'label-%09d' % cnt
  cache[imageKey] = imageBin
  cache[labelKey] = label.encode()
  if lexiconList:
   lexiconKey = 'lexicon-%09d' % cnt
   cache[lexiconKey] = ' '.join(lexiconList[i])
  if cnt % 1000 == 0:
   writeCache(env, cache)
   cache = {}
   print('Written %d / %d' % (cnt, nSamples))
  cnt += 1
 nSamples = cnt-1
 cache['num-samples'] = str(nSamples).encode()
 writeCache(env, cache)
 print('Created dataset with %d samples' % nSamples)
def get_sample_list(txt_path:str):
  with open(txt_path,'r') as fr:
    jpg_list=[x.strip() for x in fr.readlines() if os.path.exists(x.replace('.jpg','.txt').strip())]
  txt_content_list=[]
  for jpg in jpg_list:
    label_path=jpg.replace('.jpg','.txt')
    with open(label_path,'r') as fr:
      try:
        str_tmp=fr.readline()
      except UnicodeDecodeError as e:
        print(label_path)
        raise(e)
      txt_content_list.append(str_tmp.strip())
  return jpg_list,txt_content_list
if __name__ == "__main__":
 txt_path='/home/gpu-server/disk/disk1/NumberData/8NumberSample/aster_train.txt'
 lmdb_output_path = '/home/gpu-server/project/aster/dataset/train'
 imagePathList,labelList=get_sample_list(txt_path)
 createDataset(lmdb_output_path, imagePathList, labelList)

读取部分

这里用的pytorch的dataloader,简单记录一下,人比较懒,代码就直接抄过来,不整理拆分了,重点看__getitem__

from __future__ import absolute_import
# import sys
# sys.path.append('./')
import os
# import moxing as mox
import pickle
from tqdm import tqdm
from PIL import Image, ImageFile
import numpy as np
import random
import cv2
import lmdb
import sys
import six
import torch
from torch.utils import data
from torch.utils.data import sampler
from torchvision import transforms
from lib.utils.labelmaps import get_vocabulary, labels2strs
from lib.utils import to_numpy
ImageFile.LOAD_TRUNCATED_IMAGES = True
from config import get_args
global_args = get_args(sys.argv[1:])
if global_args.run_on_remote:
 import moxing as mox
 #moxing是一个分布式的框架 跳过
class LmdbDataset(data.Dataset):
 def __init__(self, root, voc_type, max_len, num_samples, transform=None):
  super(LmdbDataset, self).__init__()
  if global_args.run_on_remote:
   dataset_name = os.path.basename(root)
   data_cache_url = "/cache/%s" % dataset_name
   if not os.path.exists(data_cache_url):
    os.makedirs(data_cache_url)
   if mox.file.exists(root):
    mox.file.copy_parallel(root, data_cache_url)
   else:
    raise ValueError("%s not exists!" % root)
   self.env = lmdb.open(data_cache_url, max_readers=32, readonly=True)
  else:
   self.env = lmdb.open(root, max_readers=32, readonly=True)
  assert self.env is not None, "cannot create lmdb from %s" % root
  self.txn = self.env.begin()
  self.voc_type = voc_type
  self.transform = transform
  self.max_len = max_len
  self.nSamples = int(self.txn.get(b"num-samples"))
  self.nSamples = min(self.nSamples, num_samples)
  assert voc_type in ['LOWERCASE', 'ALLCASES', 'ALLCASES_SYMBOLS','DIGITS']
  self.EOS = 'EOS'
  self.PADDING = 'PADDING'
  self.UNKNOWN = 'UNKNOWN'
  self.voc = get_vocabulary(voc_type, EOS=self.EOS, PADDING=self.PADDING, UNKNOWN=self.UNKNOWN)
  self.char2id = dict(zip(self.voc, range(len(self.voc))))
  self.id2char = dict(zip(range(len(self.voc)), self.voc))
  self.rec_num_classes = len(self.voc)
  self.lowercase = (voc_type == 'LOWERCASE')
 def __len__(self):
  return self.nSamples
 def __getitem__(self, index):
  assert index <= len(self), 'index range error'
  index += 1
  img_key = b'image-%09d' % index
  imgbuf = self.txn.get(img_key)
  #由于Image.open需要一个类文件对象 所以这里需要把二进制转为一个类文件对象
  buf = six.BytesIO()
  buf.write(imgbuf)
  buf.seek(0)
  try:
   img = Image.open(buf).convert('RGB')
   # img = Image.open(buf).convert('L')
   # img = img.convert('RGB')
  except IOError:
   print('Corrupted image for %d' % index)
   return self[index + 1]
  # reconition labels
  label_key = b'label-%09d' % index
  word = self.txn.get(label_key).decode()
  if self.lowercase:
   word = word.lower()
  ## fill with the padding token
  label = np.full((self.max_len,), self.char2id[self.PADDING], dtype=np.int)
  label_list = []
  for char in word:
   if char in self.char2id:
    label_list.append(self.char2id[char])
   else:
    ## add the unknown token
    print('{0} is out of vocabulary.'.format(char))
    label_list.append(self.char2id[self.UNKNOWN])
  ## add a stop token
  label_list = label_list + [self.char2id[self.EOS]]
  assert len(label_list) <= self.max_len
  label[:len(label_list)] = np.array(label_list)
  if len(label) <= 0:
   return self[index + 1]
  # label length
  label_len = len(label_list)
  if self.transform is not None:
   img = self.transform(img)
  return img, label, label_len

更多关于Python相关内容可查看本站专题:《Python数学运算技巧总结》、《Python图片操作技巧总结》、《Python数据结构与算法教程》、《Python函数使用技巧总结》、《Python字符串操作技巧汇总》及《Python入门与进阶经典教程》

希望本文所述对大家Python程序设计有所帮助。

Python 相关文章推荐
Python标准异常和异常处理详解
Feb 02 Python
python 实现自动远程登陆scp文件实例代码
Mar 13 Python
简述:我为什么选择Python而不是Matlab和R语言
Nov 14 Python
Python2.7下安装Scrapy框架步骤教程
Dec 22 Python
Python Pandas找到缺失值的位置方法
Apr 12 Python
浅谈Pandas:Series和DataFrame间的算术元素
Dec 22 Python
Python面向对象程序设计示例小结
Jan 30 Python
selenium python 实现基本自动化测试的示例代码
Feb 25 Python
对Python 中矩阵或者数组相减的法则详解
Aug 26 Python
关于Tensorflow 模型持久化详解
Feb 12 Python
利用python 下载bilibili视频
Nov 13 Python
详解OpenCV曝光融合
Apr 29 Python
Python Gluon参数和模块命名操作教程
Dec 18 #Python
python turtle 绘制太极图的实例
Dec 18 #Python
Python使用gluon/mxnet模块实现的mnist手写数字识别功能完整示例
Dec 18 #Python
简单了解Python读取大文件代码实例
Dec 18 #Python
python 比较2张图片的相似度的方法示例
Dec 18 #Python
使用Python的Turtle库绘制森林的实例
Dec 18 #Python
python3 requests库实现多图片爬取教程
Dec 18 #Python
You might like
php获取网页内容方法总结
2008/12/04 PHP
PHP中用正则表达式清除字符串的空白
2011/01/17 PHP
ThinkPHP验证码使用简明教程
2014/03/05 PHP
php使用date和strtotime函数输出指定日期的方法
2014/11/14 PHP
Ajax中的JSON格式与php传输过程全面解析
2017/11/14 PHP
jQuery 使用个人心得
2009/02/26 Javascript
jQuery遍历Table应用示例
2014/04/09 Javascript
javascript原生和jquery库实现iframe自适应高度和宽度
2014/07/18 Javascript
全面解析Bootstrap布局组件应用
2016/02/22 Javascript
JavaScript中Number对象的toFixed() 方法详解
2016/09/02 Javascript
jQuery实现鼠标响应式淘宝动画效果示例
2018/02/13 jQuery
关于vue的npm run dev和npm run build的区别介绍
2019/01/14 Javascript
AngularJS实现的鼠标拖动画矩形框示例【可兼容IE8】
2019/05/17 Javascript
layui表格数据重载
2019/07/27 Javascript
python+Django+apache的配置方法详解
2016/06/01 Python
python 控制Asterisk AMI接口外呼电话的例子
2019/08/08 Python
在pytorch 中计算精度、回归率、F1 score等指标的实例
2020/01/18 Python
使用tensorflow DataSet实现高效加载变长文本输入
2020/01/20 Python
keras 特征图可视化实例(中间层)
2020/01/24 Python
python可视化text()函数使用详解
2020/02/11 Python
Python基于百度AI实现OCR文字识别
2020/04/02 Python
python中urllib.request和requests的使用及区别详解
2020/05/05 Python
python实现最短路径的实例方法
2020/07/19 Python
悦木之源美国官网:Origins美国
2016/08/01 全球购物
卡西欧B级产品官方网站:Casio Outlet
2018/05/22 全球购物
新英格兰最大的特色礼品连锁店:The Paper Store
2018/07/23 全球购物
英国日常交易网站:Wowcher
2018/09/04 全球购物
中国专业的音频分享平台:喜马拉雅
2019/05/24 全球购物
积极贯彻学习两会精神总结
2014/03/17 职场文书
处级干部反四风个人对照检查材料思想汇报
2014/09/27 职场文书
2014年干部培训工作总结
2014/12/17 职场文书
优秀党员主要事迹材料
2015/11/04 职场文书
导游词之太行山青龙峡
2020/01/14 职场文书
Redis安装启动及常见数据类型
2021/04/14 Redis
python实现的人脸识别打卡系统
2021/05/08 Python
Windows下redis下载、redis安装及使用教程
2021/06/02 Redis