Python如何加载模型并查看网络


Posted in Python onJuly 15, 2022

加载模型并查看网络

加载模型,以vgg19为例。

打开终端

> python
Python 3.7.2 (tags/v3.7.2:9a3ffc0492, Dec 23 2018, 23:09:28) [MSC v.1916 64 bit
(AMD64)] on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> from torchvision import models
>>> model = models.vgg19(pretrained=True) #此时如果是第一次加载会开始下载模型的pth文件
>>> print(model.model)

结果:

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace)
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (17): ReLU(inplace)
    (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace)
    (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (24): ReLU(inplace)
    (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (26): ReLU(inplace)
    (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace)
    (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (31): ReLU(inplace)
    (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (33): ReLU(inplace)
    (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (35): ReLU(inplace)
    (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace)
    (2): Dropout(p=0.5)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace)
    (5): Dropout(p=0.5)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

注意,直接打印模型是没有办法看到模型结构的,只能看到带模型参数的pth文件内容;需要打印model.model才可以看到模型本身。

神经网络_模型的保存,模型的加载

模型的保存(torch.save)

方式1(模型结构+模型参数)

参数:保存位置

# 创建模型
vgg16 = torchvision.models.vgg16(pretrained=False)
# 保存方式1——模型结构+模型参数
torch.save(vgg16, "vgg16_method1.pth")

方式2(模型参数)

# 保存方式2  模型参数(官方推荐)。保存成字典,只保存网络模型中的一些参数
torch.save(vgg16.state_dict(), "vgg16_method2.pth")

模型的加载(torch.load)

对应保存方式1

参数:模型路径

# 方式1 --》 保存方式1
model1 = torch.load("vgg16_method1.pth")

对应保存方式2

vgg16.load_state_dict("vgg16_method2.pth")

输出为字典形式。若要回复网络,采用以下形式:

model2 = torch.load("vgg16_method2.pth")  #输出是字典形式
# 恢复网络结构
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(model2)

方式1存储,加载时需注意事项

新建自己的网络:

class test(nn.Module):
    def __init__(self):
        super(lh, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        x = self.conv1(x)
        return x

保存自己的网络:

Test = test()
# 保存自己定义的网络
torch.save(Test, "Test_method1.pth")

加载自己的网络:

model3 = torch.load("Test_method1.pth")

会报错!!!!!!

Python如何加载模型并查看网络

解决办法(需要注意):

将定义的网络复制到加载的python文件中:

class test(nn.Module):
    def __init__(self):
        super(test, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        x = self.conv1(x)
        return x
model3 = torch.load("Test_method1.pth")

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
用Python实现一个简单的能够发送带附件的邮件程序的教程
Apr 08 Python
Python中的rfind()方法使用详解
May 19 Python
Python解析树及树的遍历
Feb 03 Python
Windows系统下多版本pip的共存问题详解
Oct 10 Python
Python面向对象编程之继承与多态详解
Jan 16 Python
Python模块的制作方法实例分析
Dec 21 Python
PyQt使用QPropertyAnimation开发简单动画
Apr 02 Python
python异常处理之try finally不报错的原因
May 18 Python
卸载tensorflow-cpu重装tensorflow-gpu操作
Jun 23 Python
Python基于tkinter canvas实现图片裁剪功能
Nov 05 Python
python爬虫如何解决图片验证码
Feb 14 Python
浅谈Python基础之列表那些事儿
May 11 Python
Python绘制散点图之可视化神器pyecharts
Jul 07 #Python
Python可视化神器pyecharts之绘制箱形图
Jul 07 #Python
Python通用验证码识别OCR库ddddocr的安装使用教程
Jul 07 #Python
Django数据库(SQlite)基本入门使用教程
Jul 07 #Python
Python可视化神器pyecharts之绘制地理图表练习
Django中celery的使用项目实例
Python可视化神器pyecharts绘制地理图表
You might like
隐藏X-Space个人空间下方版权方法隐藏X-Space个人空间标题隐藏X-Space个人空间管理版权方法
2007/02/22 PHP
PHP中is_file不能替代file_exists的理由
2014/03/04 PHP
PHP使用json_encode函数时不转义中文的解决方法
2014/11/12 PHP
用js实现下载远程文件并保存在本地的脚本
2008/05/06 Javascript
jQuery 行背景颜色的交替显示(隔行变色)实现代码
2009/12/13 Javascript
利用JS延迟加载百度分享代码,提高网页速度
2013/07/01 Javascript
jquery实现隐藏与显示动画效果/输入框字符动态递减/导航按钮切换
2013/07/01 Javascript
js动态添加事件并可传参数示例代码
2013/10/21 Javascript
jquery基础教程之deferred对象使用方法
2014/01/22 Javascript
javascript多物体运动实现方法分析
2016/01/08 Javascript
如何用JS判断两个数字的大小
2016/07/21 Javascript
Vue2.0权限树组件实现代码
2017/08/29 Javascript
解决angularjs中同步执行http请求的方法
2018/08/13 Javascript
node.js 基于cheerio的爬虫工具的实现(需要登录权限的爬虫工具)
2019/04/10 Javascript
vue cli3.0打包上线静态资源找不到路径的解决操作
2020/08/03 Javascript
[54:51]Ti4 冒泡赛第二轮LGD vs C9 3
2014/07/14 DOTA
利用Python实现命令行版的火车票查看器
2016/08/05 Python
Python3实现的画图及加载图片动画效果示例
2018/01/19 Python
python使用pandas处理大数据节省内存技巧(推荐)
2019/05/05 Python
python如何解析配置文件并应用到项目中
2019/06/27 Python
python中pow函数用法及功能说明
2020/12/04 Python
详解python的变量缓存机制
2021/01/24 Python
HTML5安全介绍之内容安全策略(CSP)简介
2012/07/10 HTML / CSS
用HTML5实现手机摇一摇的功能的教程
2012/10/30 HTML / CSS
详解HTML5中ol标签的用法
2015/09/08 HTML / CSS
整理HTML5的一些新特性与Canvas的常用属性
2016/01/29 HTML / CSS
HTML5中meta属性的使用方法
2016/02/29 HTML / CSS
关于h5中的fetch方法解读(小结)
2017/11/15 HTML / CSS
小学体育教学反思
2014/01/31 职场文书
可口可乐广告词
2014/03/20 职场文书
彩妆大赛策划方案
2014/05/13 职场文书
安全教育观后感
2015/06/17 职场文书
2015中学教师个人工作总结
2015/07/22 职场文书
漫画「你在春天醒来」第10卷封面公开
2022/03/21 日漫
openstack云计算keystone组件工作介绍
2022/04/20 Servers
python三子棋游戏
2022/05/04 Python