pytorch 输出中间层特征的实例


Posted in Python onAugust 17, 2019

pytorch 输出中间层特征:

tensorflow输出中间特征,2种方式:

1. 保存全部模型(包括结构)时,需要之前先add_to_collection 或者 用slim模块下的end_points

2. 只保存模型参数时,可以读取网络结构,然后按照对应的中间层输出即可。

but:Pytorch 论坛给出的答案并不好用,无论是hooks,还是重建网络并去掉某些层,这些方法都不好用(在我看来)。

我们可以在创建网络class时,在forward时加入一个dict 或者 list,dict是将中间层名字与中间层输出分别作为key:value,然后作为第二个值返回。前提是:运行创建自己的网络(无论fine-tune),只保存网络参数。

个人理解:虽然每次运行都返回2个值,但是运行效率基本没有变化。

附上代码例子:

import torch
import torchvision
import numpy as np
from torch import nn
from torch.nn import init
from torch.autograd import Variable
from torch.utils import data

EPOCH=20
BATCH_SIZE=64
LR=1e-2

train_data=torchvision.datasets.MNIST(root='./mnist',train=True,
                   transform=torchvision.transforms.ToTensor(),download=False)
train_loader=data.DataLoader(train_data,batch_size=BATCH_SIZE,shuffle=True)

test_data=torchvision.datasets.MNIST(root='./mnist',train=False)

test_x=Variable(torch.unsqueeze(test_data.test_data,dim=1).type(torch.FloatTensor)).cuda()/255
test_y=test_data.test_labels.cuda()

class CNN(nn.Module):
  def __init__(self):
    super().__init__()
    self.conv1=nn.Sequential(
        nn.Conv2d(in_channels=1,out_channels=16,kernel_size=4,stride=1,padding=2),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2,stride=2))
    self.conv2=nn.Sequential(nn.Conv2d(16,32,4,1,2),nn.ReLU(),nn.MaxPool2d(2,2))
    self.out=nn.Linear(32*7*7,10)
    
  def forward(self,x):
    per_out=[] ############修改处##############
    x=self.conv1(x)
    per_out.append(x) # conv1
    x=self.conv2(x)
    per_out.append(x) # conv2
    x=x.view(x.size(0),-1)
    output=self.out(x)
    return output,per_out
  
cnn=CNN().cuda() # or cnn.cuda()

optimizer=torch.optim.Adam(cnn.parameters(),lr=LR)
loss_func=nn.CrossEntropyLoss().cuda()############################

for epoch in range(EPOCH):
  for step,(x,y) in enumerate(train_loader):
    b_x=Variable(x).cuda()# if channel==1 auto add c=1
    b_y=Variable(y).cuda()
#    print(b_x.data.shape)
    optimizer.zero_grad()
    output=cnn(b_x)[0] ##原先只需要cnn(b_x) 但是现在需要用到第一个返回值##
    loss=loss_func(output,b_y)# Variable need to get .data
    loss.backward()
    optimizer.step()
    
    if step%50==0:
      test_output=cnn(test_x)[0]
      pred_y=torch.max(test_output,1)[1].cuda().data.squeeze()
      '''
      why data ,because Variable .data to Tensor;and cuda() not to numpy() ,must to cpu and to numpy 
      and .float compute decimal
      '''
      accuracy=torch.sum(pred_y==test_y).data.float()/test_y.size(0)
      print('EPOCH: ',epoch,'| train_loss:%.4f'%loss.data[0],'| test accuracy:%.2f'%accuracy)
    #                       loss.data.cpu().numpy().item() get one value

  torch.save(cnn.state_dict(),'./model/model.pth')

##输出中间层特征,根据索引调用##

conv1: conv1=cnn(b_x)[1][0]

conv2: conv2=cnn(b_x)[1][1]

##########################

hook使用:

res=torchvision.models.resnet18()

def get_features_hook(self, input, output):# self 代表类模块本身
  print(output.data.cpu().numpy().shape)

handle=res.layer2.register_forward_hook(get_features_hook)

a=torch.ones([1,3,224,224])

b=res(a) 直接打印出 layer2的输出形状,但是不好用。因为,实际中,我们需要return,而hook明确指出 不可以return 只能print。

所以,不建议使用hook。

以上这篇pytorch 输出中间层特征的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用cx_freeze把python打包exe示例
Jan 24 Python
python通过邮件服务器端口发送邮件的方法
Apr 30 Python
python求解水仙花数的方法
May 11 Python
Python中断言Assertion的一些改进方案
Oct 27 Python
Python字符串拼接六种方法介绍
Dec 18 Python
Python3结合Dlib实现人脸识别和剪切
Jan 24 Python
tensorflow建立一个简单的神经网络的方法
Feb 10 Python
python无限生成不重复(字母,数字,字符)组合的方法
Dec 04 Python
Python递归函数特点及原理解析
Mar 04 Python
解决python运行启动报错问题
Jun 01 Python
Python爬虫数据的分类及json数据使用小结
Mar 29 Python
Pytest实现setup和teardown的详细使用详解
Apr 17 Python
基于pytorch的保存和加载模型参数的方法
Aug 17 #Python
pytorch 固定部分参数训练的方法
Aug 17 #Python
python之PyQt按钮右键菜单功能的实现代码
Aug 17 #Python
pytorch 在网络中添加可训练参数,修改预训练权重文件的方法
Aug 17 #Python
python PyQt5/Pyside2 按钮右击菜单实例代码
Aug 17 #Python
Pytorch 实现自定义参数层的例子
Aug 17 #Python
Python中PyQt5/PySide2的按钮控件使用实例
Aug 17 #Python
You might like
延长phpmyadmin登录时间的方法
2011/02/06 PHP
php递归使用示例(php递归函数)
2014/02/14 PHP
php生成动态验证码gif图片
2015/10/19 PHP
用Javascript实现UTF8编码转换成gb2312编码
2006/12/22 Javascript
JavaScript asp.net 获取当前超链接中的文本
2009/04/14 Javascript
在vs2010中调试javascript代码方法
2011/02/11 Javascript
autoPlay 基于jquery的图片自动播放效果
2011/12/07 Javascript
jQuery学习笔记 操作jQuery对象 文档处理
2012/09/19 Javascript
jQuery div层的放大与缩小简单实现代码
2013/03/28 Javascript
JS简单实现登陆验证附效果图
2013/11/19 Javascript
javascript基于HTML5 canvas制作画箭头组件
2014/06/25 Javascript
写给小白的JavaScript引擎指南
2015/12/04 Javascript
JS实现页面跳转参数不丢失的方法
2016/11/28 Javascript
jQuery的事件预绑定
2016/12/05 Javascript
jQuery基本筛选选择器实例代码
2017/02/06 Javascript
微信小程序联网请求的轮播图
2017/07/07 Javascript
jQuery与vue实现拖动验证码功能
2018/01/30 jQuery
Vue路由钩子之afterEach beforeEach的区别详解
2018/07/15 Javascript
vue elementUI table表格数据 滚动懒加载的实现方法
2019/04/04 Javascript
详细讲解如何创建, 发布自己的 Vue UI 组件库
2019/05/29 Javascript
如何使用50行javaScript代码实现简单版的call,apply,bind
2019/08/14 Javascript
axios如何取消重复无用的请求详解
2019/12/15 Javascript
js生成1到100的随机数最简单的实现方法
2020/02/07 Javascript
JavaScript canvas基于数组生成柱状图代码实例
2020/03/06 Javascript
[00:34]DOTA2上海特级锦标赛 VG战队宣传片
2016/03/04 DOTA
[01:38]完美世界高校联赛决赛花絮
2018/12/02 DOTA
python 提取文件的小程序
2009/07/29 Python
python递归计算N!的方法
2015/05/05 Python
python 实现在shell窗口中编写print不向屏幕输出
2020/02/19 Python
Monnier Frères美国官网:法国知名奢侈品网站
2016/11/22 全球购物
jQuery treeview树形结构应用
2021/03/24 jQuery
社区七一党员活动方案
2014/01/25 职场文书
村干部培训班主持词
2014/03/28 职场文书
平安建设工作方案
2014/06/02 职场文书
大学生党性分析材料
2014/12/19 职场文书
Js类的构建与继承案例详解
2021/09/15 Javascript