Pytorch的mean和std调查实例


Posted in Python onJanuary 02, 2020

如下所示:

# coding: utf-8

from __future__ import print_function
import copy
import click
import cv2
import numpy as np
import torch
from torch.autograd import Variable
from torchvision import models, transforms

import matplotlib.pyplot as plt
import load_caffemodel
import scipy.io as sio

# if model has LSTM
# torch.backends.cudnn.enabled = False

imgpath = 'D:/ck/files_detected_face224/'   

imgname = 'S055_002_00000025.png' # anger
image_path = imgpath + imgname

mean_file = [0.485, 0.456, 0.406]
std_file = [0.229, 0.224, 0.225]
raw_image = cv2.imread(image_path)[..., ::-1]
print(raw_image.shape)
raw_image = cv2.resize(raw_image, (224, ) * 2)
image = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize(
    mean=mean_file,
    std =std_file,
    #mean = mean_file,
    #std = std_file,
  )
])(raw_image).unsqueeze(0)

print(image.shape)

convert_image1 = image.numpy()
convert_image1 = np.squeeze(convert_image1) # 3* 224 *224, C * H * W
convert_image1 = convert_image1 * np.reshape(std_file,(3,1,1)) + np.reshape(mean_file,(3,1,1))
convert_image1 = np.transpose(convert_image1, (1,2,0)) # H * W * C
print(convert_image1.shape)

convert_image1 = convert_image1 * 255

diff = raw_image - convert_image1
err = np.max(diff)
print(err)
plt.imshow(np.uint8(convert_image1))
plt.show()

结论:

input_image = (raw_image / 255 - mean) ./ std

下面调查均值文件和方差文件是如何生成的:

mean_file = [0.485, 0.456, 0.406]
std_file = [0.229, 0.224, 0.225]
# coding: utf-8
import matplotlib.pyplot as plt
import argparse
import os
import numpy as np
import torchvision
import torchvision.transforms as transforms

dataset_names = ('cifar10','cifar100','mnist')

parser = argparse.ArgumentParser(description='PyTorchLab')
parser.add_argument('-d', '--dataset', metavar='DATA', default='cifar10', choices=dataset_names,
          help='dataset to be used: ' + ' | '.join(dataset_names) + ' (default: cifar10)')

args = parser.parse_args()

data_dir = os.path.join('.', args.dataset)

print(args.dataset)
args.dataset = 'cifar10'
if args.dataset == "cifar10":
  train_transform = transforms.Compose([transforms.ToTensor()])
  train_set = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True, transform=train_transform)
  #print(vars(train_set))
  print(train_set.train_data.shape)
  print(train_set.train_data.mean(axis=(0,1,2))/255)
  print(train_set.train_data.std(axis=(0,1,2))/255)

  # imshow image
  train_data = train_set.train_data
  ind = 100
  img0 = train_data[ind,...]
  ## test channel number, in total , the correct channel is : RGB,not like BGR in caffe
  # error produce
  #b,g,r=cv2.split(img0)
  #img0=cv2.merge([r,g,b])

  print(img0.shape)
  print(type(img0))
  plt.imshow(img0)
  plt.show() # in ship in sea

  #img0 = cv2.resize(img0,(224,224))
  #cv2.imshow('img0',img0)
  #cv2.waitKey()

elif args.dataset == "cifar100":
  train_transform = transforms.Compose([transforms.ToTensor()])
  train_set = torchvision.datasets.CIFAR100(root=data_dir, train=True, download=True, transform=train_transform)
  #print(vars(train_set))
  print(train_set.train_data.shape)
  print(np.mean(train_set.train_data, axis=(0,1,2))/255)
  print(np.std(train_set.train_data, axis=(0,1,2))/255)

elif args.dataset == "mnist":
  train_transform = transforms.Compose([transforms.ToTensor()])
  train_set = torchvision.datasets.MNIST(root=data_dir, train=True, download=True, transform=train_transform)
  #print(vars(train_set))
  print(list(train_set.train_data.size()))
  print(train_set.train_data.float().mean()/255)
  print(train_set.train_data.float().std()/255)

结果:

cifar10
Files already downloaded and verified
(50000, 32, 32, 3)
[ 0.49139968 0.48215841 0.44653091]
[ 0.24703223 0.24348513 0.26158784]
(32, 32, 3)
<class 'numpy.ndarray'>

使用matlab检测是如何计算mean_file和std_file的:

% load cifar10 dataset

data = load('cifar10_train_data.mat');
train_data = data.train_data;
disp(size(train_data));

temp = mean(train_data,1);
disp(size(temp));

train_data = double(train_data);

% compute mean_file 
mean_val = mean(mean(mean(train_data,1),2),3)/255;


% compute std_file 
temp1 = train_data(:,:,:,1);
std_val1 = std(temp1(:))/255;

temp2 = train_data(:,:,:,2);
std_val2 = std(temp2(:))/255;

temp3 = train_data(:,:,:,3);
std_val3 = std(temp3(:))/255;

mean_val = squeeze(mean_val);
std_val = [std_val1, std_val2, std_val3];

disp(mean_val);
disp(std_val);

% result: mean_val: [0.4914, 0.4822, 0.4465]
%     std_val: [0.2470, 0.2435, 0.2616]

均值计算的过程也可以遵循标准差的计算过程。为 了简单,例如对于一个矩阵,所有元素的均值,等于两个方向上先后均值。所以会直接采用如下的形式:

mean_val = mean(mean(mean(train_data,1),2),3)/255;

标准差的计算是每一个通道的对所有样本的求标准差。然后再除以255。

以上这篇Pytorch的mean和std调查实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python批量同步web服务器代码核心程序
Sep 01 Python
Python 制作糗事百科爬虫实例
Sep 22 Python
Django 前后台的数据传递的方法
Aug 08 Python
Python if语句知识点用法总结
Jun 10 Python
详解如何为eclipse安装合适版本的python插件pydev
Nov 04 Python
Python机器学习之scikit-learn库中KNN算法的封装与使用方法
Dec 14 Python
python用插值法绘制平滑曲线
Feb 19 Python
Python+PyQT5的子线程更新UI界面的实例
Jun 14 Python
Django连接数据库并实现读写分离过程解析
Nov 13 Python
django 实现简单的插入视频
Apr 07 Python
Pytest之测试命名规则的使用
Apr 16 Python
python实现简易自习室座位预约系统
Jun 30 Python
pytorch 图像预处理之减去均值,除以方差的实例
Jan 02 #Python
Linux下升级安装python3.8并配置pip及yum的教程
Jan 02 #Python
pytorch实现focal loss的两种方式小结
Jan 02 #Python
pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解
Jan 02 #Python
基于torch.where和布尔索引的速度比较
Jan 02 #Python
Python魔法方法 容器部方法详解
Jan 02 #Python
python 图像的离散傅立叶变换实例
Jan 02 #Python
You might like
自动生成文章摘要的代码[PHP 版本]
2007/03/20 PHP
Yii快速入门经典教程
2015/12/28 PHP
一个实用的php验证码类
2017/07/06 PHP
phpinfo无法显示的原因及解决办法
2019/02/15 PHP
页面定时刷新(1秒刷新一次)
2013/11/22 Javascript
JavaScript父子窗体间的调用方法
2015/03/31 Javascript
js实现鼠标滑过文字链接色彩变化的效果
2015/05/06 Javascript
JS实现按比例缩放图片的方法(附C#版代码)
2015/12/08 Javascript
JS 动态加载js文件和css文件 同步/异步的两种简单方式
2016/09/23 Javascript
微信通过页面(H5)直接打开本地app的解决方法
2017/09/09 Javascript
vue基于mint-ui的城市选择3级联动的示例
2017/10/25 Javascript
React-native桥接Android原生开发详解
2018/01/17 Javascript
浅谈KOA2 Restful方式路由初探
2019/03/14 Javascript
Vue批量图片显示时遇到的路径被解析问题
2019/03/28 Javascript
element-ui 中使用upload多文件上传只请求一次接口
2019/07/19 Javascript
微信小程序错误this.setData报错及解决过程
2019/09/18 Javascript
vue实现整屏滚动切换
2020/06/29 Javascript
微信小程序视频弹幕发送功能的实现
2020/12/28 Javascript
python中urllib模块用法实例详解
2014/11/19 Python
DataFrame中去除指定列为空的行方法
2018/04/08 Python
python使用socket创建tcp服务器和客户端
2018/04/12 Python
对Python发送带header的http请求方法详解
2019/01/02 Python
对Python+opencv将图片生成视频的实例详解
2019/01/08 Python
Python设计模式之备忘录模式原理与用法详解
2019/01/15 Python
对python生成业务报表的实例详解
2019/02/03 Python
keras使用Sequence类调用大规模数据集进行训练的实现
2020/06/22 Python
html5实现输入框fixed定位在屏幕最底部兼容性
2020/07/03 HTML / CSS
欧洲最大的笔和书写专家:The Pen Shop
2017/03/19 全球购物
欧舒丹澳洲版:L’OCCITANE
2017/07/17 全球购物
火山咖啡:Volcanica Coffee
2019/10/29 全球购物
大学秋游活动方案
2014/02/11 职场文书
学生党员一帮一活动总结
2014/07/08 职场文书
设立有限责任公司出资协议书
2014/11/01 职场文书
上甘岭观后感
2015/06/10 职场文书
幼儿园迎新生欢迎词
2015/09/30 职场文书
教师网络培训心得体会
2016/01/09 职场文书