pytorch实现用Resnet提取特征并保存为txt文件的方法


Posted in Python onAugust 20, 2019

接触pytorch一天,发现pytorch上手的确比TensorFlow更快。可以更方便地实现用预训练的网络提特征。

以下是提取一张jpg图像的特征的程序:

# -*- coding: utf-8 -*-
 
import os.path
 
import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.autograd import Variable 
 
import numpy as np
from PIL import Image 
 
features_dir = './features'
 
img_path = "hymenoptera_data/train/ants/0013035.jpg"
file_name = img_path.split('/')[-1]
feature_path = os.path.join(features_dir, file_name + '.txt')
 
 
transform1 = transforms.Compose([
    transforms.Scale(256),
    transforms.CenterCrop(224),
    transforms.ToTensor()  ]
)
 
img = Image.open(img_path)
img1 = transform1(img)
 
#resnet18 = models.resnet18(pretrained = True)
resnet50_feature_extractor = models.resnet50(pretrained = True)
resnet50_feature_extractor.fc = nn.Linear(2048, 2048)
torch.nn.init.eye(resnet50_feature_extractor.fc.weight)
 
for param in resnet50_feature_extractor.parameters():
  param.requires_grad = False
#resnet152 = models.resnet152(pretrained = True)
#densenet201 = models.densenet201(pretrained = True) 
x = Variable(torch.unsqueeze(img1, dim=0).float(), requires_grad=False)
#y1 = resnet18(x)
y = resnet50_feature_extractor(x)
y = y.data.numpy()
np.savetxt(feature_path, y, delimiter=',')
#y3 = resnet152(x)
#y4 = densenet201(x)
 
y_ = np.loadtxt(feature_path, delimiter=',').reshape(1, 2048)

以下是提取一个文件夹下所有jpg、jpeg图像的程序:

# -*- coding: utf-8 -*-
import os, torch, glob
import numpy as np
from torch.autograd import Variable
from PIL import Image 
from torchvision import models, transforms
import torch.nn as nn
import shutil
data_dir = './hymenoptera_data'
features_dir = './features'
shutil.copytree(data_dir, os.path.join(features_dir, data_dir[2:]))
 
 
def extractor(img_path, saved_path, net, use_gpu):
  transform = transforms.Compose([
      transforms.Scale(256),
      transforms.CenterCrop(224),
      transforms.ToTensor()  ]
  )
  
  img = Image.open(img_path)
  img = transform(img)
  
 
 
  x = Variable(torch.unsqueeze(img, dim=0).float(), requires_grad=False)
  if use_gpu:
    x = x.cuda()
    net = net.cuda()
  y = net(x).cpu()
  y = y.data.numpy()
  np.savetxt(saved_path, y, delimiter=',')
  
if __name__ == '__main__':
  extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
    
  files_list = []
  sub_dirs = [x[0] for x in os.walk(data_dir) ]
  sub_dirs = sub_dirs[1:]
  for sub_dir in sub_dirs:
    for extention in extensions:
      file_glob = os.path.join(sub_dir, '*.' + extention)
      files_list.extend(glob.glob(file_glob))
    
  resnet50_feature_extractor = models.resnet50(pretrained = True)
  resnet50_feature_extractor.fc = nn.Linear(2048, 2048)
  torch.nn.init.eye(resnet50_feature_extractor.fc.weight)
  for param in resnet50_feature_extractor.parameters():
    param.requires_grad = False  
    
  use_gpu = torch.cuda.is_available()
 
  for x_path in files_list:
    print(x_path)
    fx_path = os.path.join(features_dir, x_path[2:] + '.txt')
    extractor(x_path, fx_path, resnet50_feature_extractor, use_gpu)

另外最近发现一个很简单的提取不含FC层的网络的方法:

resnet = models.resnet152(pretrained=True)
    modules = list(resnet.children())[:-1]   # delete the last fc layer.
    convnet = nn.Sequential(*modules)

另一种更简单的方法:

resnet = models.resnet152(pretrained=True)
del resnet.fc

以上这篇pytorch实现用Resnet提取特征并保存为txt文件的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python socket.error: [Errno 98] Address already in use的原因和解决方法
Aug 25 Python
Python中动态获取对象的属性和方法的教程
Apr 09 Python
Python2.7下安装Scrapy框架步骤教程
Dec 22 Python
对Python 3.5拼接列表的新语法详解
Nov 08 Python
Python面向对象之类和对象属性的增删改查操作示例
Dec 14 Python
只需7行Python代码玩转微信自动聊天
Jan 27 Python
django模板加载静态文件的方法步骤
Mar 01 Python
Python实现通过解析域名获取ip地址的方法分析
May 17 Python
详解Python可视化神器Yellowbrick使用
Nov 11 Python
最小二乘法及其python实现详解
Feb 24 Python
初学者学习Python好还是Java好
May 26 Python
python根据用户需求输入想爬取的内容及页数爬取图片方法详解
Aug 03 Python
python web框架 django wsgi原理解析
Aug 20 #Python
opencv转换颜色空间更改图片背景
Aug 20 #Python
pytorch 预训练层的使用方法
Aug 20 #Python
python爬虫 urllib模块反爬虫机制UA详解
Aug 20 #Python
Pytorch 抽取vgg各层并进行定制化处理的方法
Aug 20 #Python
python实现抠图给证件照换背景源码
Aug 20 #Python
python爬虫 基于requests模块发起ajax的get请求实现解析
Aug 20 #Python
You might like
使用PHP制作新闻系统的思路
2006/10/09 PHP
PHP 处理TXT文件(打开/关闭/检查/读取)
2013/05/13 PHP
PHP与MYSQL中UTF8编码的中文排序实例
2014/10/21 PHP
php实现的简单检验登陆类
2015/06/18 PHP
PHP中如何使用session实现保存用户登录信息
2015/10/20 PHP
PHP中字符与字节的区别及字符串与字节转换示例
2016/10/15 PHP
基于PHP实现的多元线性回归模拟曲线算法
2018/01/30 PHP
PHP实现微信小程序用户授权的工具类示例
2019/03/05 PHP
php弹出提示框的是实例写法
2019/09/26 PHP
laravel框架实现去掉URL中index.php的方法
2019/10/12 PHP
ExtJS 2.0实用简明教程 之获得ExtJS
2009/04/29 Javascript
javascript+iframe 实现无刷新载入整页的代码
2010/03/17 Javascript
jQuery中first()方法用法实例
2015/01/06 Javascript
微信小程序 loading(加载中提示框)实例
2016/10/28 Javascript
详解Angular的内置过滤器和自定义过滤器【推荐】
2016/12/26 Javascript
原生js实现节日时间倒计时功能
2017/01/18 Javascript
JavaScript实现各种排序的代码详解
2017/08/28 Javascript
详解vue-router 命名路由和命名视图
2018/06/01 Javascript
vue组件实现可搜索下拉框扩展
2020/10/23 Javascript
深入理解JavaScript 中的匿名函数((function() {})();)与变量的作用域
2018/08/28 Javascript
vue项目使用axios发送请求让ajax请求头部携带cookie的方法
2018/09/26 Javascript
vue 基于element-ui 分页组件封装的实例代码
2018/12/10 Javascript
webpack 如何解析代码模块路径的实现
2019/09/04 Javascript
vue+elementui 对话框取消 表单验证重置示例
2019/10/29 Javascript
vue浏览器返回监听的具体步骤
2021/02/03 Vue.js
解决python执行较大excel文件openpyxl慢问题
2020/05/15 Python
Python StringIO及BytesIO包使用方法解析
2020/06/15 Python
Win10环境中如何实现python2和python3并存
2020/07/20 Python
HTML5移动端开发遇见的东西
2019/10/11 HTML / CSS
HTML5 canvas画矩形时出现边框样式不一致的解决方法
2013/10/14 HTML / CSS
新年团拜会主持词
2014/04/02 职场文书
公司捐书倡议书
2015/04/27 职场文书
2015年医院科室工作总结范文
2015/05/26 职场文书
公司处罚决定书
2015/06/24 职场文书
详解Spring Boot使用系统参数表提升系统的灵活性
2021/06/30 Java/Android
「约定的梦幻岛」作画发布诺曼生日新绘
2022/03/21 日漫