Pytorch的mean和std调查实例


Posted in Python onJanuary 02, 2020

如下所示:

# coding: utf-8

from __future__ import print_function
import copy
import click
import cv2
import numpy as np
import torch
from torch.autograd import Variable
from torchvision import models, transforms

import matplotlib.pyplot as plt
import load_caffemodel
import scipy.io as sio

# if model has LSTM
# torch.backends.cudnn.enabled = False

imgpath = 'D:/ck/files_detected_face224/'   

imgname = 'S055_002_00000025.png' # anger
image_path = imgpath + imgname

mean_file = [0.485, 0.456, 0.406]
std_file = [0.229, 0.224, 0.225]
raw_image = cv2.imread(image_path)[..., ::-1]
print(raw_image.shape)
raw_image = cv2.resize(raw_image, (224, ) * 2)
image = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize(
    mean=mean_file,
    std =std_file,
    #mean = mean_file,
    #std = std_file,
  )
])(raw_image).unsqueeze(0)

print(image.shape)

convert_image1 = image.numpy()
convert_image1 = np.squeeze(convert_image1) # 3* 224 *224, C * H * W
convert_image1 = convert_image1 * np.reshape(std_file,(3,1,1)) + np.reshape(mean_file,(3,1,1))
convert_image1 = np.transpose(convert_image1, (1,2,0)) # H * W * C
print(convert_image1.shape)

convert_image1 = convert_image1 * 255

diff = raw_image - convert_image1
err = np.max(diff)
print(err)
plt.imshow(np.uint8(convert_image1))
plt.show()

结论:

input_image = (raw_image / 255 - mean) ./ std

下面调查均值文件和方差文件是如何生成的:

mean_file = [0.485, 0.456, 0.406]
std_file = [0.229, 0.224, 0.225]
# coding: utf-8
import matplotlib.pyplot as plt
import argparse
import os
import numpy as np
import torchvision
import torchvision.transforms as transforms

dataset_names = ('cifar10','cifar100','mnist')

parser = argparse.ArgumentParser(description='PyTorchLab')
parser.add_argument('-d', '--dataset', metavar='DATA', default='cifar10', choices=dataset_names,
          help='dataset to be used: ' + ' | '.join(dataset_names) + ' (default: cifar10)')

args = parser.parse_args()

data_dir = os.path.join('.', args.dataset)

print(args.dataset)
args.dataset = 'cifar10'
if args.dataset == "cifar10":
  train_transform = transforms.Compose([transforms.ToTensor()])
  train_set = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True, transform=train_transform)
  #print(vars(train_set))
  print(train_set.train_data.shape)
  print(train_set.train_data.mean(axis=(0,1,2))/255)
  print(train_set.train_data.std(axis=(0,1,2))/255)

  # imshow image
  train_data = train_set.train_data
  ind = 100
  img0 = train_data[ind,...]
  ## test channel number, in total , the correct channel is : RGB,not like BGR in caffe
  # error produce
  #b,g,r=cv2.split(img0)
  #img0=cv2.merge([r,g,b])

  print(img0.shape)
  print(type(img0))
  plt.imshow(img0)
  plt.show() # in ship in sea

  #img0 = cv2.resize(img0,(224,224))
  #cv2.imshow('img0',img0)
  #cv2.waitKey()

elif args.dataset == "cifar100":
  train_transform = transforms.Compose([transforms.ToTensor()])
  train_set = torchvision.datasets.CIFAR100(root=data_dir, train=True, download=True, transform=train_transform)
  #print(vars(train_set))
  print(train_set.train_data.shape)
  print(np.mean(train_set.train_data, axis=(0,1,2))/255)
  print(np.std(train_set.train_data, axis=(0,1,2))/255)

elif args.dataset == "mnist":
  train_transform = transforms.Compose([transforms.ToTensor()])
  train_set = torchvision.datasets.MNIST(root=data_dir, train=True, download=True, transform=train_transform)
  #print(vars(train_set))
  print(list(train_set.train_data.size()))
  print(train_set.train_data.float().mean()/255)
  print(train_set.train_data.float().std()/255)

结果:

cifar10
Files already downloaded and verified
(50000, 32, 32, 3)
[ 0.49139968 0.48215841 0.44653091]
[ 0.24703223 0.24348513 0.26158784]
(32, 32, 3)
<class 'numpy.ndarray'>

使用matlab检测是如何计算mean_file和std_file的:

% load cifar10 dataset

data = load('cifar10_train_data.mat');
train_data = data.train_data;
disp(size(train_data));

temp = mean(train_data,1);
disp(size(temp));

train_data = double(train_data);

% compute mean_file 
mean_val = mean(mean(mean(train_data,1),2),3)/255;


% compute std_file 
temp1 = train_data(:,:,:,1);
std_val1 = std(temp1(:))/255;

temp2 = train_data(:,:,:,2);
std_val2 = std(temp2(:))/255;

temp3 = train_data(:,:,:,3);
std_val3 = std(temp3(:))/255;

mean_val = squeeze(mean_val);
std_val = [std_val1, std_val2, std_val3];

disp(mean_val);
disp(std_val);

% result: mean_val: [0.4914, 0.4822, 0.4465]
%     std_val: [0.2470, 0.2435, 0.2616]

均值计算的过程也可以遵循标准差的计算过程。为 了简单,例如对于一个矩阵,所有元素的均值,等于两个方向上先后均值。所以会直接采用如下的形式:

mean_val = mean(mean(mean(train_data,1),2),3)/255;

标准差的计算是每一个通道的对所有样本的求标准差。然后再除以255。

以上这篇Pytorch的mean和std调查实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python smtplib模块发送SSL/TLS安全邮件实例
Apr 08 Python
高效使用Python字典的清单
Apr 04 Python
python实现在pandas.DataFrame添加一行
Apr 04 Python
python解决js文件utf-8编码乱码问题(推荐)
May 02 Python
Python实现正弦信号的时域波形和频谱图示例【基于matplotlib】
May 04 Python
用pycharm开发django项目示例代码
Jun 13 Python
Python 实现将大图切片成小图,将小图组合成大图的例子
Mar 14 Python
如何将PySpark导入Python的放实现(2种)
Apr 26 Python
Django cookie和session的应用场景及如何使用
Apr 29 Python
python神经网络编程之手写数字识别
May 08 Python
python树莓派通过队列实现进程交互的程序分析
Jul 04 Python
Python中文纠错的简单实现
Jul 07 Python
pytorch 图像预处理之减去均值,除以方差的实例
Jan 02 #Python
Linux下升级安装python3.8并配置pip及yum的教程
Jan 02 #Python
pytorch实现focal loss的两种方式小结
Jan 02 #Python
pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解
Jan 02 #Python
基于torch.where和布尔索引的速度比较
Jan 02 #Python
Python魔法方法 容器部方法详解
Jan 02 #Python
python 图像的离散傅立叶变换实例
Jan 02 #Python
You might like
Windows2003下php5.4安装配置教程(Apache2.4)
2016/06/30 PHP
php cookie 详解使用实例
2016/11/03 PHP
thinkphp框架实现路由重定义简化url访问地址的方法分析
2020/04/04 PHP
jQuery checkbox全选/取消全选实现代码
2009/11/14 Javascript
asp.net下利用js实现返回上一页的实现方法小集
2009/11/24 Javascript
解决html按钮切换绑定不同函数后点击时执行多次函数问题
2014/05/14 Javascript
一个JavaScript处理textarea中的字符成每一行实例
2014/09/22 Javascript
js实现简单鼠标跟随效果的方法
2015/04/10 Javascript
js变形金刚文字特效代码分享
2015/08/20 Javascript
AngularJS基础 ng-value 指令简单示例
2016/08/03 Javascript
JavaScript 函数模式详解及示例
2016/09/07 Javascript
详解vue表单验证组件 v-verify-plugin
2017/04/19 Javascript
NodeJS链接MySql数据库的操作方法
2017/06/27 NodeJs
jQuery条件分页 代替离线查询(附代码)
2017/08/17 jQuery
基于BootStrap的文本编辑器组件Summernote
2017/10/27 Javascript
Angular数据绑定机制原理
2018/04/17 Javascript
详解Webpack-dev-server的proxy用法
2018/09/08 Javascript
利用百度echarts实现图表功能简单入门示例【附源码下载】
2019/06/10 Javascript
JS/CSS实现字符串单词首字母大写功能
2019/09/03 Javascript
[10:18]2018DOTA2国际邀请赛寻真——找回自信的TNCPredator
2018/08/13 DOTA
python的urllib模块显示下载进度示例
2014/01/17 Python
python实现汉诺塔方法汇总
2016/07/25 Python
怎样使用Python脚本日志功能
2016/08/14 Python
python 从文件夹抽取图片另存的方法
2018/12/04 Python
Python的numpy库下的几个小函数的用法(小结)
2019/07/12 Python
Pyecharts绘制全球流向图的示例代码
2020/01/08 Python
Django权限设置及验证方式
2020/05/13 Python
Python如何解除一个装饰器
2020/08/07 Python
用ldap作为django后端用户登录验证的实现
2020/12/07 Python
世界上最大的餐具公司:Oneida
2016/12/17 全球购物
澳大利亚Rockwear官网:女子瑜伽、健身和运动服
2021/01/26 全球购物
保加利亚服装和鞋类购物网站:Bibloo.bg
2020/11/08 全球购物
拓展培训心得体会
2014/01/04 职场文书
《三亚落日》教学反思
2014/04/26 职场文书
喋血孤城观后感
2015/06/08 职场文书
MybatisPlus EntityWrapper如何自定义SQL
2022/03/22 Java/Android