pytorch 输出中间层特征的实例


Posted in Python onAugust 17, 2019

pytorch 输出中间层特征:

tensorflow输出中间特征,2种方式:

1. 保存全部模型(包括结构)时,需要之前先add_to_collection 或者 用slim模块下的end_points

2. 只保存模型参数时,可以读取网络结构,然后按照对应的中间层输出即可。

but:Pytorch 论坛给出的答案并不好用,无论是hooks,还是重建网络并去掉某些层,这些方法都不好用(在我看来)。

我们可以在创建网络class时,在forward时加入一个dict 或者 list,dict是将中间层名字与中间层输出分别作为key:value,然后作为第二个值返回。前提是:运行创建自己的网络(无论fine-tune),只保存网络参数。

个人理解:虽然每次运行都返回2个值,但是运行效率基本没有变化。

附上代码例子:

import torch
import torchvision
import numpy as np
from torch import nn
from torch.nn import init
from torch.autograd import Variable
from torch.utils import data

EPOCH=20
BATCH_SIZE=64
LR=1e-2

train_data=torchvision.datasets.MNIST(root='./mnist',train=True,
                   transform=torchvision.transforms.ToTensor(),download=False)
train_loader=data.DataLoader(train_data,batch_size=BATCH_SIZE,shuffle=True)

test_data=torchvision.datasets.MNIST(root='./mnist',train=False)

test_x=Variable(torch.unsqueeze(test_data.test_data,dim=1).type(torch.FloatTensor)).cuda()/255
test_y=test_data.test_labels.cuda()

class CNN(nn.Module):
  def __init__(self):
    super().__init__()
    self.conv1=nn.Sequential(
        nn.Conv2d(in_channels=1,out_channels=16,kernel_size=4,stride=1,padding=2),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2,stride=2))
    self.conv2=nn.Sequential(nn.Conv2d(16,32,4,1,2),nn.ReLU(),nn.MaxPool2d(2,2))
    self.out=nn.Linear(32*7*7,10)
    
  def forward(self,x):
    per_out=[] ############修改处##############
    x=self.conv1(x)
    per_out.append(x) # conv1
    x=self.conv2(x)
    per_out.append(x) # conv2
    x=x.view(x.size(0),-1)
    output=self.out(x)
    return output,per_out
  
cnn=CNN().cuda() # or cnn.cuda()

optimizer=torch.optim.Adam(cnn.parameters(),lr=LR)
loss_func=nn.CrossEntropyLoss().cuda()############################

for epoch in range(EPOCH):
  for step,(x,y) in enumerate(train_loader):
    b_x=Variable(x).cuda()# if channel==1 auto add c=1
    b_y=Variable(y).cuda()
#    print(b_x.data.shape)
    optimizer.zero_grad()
    output=cnn(b_x)[0] ##原先只需要cnn(b_x) 但是现在需要用到第一个返回值##
    loss=loss_func(output,b_y)# Variable need to get .data
    loss.backward()
    optimizer.step()
    
    if step%50==0:
      test_output=cnn(test_x)[0]
      pred_y=torch.max(test_output,1)[1].cuda().data.squeeze()
      '''
      why data ,because Variable .data to Tensor;and cuda() not to numpy() ,must to cpu and to numpy 
      and .float compute decimal
      '''
      accuracy=torch.sum(pred_y==test_y).data.float()/test_y.size(0)
      print('EPOCH: ',epoch,'| train_loss:%.4f'%loss.data[0],'| test accuracy:%.2f'%accuracy)
    #                       loss.data.cpu().numpy().item() get one value

  torch.save(cnn.state_dict(),'./model/model.pth')

##输出中间层特征,根据索引调用##

conv1: conv1=cnn(b_x)[1][0]

conv2: conv2=cnn(b_x)[1][1]

##########################

hook使用:

res=torchvision.models.resnet18()

def get_features_hook(self, input, output):# self 代表类模块本身
  print(output.data.cpu().numpy().shape)

handle=res.layer2.register_forward_hook(get_features_hook)

a=torch.ones([1,3,224,224])

b=res(a) 直接打印出 layer2的输出形状,但是不好用。因为,实际中,我们需要return,而hook明确指出 不可以return 只能print。

所以,不建议使用hook。

以上这篇pytorch 输出中间层特征的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python编写暴力破解FTP密码小工具
Nov 19 Python
Python使用正则匹配实现抓图代码分享
Apr 02 Python
Python搭建APNS苹果推送通知推送服务的相关模块使用指南
Jun 02 Python
python入门前的第一课 python怎样入门
Mar 06 Python
python文本数据相似度的度量
Mar 12 Python
python中字符串变二维数组的实例讲解
Apr 03 Python
利用python开发app实战的方法
Jul 09 Python
深入浅析Python 中的sklearn模型选择
Oct 12 Python
python set集合使用方法解析
Nov 05 Python
python-sys.stdout作为默认函数参数的实现
Feb 21 Python
python模拟实现分发扑克牌
Apr 22 Python
pytorch查看通道数 维数 尺寸大小方式
May 26 Python
基于pytorch的保存和加载模型参数的方法
Aug 17 #Python
pytorch 固定部分参数训练的方法
Aug 17 #Python
python之PyQt按钮右键菜单功能的实现代码
Aug 17 #Python
pytorch 在网络中添加可训练参数,修改预训练权重文件的方法
Aug 17 #Python
python PyQt5/Pyside2 按钮右击菜单实例代码
Aug 17 #Python
Pytorch 实现自定义参数层的例子
Aug 17 #Python
Python中PyQt5/PySide2的按钮控件使用实例
Aug 17 #Python
You might like
php中curl、fsocket、file_get_content三个函数的使用比较
2014/05/09 PHP
php在数组中查找指定值的方法
2015/03/17 PHP
Laravel 5框架学习之向视图传送数据
2015/04/08 PHP
js函数般调用正则
2008/04/08 Javascript
jquery prop的使用介绍及与attr的区别
2013/12/19 Javascript
检查输入的是否是数字使用keyCode配合onkeypress事件
2014/01/23 Javascript
用于deeplink的js方法(判断手机是否安装app)
2014/04/02 Javascript
javascript如何判断输入的url是否正确
2014/04/11 Javascript
js实现右下角提示框的方法
2015/02/03 Javascript
js实现简单的联动菜单效果
2015/08/19 Javascript
Bootstrap每天必学之前端开发框架
2015/11/19 Javascript
JavaScript结合Bootstrap仿微信后台多图文界面管理
2016/07/22 Javascript
关于vue面试题汇总
2018/03/20 Javascript
JS实现获取数组中最大值或最小值功能示例
2019/03/02 Javascript
vue子传父关于.sync与$emit的实现
2019/11/05 Javascript
Python strip lstrip rstrip使用方法
2008/09/06 Python
Python Queue模块详解
2014/11/30 Python
Python2.x中文乱码问题解决方法
2015/06/02 Python
Python三级目录展示的实现方法
2016/09/28 Python
urllib和BeautifulSoup爬取维基百科的词条简单实例
2018/01/17 Python
Python3随机漫步生成数据并绘制
2018/08/27 Python
Tensorflow卷积实现原理+手写python代码实现卷积教程
2020/05/22 Python
python中selenium库的基本使用详解
2020/07/31 Python
通过HTML5 Canvas API绘制弧线和圆形的教程
2016/03/14 HTML / CSS
Web前端页面跳转并取到值
2017/04/24 HTML / CSS
canvas实现滑动验证的实现示例
2020/08/11 HTML / CSS
Miller Harris官网:英国小众香水品牌
2020/09/24 全球购物
女大学生毕业找工作的自我评价
2013/10/03 职场文书
中职生自我鉴定范文
2013/10/03 职场文书
学期自我鉴定
2013/11/04 职场文书
庆元旦迎新年广播稿
2014/02/18 职场文书
中职毕业生自我鉴定范文(3篇)
2014/09/28 职场文书
2014年小学教导处工作总结
2014/12/19 职场文书
2015年高中班主任工作总结
2015/04/30 职场文书
python使用PySimpleGUI设置进度条及控件使用
2021/06/10 Python
MySQL分布式恢复进阶
2022/07/23 MySQL