pytorch 输出中间层特征的实例


Posted in Python onAugust 17, 2019

pytorch 输出中间层特征:

tensorflow输出中间特征,2种方式:

1. 保存全部模型(包括结构)时,需要之前先add_to_collection 或者 用slim模块下的end_points

2. 只保存模型参数时,可以读取网络结构,然后按照对应的中间层输出即可。

but:Pytorch 论坛给出的答案并不好用,无论是hooks,还是重建网络并去掉某些层,这些方法都不好用(在我看来)。

我们可以在创建网络class时,在forward时加入一个dict 或者 list,dict是将中间层名字与中间层输出分别作为key:value,然后作为第二个值返回。前提是:运行创建自己的网络(无论fine-tune),只保存网络参数。

个人理解:虽然每次运行都返回2个值,但是运行效率基本没有变化。

附上代码例子:

import torch
import torchvision
import numpy as np
from torch import nn
from torch.nn import init
from torch.autograd import Variable
from torch.utils import data

EPOCH=20
BATCH_SIZE=64
LR=1e-2

train_data=torchvision.datasets.MNIST(root='./mnist',train=True,
                   transform=torchvision.transforms.ToTensor(),download=False)
train_loader=data.DataLoader(train_data,batch_size=BATCH_SIZE,shuffle=True)

test_data=torchvision.datasets.MNIST(root='./mnist',train=False)

test_x=Variable(torch.unsqueeze(test_data.test_data,dim=1).type(torch.FloatTensor)).cuda()/255
test_y=test_data.test_labels.cuda()

class CNN(nn.Module):
  def __init__(self):
    super().__init__()
    self.conv1=nn.Sequential(
        nn.Conv2d(in_channels=1,out_channels=16,kernel_size=4,stride=1,padding=2),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2,stride=2))
    self.conv2=nn.Sequential(nn.Conv2d(16,32,4,1,2),nn.ReLU(),nn.MaxPool2d(2,2))
    self.out=nn.Linear(32*7*7,10)
    
  def forward(self,x):
    per_out=[] ############修改处##############
    x=self.conv1(x)
    per_out.append(x) # conv1
    x=self.conv2(x)
    per_out.append(x) # conv2
    x=x.view(x.size(0),-1)
    output=self.out(x)
    return output,per_out
  
cnn=CNN().cuda() # or cnn.cuda()

optimizer=torch.optim.Adam(cnn.parameters(),lr=LR)
loss_func=nn.CrossEntropyLoss().cuda()############################

for epoch in range(EPOCH):
  for step,(x,y) in enumerate(train_loader):
    b_x=Variable(x).cuda()# if channel==1 auto add c=1
    b_y=Variable(y).cuda()
#    print(b_x.data.shape)
    optimizer.zero_grad()
    output=cnn(b_x)[0] ##原先只需要cnn(b_x) 但是现在需要用到第一个返回值##
    loss=loss_func(output,b_y)# Variable need to get .data
    loss.backward()
    optimizer.step()
    
    if step%50==0:
      test_output=cnn(test_x)[0]
      pred_y=torch.max(test_output,1)[1].cuda().data.squeeze()
      '''
      why data ,because Variable .data to Tensor;and cuda() not to numpy() ,must to cpu and to numpy 
      and .float compute decimal
      '''
      accuracy=torch.sum(pred_y==test_y).data.float()/test_y.size(0)
      print('EPOCH: ',epoch,'| train_loss:%.4f'%loss.data[0],'| test accuracy:%.2f'%accuracy)
    #                       loss.data.cpu().numpy().item() get one value

  torch.save(cnn.state_dict(),'./model/model.pth')

##输出中间层特征,根据索引调用##

conv1: conv1=cnn(b_x)[1][0]

conv2: conv2=cnn(b_x)[1][1]

##########################

hook使用:

res=torchvision.models.resnet18()

def get_features_hook(self, input, output):# self 代表类模块本身
  print(output.data.cpu().numpy().shape)

handle=res.layer2.register_forward_hook(get_features_hook)

a=torch.ones([1,3,224,224])

b=res(a) 直接打印出 layer2的输出形状,但是不好用。因为,实际中,我们需要return,而hook明确指出 不可以return 只能print。

所以,不建议使用hook。

以上这篇pytorch 输出中间层特征的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python实现子类调用父类的方法
Nov 10 Python
Python中super的用法实例
May 28 Python
Python首次安装后运行报错(0xc000007b)的解决方法
Oct 18 Python
python thrift搭建服务端和客户端测试程序
Jan 17 Python
flask中使用蓝图将路由分开写在不同文件实例解析
Jan 19 Python
Python实现PyPDF2处理PDF文件的方法示例
Sep 25 Python
python 中的[:-1]和[::-1]的具体使用
Feb 13 Python
python中sklearn的pipeline模块实例详解
May 21 Python
如何使用 Flask 做一个评论系统
Nov 27 Python
python中温度单位转换的实例方法
Dec 27 Python
Python列表元素删除和remove()方法详解
Jan 04 Python
Python绘制地图神器folium的新人入门指南
May 23 Python
基于pytorch的保存和加载模型参数的方法
Aug 17 #Python
pytorch 固定部分参数训练的方法
Aug 17 #Python
python之PyQt按钮右键菜单功能的实现代码
Aug 17 #Python
pytorch 在网络中添加可训练参数,修改预训练权重文件的方法
Aug 17 #Python
python PyQt5/Pyside2 按钮右击菜单实例代码
Aug 17 #Python
Pytorch 实现自定义参数层的例子
Aug 17 #Python
Python中PyQt5/PySide2的按钮控件使用实例
Aug 17 #Python
You might like
php中文件上传的安全问题
2006/10/09 PHP
PHP实现下载断点续传的方法
2014/11/12 PHP
PHP经典面试题集锦
2015/03/19 PHP
通过实例解析PHP数据类型转换方法
2020/07/11 PHP
php远程请求CURL实例教程(爬虫、保存登录状态)
2020/12/10 PHP
javascript css float属性的特殊写法
2008/11/13 Javascript
ExtJs 表单提交登陆实现代码
2010/08/19 Javascript
JavaScript高级程序设计 扩展--关于动态原型
2010/11/09 Javascript
主页面中的两个iframe实现鼠标拖动改变其大小
2013/04/16 Javascript
JS文本框追加多个下拉框的值的简单实例
2013/07/12 Javascript
JS随机生成不重复数据的实例方法
2013/07/17 Javascript
jquery队列函数用法实例
2014/12/16 Javascript
jQuery在线选座位插件seat-charts特效代码分享
2015/08/27 Javascript
JavaScript如何实现对数字保留两位小数一位自动补零
2015/12/18 Javascript
jQuery实用小技巧_输入框文字获取和失去焦点的简单实例
2016/08/25 Javascript
jQuery获取复选框选中的当前行的某个字段的值
2017/09/15 jQuery
图片懒加载imgLazyLoading.js使用详解
2020/09/15 Javascript
vue之父子组件间通信实例讲解(props、$ref、$emit)
2018/05/22 Javascript
JavaScript实现的级联算法示例【省市二级联动功能】
2018/12/25 Javascript
Python中实现常量(Const)功能
2015/01/28 Python
Python爬虫:通过关键字爬取百度图片
2017/02/17 Python
numpy中实现ndarray数组返回符合特定条件的索引方法
2018/04/17 Python
解决PyCharm控制台输出乱码的问题
2019/01/16 Python
如何在django中添加日志功能
2020/02/06 Python
浅谈tensorflow 中的图片读取和裁剪方式
2020/06/30 Python
基于opencv实现简单画板功能
2020/08/02 Python
html5/css3响应式页面开发总结
2018/10/16 HTML / CSS
html5 浏览器支持 如何让所有的浏览器都支持HTML5标签样式
2012/12/07 HTML / CSS
ASOS英国官网:英国在线时装和化妆品零售商
2017/05/19 全球购物
华为慧通笔试题
2016/04/22 面试题
幼儿教师师德承诺书
2014/05/23 职场文书
司机工作自我鉴定
2014/09/19 职场文书
大学生军训感言
2015/08/01 职场文书
导游词之沈阳植物园
2019/11/30 职场文书
MySQL 自定义变量的概念及特点
2021/05/13 MySQL
浅谈MySql整型索引和字符串索引失效或隐式转换问题
2021/11/20 MySQL