pytorch实现加载保存查看checkpoint文件


Posted in Python onJuly 15, 2022

1.保存加载checkpoint文件

# 方式一:保存加载整个state_dict(推荐)
# 保存
torch.save(model.state_dict(), PATH)
# 加载
model.load_state_dict(torch.load(PATH))
# 测试时不启用 BatchNormalization 和 Dropout
model.eval()
# 方式二:保存加载整个模型
# 保存
torch.save(model, PATH)
# 加载
model = torch.load(PATH)
model.eval()
# 方式三:保存用于继续训练的checkpoint或者多个模型
# 保存
torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            ...
            }, PATH)
# 加载
checkpoint = torch.load(PATH)
start_epoch=checkpoint['epoch']
model.load_state_dict(checkpoint['model_state_dict'])
# 测试时
model.eval()
# 或者训练时
model.train()

2.跨gpu和cpu

# GPU上保存,CPU上加载
# 保存
torch.save(model.state_dict(), PATH)
# 加载
device = torch.device('cpu')
model.load_state_dict(torch.load(PATH, map_location=device))
# 如果是多gpu保存,需要去除关键字中的module,见第4部分
# GPU上保存,GPU上加载
# 保存
torch.save(model.state_dict(), PATH)
# 加载
device = torch.device("cuda")
model.load_state_dict(torch.load(PATH))
model.to(device)
# CPU上保存,GPU上加载
# 保存
torch.save(model.state_dict(), PATH)
# 加载
device = torch.device("cuda")
# 选择希望使用的GPU
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  
model.to(device)

3.查看checkpoint文件内容

# 打印模型的 state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

4.常见问题

多gpu

报错为KeyError: ‘unexpected key “module.conv1.weight” in state_dict’

原因:当使用多gpu时,会使用torch.nn.DataParallel,所以checkpoint中有module字样

#解决1:加载时将module去掉

# 创建一个不包含`module.`的新OrderedDict
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] # 去掉 `module.`
    new_state_dict[name] = v
# 加载参数
model.load_state_dict(new_state_dict)
# 解决2:保存checkpoint时不保存module
torch.save(model.module.state_dict(), PATH)

pytorch保存和加载文件的方法,从断点处继续训练

'''本文件用于举例说明pytorch保存和加载文件的方法''' 
import torch as torch
import torchvision as tv
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
import os
  
# 参数声明
batch_size = 32
epochs = 10
WORKERS = 0  # dataloder线程数
test_flag = False  # 测试标志,True时加载保存好的模型进行测试
ROOT = '/home/pxt/pytorch/cifar'  # MNIST数据集保存路径
log_dir = '/home/pxt/pytorch/logs/cifar_model.pth'  # 模型保存路径
# 加载MNIST数据集
transform = tv.transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
 
train_data = tv.datasets.CIFAR10(root=ROOT, train=True, download=True, transform=transform)
test_data = tv.datasets.CIFAR10(root=ROOT, train=False, download=False, transform=transform)
 
train_load = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=WORKERS)
test_load = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=WORKERS)
 
 
# 构造模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv3 = nn.Conv2d(128, 256, 3, padding=1)
        self.conv4 = nn.Conv2d(256, 256, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(256 * 8 * 8, 1024)
        self.fc2 = nn.Linear(1024, 256)
        self.fc3 = nn.Linear(256, 10)
 
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(F.relu(self.conv2(x)))
        x = F.relu(self.conv3(x))
        x = self.pool(F.relu(self.conv4(x)))
        x = x.view(-1, x.size()[1] * x.size()[2] * x.size()[3])
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
  
model = Net().cpu()
 
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
 
 
# 模型训练
def train(model, train_loader, epoch):
    model.train()
    train_loss = 0
    for i, data in enumerate(train_loader, 0):
        x, y = data
        x = x.cpu()
        y = y.cpu()
 
        optimizer.zero_grad()
        y_hat = model(x)
        loss = criterion(y_hat, y)
        loss.backward()
        optimizer.step()
        train_loss += loss
        print('正在进行第{}个epoch中的第{}次循环'.format(epoch,i))
 
    loss_mean = train_loss / (i + 1)
    print('Train Epoch: {}\t Loss: {:.6f}'.format(epoch, loss_mean.item()))
 
 
# 模型测试
def test(model, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for i, data in enumerate(test_loader, 0):
            x, y = data
            x = x.cpu()
            y = y.cpu()
 
            optimizer.zero_grad()
            y_hat = model(x)
            test_loss += criterion(y_hat, y).item()
            pred = y_hat.max(1, keepdim=True)[1]
            correct += pred.eq(y.view_as(pred)).sum().item()
        test_loss /= (i + 1)
        print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, len(test_data), 100. * correct / len(test_data)))
  
def main():
    # 如果test_flag=True,则加载已保存的模型并进行测试,测试以后不进行此模块以后的步骤
    if test_flag:
        # 加载保存的模型直接进行测试机验证
        checkpoint = torch.load(log_dir)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        test(model, test_load)
        return
 
    # 如果有保存的模型,则加载模型,并在其基础上继续训练
    if os.path.exists(log_dir):
        checkpoint = torch.load(log_dir)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        print('加载 epoch {} 成功!'.format(start_epoch))
    else:
        start_epoch = 0
        print('无保存了的模型,将从头开始训练!')
 
    for epoch in range(start_epoch+1, epochs):
        train(model, train_load, epoch)
        test(model, test_load)
        # 保存模型
        state = {'model':model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch':epoch}
        torch.save(state, log_dir)
 
if __name__ == '__main__':
    main()

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python Tkinter基础控件用法
Sep 03 Python
Python 类与元类的深度挖掘 I【经验】
May 06 Python
深入浅析ImageMagick命令执行漏洞
Oct 11 Python
对numpy和pandas中数组的合并和拆分详解
Apr 11 Python
python 识别图片中的文字信息方法
May 10 Python
python使用xlrd模块读取xlsx文件中的ip方法
Jan 11 Python
对Python3中列表乘以某一个数的示例详解
Jul 20 Python
详解Python 4.0 预计推出的新功能
Jul 26 Python
python 追踪except信息方式
Apr 25 Python
Python爬取阿拉丁统计信息过程图解
May 12 Python
python中wheel的用法整理
Jun 15 Python
python集合能干吗
Jul 19 Python
pytest实现多进程与多线程运行超好用的插件
Jul 15 #Python
python如何将mat文件转为png
Jul 15 #Python
python读取mat文件生成h5文件的实现
Jul 15 #Python
全网非常详细的pytest配置文件
Jul 15 #Python
Python如何加载模型并查看网络
Jul 15 #Python
Python绘制散点图之可视化神器pyecharts
Jul 07 #Python
Python可视化神器pyecharts之绘制箱形图
Jul 07 #Python
You might like
PHP中redis的用法深入解析
2014/02/20 PHP
thinkphp实现图片上传功能分享
2014/03/04 PHP
php使用glob函数快速查询指定目录文件的方法
2014/11/15 PHP
Ajax提交表单时验证码自动验证 php后端验证码检测
2016/07/20 PHP
PHP+Redis链表解决高并发下商品超卖问题(实现原理及步骤)
2020/08/03 PHP
java script编程起步(第三课)
2007/01/10 Javascript
详解JavaScript对Date对象的操作问题(生成一个倒数7天的数组)
2015/10/01 Javascript
JavaScript+html5 canvas绘制的圆弧荡秋千效果完整实例
2016/01/26 Javascript
jQuery+Ajax+PHP弹出层异步登录效果(附源码下载)
2016/05/27 Javascript
深入解析Javascript闭包的功能及实现方法
2016/07/10 Javascript
js中 计算两个日期间的工作日的简单实例
2016/08/08 Javascript
把多个JavaScript函数绑定到onload事件处理函数上的方法
2016/09/04 Javascript
利用js来实现缩略语列表、文献来源链接和快捷键列表
2016/12/16 Javascript
JS动态遍历json中所有键值对的方法(不知道属性名的情况)
2016/12/28 Javascript
javascript设计模式之单体模式学习笔记
2017/02/15 Javascript
JS字符串长度判断,超出进行自动截取的实例(支持中文)
2017/03/06 Javascript
B/S(Web)实时通讯解决方案分享
2017/04/06 Javascript
Vue SPA单页应用首屏优化实践
2018/06/28 Javascript
微信小程序签到功能
2018/10/31 Javascript
vue.js实现三级菜单效果
2019/10/19 Javascript
原生JavaScript实现贪吃蛇游戏
2020/11/04 Javascript
[40:05]DOTA2上海特级锦标赛A组小组赛#1 EHOME VS MVP.Phx第一局
2016/02/25 DOTA
Python正则表达式匹配ip地址实例
2014/10/09 Python
Python之两种模式的生产者消费者模型详解
2018/10/26 Python
python图像处理入门(一)
2019/04/04 Python
Python3将数据保存为txt文件的方法
2019/09/12 Python
django-orm F对象的使用 按照两个字段的和,乘积排序实例
2020/05/18 Python
CSS3图片旋转特效(360/60/-360度)
2013/10/10 HTML / CSS
MATCHESFASHION.COM美国官网:英国奢侈品零售商
2018/10/29 全球购物
某公司.Net方向面试题
2014/04/24 面试题
外贸英语毕业生自荐信
2013/11/14 职场文书
网络教育自我鉴定
2014/02/04 职场文书
大学生逃课检讨书
2015/05/04 职场文书
市语委办2016年第十九届“推普周”活动总结
2016/04/05 职场文书
亲情作文之母爱
2019/09/25 职场文书
WordPress多语言翻译插件 - WPML使用教程
2021/04/01 PHP