pytorch 输出中间层特征的实例


Posted in Python onAugust 17, 2019

pytorch 输出中间层特征:

tensorflow输出中间特征,2种方式:

1. 保存全部模型(包括结构)时,需要之前先add_to_collection 或者 用slim模块下的end_points

2. 只保存模型参数时,可以读取网络结构,然后按照对应的中间层输出即可。

but:Pytorch 论坛给出的答案并不好用,无论是hooks,还是重建网络并去掉某些层,这些方法都不好用(在我看来)。

我们可以在创建网络class时,在forward时加入一个dict 或者 list,dict是将中间层名字与中间层输出分别作为key:value,然后作为第二个值返回。前提是:运行创建自己的网络(无论fine-tune),只保存网络参数。

个人理解:虽然每次运行都返回2个值,但是运行效率基本没有变化。

附上代码例子:

import torch
import torchvision
import numpy as np
from torch import nn
from torch.nn import init
from torch.autograd import Variable
from torch.utils import data

EPOCH=20
BATCH_SIZE=64
LR=1e-2

train_data=torchvision.datasets.MNIST(root='./mnist',train=True,
                   transform=torchvision.transforms.ToTensor(),download=False)
train_loader=data.DataLoader(train_data,batch_size=BATCH_SIZE,shuffle=True)

test_data=torchvision.datasets.MNIST(root='./mnist',train=False)

test_x=Variable(torch.unsqueeze(test_data.test_data,dim=1).type(torch.FloatTensor)).cuda()/255
test_y=test_data.test_labels.cuda()

class CNN(nn.Module):
  def __init__(self):
    super().__init__()
    self.conv1=nn.Sequential(
        nn.Conv2d(in_channels=1,out_channels=16,kernel_size=4,stride=1,padding=2),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2,stride=2))
    self.conv2=nn.Sequential(nn.Conv2d(16,32,4,1,2),nn.ReLU(),nn.MaxPool2d(2,2))
    self.out=nn.Linear(32*7*7,10)
    
  def forward(self,x):
    per_out=[] ############修改处##############
    x=self.conv1(x)
    per_out.append(x) # conv1
    x=self.conv2(x)
    per_out.append(x) # conv2
    x=x.view(x.size(0),-1)
    output=self.out(x)
    return output,per_out
  
cnn=CNN().cuda() # or cnn.cuda()

optimizer=torch.optim.Adam(cnn.parameters(),lr=LR)
loss_func=nn.CrossEntropyLoss().cuda()############################

for epoch in range(EPOCH):
  for step,(x,y) in enumerate(train_loader):
    b_x=Variable(x).cuda()# if channel==1 auto add c=1
    b_y=Variable(y).cuda()
#    print(b_x.data.shape)
    optimizer.zero_grad()
    output=cnn(b_x)[0] ##原先只需要cnn(b_x) 但是现在需要用到第一个返回值##
    loss=loss_func(output,b_y)# Variable need to get .data
    loss.backward()
    optimizer.step()
    
    if step%50==0:
      test_output=cnn(test_x)[0]
      pred_y=torch.max(test_output,1)[1].cuda().data.squeeze()
      '''
      why data ,because Variable .data to Tensor;and cuda() not to numpy() ,must to cpu and to numpy 
      and .float compute decimal
      '''
      accuracy=torch.sum(pred_y==test_y).data.float()/test_y.size(0)
      print('EPOCH: ',epoch,'| train_loss:%.4f'%loss.data[0],'| test accuracy:%.2f'%accuracy)
    #                       loss.data.cpu().numpy().item() get one value

  torch.save(cnn.state_dict(),'./model/model.pth')

##输出中间层特征,根据索引调用##

conv1: conv1=cnn(b_x)[1][0]

conv2: conv2=cnn(b_x)[1][1]

##########################

hook使用:

res=torchvision.models.resnet18()

def get_features_hook(self, input, output):# self 代表类模块本身
  print(output.data.cpu().numpy().shape)

handle=res.layer2.register_forward_hook(get_features_hook)

a=torch.ones([1,3,224,224])

b=res(a) 直接打印出 layer2的输出形状,但是不好用。因为,实际中,我们需要return,而hook明确指出 不可以return 只能print。

所以,不建议使用hook。

以上这篇pytorch 输出中间层特征的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的Classes和Metaclasses详解
Apr 02 Python
Python实现对excel文件列表值进行统计的方法
Jul 25 Python
Python 中的with关键字使用详解
Sep 11 Python
Python爬取三国演义的实现方法
Sep 12 Python
TensorFlow深度学习之卷积神经网络CNN
Mar 09 Python
python实现海螺图片的方法示例
May 12 Python
Python 使用PyQt5 完成选择文件或目录的对话框方法
Jun 27 Python
Tensorflow 定义变量,函数,数值计算等名字的更新方式
Feb 10 Python
python打印文件的前几行或最后几行教程
Feb 13 Python
Windows下Pycharm远程连接虚拟机中Centos下的Python环境(图文教程详解)
Mar 19 Python
基于python实现对文件进行切分行
Apr 26 Python
基于python实现破解滑动验证码过程解析
May 28 Python
基于pytorch的保存和加载模型参数的方法
Aug 17 #Python
pytorch 固定部分参数训练的方法
Aug 17 #Python
python之PyQt按钮右键菜单功能的实现代码
Aug 17 #Python
pytorch 在网络中添加可训练参数,修改预训练权重文件的方法
Aug 17 #Python
python PyQt5/Pyside2 按钮右击菜单实例代码
Aug 17 #Python
Pytorch 实现自定义参数层的例子
Aug 17 #Python
Python中PyQt5/PySide2的按钮控件使用实例
Aug 17 #Python
You might like
追忆往昔!浅谈收音机的百年发展历史
2021/03/01 无线电
PHP6 mysql连接方式说明
2009/02/09 PHP
PHP实现动态获取函数参数的方法示例
2018/04/02 PHP
TP3.2.3框架使用CKeditor编辑器在页面中上传图片的方法分析
2019/12/31 PHP
[JS]点出统计器
2020/10/11 Javascript
jQuery实现行文字链接提示效果的方法
2015/03/10 Javascript
AngularJS学习第二篇 AngularJS依赖注入
2017/02/13 Javascript
JavaScript使用readAsDataURL读取图像文件
2017/05/10 Javascript
Angular2 父子组件通信方式的示例
2018/01/29 Javascript
详解webpack运行Babel教程
2018/06/13 Javascript
解决ng-repeat产生的ng-model中取不到值的问题
2018/10/02 Javascript
如何使用vuex实现兄弟组件通信
2018/11/02 Javascript
解决vue-router路由拦截造成死循环问题
2020/08/05 Javascript
[03:27]《辉夜杯》线下训练营 导师CU和海涛指点迷津
2015/10/23 DOTA
[37:23]DOTA2上海特级锦标赛主赛事日 - 3 胜者组第二轮#2Secret VS EG第二局
2016/03/04 DOTA
[01:03:22]LGD vs OG 2018国际邀请赛淘汰赛BO3 第一场 8.25
2018/08/29 DOTA
在Python中使用mongoengine操作MongoDB教程
2015/04/24 Python
python3实现短网址和数字相互转换的方法
2015/04/28 Python
web.py 十分钟创建简易博客实现代码
2016/04/22 Python
教你用python3根据关键词爬取百度百科的内容
2016/08/18 Python
python数据预处理之数据标准化的几种处理方式
2019/07/17 Python
Pandas0.25来了千万别错过这10大好用的新功能
2019/08/07 Python
深入浅析Python科学计算库Scipy及安装步骤
2019/10/12 Python
解决pytorch DataLoader num_workers出现的问题
2020/01/14 Python
简单介绍一下pyinstaller打包以及安全性的实现
2020/06/02 Python
python用Tkinter做自己的中文代码编辑器
2020/09/07 Python
Pycharm制作搞怪弹窗的实现代码
2021/02/19 Python
CSS3 Notes: -webkit-box-reflect实现倒影的实例
2016/12/08 HTML / CSS
怎样实现H5+CSS3手指滑动切换图片的示例代码
2019/05/05 HTML / CSS
关于学习的演讲稿
2014/05/10 职场文书
毕业论文答辩开场白和结束语
2015/05/27 职场文书
班主任培训研修日志
2015/11/13 职场文书
学习杨善洲同志先进事迹心得体会
2016/01/23 职场文书
又涨知识了,自律到底多重要?
2019/06/27 职场文书
SQL实现LeetCode(178.分数排行)
2021/08/04 MySQL
《帝国时代4》赛季预告 新增内容编译器可创造地图
2022/04/03 其他游戏