Python如何加载模型并查看网络


Posted in Python onJuly 15, 2022

加载模型并查看网络

加载模型,以vgg19为例。

打开终端

> python
Python 3.7.2 (tags/v3.7.2:9a3ffc0492, Dec 23 2018, 23:09:28) [MSC v.1916 64 bit
(AMD64)] on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> from torchvision import models
>>> model = models.vgg19(pretrained=True) #此时如果是第一次加载会开始下载模型的pth文件
>>> print(model.model)

结果:

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace)
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (17): ReLU(inplace)
    (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace)
    (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (24): ReLU(inplace)
    (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (26): ReLU(inplace)
    (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace)
    (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (31): ReLU(inplace)
    (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (33): ReLU(inplace)
    (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (35): ReLU(inplace)
    (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace)
    (2): Dropout(p=0.5)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace)
    (5): Dropout(p=0.5)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

注意,直接打印模型是没有办法看到模型结构的,只能看到带模型参数的pth文件内容;需要打印model.model才可以看到模型本身。

神经网络_模型的保存,模型的加载

模型的保存(torch.save)

方式1(模型结构+模型参数)

参数:保存位置

# 创建模型
vgg16 = torchvision.models.vgg16(pretrained=False)
# 保存方式1——模型结构+模型参数
torch.save(vgg16, "vgg16_method1.pth")

方式2(模型参数)

# 保存方式2  模型参数(官方推荐)。保存成字典,只保存网络模型中的一些参数
torch.save(vgg16.state_dict(), "vgg16_method2.pth")

模型的加载(torch.load)

对应保存方式1

参数:模型路径

# 方式1 --》 保存方式1
model1 = torch.load("vgg16_method1.pth")

对应保存方式2

vgg16.load_state_dict("vgg16_method2.pth")

输出为字典形式。若要回复网络,采用以下形式:

model2 = torch.load("vgg16_method2.pth")  #输出是字典形式
# 恢复网络结构
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(model2)

方式1存储,加载时需注意事项

新建自己的网络:

class test(nn.Module):
    def __init__(self):
        super(lh, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        x = self.conv1(x)
        return x

保存自己的网络:

Test = test()
# 保存自己定义的网络
torch.save(Test, "Test_method1.pth")

加载自己的网络:

model3 = torch.load("Test_method1.pth")

会报错!!!!!!

Python如何加载模型并查看网络

解决办法(需要注意):

将定义的网络复制到加载的python文件中:

class test(nn.Module):
    def __init__(self):
        super(test, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        x = self.conv1(x)
        return x
model3 = torch.load("Test_method1.pth")

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python中的列表推导浅析
Apr 26 Python
Tornado服务器中绑定域名、虚拟主机的方法
Aug 22 Python
跟老齐学Python之再深点,更懂list
Sep 20 Python
Python简单遍历字典及删除元素的方法
Sep 18 Python
Python类的动态修改的实例方法
Mar 24 Python
Python3中条件控制、循环与函数的简易教程
Nov 21 Python
如何在sae中设置django,让sae的工作环境跟本地python环境一致
Nov 21 Python
Python实现去除列表中重复元素的方法小结【4种方法】
Apr 27 Python
Django配置celery(非djcelery)执行异步任务和定时任务
Jul 16 Python
python中的 zip函数详解及用法举例
Feb 16 Python
django-orm F对象的使用 按照两个字段的和,乘积排序实例
May 18 Python
利用python实时刷新基金估值(摸鱼小工具)
Sep 15 Python
Python绘制散点图之可视化神器pyecharts
Jul 07 #Python
Python可视化神器pyecharts之绘制箱形图
Jul 07 #Python
Python通用验证码识别OCR库ddddocr的安装使用教程
Jul 07 #Python
Django数据库(SQlite)基本入门使用教程
Jul 07 #Python
Python可视化神器pyecharts之绘制地理图表练习
Django中celery的使用项目实例
Python可视化神器pyecharts绘制地理图表
You might like
php GUID生成函数和类
2014/03/10 PHP
百度地图经纬度转换到腾讯地图/Google 对应的经纬度
2015/08/28 PHP
详解thinkphp中的volist标签
2018/01/15 PHP
取得一定长度的内容,处理中文
2006/12/20 Javascript
Jsonp 跨域的原理以及Jquery的解决方案
2010/05/18 Javascript
判断对象是否Window的实现代码
2012/01/10 Javascript
分享XmlHttpRequest调用Webservice的一点心得
2012/07/20 Javascript
javascript结合html5 canvas实现(可调画笔颜色/粗细/橡皮)的涂鸦板
2013/04/27 Javascript
js中settimeout方法加参数
2014/02/28 Javascript
jQuery实现下拉框选择图片功能实例
2015/08/08 Javascript
jQuery、layer实现弹出层的打开、关闭功能
2017/06/28 jQuery
JavaScript 中Date对象的格式化代码方法汇总
2017/09/06 Javascript
Vue.js实现数据响应的方法
2018/08/13 Javascript
配置一个vue3.0项目的完整步骤
2019/04/26 Javascript
浅谈Three.js截图并下载的大坑
2019/11/01 Javascript
js回调函数原理与用法案例分析
2020/03/04 Javascript
vue实现瀑布流组件滑动加载更多
2020/03/10 Javascript
Vue CLI3移动端适配(px2rem或postcss-plugin-px2rem)
2020/04/27 Javascript
使用React-Router实现前端路由鉴权的示例代码
2020/07/26 Javascript
[04:10]2018年度CS GO玩家最喜爱的主播-完美盛典
2018/12/16 DOTA
python开发的小球完全弹性碰撞游戏代码
2013/10/15 Python
2款Python内存检测工具介绍和使用方法
2014/06/01 Python
python使用clear方法清除字典内全部数据实例
2015/07/11 Python
老生常谈Python之装饰器、迭代器和生成器
2017/07/26 Python
Python2.7编程中SQLite3基本操作方法示例
2017/08/09 Python
解决Django的request.POST获取不到内容的问题
2018/05/28 Python
关于python下cv.waitKey无响应的原因及解决方法
2019/01/10 Python
Pandas_cum累积计算和rolling滚动计算的用法详解
2019/07/04 Python
python监控nginx端口和进程状态
2019/09/06 Python
python-OpenCV 实现将数组转换成灰度图和彩图
2020/01/09 Python
基于Python爬取搜狐证券股票过程解析
2020/11/18 Python
美国体育用品商店:Rally House(NCAA、NFL、MLB、NBA、NHL和MLS)
2018/01/03 全球购物
出纳岗位职责范本
2015/03/31 职场文书
分析并发编程之LongAdder原理
2021/06/29 Java/Android
Golang日志包的使用
2022/04/20 Golang
Java对文件的读写操作方法
2022/04/29 Java/Android