pytorch实现加载保存查看checkpoint文件


Posted in Python onJuly 15, 2022

1.保存加载checkpoint文件

# 方式一:保存加载整个state_dict(推荐)
# 保存
torch.save(model.state_dict(), PATH)
# 加载
model.load_state_dict(torch.load(PATH))
# 测试时不启用 BatchNormalization 和 Dropout
model.eval()
# 方式二:保存加载整个模型
# 保存
torch.save(model, PATH)
# 加载
model = torch.load(PATH)
model.eval()
# 方式三:保存用于继续训练的checkpoint或者多个模型
# 保存
torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            ...
            }, PATH)
# 加载
checkpoint = torch.load(PATH)
start_epoch=checkpoint['epoch']
model.load_state_dict(checkpoint['model_state_dict'])
# 测试时
model.eval()
# 或者训练时
model.train()

2.跨gpu和cpu

# GPU上保存,CPU上加载
# 保存
torch.save(model.state_dict(), PATH)
# 加载
device = torch.device('cpu')
model.load_state_dict(torch.load(PATH, map_location=device))
# 如果是多gpu保存,需要去除关键字中的module,见第4部分
# GPU上保存,GPU上加载
# 保存
torch.save(model.state_dict(), PATH)
# 加载
device = torch.device("cuda")
model.load_state_dict(torch.load(PATH))
model.to(device)
# CPU上保存,GPU上加载
# 保存
torch.save(model.state_dict(), PATH)
# 加载
device = torch.device("cuda")
# 选择希望使用的GPU
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  
model.to(device)

3.查看checkpoint文件内容

# 打印模型的 state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

4.常见问题

多gpu

报错为KeyError: ‘unexpected key “module.conv1.weight” in state_dict’

原因:当使用多gpu时,会使用torch.nn.DataParallel,所以checkpoint中有module字样

#解决1:加载时将module去掉

# 创建一个不包含`module.`的新OrderedDict
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] # 去掉 `module.`
    new_state_dict[name] = v
# 加载参数
model.load_state_dict(new_state_dict)
# 解决2:保存checkpoint时不保存module
torch.save(model.module.state_dict(), PATH)

pytorch保存和加载文件的方法,从断点处继续训练

'''本文件用于举例说明pytorch保存和加载文件的方法''' 
import torch as torch
import torchvision as tv
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
import os
  
# 参数声明
batch_size = 32
epochs = 10
WORKERS = 0  # dataloder线程数
test_flag = False  # 测试标志,True时加载保存好的模型进行测试
ROOT = '/home/pxt/pytorch/cifar'  # MNIST数据集保存路径
log_dir = '/home/pxt/pytorch/logs/cifar_model.pth'  # 模型保存路径
# 加载MNIST数据集
transform = tv.transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
 
train_data = tv.datasets.CIFAR10(root=ROOT, train=True, download=True, transform=transform)
test_data = tv.datasets.CIFAR10(root=ROOT, train=False, download=False, transform=transform)
 
train_load = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=WORKERS)
test_load = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=WORKERS)
 
 
# 构造模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv3 = nn.Conv2d(128, 256, 3, padding=1)
        self.conv4 = nn.Conv2d(256, 256, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(256 * 8 * 8, 1024)
        self.fc2 = nn.Linear(1024, 256)
        self.fc3 = nn.Linear(256, 10)
 
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(F.relu(self.conv2(x)))
        x = F.relu(self.conv3(x))
        x = self.pool(F.relu(self.conv4(x)))
        x = x.view(-1, x.size()[1] * x.size()[2] * x.size()[3])
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
  
model = Net().cpu()
 
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
 
 
# 模型训练
def train(model, train_loader, epoch):
    model.train()
    train_loss = 0
    for i, data in enumerate(train_loader, 0):
        x, y = data
        x = x.cpu()
        y = y.cpu()
 
        optimizer.zero_grad()
        y_hat = model(x)
        loss = criterion(y_hat, y)
        loss.backward()
        optimizer.step()
        train_loss += loss
        print('正在进行第{}个epoch中的第{}次循环'.format(epoch,i))
 
    loss_mean = train_loss / (i + 1)
    print('Train Epoch: {}\t Loss: {:.6f}'.format(epoch, loss_mean.item()))
 
 
# 模型测试
def test(model, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for i, data in enumerate(test_loader, 0):
            x, y = data
            x = x.cpu()
            y = y.cpu()
 
            optimizer.zero_grad()
            y_hat = model(x)
            test_loss += criterion(y_hat, y).item()
            pred = y_hat.max(1, keepdim=True)[1]
            correct += pred.eq(y.view_as(pred)).sum().item()
        test_loss /= (i + 1)
        print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, len(test_data), 100. * correct / len(test_data)))
  
def main():
    # 如果test_flag=True,则加载已保存的模型并进行测试,测试以后不进行此模块以后的步骤
    if test_flag:
        # 加载保存的模型直接进行测试机验证
        checkpoint = torch.load(log_dir)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        test(model, test_load)
        return
 
    # 如果有保存的模型,则加载模型,并在其基础上继续训练
    if os.path.exists(log_dir):
        checkpoint = torch.load(log_dir)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        print('加载 epoch {} 成功!'.format(start_epoch))
    else:
        start_epoch = 0
        print('无保存了的模型,将从头开始训练!')
 
    for epoch in range(start_epoch+1, epochs):
        train(model, train_load, epoch)
        test(model, test_load)
        # 保存模型
        state = {'model':model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch':epoch}
        torch.save(state, log_dir)
 
if __name__ == '__main__':
    main()

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
用python写的一个wordpress的采集程序
Feb 27 Python
利用标准库fractions模块让Python支持分数类型的方法详解
Aug 11 Python
python 提取key 为中文的json 串方法
Dec 31 Python
pycharm打开命令行或Terminal的方法
Jan 16 Python
python的pytest框架之命令行参数详解(下)
Jun 27 Python
python实现批量nii文件转换为png图像
Jul 18 Python
python opencv鼠标事件实现画框圈定目标获取坐标信息
Apr 18 Python
详解python中docx库的安装过程
Nov 08 Python
python 中值滤波,椒盐去噪,图片增强实例
Dec 18 Python
Pandas —— resample()重采样和asfreq()频度转换方式
Feb 26 Python
如何解决pycharm调试报错的问题
Aug 06 Python
详解Python利用configparser对配置文件进行读写操作
Nov 03 Python
pytest实现多进程与多线程运行超好用的插件
Jul 15 #Python
python如何将mat文件转为png
Jul 15 #Python
python读取mat文件生成h5文件的实现
Jul 15 #Python
全网非常详细的pytest配置文件
Jul 15 #Python
Python如何加载模型并查看网络
Jul 15 #Python
Python绘制散点图之可视化神器pyecharts
Jul 07 #Python
Python可视化神器pyecharts之绘制箱形图
Jul 07 #Python
You might like
php对gzip文件或者字符串解压实例参考
2008/07/25 PHP
JS 网站性能优化笔记
2011/05/24 PHP
Yii框架核心组件类实例详解
2019/08/06 PHP
PHP获取当前时间不准确问题解决方案
2020/08/14 PHP
一个很酷的拖动层的js类,兼容IE及Firefox
2009/06/23 Javascript
动态的改变IFrame的高度实现IFrame自动伸展适应高度
2012/12/28 Javascript
jquery入门—选择器实现隔行变色实例代码
2013/01/04 Javascript
当滚动条滚动到页面底部自动加载增加内容的js代码
2014/05/13 Javascript
javascript获取隐藏元素(display:none)的高度和宽度的方法
2014/06/06 Javascript
jQuery中关于ScrollableGridPlugin.js(固定表头)插件的使用逐步解析
2014/07/17 Javascript
get(0).tagName获得作用标签示例代码
2014/10/08 Javascript
jquery移动节点实例
2015/01/14 Javascript
jquery图片轮播特效代码分享
2020/04/20 Javascript
BootStrap Table后台分页时前台删除最后一页所有数据refresh刷新后无数据问题
2016/12/28 Javascript
完美解决node.js中使用https请求报CERT_UNTRUSTED的问题
2017/01/08 Javascript
Vue表单验证插件的制作过程
2017/04/01 Javascript
JavaScript门面模式详解
2017/10/19 Javascript
NodeJS实现自定义流的方法
2018/08/01 NodeJs
vue实现日历备忘录功能
2020/09/24 Javascript
[54:19]完美世界DOTA2联赛PWL S2 Magma vs PXG 第二场 11.28
2020/12/01 DOTA
pygame学习笔记(3):运动速率、时间、事件、文字
2015/04/15 Python
正确理解Python中if __name__ == '__main__'
2019/01/24 Python
python代码打印100-999之间的回文数示例
2019/11/24 Python
python实现udp聊天窗口
2020/03/31 Python
快速解决jupyter notebook启动需要密码的问题
2020/04/21 Python
CentOS 7如何实现定时执行python脚本
2020/06/24 Python
python入门教程之基本算术运算符
2020/11/13 Python
recorder.js 基于Html5录音功能的实现
2020/05/26 HTML / CSS
Timberland美国官网:全球领先的户外品牌
2016/08/15 全球购物
雅诗兰黛(Estee Lauder)英国官方网站:世界顶级化妆品牌
2016/12/29 全球购物
Superdry瑞典官网:英国日本街头风品牌
2017/05/17 全球购物
澳大利亚首屈一指的鞋类品牌:Tony Bianco
2018/03/13 全球购物
JAVA的事件委托机制和垃圾回收机制
2014/09/07 面试题
大学生求职自我评价
2014/01/16 职场文书
2014年语文教研组工作总结
2014/12/06 职场文书
入党积极分子党支部意见
2015/06/02 职场文书