Pytorch的mean和std调查实例


Posted in Python onJanuary 02, 2020

如下所示:

# coding: utf-8

from __future__ import print_function
import copy
import click
import cv2
import numpy as np
import torch
from torch.autograd import Variable
from torchvision import models, transforms

import matplotlib.pyplot as plt
import load_caffemodel
import scipy.io as sio

# if model has LSTM
# torch.backends.cudnn.enabled = False

imgpath = 'D:/ck/files_detected_face224/'   

imgname = 'S055_002_00000025.png' # anger
image_path = imgpath + imgname

mean_file = [0.485, 0.456, 0.406]
std_file = [0.229, 0.224, 0.225]
raw_image = cv2.imread(image_path)[..., ::-1]
print(raw_image.shape)
raw_image = cv2.resize(raw_image, (224, ) * 2)
image = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize(
    mean=mean_file,
    std =std_file,
    #mean = mean_file,
    #std = std_file,
  )
])(raw_image).unsqueeze(0)

print(image.shape)

convert_image1 = image.numpy()
convert_image1 = np.squeeze(convert_image1) # 3* 224 *224, C * H * W
convert_image1 = convert_image1 * np.reshape(std_file,(3,1,1)) + np.reshape(mean_file,(3,1,1))
convert_image1 = np.transpose(convert_image1, (1,2,0)) # H * W * C
print(convert_image1.shape)

convert_image1 = convert_image1 * 255

diff = raw_image - convert_image1
err = np.max(diff)
print(err)
plt.imshow(np.uint8(convert_image1))
plt.show()

结论:

input_image = (raw_image / 255 - mean) ./ std

下面调查均值文件和方差文件是如何生成的:

mean_file = [0.485, 0.456, 0.406]
std_file = [0.229, 0.224, 0.225]
# coding: utf-8
import matplotlib.pyplot as plt
import argparse
import os
import numpy as np
import torchvision
import torchvision.transforms as transforms

dataset_names = ('cifar10','cifar100','mnist')

parser = argparse.ArgumentParser(description='PyTorchLab')
parser.add_argument('-d', '--dataset', metavar='DATA', default='cifar10', choices=dataset_names,
          help='dataset to be used: ' + ' | '.join(dataset_names) + ' (default: cifar10)')

args = parser.parse_args()

data_dir = os.path.join('.', args.dataset)

print(args.dataset)
args.dataset = 'cifar10'
if args.dataset == "cifar10":
  train_transform = transforms.Compose([transforms.ToTensor()])
  train_set = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True, transform=train_transform)
  #print(vars(train_set))
  print(train_set.train_data.shape)
  print(train_set.train_data.mean(axis=(0,1,2))/255)
  print(train_set.train_data.std(axis=(0,1,2))/255)

  # imshow image
  train_data = train_set.train_data
  ind = 100
  img0 = train_data[ind,...]
  ## test channel number, in total , the correct channel is : RGB,not like BGR in caffe
  # error produce
  #b,g,r=cv2.split(img0)
  #img0=cv2.merge([r,g,b])

  print(img0.shape)
  print(type(img0))
  plt.imshow(img0)
  plt.show() # in ship in sea

  #img0 = cv2.resize(img0,(224,224))
  #cv2.imshow('img0',img0)
  #cv2.waitKey()

elif args.dataset == "cifar100":
  train_transform = transforms.Compose([transforms.ToTensor()])
  train_set = torchvision.datasets.CIFAR100(root=data_dir, train=True, download=True, transform=train_transform)
  #print(vars(train_set))
  print(train_set.train_data.shape)
  print(np.mean(train_set.train_data, axis=(0,1,2))/255)
  print(np.std(train_set.train_data, axis=(0,1,2))/255)

elif args.dataset == "mnist":
  train_transform = transforms.Compose([transforms.ToTensor()])
  train_set = torchvision.datasets.MNIST(root=data_dir, train=True, download=True, transform=train_transform)
  #print(vars(train_set))
  print(list(train_set.train_data.size()))
  print(train_set.train_data.float().mean()/255)
  print(train_set.train_data.float().std()/255)

结果:

cifar10
Files already downloaded and verified
(50000, 32, 32, 3)
[ 0.49139968 0.48215841 0.44653091]
[ 0.24703223 0.24348513 0.26158784]
(32, 32, 3)
<class 'numpy.ndarray'>

使用matlab检测是如何计算mean_file和std_file的:

% load cifar10 dataset

data = load('cifar10_train_data.mat');
train_data = data.train_data;
disp(size(train_data));

temp = mean(train_data,1);
disp(size(temp));

train_data = double(train_data);

% compute mean_file 
mean_val = mean(mean(mean(train_data,1),2),3)/255;


% compute std_file 
temp1 = train_data(:,:,:,1);
std_val1 = std(temp1(:))/255;

temp2 = train_data(:,:,:,2);
std_val2 = std(temp2(:))/255;

temp3 = train_data(:,:,:,3);
std_val3 = std(temp3(:))/255;

mean_val = squeeze(mean_val);
std_val = [std_val1, std_val2, std_val3];

disp(mean_val);
disp(std_val);

% result: mean_val: [0.4914, 0.4822, 0.4465]
%     std_val: [0.2470, 0.2435, 0.2616]

均值计算的过程也可以遵循标准差的计算过程。为 了简单,例如对于一个矩阵,所有元素的均值,等于两个方向上先后均值。所以会直接采用如下的形式:

mean_val = mean(mean(mean(train_data,1),2),3)/255;

标准差的计算是每一个通道的对所有样本的求标准差。然后再除以255。

以上这篇Pytorch的mean和std调查实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python list中append()与extend()用法分享
Mar 24 Python
Tensorflow实现卷积神经网络用于人脸关键点识别
Mar 05 Python
在Mac上删除自己安装的Python方法
Oct 29 Python
python基于Selenium的web自动化框架
Jul 14 Python
Python 如何优雅的将数字转化为时间格式的方法
Sep 26 Python
Python编译成.so文件进行加密后调用的实现
Dec 23 Python
pytorch实现focal loss的两种方式小结
Jan 02 Python
Python3 Click模块的使用方法详解
Feb 12 Python
Python opencv相机标定实现原理及步骤详解
Apr 09 Python
使用python创建生成动态链接库dll的方法
May 09 Python
Python 实现PS滤镜的旋涡特效
Dec 03 Python
call在Python中改进数列的实例讲解
Dec 09 Python
pytorch 图像预处理之减去均值,除以方差的实例
Jan 02 #Python
Linux下升级安装python3.8并配置pip及yum的教程
Jan 02 #Python
pytorch实现focal loss的两种方式小结
Jan 02 #Python
pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解
Jan 02 #Python
基于torch.where和布尔索引的速度比较
Jan 02 #Python
Python魔法方法 容器部方法详解
Jan 02 #Python
python 图像的离散傅立叶变换实例
Jan 02 #Python
You might like
虹吸壶煮咖啡26个注意事项
2021/03/03 冲泡冲煮
PHP 一个页面执行时间类代码
2010/03/05 PHP
php将会员数据导入到ucenter的代码
2010/07/18 PHP
php urlencode()与urldecode()函数字符编码原理详解
2011/12/06 PHP
php创建和删除目录函数介绍和递归删除目录函数分享
2014/11/18 PHP
php+ajax实时刷新简单实例
2015/02/25 PHP
php写app接口并返回json数据的实例(分享)
2017/05/20 PHP
PHP实现模拟http请求的方法分析
2017/12/20 PHP
Jquery中国地图热点效果-鼠标经过弹出提示层信息的简单实例
2014/02/12 Javascript
EasyUI中datagrid在ie下reload失败解决方案
2015/03/09 Javascript
JavaScript中的setUTCDate()方法使用详解
2015/06/11 Javascript
Javascript农历与公历相互转换的简单实例
2016/10/09 Javascript
详解如何在Angular优雅编写HTTP请求
2018/12/05 Javascript
微信小程序反编译的实现
2020/12/10 Javascript
Python Web框架Pylons中使用MongoDB的例子
2013/12/03 Python
Python:Scrapy框架中Item Pipeline组件使用详解
2017/12/27 Python
django的csrf实现过程详解
2019/07/26 Python
Python算法中的时间复杂度问题
2019/11/19 Python
使用Python和百度语音识别生成视频字幕的实现
2020/04/09 Python
css3一个简易的 LED 数字时钟实现方法
2020/01/15 HTML / CSS
Groupon比利时官方网站:特卖和网上购物高达-70%
2019/08/09 全球购物
模具设计与制造专业应届生求职信
2013/10/18 职场文书
园林设计师自荐信
2013/11/18 职场文书
服装设计行业个人的自我评价
2013/12/20 职场文书
名人演讲稿范文
2013/12/28 职场文书
财产公证书格式
2014/04/10 职场文书
学生手册评语
2014/05/05 职场文书
竞聘演讲稿精彩开头和结尾
2014/05/14 职场文书
学校感恩教育活动总结
2014/07/07 职场文书
乒乓球兴趣小组活动总结
2014/07/08 职场文书
同意离婚答辩状
2015/05/22 职场文书
退休欢送会致辞
2015/07/31 职场文书
500字作文之难忘的同学
2019/12/20 职场文书
Pandas实现DataFrame的简单运算、统计与排序
2022/03/31 Python
Win11 22H2 2022怎么更新? 获得Win1122H22022版本升级技巧
2022/09/23 数码科技
Python使用pandas导入xlsx格式的excel文件内容操作代码
2022/12/24 Python