Pytorch 使用 nii数据做输入数据的操作


Posted in Python onMay 26, 2020

使用pix2pix-gan做医学图像合成的时候,如果把nii数据转成png格式会损失很多信息,以为png格式图像的灰度值有256阶,因此直接使用nii的医学图像做输入会更好一点。

但是Pythorch中的Dataloader是不能直接读取nii图像的,因此加一个CreateNiiDataset的类。

先来了解一下pytorch中读取数据的主要途径——Dataset类。在自己构建数据层时都要基于这个类,类似于C++中的虚基类。

自己构建的数据层包含三个部分

class Dataset(object):
"""An abstract class representing a Dataset.
All other datasets should subclass it. All subclasses should override
``__len__``, that provides the size of the dataset, and ``__getitem__``,
supporting integer indexing in range from 0 to len(self) exclusive.
"""
def __getitem__(self, index):
 raise NotImplementedError
def __len__(self):
 raise NotImplementedError
def __add__(self, other):
 return ConcatDataset([self, other])

根据自己的需要编写CreateNiiDataset子类:

因为我是基于https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix

做pix2pix-gan的实验,数据包含两个部分mr 和 ct,不需要标签,因此上面的 def getitem(self, index):中不需要index这个参数了,类似地,根据需要,加入自己的参数,去掉不需要的参数。

class CreateNiiDataset(Dataset):
 def __init__(self, opt, transform = None, target_transform = None):
  self.path1 = opt.dataroot # parameter passing
  self.A = 'MR' 
  self.B = 'CT'
  lines = os.listdir(os.path.join(self.path1, self.A))
  lines.sort()
  imgs = []
  for line in lines:
   imgs.append(line)
  self.imgs = imgs
  self.transform = transform
  self.target_transform = target_transform

 def crop(self, image, crop_size):
  shp = image.shape
  scl = [int((shp[0] - crop_size[0]) / 2), int((shp[1] - crop_size[1]) / 2)]
  image_crop = image[scl[0]:scl[0] + crop_size[0], scl[1]:scl[1] + crop_size[1]]
  return image_crop

 def __getitem__(self, item):
  file = self.imgs[item]
  img1 = sitk.ReadImage(os.path.join(self.path1, self.A, file))
  img2 = sitk.ReadImage(os.path.join(self.path1, self.B, file))
  data1 = sitk.GetArrayFromImage(img1)
  data2 = sitk.GetArrayFromImage(img2)

  if data1.shape[0] != 256:
   data1 = self.crop(data1, [256, 256])
   data2 = self.crop(data2, [256, 256])
  if self.transform is not None:
   data1 = self.transform(data1)
   data2 = self.transform(data2)

  if np.min(data1)<0:
   data1 = (data1 - np.min(data1))/(np.max(data1)-np.min(data1))

  if np.min(data2)<0:
   #data2 = data2 - np.min(data2)
   data2 = (data2 - np.min(data2))/(np.max(data2)-np.min(data2))

  data = {}
  data1 = data1[np.newaxis, np.newaxis, :, :]
  data1_tensor = torch.from_numpy(np.concatenate([data1,data1,data1], 1))
  data1_tensor = data1_tensor.type(torch.FloatTensor)
  data['A'] = data1_tensor # should be a tensor in Float Tensor Type

  data2 = data2[np.newaxis, np.newaxis, :, :]
  data2_tensor = torch.from_numpy(np.concatenate([data2,data2,data2], 1))
  data2_tensor = data2_tensor.type(torch.FloatTensor)
  data['B'] = data2_tensor # should be a tensor in Float Tensor Type
  data['A_paths'] = [os.path.join(self.path1, self.A, file)] # should be a list, with path inside
  data['B_paths'] = [os.path.join(self.path1, self.B, file)]
  return data

 def load_data(self):
  return self

 def __len__(self):
  return len(self.imgs)

注意:最后输出的data是一个字典,里面有四个keys=[‘A',‘B',‘A_paths',‘B_paths'], 一定要注意数据要转成FloatTensor。

其次是data[‘A_paths'] 接收的值是一个list,一定要加[ ] 扩起来,要不然测试存图的时候会有问题,找这个问题找了好久才发现。

然后直接在train.py的主函数里面把数据加载那行改掉就好了

data_loader = CreateNiiDataset(opt)
dataset = data_loader.load_data()

Over!

补充知识:nii格式图像存为npy格式

我就废话不多说了,大家还是直接看代码吧!

import nibabel as nib
import os
import numpy as np
 
img_path = '/home/lei/train/img/'
seg_path = '/home/lei/train/seg/'
saveimg_path = '/home/lei/train/npy_img/'
saveseg_path = '/home/lei/train/npy_seg/'
 
img_names = os.listdir(img_path)
seg_names = os.listdir(seg_path)
 
for img_name in img_names:
 print(img_name)
 img = nib.load(img_path + img_name).get_data() #载入
 img = np.array(img)
 np.save(saveimg_path + str(img_name).split('.')[0] + '.npy', img) #保存
 
for seg_name in seg_names:
 print(seg_name)
 seg = nib.load(seg_path + seg_name).get_data()
 seg = np.array(seg)
 np.save(saveseg_path + str(seg_name).split('.')[0] + '.npy

以上这篇Pytorch 使用 nii数据做输入数据的操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python sys模块sys.path使用方法示例
Dec 04 Python
python数据结构之二叉树的建立实例
Apr 29 Python
Python中利用Scipy包的SIFT方法进行图片识别的实例教程
Jun 03 Python
Python使用回溯法子集树模板解决迷宫问题示例
Sep 01 Python
django启动uwsgi报错的解决方法
Apr 08 Python
Python中defaultdict与lambda表达式用法实例小结
Apr 09 Python
十分钟利用Python制作属于你自己的个性logo
May 07 Python
解决使用pycharm提交代码时冲突之后文件丢失找回的方法
Aug 05 Python
使用python搭建服务器并实现Android端与之通信的方法
Jun 28 Python
python中使用.py配置文件的方法详解
Nov 23 Python
Python爬虫逆向分析某云音乐加密参数的实例分析
Dec 04 Python
自己搭建resnet18网络并加载torchvision自带权重的操作
May 13 Python
python变量的作用域是什么
May 26 #Python
Python3 pywin32模块安装的详细步骤
May 26 #Python
什么是python的列表推导式
May 26 #Python
python中列表的含义及用法
May 26 #Python
初学者学习Python好还是Java好
May 26 #Python
python函数map()和partial()的知识点总结
May 26 #Python
Python selenium使用autoIT上传附件过程详解
May 26 #Python
You might like
关于js和php对url编码的处理方法
2014/03/04 PHP
PHP大批量插入数据库的3种方法和速度对比
2014/07/08 PHP
php修改上传图片尺寸的方法
2015/04/14 PHP
PHP 7.0.2 正式版发布
2016/01/08 PHP
11个用于提高排版水平的基于jquery的文字效果插件
2012/09/14 Javascript
12行javascript代码绘制一个八卦图
2015/04/02 Javascript
JS实现超简单的仿QQ折叠菜单效果
2015/09/21 Javascript
Bootstrap 树控件使用经验分享(图文解说)
2017/11/06 Javascript
详解用场景去理解函数柯里化(入门篇)
2019/04/11 Javascript
微信小程序全局变量改变监听的实现方法
2019/07/15 Javascript
es6中let和const的使用方法详解
2020/02/24 Javascript
vue实现瀑布流组件滑动加载更多
2020/03/10 Javascript
Node.js设置定时任务之node-schedule模块的使用详解
2020/04/28 Javascript
解决vant的Toast组件时提示not defined的问题
2020/11/11 Javascript
[01:50]WODOTA制作 DOTA2中文宣传片《HERO》
2013/04/28 DOTA
[00:32]2018DOTA2亚洲邀请赛出场——LGD
2018/04/04 DOTA
Python的lambda匿名函数的简单介绍
2013/04/25 Python
Python解释执行原理分析
2014/08/22 Python
Centos5.x下升级python到python2.7版本教程
2015/02/14 Python
Python内置函数reversed()用法分析
2018/03/20 Python
tf.concat中axis的含义与使用详解
2020/02/07 Python
UNDONE手表官网:世界领先的定制手表品牌
2018/11/13 全球购物
The Body Shop美体小铺西班牙官网:天然化妆品
2019/06/21 全球购物
葡萄牙航空官方网站:TAP Air Portugal
2019/10/31 全球购物
正风肃纪剖析材料
2014/02/18 职场文书
《小猪家的桃花树》教学反思
2014/04/11 职场文书
初中升旗仪式演讲稿
2014/05/08 职场文书
小学语文教学经验交流材料
2014/06/02 职场文书
员工年终自我评价
2014/09/14 职场文书
2014年生活老师工作总结
2014/12/23 职场文书
冬季作息时间调整通知
2015/04/24 职场文书
盲山观后感
2015/06/11 职场文书
学生病假条范文
2015/08/17 职场文书
2016年度创先争优活动总结
2016/04/05 职场文书
python turtle绘图
2022/05/04 Python
前端传参数进行Mybatis调用mysql存储过程执行返回值详解
2022/08/14 MySQL