Pytorch distributed 多卡并行载入模型操作


Posted in Python onJune 05, 2021

一、Pytorch distributed 多卡并行载入模型

这次来介绍下如何载入模型。

目前没有找到官方的distribute 载入模型的方式,所以采用如下方式。

大部分情况下,我们在测试时不需要多卡并行计算。

所以,我在测试时只使用单卡。

from collections import OrderedDict
device = torch.device("cuda")
model = DGCNN(args).to(device)  #自己的模型
state_dict = torch.load(args.model_path)    #存放模型的位置

new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] # remove `module.`
    new_state_dict[name] = v
    # load params
model.load_state_dict (new_state_dict)

二、pytorch DistributedParallel进行单机多卡训练

One_导入库:

import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler

Two_进程初始化:

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=-1)
# 添加必要参数
# local_rank:系统自动赋予的进程编号,可以利用该编号控制打印输出以及设置device

torch.distributed.init_process_group(backend="nccl", init_method='file://shared/sharedfile',
rank=local_rank, world_size=world_size)

# world_size:所创建的进程数,也就是所使用的GPU数量
# (初始化设置详见参考文档)

Three_数据分发:

dataset = datasets.ImageFolder(dataPath)
data_sampler = DistributedSampler(dataset, rank=local_rank, num_replicas=world_size)
# 使用DistributedSampler来为各个进程分发数据,其中num_replicas与world_size保持一致,用于将数据集等分成不重叠的数个子集

dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=1,drop_last=True, pin_memory=True, sampler=data_sampler)
# 在Dataloader中指定sampler时,其中的shuffle必须为False,而DistributedSampler中的shuffle项默认为True,因此训练过程默认执行shuffle

Four_网络模型:

torch.cuda.set_device(local_rank)
device = torch.device('cuda:'+f'{local_rank}')
# 设置每个进程对应的GPU设备

D = Model()
D = torch.nn.SyncBatchNorm.convert_sync_batchnorm(D).to(device)
# 由于在训练过程中各卡的前向后向传播均独立进行,因此无法进行统一的批归一化,如果想要将各卡的输出统一进行批归一化,需要将模型中的BN转换成SyncBN
   
D = torch.nn.parallel.DistributedDataParallel(
D, find_unused_parameters=True, device_ids=[local_rank], output_device=local_rank)
# 如果有forward的返回值如果不在计算loss的计算图里,那么需要find_unused_parameters=True,即返回值不进入backward去算grad,也不需要在不同进程之间进行通信。

Five_迭代:

data_sampler.set_epoch(epoch)
# 每个epoch需要为sampler设置当前epoch

Six_加载:

dist.barrier()
D.load_state_dict(torch.load('D.pth'), map_location=torch.device('cpu'))
dist.barrier()
# 加载模型前后用dist.barrier()来同步不同进程间的快慢

Seven_启动:

CUDA_VISIBLE_DEVICES=1,3 python -m torch.distributed.launch --nproc_per_node=2 train.py --epochs 15000 --batchsize 10 --world_size 2
# 用-m torch.distributed.launch启动,nproc_per_node为所使用的卡数,batchsize设置为每张卡各自的批大小

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 迭代器和iter()函数详解及实例
Mar 21 Python
python使用pandas实现数据分割实例代码
Jan 25 Python
Python实现的自定义多线程多进程类示例
Mar 23 Python
python在html中插入简单的代码并加上时间戳的方法
Oct 16 Python
python ddt数据驱动最简实例代码
Feb 22 Python
Python自动化之数据驱动让你的脚本简洁10倍【推荐】
Jun 04 Python
通过PHP与Python代码对比的语法差异详解
Jul 10 Python
django admin组件使用方法详解
Jul 19 Python
Django框架静态文件使用/中间件/禁用ip功能实例详解
Jul 22 Python
python自动结束mysql慢查询会话的实例代码
Oct 27 Python
在Ubuntu 20.04中安装Pycharm 2020.1的图文教程
Apr 30 Python
python subprocess pipe 实时输出日志的操作
Dec 05 Python
Pytorch中的学习率衰减及其用法详解
Jun 05 #Python
pytorch finetuning 自己的图片进行训练操作
Jun 05 #Python
Python 如何将integer转化为罗马数(3999以内)
Jun 05 #Python
刚学完怎么用Python实现定时任务,转头就跑去撩妹!
OpenCV中resize函数插值算法的实现过程(五种)
Jun 05 #Python
OpenCV全景图像拼接的实现示例
opencv 分类白天与夜景视频的方法
You might like
PHP.MVC的模板标签系统(四)
2006/09/05 PHP
php.ini-dist 和 php.ini-recommended 的区别介绍(方便开发与安全的朋友)
2012/07/01 PHP
PHP实现批量生成App各种尺寸Logo
2015/03/19 PHP
php基于jquery的ajax技术传递json数据简单实例
2016/04/15 PHP
php metaphone()函数及php localeconv() 函数实例解析
2016/05/15 PHP
走出JavaScript初学困境—js初学
2008/12/29 Javascript
javascript form 验证函数 弹出对话框形式
2009/06/23 Javascript
页面图片浮动左右滑动效果的简单实现案例
2014/02/10 Javascript
NodeJS学习笔记之MongoDB模块
2015/01/13 NodeJs
jQuery实现导航滚动到指定内容效果完整实例【附demo源码下载】
2016/09/20 Javascript
vue2.0父子组件及非父子组件之间的通信方法
2017/01/21 Javascript
JS对象深度克隆实例分析
2017/03/16 Javascript
JS实现中文汉字按拼音排序的方法
2017/10/09 Javascript
JavaScript实现数字前补“0”的五种方法示例
2019/01/03 Javascript
ES6数组与对象的解构赋值详解
2019/06/14 Javascript
Python实现判断一行代码是否为注释的方法
2018/05/23 Python
Python使用装饰器模拟用户登陆验证功能示例
2018/08/24 Python
详解Django的CSRF认证实现
2018/10/09 Python
python实现文件的备份流程详解
2019/06/18 Python
利用python实现短信和电话提醒功能的例子
2019/08/08 Python
pygame实现俄罗斯方块游戏(基础篇2)
2019/10/29 Python
keras 自定义loss层+接受输入实例
2020/06/28 Python
基于CentOS搭建Python Django环境过程解析
2020/08/24 Python
HTML5页面音视频在微信和app下自动播放的实现方法
2016/10/20 HTML / CSS
美国中小型企业领先的办公家具供应商:Office Designs
2016/11/26 全球购物
阿迪达斯希腊官方网上商店:adidas希腊
2019/04/06 全球购物
建筑工程自我鉴定
2013/10/18 职场文书
财务部经理岗位职责
2014/02/03 职场文书
工作疏忽检讨书500字
2014/10/26 职场文书
2014年度考核工作总结
2014/12/24 职场文书
工作经验交流材料
2014/12/30 职场文书
新郎新娘答谢词
2015/01/04 职场文书
神龙架导游词
2015/02/11 职场文书
2015年度员工自我评价范文
2015/03/11 职场文书
奖学金主要事迹范文
2015/11/04 职场文书
离婚协议书范本(2016最新版)
2016/03/18 职场文书