pytorch 使用半精度模型部署的操作


Posted in Python onMay 24, 2021

背景

pytorch作为深度学习的计算框架正得到越来越多的应用.

我们除了在模型训练阶段应用外,最近也把pytorch应用在了部署上.

在部署时,为了减少计算量,可以考虑使用16位浮点模型,而训练时涉及到梯度计算,需要使用32位浮点,这种精度的不一致经过测试,模型性能下降有限,可以接受.

但是推断时计算量可以降低一半,同等计算资源下,并发度可提升近一倍

具体方法

在pytorch中,一般模型定义都继承torch.nn.Moudle,torch.nn.Module基类的half()方法会把所有参数转为16位浮点,所以在模型加载后,调用一下该方法即可达到模型切换的目的.接下来只需要在推断时把input的tensor切换为16位浮点即可

另外还有一个小的trick,在推理过程中模型输出的tensor自然会成为16位浮点,如果需要新创建tensor,最好调用已有tensor的new_zeros,new_full等方法而不是torch.zeros和torch.full,前者可以自动继承已有tensor的类型,这样就不需要到处增加代码判断是使用16位还是32位了,只需要针对input tensor切换.

补充:pytorch 使用amp.autocast半精度加速训练

准备工作

pytorch 1.6+

如何使用autocast?

根据官方提供的方法,

答案就是autocast + GradScaler。

1,autocast

正如前文所说,需要使用torch.cuda.amp模块中的autocast 类。使用也是非常简单的:

如何在PyTorch中使用自动混合精度?

答案:autocast + GradScaler。

1.autocast

正如前文所说,需要使用torch.cuda.amp模块中的autocast 类。使用也是非常简单的

from torch.cuda.amp import autocast as autocast

# 创建model,默认是torch.FloatTensor
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)

for input, target in data:
    optimizer.zero_grad()

    # 前向过程(model + loss)开启 autocast
    with autocast():
        output = model(input)
        loss = loss_fn(output, target)

    # 反向传播在autocast上下文之外
    loss.backward()
    optimizer.step()

2.GradScaler

GradScaler就是梯度scaler模块,需要在训练最开始之前实例化一个GradScaler对象。

因此PyTorch中经典的AMP使用方式如下:

from torch.cuda.amp import autocast as autocast

# 创建model,默认是torch.FloatTensor
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)
# 在训练最开始之前实例化一个GradScaler对象
scaler = GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()

        # 前向过程(model + loss)开启 autocast
        with autocast():
            output = model(input)
            loss = loss_fn(output, target)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

3.nn.DataParallel

单卡训练的话上面的代码已经够了,亲测在2080ti上能减少至少1/3的显存,至于速度。。。

要是想多卡跑的话仅仅这样还不够,会发现在forward里面的每个结果都还是float32的,怎么办?

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, input_data_c1):
     with autocast():
      # code
     return

只要把forward里面的代码用autocast代码块方式运行就好啦!

自动进行autocast的操作

如下操作中tensor会被自动转化为半精度浮点型的torch.HalfTensor:

1、matmul

2、addbmm

3、addmm

4、addmv

5、addr

6、baddbmm

7、bmm

8、chain_matmul

9、conv1d

10、conv2d

11、conv3d

12、conv_transpose1d

13、conv_transpose2d

14、conv_transpose3d

15、linear

16、matmul

17、mm

18、mv

19、prelu

那么只有这些操作才能半精度吗?不是。其他操作比如rnn也可以进行半精度运行,但是需要自己手动,暂时没有提供自动的转换。

Python 相关文章推荐
在Docker上开始部署Python应用的教程
Apr 17 Python
Python操作MongoDB数据库PyMongo库使用方法
Apr 27 Python
Python调用C++程序的方法详解
Jan 24 Python
CentOS 6.5中安装Python 3.6.2的方法步骤
Dec 03 Python
对Python3 解析html的几种操作方式小结
Feb 16 Python
基于Python打造账号共享浏览器功能
May 30 Python
python文件选择对话框的操作方法
Jun 27 Python
Django生成PDF文档显示在网页上以及解决PDF中文显示乱码的问题
Jul 04 Python
如何利用Pyecharts可视化微信好友
Jul 04 Python
python可视化篇之流式数据监控的实现
Aug 07 Python
python实现密码验证合格程序的思路详解
Jun 01 Python
Python3爬虫ChromeDriver的安装实例
Feb 06 Python
解决Pytorch半精度浮点型网络训练的问题
May 24 #Python
Python办公自动化之Excel(中)
May 24 #Python
PyTorch梯度裁剪避免训练loss nan的操作
May 24 #Python
python3读取文件指定行的三种方法
May 24 #Python
pytorch中Schedule与warmup_steps的用法说明
May 24 #Python
Python Pycharm虚拟下百度飞浆PaddleX安装报错问题及处理方法(亲测100%有效)
May 24 #Python
pytorch交叉熵损失函数的weight参数的使用
May 24 #Python
You might like
php5.3中连接sqlserver2000的两种方法(com与ODBC)
2012/12/29 PHP
深入理解PHP中mt_rand()随机数的安全
2017/10/12 PHP
PHP实现的只保留字符串首尾字符功能示例【隐藏部分字符串】
2019/03/11 PHP
PHP5.6.8连接SQL Server 2008 R2数据库常用技巧分析总结
2019/05/06 PHP
Javascript 入门基础学习
2010/03/10 Javascript
Highcharts 非常实用的Javascript统计图demo示例
2013/07/03 Javascript
弹出窗口并且此窗口带有半透明的遮罩层效果
2014/03/13 Javascript
使用jQuery设置disabled属性与移除disabled属性
2014/08/21 Javascript
node.js中的querystring.unescape方法使用说明
2014/12/10 Javascript
JS实现向表格行添加新单元格的方法
2015/03/30 Javascript
解决jQuery uploadify在非IE核心浏览器下无法上传
2015/08/05 Javascript
浅谈JS中String()与 .toString()的区别
2016/10/20 Javascript
通过sails和阿里大于实现短信验证
2017/01/04 Javascript
Bootstrap禁用响应式布局的实现方法
2017/03/09 Javascript
详解关于webpack多入口热加载很慢的原因
2019/04/24 Javascript
vue路由 遍历生成复数router-link的例子
2019/10/30 Javascript
[01:09:23]KG vs TNC 2019国际邀请赛小组赛 BO2 第一场 8.15
2019/08/16 DOTA
Python二分查找详解
2015/09/13 Python
python编程开发之类型转换convert实例分析
2015/11/13 Python
python 获取网页编码方式实现代码
2017/03/11 Python
Python定时任务随机时间执行的实现方法
2019/08/14 Python
Python GUI库Tkiner使用方法代码示例
2020/11/27 Python
canvas绘制视频封面的方法
2018/02/05 HTML / CSS
美国艺术和工艺品商店:Hobby Lobby
2020/12/09 全球购物
经贸日语专业个人求职信范文
2013/12/28 职场文书
校园餐饮创业计划书
2014/01/10 职场文书
回门宴新郎答谢词
2014/01/12 职场文书
模具专业毕业生自荐书范文
2014/02/19 职场文书
新品发布会主持词
2014/04/02 职场文书
社团活动总结怎么写
2014/06/30 职场文书
欢迎新生标语
2014/10/06 职场文书
民主评议党员个人自我评价
2015/03/03 职场文书
2016应届毕业生实习心得体会
2015/10/09 职场文书
高中物理教学反思
2016/02/19 职场文书
css3 filter属性的使用简介
2021/03/31 HTML / CSS
Vue中插槽slot的使用方法与应用场景详析
2021/06/08 Vue.js