MxNet预训练模型到Pytorch模型的转换方式


Posted in Python onMay 25, 2020

预训练模型在不同深度学习框架中的转换是一种常见的任务。今天刚好DPN预训练模型转换问题,顺手将这个过程记录一下。

核心转换函数如下所示:

def convert_from_mxnet(model, checkpoint_prefix, debug=False):
 _, mxnet_weights, mxnet_aux = mxnet.model.load_checkpoint(checkpoint_prefix, 0)
 remapped_state = {}
 for state_key in model.state_dict().keys():
  k = state_key.split('.')
  aux = False
  mxnet_key = ''
  if k[0] == 'features':
   if k[1] == 'conv1_1':
    # input block
    mxnet_key += 'conv1_x_1__'
    if k[2] == 'bn':
     mxnet_key += 'relu-sp__bn_'
     aux, key_add = _convert_bn(k[3])
     mxnet_key += key_add
    else:
     assert k[3] == 'weight'
     mxnet_key += 'conv_' + k[3]
   elif k[1] == 'conv5_bn_ac':
    # bn + ac at end of features block
    mxnet_key += 'conv5_x_x__relu-sp__bn_'
    assert k[2] == 'bn'
    aux, key_add = _convert_bn(k[3])
    mxnet_key += key_add
   else:
    # middle blocks
    if model.b and 'c1x1_c' in k[2]:
     bc_block = True # b-variant split c-block special treatment
    else:
     bc_block = False
    ck = k[1].split('_')
    mxnet_key += ck[0] + '_x__' + ck[1] + '_'
    ck = k[2].split('_')
    mxnet_key += ck[0] + '-' + ck[1]
    if ck[1] == 'w' and len(ck) > 2:
     mxnet_key += '(s/2)' if ck[2] == 's2' else '(s/1)'
    mxnet_key += '__'
    if k[3] == 'bn':
     mxnet_key += 'bn_' if bc_block else 'bn__bn_'
     aux, key_add = _convert_bn(k[4])
     mxnet_key += key_add
    else:
     ki = 3 if bc_block else 4
     assert k[ki] == 'weight'
     mxnet_key += 'conv_' + k[ki]
  elif k[0] == 'classifier':
   if 'fc6-1k_weight' in mxnet_weights:
    mxnet_key += 'fc6-1k_'
   else:
    mxnet_key += 'fc6_'
   mxnet_key += k[1]
  else:
   assert False, 'Unexpected token'
 
  if debug:
   print(mxnet_key, '=> ', state_key, end=' ')
 
  mxnet_array = mxnet_aux[mxnet_key] if aux else mxnet_weights[mxnet_key]
  torch_tensor = torch.from_numpy(mxnet_array.asnumpy())
  if k[0] == 'classifier' and k[1] == 'weight':
   torch_tensor = torch_tensor.view(torch_tensor.size() + (1, 1))
  remapped_state[state_key] = torch_tensor
 
  if debug:
   print(list(torch_tensor.size()), torch_tensor.mean(), torch_tensor.std())
 
 model.load_state_dict(remapped_state)
 
 return model

从中可以看出,其转换步骤如下:

(1)创建pytorch的网络结构模型,设为model

(2)利用mxnet来读取其存储的预训练模型,得到mxnet_weights;

(3)遍历加载后模型mxnet_weights的state_dict().keys

(4)对一些指定的key值,需要进行相应的处理和转换

(5)对修改键名之后的key利用numpy之间的转换来实现加载。

为了实现上述转换,首先pip安装mxnet,现在新版的mxnet安装还是非常方便的。

MxNet预训练模型到Pytorch模型的转换方式

第二步,运行转换程序,实现预训练模型的转换。

MxNet预训练模型到Pytorch模型的转换方式

可以看到在相当的文件夹下已经出现了转换后的模型。

以上这篇MxNet预训练模型到Pytorch模型的转换方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中用PIL库批量给图片加上序号的教程
May 06 Python
使用python编写简单的小程序编译成exe跑在win10上
Jan 15 Python
Linux下python与C++使用dlib实现人脸检测
Jun 29 Python
使用python的pandas库读取csv文件保存至mysql数据库
Aug 20 Python
Python列表list排列组合操作示例
Dec 18 Python
对Python3使运行暂停的方法详解
Feb 18 Python
使用python制作一个为hex文件增加版本号的脚本实例
Jun 12 Python
python操作kafka实践的示例代码
Jun 19 Python
tornado+celery的简单使用详解
Dec 21 Python
Python Tornado核心及相关原理详解
Jun 24 Python
Pycharm添加虚拟解释器报错问题解决方案
Oct 13 Python
将Python代码打包成.exe可执行文件的完整步骤
May 12 Python
浅谈pytorch 模型 .pt, .pth, .pkl的区别及模型保存方式
May 25 #Python
Pytorch通过保存为ONNX模型转TensorRT5的实现
May 25 #Python
tensorflow pb to tflite 精度下降详解
May 25 #Python
Python HTMLTestRunner测试报告view按钮失效解决方案
May 25 #Python
python用opencv完成图像分割并进行目标物的提取
May 25 #Python
Pytorch转tflite方式
May 25 #Python
Python HTMLTestRunner库安装过程解析
May 25 #Python
You might like
PHP 数字左侧自动补0
2008/03/31 PHP
iis下php mail函数的sendmail配置方法(官方推荐)
2012/04/25 PHP
解析在zend Farmework下如何创立一个FORM表单
2013/06/28 PHP
PHP实现图片压缩的两则实例
2014/07/19 PHP
理解php依赖注入和控制反转
2016/05/11 PHP
Zend Framework教程之Zend_Helpers动作助手ViewRenderer用法详解
2016/07/20 PHP
laravel实现按时间日期进行分组统计方法示例
2019/03/23 PHP
屏蔽IE弹出"您查看的网页正在试图关闭窗口,是否关闭此窗口"的方法
2013/12/31 Javascript
JS使用ajax方法获取指定url的head信息中指定字段值的方法
2015/03/24 Javascript
JavaScript实现数字数组按照倒序排列的方法
2015/04/06 Javascript
javascript实现的图片切割多块效果实例
2015/05/07 Javascript
jquery ajax后台返回list,前台用jquery遍历list的实现
2016/10/30 Javascript
微信小程序 获取当前地理位置和经纬度实例代码
2016/12/05 Javascript
JavaScript实现精美个性导航栏筋斗云效果
2017/10/29 Javascript
node 使用 async 控制并发的方法
2018/05/07 Javascript
Layui给switch添加响应事件的例子
2019/09/03 Javascript
JS FormData对象使用方法实例详解
2020/02/12 Javascript
vue使用keep-alive实现组件切换时保存原组件数据方法
2020/10/30 Javascript
利用JavaScript为句子加标题的3种方法示例
2021/01/05 Javascript
python获取本地计算机名字的方法
2015/04/29 Python
完美解决安装完tensorflow后pip无法使用的问题
2018/06/11 Python
强悍的Python读取大文件的解决方案
2019/02/16 Python
Python和Sublime整合过程图示
2019/12/25 Python
Python内置异常类型全面汇总
2020/05/28 Python
Python中的特殊方法以及应用详解
2020/09/20 Python
关于老式浏览器兼容HTML5和CSS3的问题
2016/06/01 HTML / CSS
美国男女折扣服饰百货连锁店:Stein Mart
2017/05/02 全球购物
发现世界上最好的珠宝设计师:JewelStreet
2017/12/17 全球购物
高中自我鉴定
2013/12/20 职场文书
工程力学专业自荐信范文
2014/03/17 职场文书
艺术学院毕业生求职信
2014/07/09 职场文书
新颖的化妆品活动方案
2014/08/21 职场文书
实验心得体会
2014/09/05 职场文书
行政工作试用期自我评价
2014/09/14 职场文书
给校长的一封检讨书
2014/09/20 职场文书
利用js实现简单开关灯代码
2021/11/23 Javascript