pytorch fine-tune 预训练的模型操作


Posted in Python onJune 03, 2021

之一:

torchvision 中包含了很多预训练好的模型,这样就使得 fine-tune 非常容易。本文主要介绍如何 fine-tune torchvision 中预训练好的模型。

安装

pip install torchvision

如何 fine-tune

以 resnet18 为例:

from torchvision import models
from torch import nn
from torch import optim
 
resnet_model = models.resnet18(pretrained=True) 
# pretrained 设置为 True,会自动下载模型 所对应权重,并加载到模型中
# 也可以自己下载 权重,然后 load 到 模型中,源码中有 权重的地址。
 
# 假设 我们的 分类任务只需要 分 100 类,那么我们应该做的是
# 1. 查看 resnet 的源码
# 2. 看最后一层的 名字是啥 (在 resnet 里是 self.fc = nn.Linear(512 * block.expansion, num_classes))
# 3. 在外面替换掉这个层
resnet_model.fc= nn.Linear(in_features=..., out_features=100)
 
# 这样就 哦了,修改后的模型除了输出层的参数是 随机初始化的,其他层都是用预训练的参数初始化的。
 
# 如果只想训练 最后一层的话,应该做的是:
# 1. 将其它层的参数 requires_grad 设置为 False
# 2. 构建一个 optimizer, optimizer 管理的参数只有最后一层的参数
# 3. 然后 backward, step 就可以了
 
# 这一步可以节省大量的时间,因为多数的参数不需要计算梯度
for para in list(resnet_model.parameters())[:-2]:
    para.requires_grad=False 
 
optimizer = optim.SGD(params=[resnet_model.fc.weight, resnet_model.fc.bias], lr=1e-3)
 
...

为什么

这里介绍下 运行resnet_model.fc= nn.Linear(in_features=..., out_features=100)时 框架内发生了什么

这时应该看 nn.Module 源码的 __setattr__ 部分,因为 setattr 时都会调用这个方法:

def __setattr__(self, name, value):
    def remove_from(*dicts):
        for d in dicts:
            if name in d:
                del d[name]

首先映入眼帘就是 remove_from 这个函数,这个函数的目的就是,如果出现了 同名的属性,就将旧的属性移除。 用刚才举的例子就是:

预训练的模型中 有个 名字叫fc 的 Module。

在类定义外,我们 将另一个 Module 重新 赋值给了 fc。

类定义内的 fc 对应的 Module 就会从 模型中 删除。

之二:

前言

这篇文章算是论坛PyTorch Forums关于参数初始化和finetune的总结,也是我在写代码中用的算是“最佳实践”吧。最后希望大家没事多逛逛论坛,有很多高质量的回答。

参数初始化

参数的初始化其实就是对参数赋值。而我们需要学习的参数其实都是Variable,它其实是对Tensor的封装,同时提供了data,grad等借口,这就意味着我们可以直接对这些参数进行操作赋值了。这就是PyTorch简洁高效所在。

pytorch fine-tune 预训练的模型操作

所以我们可以进行如下操作进行初始化,当然其实有其他的方法,但是这种方法是PyTorch作者所推崇的:

def weight_init(m):
# 使用isinstance来判断m属于什么类型
    if isinstance(m, nn.Conv2d):
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2. / n))
    elif isinstance(m, nn.BatchNorm2d):
# m中的weight,bias其实都是Variable,为了能学习参数以及后向传播
        m.weight.data.fill_(1)
        m.bias.data.zero_()

Finetune

往往在加载了预训练模型的参数之后,我们需要finetune模型,可以使用不同的方式finetune。

局部微调

有时候我们加载了训练模型后,只想调节最后的几层,其他层不训练。其实不训练也就意味着不进行梯度计算,PyTorch中提供的requires_grad使得对训练的控制变得非常简单。

model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():
    param.requires_grad = False
# 替换最后的全连接层, 改为训练100类
# 新构造的模块的参数默认requires_grad为True
model.fc = nn.Linear(512, 100)
 
# 只优化最后的分类层
optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)

全局微调

有时候我们需要对全局都进行finetune,只不过我们希望改换过的层和其他层的学习速率不一样,这时候我们可以把其他层和新层在optimizer中单独赋予不同的学习速率。比如:

ignored_params = list(map(id, model.fc.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params,
                     model.parameters())
 
optimizer = torch.optim.SGD([
            {'params': base_params},
            {'params': model.fc.parameters(), 'lr': 1e-3}
            ], lr=1e-2, momentum=0.9)

其中base_params使用1e-3来训练,model.fc.parameters使用1e-2来训练,momentum是二者共有的。

之三:

pytorch finetune模型

文章主要讲述如何在pytorch上读取以往训练的模型参数,在模型的名字已经变更的情况下又如何读取模型的部分参数等。

pytorch 模型的存储与读取

其中在模型的保存过程有存储模型和参数一起的也有单独存储模型参数的

单独存储模型参数

存储时使用:

torch.save(the_model.state_dict(), PATH)

读取时:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

存储模型与参数

存储:

torch.save(the_model, PATH)

读取:

the_model = torch.load(PATH)

模型的参数

fine-tune的过程是读取原有模型的参数,但是由于模型的所要处理的数据集不同,最后的一层class的总数不同,所以需要修改模型的最后一层,这样模型读取的参数,和在大数据集上训练好下载的模型参数在形式上不一样。需要我们自己去写函数读取参数。

pytorch模型参数的形式

模型的参数是以字典的形式存储的。

model_dict = the_model.state_dict(),
for k,v in model_dict.items():
    print(k)

即可看到所有的键值

如果想修改模型的参数,给相应的键值赋值即可

model_dict[k] = new_value

最后更新模型的参数

the_model.load_state_dict(model_dict)

如果模型的key值和在大数据集上训练时的key值是一样的

我们可以通过下列算法进行读取模型

model_dict = model.state_dict() 
pretrained_dict = torch.load(model_path)
 # 1. filter out unnecessary keys
diff = {k: v for k, v in model_dict.items() if \
            k in pretrained_dict and pretrained_dict[k].size() == v.size()}
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
pretrained_dict.update(diff)
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(model_dict)

如果模型的key值和在大数据集上训练时的key值是不一样的,但是顺序是一样的

model_dict = model.state_dict() 
pretrained_dict = torch.load(model_path)
keys = []
for k,v in pretrained_dict.items():
    keys.append(k)
i = 0
for k,v in model_dict.items():
    if v.size() == pretrained_dict[keys[i]].size():
        print(k, ',', keys[i])
         model_dict[k]=pretrained_dict[keys[i]]
    i = i + 1
model.load_state_dict(model_dict)

如果模型的key值和在大数据集上训练时的key值是不一样的,但是顺序是也不一样的

自己找对应关系,一个key对应一个key的赋值

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
10个易被忽视但应掌握的Python基本用法
Apr 01 Python
详解Python中for循环的使用方法
May 14 Python
浅谈Python中的闭包
Jul 08 Python
特征脸(Eigenface)理论基础之PCA主成分分析法
Mar 13 Python
python web基础之加载静态文件实例
Mar 20 Python
python3.6.3+opencv3.3.0实现动态人脸捕获
May 25 Python
详解python的argpare和click模块小结
Mar 31 Python
Python定义函数功能与用法实例详解
Apr 08 Python
Python递归函数特点及原理解析
Mar 04 Python
一文轻松掌握python语言命名规范规则
Jun 18 Python
详解python中GPU版本的opencv常用方法介绍
Jul 24 Python
python 匿名函数与三元运算学习笔记
Oct 23 Python
Python实现byte转integer
Jun 03 #Python
Python数据分析之绘图和可视化详解
Python数据分析之pandas读取数据
Jun 02 #Python
Python 如何实现文件自动去重
python状态机transitions库详解
Jun 02 #Python
python爬取某网站原图作为壁纸
Python爬虫之自动爬取某车之家各车销售数据
You might like
PHP 编写大型网站问题集
2010/05/07 PHP
ThinkPHP模板判断输出Defined标签用法详解
2014/06/30 PHP
简单分析ucenter 会员同步登录通信原理
2014/08/25 PHP
Zend Framework缓存Cache用法简单实例
2016/03/19 PHP
PHP封装的MSSql操作类完整实例
2016/05/26 PHP
PHP开发的微信现金红包功能示例
2017/06/29 PHP
PHP完全二叉树定义与实现方法示例
2017/10/09 PHP
PHP 7.4中使用预加载的方法详解
2019/07/08 PHP
Laravel 创建可以传递参数 Console服务的例子
2019/10/14 PHP
laravel框架邮箱认证实现方法详解
2019/11/22 PHP
有趣的javascript数组定义方法
2010/09/10 Javascript
js与jquery获取父元素,删除子元素的两种不同方法
2014/01/09 Javascript
jQuery中closest()函数用法实例
2015/01/07 Javascript
轻松学习jQuery插件EasyUI EasyUI创建树形网络(1)
2015/11/30 Javascript
Javascript同时声明一连串(多个)变量的方法
2017/01/23 Javascript
理解 JavaScript EventEmitter
2018/03/29 Javascript
React SSR样式及SEO的实践
2018/10/22 Javascript
vue+SSM实现验证码功能
2018/12/07 Javascript
详解Python中映射类型(字典)操作符的概念和使用
2015/08/19 Python
总结网络IO模型与select模型的Python实例讲解
2016/06/27 Python
Python 字符串换行的多种方式
2018/09/06 Python
Django实现发送邮件功能
2019/07/18 Python
python for循环remove同一个list过程解析
2019/08/14 Python
使用Pyhton 分析酒店针孔摄像头
2020/03/04 Python
OpenCV 表盘指针自动读数的示例代码
2020/04/10 Python
pandas针对excel处理的实现
2021/01/15 Python
解决margin 外边距合并问题
2019/07/03 HTML / CSS
html5移动端自适应布局的实现
2020/04/15 HTML / CSS
Yahoo-PHP面试题3
2012/01/14 面试题
大二自我鉴定
2014/01/31 职场文书
冰淇淋开店创业计划书
2014/02/01 职场文书
保护环境标语
2014/06/09 职场文书
小学感恩节活动策划方案
2014/10/06 职场文书
回复函范文
2015/07/14 职场文书
如何使用Maxwell实时同步mysql数据
2021/04/08 MySQL
mysql中整数数据类型tinyint详解
2021/12/06 MySQL