pytorch实现加载保存查看checkpoint文件


Posted in Python onJuly 15, 2022

1.保存加载checkpoint文件

# 方式一:保存加载整个state_dict(推荐)
# 保存
torch.save(model.state_dict(), PATH)
# 加载
model.load_state_dict(torch.load(PATH))
# 测试时不启用 BatchNormalization 和 Dropout
model.eval()
# 方式二:保存加载整个模型
# 保存
torch.save(model, PATH)
# 加载
model = torch.load(PATH)
model.eval()
# 方式三:保存用于继续训练的checkpoint或者多个模型
# 保存
torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            ...
            }, PATH)
# 加载
checkpoint = torch.load(PATH)
start_epoch=checkpoint['epoch']
model.load_state_dict(checkpoint['model_state_dict'])
# 测试时
model.eval()
# 或者训练时
model.train()

2.跨gpu和cpu

# GPU上保存,CPU上加载
# 保存
torch.save(model.state_dict(), PATH)
# 加载
device = torch.device('cpu')
model.load_state_dict(torch.load(PATH, map_location=device))
# 如果是多gpu保存,需要去除关键字中的module,见第4部分
# GPU上保存,GPU上加载
# 保存
torch.save(model.state_dict(), PATH)
# 加载
device = torch.device("cuda")
model.load_state_dict(torch.load(PATH))
model.to(device)
# CPU上保存,GPU上加载
# 保存
torch.save(model.state_dict(), PATH)
# 加载
device = torch.device("cuda")
# 选择希望使用的GPU
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  
model.to(device)

3.查看checkpoint文件内容

# 打印模型的 state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

4.常见问题

多gpu

报错为KeyError: ‘unexpected key “module.conv1.weight” in state_dict’

原因:当使用多gpu时,会使用torch.nn.DataParallel,所以checkpoint中有module字样

#解决1:加载时将module去掉

# 创建一个不包含`module.`的新OrderedDict
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] # 去掉 `module.`
    new_state_dict[name] = v
# 加载参数
model.load_state_dict(new_state_dict)
# 解决2:保存checkpoint时不保存module
torch.save(model.module.state_dict(), PATH)

pytorch保存和加载文件的方法,从断点处继续训练

'''本文件用于举例说明pytorch保存和加载文件的方法''' 
import torch as torch
import torchvision as tv
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
import os
  
# 参数声明
batch_size = 32
epochs = 10
WORKERS = 0  # dataloder线程数
test_flag = False  # 测试标志,True时加载保存好的模型进行测试
ROOT = '/home/pxt/pytorch/cifar'  # MNIST数据集保存路径
log_dir = '/home/pxt/pytorch/logs/cifar_model.pth'  # 模型保存路径
# 加载MNIST数据集
transform = tv.transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
 
train_data = tv.datasets.CIFAR10(root=ROOT, train=True, download=True, transform=transform)
test_data = tv.datasets.CIFAR10(root=ROOT, train=False, download=False, transform=transform)
 
train_load = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=WORKERS)
test_load = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=WORKERS)
 
 
# 构造模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv3 = nn.Conv2d(128, 256, 3, padding=1)
        self.conv4 = nn.Conv2d(256, 256, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(256 * 8 * 8, 1024)
        self.fc2 = nn.Linear(1024, 256)
        self.fc3 = nn.Linear(256, 10)
 
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(F.relu(self.conv2(x)))
        x = F.relu(self.conv3(x))
        x = self.pool(F.relu(self.conv4(x)))
        x = x.view(-1, x.size()[1] * x.size()[2] * x.size()[3])
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
  
model = Net().cpu()
 
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
 
 
# 模型训练
def train(model, train_loader, epoch):
    model.train()
    train_loss = 0
    for i, data in enumerate(train_loader, 0):
        x, y = data
        x = x.cpu()
        y = y.cpu()
 
        optimizer.zero_grad()
        y_hat = model(x)
        loss = criterion(y_hat, y)
        loss.backward()
        optimizer.step()
        train_loss += loss
        print('正在进行第{}个epoch中的第{}次循环'.format(epoch,i))
 
    loss_mean = train_loss / (i + 1)
    print('Train Epoch: {}\t Loss: {:.6f}'.format(epoch, loss_mean.item()))
 
 
# 模型测试
def test(model, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for i, data in enumerate(test_loader, 0):
            x, y = data
            x = x.cpu()
            y = y.cpu()
 
            optimizer.zero_grad()
            y_hat = model(x)
            test_loss += criterion(y_hat, y).item()
            pred = y_hat.max(1, keepdim=True)[1]
            correct += pred.eq(y.view_as(pred)).sum().item()
        test_loss /= (i + 1)
        print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, len(test_data), 100. * correct / len(test_data)))
  
def main():
    # 如果test_flag=True,则加载已保存的模型并进行测试,测试以后不进行此模块以后的步骤
    if test_flag:
        # 加载保存的模型直接进行测试机验证
        checkpoint = torch.load(log_dir)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        test(model, test_load)
        return
 
    # 如果有保存的模型,则加载模型,并在其基础上继续训练
    if os.path.exists(log_dir):
        checkpoint = torch.load(log_dir)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch']
        print('加载 epoch {} 成功!'.format(start_epoch))
    else:
        start_epoch = 0
        print('无保存了的模型,将从头开始训练!')
 
    for epoch in range(start_epoch+1, epochs):
        train(model, train_load, epoch)
        test(model, test_load)
        # 保存模型
        state = {'model':model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch':epoch}
        torch.save(state, log_dir)
 
if __name__ == '__main__':
    main()

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python不带重复的全排列代码
Aug 13 Python
Python random模块(获取随机数)常用方法和使用例子
May 13 Python
python执行子进程实现进程间通信的方法
Jun 02 Python
Python爬虫包 BeautifulSoup  递归抓取实例详解
Jan 28 Python
使用Python和xlwt向Excel文件中写入中文的实例
Apr 21 Python
Django1.9 加载通过ImageField上传的图片方法
May 25 Python
Pytorch十九种损失函数的使用详解
Apr 29 Python
使用OpenCV对车道进行实时检测的实现示例代码
Jun 19 Python
tensorflow 2.1.0 安装与实战教程(CASIA FACE v5)
Jun 30 Python
pycharm激活方法到2099年(激活流程)
Sep 22 Python
python实现计算图形面积
Feb 22 Python
只需要这一行代码就能让python计算速度提高十倍
May 24 Python
pytest实现多进程与多线程运行超好用的插件
Jul 15 #Python
python如何将mat文件转为png
Jul 15 #Python
python读取mat文件生成h5文件的实现
Jul 15 #Python
全网非常详细的pytest配置文件
Jul 15 #Python
Python如何加载模型并查看网络
Jul 15 #Python
Python绘制散点图之可视化神器pyecharts
Jul 07 #Python
Python可视化神器pyecharts之绘制箱形图
Jul 07 #Python
You might like
DOTA2 无惧惊涛骇浪 昆卡大型水友攻略
2020/04/20 DOTA
PHP 文件上传进度条的两种实现方法的代码
2007/11/25 PHP
PHP实现通过URL提取根域名
2016/03/31 PHP
PHP在弹框中获取foreach中遍历的id值并传递给地址栏
2017/06/13 PHP
jQuery实现form表单reset按钮重置清空表单功能
2012/12/18 Javascript
JQuery实现简单的图片滑动切换特效
2015/11/22 Javascript
给before和after伪元素设置js效果的方法
2015/12/04 Javascript
解决jQuery上传插件Uploadify出现Http Error 302错误的方法
2015/12/18 Javascript
JS实现弹出居中的模式窗口示例
2016/06/20 Javascript
深入理解bootstrap框架之第二章整体架构
2016/10/09 Javascript
浅谈Angular的$q, defer, promise
2016/12/20 Javascript
JavaScript在控件上添加倒计时功能的实现代码
2017/07/04 Javascript
vue mint-ui 实现省市区街道4级联动示例(仿淘宝京东收货地址4级联动)
2017/10/16 Javascript
node.js中TCP Socket多进程间的消息推送示例详解
2018/07/10 Javascript
vue-cli脚手架的安装教程图解
2018/09/02 Javascript
Vue文件配置全局变量的实例
2018/09/06 Javascript
JS运算符简单用法示例
2020/01/19 Javascript
ES6学习笔记之let与const用法实例分析
2020/01/22 Javascript
vue css 引入asstes中的图片无法显示的四种解决方法
2020/03/16 Javascript
JavaScript实现网页tab栏效果制作
2020/11/20 Javascript
python教程之用py2exe将PY文件转成EXE文件
2014/06/12 Python
python单例模式实例分析
2015/04/08 Python
Python语言实现百度语音识别API的使用实例
2017/12/13 Python
解决Mac下首次安装pycharm无project interpreter的问题
2018/10/29 Python
selenium处理元素定位点击无效问题
2019/06/12 Python
python实现将json多行数据传入到mysql中使用
2019/12/31 Python
Pytorch通过保存为ONNX模型转TensorRT5的实现
2020/05/25 Python
Python GUI之tkinter窗口视窗教程大集合(推荐)
2020/10/20 Python
专科文秘应届生求职信
2013/11/18 职场文书
法院信息化建设方案
2014/05/21 职场文书
南京青奥会口号
2014/06/12 职场文书
文明城市标语
2014/06/16 职场文书
就业协议书盖章的注意事项
2014/09/28 职场文书
文案策划岗位个人自我评价(范文)
2019/08/08 职场文书
python中requests库+xpath+lxml简单使用
2021/04/29 Python
Go语言实现一个简单的并发聊天室的项目实战
2022/03/18 Golang