Pytorch distributed 多卡并行载入模型操作


Posted in Python onJune 05, 2021

一、Pytorch distributed 多卡并行载入模型

这次来介绍下如何载入模型。

目前没有找到官方的distribute 载入模型的方式,所以采用如下方式。

大部分情况下,我们在测试时不需要多卡并行计算。

所以,我在测试时只使用单卡。

from collections import OrderedDict
device = torch.device("cuda")
model = DGCNN(args).to(device)  #自己的模型
state_dict = torch.load(args.model_path)    #存放模型的位置

new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] # remove `module.`
    new_state_dict[name] = v
    # load params
model.load_state_dict (new_state_dict)

二、pytorch DistributedParallel进行单机多卡训练

One_导入库:

import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler

Two_进程初始化:

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=-1)
# 添加必要参数
# local_rank:系统自动赋予的进程编号,可以利用该编号控制打印输出以及设置device

torch.distributed.init_process_group(backend="nccl", init_method='file://shared/sharedfile',
rank=local_rank, world_size=world_size)

# world_size:所创建的进程数,也就是所使用的GPU数量
# (初始化设置详见参考文档)

Three_数据分发:

dataset = datasets.ImageFolder(dataPath)
data_sampler = DistributedSampler(dataset, rank=local_rank, num_replicas=world_size)
# 使用DistributedSampler来为各个进程分发数据,其中num_replicas与world_size保持一致,用于将数据集等分成不重叠的数个子集

dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=1,drop_last=True, pin_memory=True, sampler=data_sampler)
# 在Dataloader中指定sampler时,其中的shuffle必须为False,而DistributedSampler中的shuffle项默认为True,因此训练过程默认执行shuffle

Four_网络模型:

torch.cuda.set_device(local_rank)
device = torch.device('cuda:'+f'{local_rank}')
# 设置每个进程对应的GPU设备

D = Model()
D = torch.nn.SyncBatchNorm.convert_sync_batchnorm(D).to(device)
# 由于在训练过程中各卡的前向后向传播均独立进行,因此无法进行统一的批归一化,如果想要将各卡的输出统一进行批归一化,需要将模型中的BN转换成SyncBN
   
D = torch.nn.parallel.DistributedDataParallel(
D, find_unused_parameters=True, device_ids=[local_rank], output_device=local_rank)
# 如果有forward的返回值如果不在计算loss的计算图里,那么需要find_unused_parameters=True,即返回值不进入backward去算grad,也不需要在不同进程之间进行通信。

Five_迭代:

data_sampler.set_epoch(epoch)
# 每个epoch需要为sampler设置当前epoch

Six_加载:

dist.barrier()
D.load_state_dict(torch.load('D.pth'), map_location=torch.device('cpu'))
dist.barrier()
# 加载模型前后用dist.barrier()来同步不同进程间的快慢

Seven_启动:

CUDA_VISIBLE_DEVICES=1,3 python -m torch.distributed.launch --nproc_per_node=2 train.py --epochs 15000 --batchsize 10 --world_size 2
# 用-m torch.distributed.launch启动,nproc_per_node为所使用的卡数,batchsize设置为每张卡各自的批大小

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用Python程序抓取新浪在国内的所有IP的教程
May 04 Python
Python字符串逐字符或逐词反转方法
May 21 Python
Python3学习笔记之列表方法示例详解
Oct 06 Python
django自带serializers序列化返回指定字段的方法
Aug 21 Python
Python统计文本词汇出现次数的实例代码
Feb 27 Python
Django框架获取form表单数据方式总结
Apr 22 Python
python 识别登录验证码图片功能的实现代码(完整代码)
Jul 03 Python
有关pycharm登录github时有的时候会报错connection reset的问题
Sep 15 Python
安装不同版本的tensorflow与models方法实现
Feb 20 Python
PyTorch 如何自动计算梯度
May 23 Python
python引入其他文件夹下的py文件具体方法
May 23 Python
Matplotlib绘制混淆矩阵的实现
May 27 Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
Python 如何将integer转化为罗马数(3999以内)
Jun 05 #Python
刚学完怎么用Python实现定时任务,转头就跑去撩妹!
OpenCV中resize函数插值算法的实现过程(五种)
Jun 05 #Python
OpenCV全景图像拼接的实现示例
opencv 分类白天与夜景视频的方法
You might like
phpmyadmin 常用选项设置详解版
2010/03/07 PHP
php去除换行(回车换行)的三种方法
2014/03/26 PHP
PHP按行读取、处理较大CSV文件的代码实例
2014/04/09 PHP
php实现检查文章是否被百度收录
2015/01/27 PHP
php基于闭包实现函数的自调用(递归)实例分析
2016/11/11 PHP
php 如何设置一个严格控制过期时间的session
2017/05/05 PHP
关于Javascript 的 prototype问题。
2007/01/03 Javascript
学习ExtJS border布局
2009/10/08 Javascript
基于node.js的快速开发透明代理
2010/12/25 Javascript
jquery DIV撑大让滚动条滚到最底部代码
2013/06/06 Javascript
javascript用户注册提示效果的简单实例
2013/08/17 Javascript
JavaScript实现SHA-1加密算法的方法
2015/03/11 Javascript
BootStrap入门教程(二)之固定的内置样式
2016/09/19 Javascript
webpack4 处理SCSS的方法示例
2018/09/03 Javascript
angularjs性能优化的方法
2018/09/05 Javascript
vue悬浮可拖拽悬浮按钮的实例代码
2019/08/20 Javascript
python+django快速实现文件上传
2016/10/24 Python
Python生成MD5值的两种方法实例分析
2019/04/26 Python
python config文件的读写操作示例
2019/09/27 Python
Python上下文管理器类和上下文管理器装饰器contextmanager用法实例分析
2019/11/07 Python
python修改文件内容的3种方法详解
2019/11/15 Python
Python如何使用函数做字典的值
2019/11/30 Python
python单例设计模式实现解析
2020/01/07 Python
如何搭建pytorch环境的方法步骤
2020/05/06 Python
详细分析Python可变对象和不可变对象
2020/07/09 Python
Python3爬虫中Selenium的用法详解
2020/07/10 Python
CSS3 实现弹幕的示例代码
2017/08/07 HTML / CSS
匡威俄罗斯官网:Converse俄罗斯
2020/05/09 全球购物
校园门卫岗位职责
2013/12/09 职场文书
致跳高运动员广播稿
2014/01/13 职场文书
简历的自我评价范文
2014/02/04 职场文书
工程负责人任命书
2014/06/06 职场文书
团组织推优材料
2014/12/29 职场文书
酒店人事专员岗位职责
2015/04/07 职场文书
详解python的异常捕获
2022/03/03 Python
一次SQL查询优化原理分析(900W+数据从17s到300ms)
2022/06/10 SQL Server