Pytorch distributed 多卡并行载入模型操作


Posted in Python onJune 05, 2021

一、Pytorch distributed 多卡并行载入模型

这次来介绍下如何载入模型。

目前没有找到官方的distribute 载入模型的方式,所以采用如下方式。

大部分情况下,我们在测试时不需要多卡并行计算。

所以,我在测试时只使用单卡。

from collections import OrderedDict
device = torch.device("cuda")
model = DGCNN(args).to(device)  #自己的模型
state_dict = torch.load(args.model_path)    #存放模型的位置

new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] # remove `module.`
    new_state_dict[name] = v
    # load params
model.load_state_dict (new_state_dict)

二、pytorch DistributedParallel进行单机多卡训练

One_导入库:

import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler

Two_进程初始化:

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=-1)
# 添加必要参数
# local_rank:系统自动赋予的进程编号,可以利用该编号控制打印输出以及设置device

torch.distributed.init_process_group(backend="nccl", init_method='file://shared/sharedfile',
rank=local_rank, world_size=world_size)

# world_size:所创建的进程数,也就是所使用的GPU数量
# (初始化设置详见参考文档)

Three_数据分发:

dataset = datasets.ImageFolder(dataPath)
data_sampler = DistributedSampler(dataset, rank=local_rank, num_replicas=world_size)
# 使用DistributedSampler来为各个进程分发数据,其中num_replicas与world_size保持一致,用于将数据集等分成不重叠的数个子集

dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=1,drop_last=True, pin_memory=True, sampler=data_sampler)
# 在Dataloader中指定sampler时,其中的shuffle必须为False,而DistributedSampler中的shuffle项默认为True,因此训练过程默认执行shuffle

Four_网络模型:

torch.cuda.set_device(local_rank)
device = torch.device('cuda:'+f'{local_rank}')
# 设置每个进程对应的GPU设备

D = Model()
D = torch.nn.SyncBatchNorm.convert_sync_batchnorm(D).to(device)
# 由于在训练过程中各卡的前向后向传播均独立进行,因此无法进行统一的批归一化,如果想要将各卡的输出统一进行批归一化,需要将模型中的BN转换成SyncBN
   
D = torch.nn.parallel.DistributedDataParallel(
D, find_unused_parameters=True, device_ids=[local_rank], output_device=local_rank)
# 如果有forward的返回值如果不在计算loss的计算图里,那么需要find_unused_parameters=True,即返回值不进入backward去算grad,也不需要在不同进程之间进行通信。

Five_迭代:

data_sampler.set_epoch(epoch)
# 每个epoch需要为sampler设置当前epoch

Six_加载:

dist.barrier()
D.load_state_dict(torch.load('D.pth'), map_location=torch.device('cpu'))
dist.barrier()
# 加载模型前后用dist.barrier()来同步不同进程间的快慢

Seven_启动:

CUDA_VISIBLE_DEVICES=1,3 python -m torch.distributed.launch --nproc_per_node=2 train.py --epochs 15000 --batchsize 10 --world_size 2
# 用-m torch.distributed.launch启动,nproc_per_node为所使用的卡数,batchsize设置为每张卡各自的批大小

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
pycharm 使用心得(六)进行简单的数据库管理
Jun 06 Python
python中管道用法入门实例
Jun 04 Python
Python编程pygame模块实现移动的小车示例代码
Jan 03 Python
python url 参数修改方法
Dec 26 Python
Python IDE Pycharm中的快捷键列表用法
Aug 08 Python
对django2.0 关联表的必填on_delete参数的含义解析
Aug 09 Python
django ManyToManyField多对多关系的实例详解
Aug 09 Python
wxPython绘图模块wxPyPlot实现数据可视化
Nov 19 Python
关于pandas的离散化,面元划分详解
Nov 22 Python
利用pandas将非数值数据转换成数值的方式
Dec 18 Python
pytorch中tensor张量数据类型的转化方式
Dec 31 Python
Python 动态变量名定义与调用方法
Feb 09 Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
Python 如何将integer转化为罗马数(3999以内)
Jun 05 #Python
刚学完怎么用Python实现定时任务,转头就跑去撩妹!
OpenCV中resize函数插值算法的实现过程(五种)
Jun 05 #Python
OpenCV全景图像拼接的实现示例
opencv 分类白天与夜景视频的方法
You might like
PHP序列化操作方法分析
2016/09/28 PHP
PHP进程通信基础之信号量与共享内存通信
2017/02/19 PHP
PHP实现从上往下打印二叉树的方法
2018/01/18 PHP
JavaScript DOM 添加事件
2009/02/14 Javascript
JS中 用户登录系统的解决办法
2013/04/15 Javascript
如何让浏览器支持jquery ajax load 前进、后退功能
2014/06/12 Javascript
javascript学习笔记(六)数据类型和JSON格式
2014/10/08 Javascript
javascript 常见功能汇总
2015/06/11 Javascript
详解JavaScript中常用的函数类型
2015/11/18 Javascript
学习JavaScript设计模式之享元模式
2016/01/18 Javascript
JS中使用DOM来控制HTML元素
2016/07/31 Javascript
简单学习vue指令directive
2016/11/03 Javascript
微信小程序实战之运维小项目
2017/01/17 Javascript
D3.js中强制异步文件读取同步的几种方法
2017/02/06 Javascript
jQuery Form插件使用详解_动力节点Java学院整理
2017/07/17 jQuery
node.js实现微信JS-API封装接口的示例代码
2017/09/06 Javascript
JavaScript面向对象精要(上部)
2017/09/12 Javascript
JS库particles.js创建超炫背景粒子插件(附源码下载)
2017/09/13 Javascript
基于vue组件实现猜数字游戏
2020/05/28 Javascript
脚手架vue-cli工程webpack的作用和特点
2018/09/29 Javascript
js实现轮播图效果 z-index实现轮播图
2020/01/17 Javascript
简单谈谈python中的Queue与多进程
2016/08/25 Python
VScode编写第一个Python程序HelloWorld步骤
2018/04/06 Python
Python基于类路径字符串获取静态属性
2020/03/12 Python
Python中的wordcloud库安装问题及解决方法
2020/05/27 Python
Python绘制组合图的示例
2020/09/18 Python
Sperry澳大利亚官网:源自美国帆船鞋创始品牌
2019/07/29 全球购物
Talbots官网:美国成熟女装品牌
2019/11/15 全球购物
ASOS西班牙官网:英国在线时尚和美容零售商
2020/01/10 全球购物
家长给孩子的表扬信
2014/01/17 职场文书
跟单业务员岗位职责
2014/03/08 职场文书
小学感恩教育活动总结
2014/07/07 职场文书
一篇文章学会Vue中间件管道
2021/06/20 Vue.js
Python语言规范之Pylint的详细用法
2021/06/24 Python
windows11怎么查看wifi密码? win11查看wifi密码的技巧
2021/11/21 数码科技
SQL使用复合索引实现数据库查询的优化
2022/05/25 SQL Server