pytorch实现用Resnet提取特征并保存为txt文件的方法


Posted in Python onAugust 20, 2019

接触pytorch一天,发现pytorch上手的确比TensorFlow更快。可以更方便地实现用预训练的网络提特征。

以下是提取一张jpg图像的特征的程序:

# -*- coding: utf-8 -*-
 
import os.path
 
import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.autograd import Variable 
 
import numpy as np
from PIL import Image 
 
features_dir = './features'
 
img_path = "hymenoptera_data/train/ants/0013035.jpg"
file_name = img_path.split('/')[-1]
feature_path = os.path.join(features_dir, file_name + '.txt')
 
 
transform1 = transforms.Compose([
    transforms.Scale(256),
    transforms.CenterCrop(224),
    transforms.ToTensor()  ]
)
 
img = Image.open(img_path)
img1 = transform1(img)
 
#resnet18 = models.resnet18(pretrained = True)
resnet50_feature_extractor = models.resnet50(pretrained = True)
resnet50_feature_extractor.fc = nn.Linear(2048, 2048)
torch.nn.init.eye(resnet50_feature_extractor.fc.weight)
 
for param in resnet50_feature_extractor.parameters():
  param.requires_grad = False
#resnet152 = models.resnet152(pretrained = True)
#densenet201 = models.densenet201(pretrained = True) 
x = Variable(torch.unsqueeze(img1, dim=0).float(), requires_grad=False)
#y1 = resnet18(x)
y = resnet50_feature_extractor(x)
y = y.data.numpy()
np.savetxt(feature_path, y, delimiter=',')
#y3 = resnet152(x)
#y4 = densenet201(x)
 
y_ = np.loadtxt(feature_path, delimiter=',').reshape(1, 2048)

以下是提取一个文件夹下所有jpg、jpeg图像的程序:

# -*- coding: utf-8 -*-
import os, torch, glob
import numpy as np
from torch.autograd import Variable
from PIL import Image 
from torchvision import models, transforms
import torch.nn as nn
import shutil
data_dir = './hymenoptera_data'
features_dir = './features'
shutil.copytree(data_dir, os.path.join(features_dir, data_dir[2:]))
 
 
def extractor(img_path, saved_path, net, use_gpu):
  transform = transforms.Compose([
      transforms.Scale(256),
      transforms.CenterCrop(224),
      transforms.ToTensor()  ]
  )
  
  img = Image.open(img_path)
  img = transform(img)
  
 
 
  x = Variable(torch.unsqueeze(img, dim=0).float(), requires_grad=False)
  if use_gpu:
    x = x.cuda()
    net = net.cuda()
  y = net(x).cpu()
  y = y.data.numpy()
  np.savetxt(saved_path, y, delimiter=',')
  
if __name__ == '__main__':
  extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
    
  files_list = []
  sub_dirs = [x[0] for x in os.walk(data_dir) ]
  sub_dirs = sub_dirs[1:]
  for sub_dir in sub_dirs:
    for extention in extensions:
      file_glob = os.path.join(sub_dir, '*.' + extention)
      files_list.extend(glob.glob(file_glob))
    
  resnet50_feature_extractor = models.resnet50(pretrained = True)
  resnet50_feature_extractor.fc = nn.Linear(2048, 2048)
  torch.nn.init.eye(resnet50_feature_extractor.fc.weight)
  for param in resnet50_feature_extractor.parameters():
    param.requires_grad = False  
    
  use_gpu = torch.cuda.is_available()
 
  for x_path in files_list:
    print(x_path)
    fx_path = os.path.join(features_dir, x_path[2:] + '.txt')
    extractor(x_path, fx_path, resnet50_feature_extractor, use_gpu)

另外最近发现一个很简单的提取不含FC层的网络的方法:

resnet = models.resnet152(pretrained=True)
    modules = list(resnet.children())[:-1]   # delete the last fc layer.
    convnet = nn.Sequential(*modules)

另一种更简单的方法:

resnet = models.resnet152(pretrained=True)
del resnet.fc

以上这篇pytorch实现用Resnet提取特征并保存为txt文件的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python和GO语言实现的消息摘要算法示例
Mar 10 Python
git进行版本控制心得详谈
Dec 10 Python
Python爬虫实例_城市公交网络站点数据的爬取方法
Jan 10 Python
python的re正则表达式实例代码
Jan 24 Python
django允许外部访问的实例讲解
May 14 Python
Python 打印中文字符的三种方法
Aug 14 Python
python实现括号匹配的思路详解
Aug 23 Python
python 对多个csv文件分别进行处理的方法
Jan 07 Python
python实现远程控制电脑
May 23 Python
Python 文件数据读写的具体实现
Jan 24 Python
python字典的值可以修改吗
Jun 29 Python
互斥锁解决 Python 中多线程共享全局变量的问题(推荐)
Sep 28 Python
python web框架 django wsgi原理解析
Aug 20 #Python
opencv转换颜色空间更改图片背景
Aug 20 #Python
pytorch 预训练层的使用方法
Aug 20 #Python
python爬虫 urllib模块反爬虫机制UA详解
Aug 20 #Python
Pytorch 抽取vgg各层并进行定制化处理的方法
Aug 20 #Python
python实现抠图给证件照换背景源码
Aug 20 #Python
python爬虫 基于requests模块发起ajax的get请求实现解析
Aug 20 #Python
You might like
php中try catch捕获异常实例详解
2014/11/21 PHP
php数组转成json格式的方法
2015/03/09 PHP
深入研究PHP中的preg_replace和代码执行
2018/08/15 PHP
JavaScript 拖拉缩放效果
2008/12/10 Javascript
javascript中的document.open()方法使用介绍
2013/10/09 Javascript
Javascript 拖拽雏形(逐行分析代码,让你轻松了拖拽的原理)
2015/01/23 Javascript
JS 对象属性相关(检查属性、枚举属性等)
2015/04/05 Javascript
简单的JS轮播图代码
2016/07/18 Javascript
vue从使用到源码实现教程详解
2016/09/19 Javascript
Node.js连接postgreSQL并进行数据操作
2016/12/18 Javascript
微信小程序 sha1 实现密码加密实例详解
2017/07/06 Javascript
微信小程序 获取javascript 里的数据
2017/08/17 Javascript
Vue 重置组件到初始状态的方法示例
2018/10/10 Javascript
vue mounted 调用两次的完美解决办法
2018/10/29 Javascript
jQuery-Citys省市区三级菜单联动插件使用详解
2019/07/26 jQuery
JavaScript canvas绘制渐变颜色的矩形
2020/02/18 Javascript
vue开发简单上传图片功能
2020/06/30 Javascript
OpenLayer学习之自定义测量控件
2020/09/28 Javascript
Python线程的两种编程方式
2015/04/14 Python
pytorch 实现将自己的图片数据处理成可以训练的图片类型
2020/01/08 Python
Python模块/包/库安装的六种方法及区别
2020/02/24 Python
Python 生成VOC格式的标签实例
2020/03/10 Python
Spartoo英国:欧洲最大的网上鞋店
2016/09/13 全球购物
BudgetAir印度:预订航班、酒店和汽车租赁
2019/07/07 全球购物
意大利和国际奢侈品牌购物网站:Suitnegozi.com
2021/01/15 全球购物
面向对象编程OOP的优点
2013/01/22 面试题
实习生的自我鉴定范文欣赏
2013/11/20 职场文书
上学迟到的检讨书
2014/01/11 职场文书
网络书店创业计划书
2014/02/07 职场文书
经济担保书范文
2014/04/02 职场文书
经典毕业生求职信
2014/07/12 职场文书
班主任师德师风自我剖析材料
2014/10/02 职场文书
入股协议书范本
2014/11/01 职场文书
《叶问2》观后感
2015/06/15 职场文书
公务员岗前培训心得体会
2016/01/08 职场文书
详解PHP服务器如何在有限的资源里最大提升并发能力
2021/05/25 PHP