Pytorch的mean和std调查实例


Posted in Python onJanuary 02, 2020

如下所示:

# coding: utf-8

from __future__ import print_function
import copy
import click
import cv2
import numpy as np
import torch
from torch.autograd import Variable
from torchvision import models, transforms

import matplotlib.pyplot as plt
import load_caffemodel
import scipy.io as sio

# if model has LSTM
# torch.backends.cudnn.enabled = False

imgpath = 'D:/ck/files_detected_face224/'   

imgname = 'S055_002_00000025.png' # anger
image_path = imgpath + imgname

mean_file = [0.485, 0.456, 0.406]
std_file = [0.229, 0.224, 0.225]
raw_image = cv2.imread(image_path)[..., ::-1]
print(raw_image.shape)
raw_image = cv2.resize(raw_image, (224, ) * 2)
image = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize(
    mean=mean_file,
    std =std_file,
    #mean = mean_file,
    #std = std_file,
  )
])(raw_image).unsqueeze(0)

print(image.shape)

convert_image1 = image.numpy()
convert_image1 = np.squeeze(convert_image1) # 3* 224 *224, C * H * W
convert_image1 = convert_image1 * np.reshape(std_file,(3,1,1)) + np.reshape(mean_file,(3,1,1))
convert_image1 = np.transpose(convert_image1, (1,2,0)) # H * W * C
print(convert_image1.shape)

convert_image1 = convert_image1 * 255

diff = raw_image - convert_image1
err = np.max(diff)
print(err)
plt.imshow(np.uint8(convert_image1))
plt.show()

结论:

input_image = (raw_image / 255 - mean) ./ std

下面调查均值文件和方差文件是如何生成的:

mean_file = [0.485, 0.456, 0.406]
std_file = [0.229, 0.224, 0.225]
# coding: utf-8
import matplotlib.pyplot as plt
import argparse
import os
import numpy as np
import torchvision
import torchvision.transforms as transforms

dataset_names = ('cifar10','cifar100','mnist')

parser = argparse.ArgumentParser(description='PyTorchLab')
parser.add_argument('-d', '--dataset', metavar='DATA', default='cifar10', choices=dataset_names,
          help='dataset to be used: ' + ' | '.join(dataset_names) + ' (default: cifar10)')

args = parser.parse_args()

data_dir = os.path.join('.', args.dataset)

print(args.dataset)
args.dataset = 'cifar10'
if args.dataset == "cifar10":
  train_transform = transforms.Compose([transforms.ToTensor()])
  train_set = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True, transform=train_transform)
  #print(vars(train_set))
  print(train_set.train_data.shape)
  print(train_set.train_data.mean(axis=(0,1,2))/255)
  print(train_set.train_data.std(axis=(0,1,2))/255)

  # imshow image
  train_data = train_set.train_data
  ind = 100
  img0 = train_data[ind,...]
  ## test channel number, in total , the correct channel is : RGB,not like BGR in caffe
  # error produce
  #b,g,r=cv2.split(img0)
  #img0=cv2.merge([r,g,b])

  print(img0.shape)
  print(type(img0))
  plt.imshow(img0)
  plt.show() # in ship in sea

  #img0 = cv2.resize(img0,(224,224))
  #cv2.imshow('img0',img0)
  #cv2.waitKey()

elif args.dataset == "cifar100":
  train_transform = transforms.Compose([transforms.ToTensor()])
  train_set = torchvision.datasets.CIFAR100(root=data_dir, train=True, download=True, transform=train_transform)
  #print(vars(train_set))
  print(train_set.train_data.shape)
  print(np.mean(train_set.train_data, axis=(0,1,2))/255)
  print(np.std(train_set.train_data, axis=(0,1,2))/255)

elif args.dataset == "mnist":
  train_transform = transforms.Compose([transforms.ToTensor()])
  train_set = torchvision.datasets.MNIST(root=data_dir, train=True, download=True, transform=train_transform)
  #print(vars(train_set))
  print(list(train_set.train_data.size()))
  print(train_set.train_data.float().mean()/255)
  print(train_set.train_data.float().std()/255)

结果:

cifar10
Files already downloaded and verified
(50000, 32, 32, 3)
[ 0.49139968 0.48215841 0.44653091]
[ 0.24703223 0.24348513 0.26158784]
(32, 32, 3)
<class 'numpy.ndarray'>

使用matlab检测是如何计算mean_file和std_file的:

% load cifar10 dataset

data = load('cifar10_train_data.mat');
train_data = data.train_data;
disp(size(train_data));

temp = mean(train_data,1);
disp(size(temp));

train_data = double(train_data);

% compute mean_file 
mean_val = mean(mean(mean(train_data,1),2),3)/255;


% compute std_file 
temp1 = train_data(:,:,:,1);
std_val1 = std(temp1(:))/255;

temp2 = train_data(:,:,:,2);
std_val2 = std(temp2(:))/255;

temp3 = train_data(:,:,:,3);
std_val3 = std(temp3(:))/255;

mean_val = squeeze(mean_val);
std_val = [std_val1, std_val2, std_val3];

disp(mean_val);
disp(std_val);

% result: mean_val: [0.4914, 0.4822, 0.4465]
%     std_val: [0.2470, 0.2435, 0.2616]

均值计算的过程也可以遵循标准差的计算过程。为 了简单,例如对于一个矩阵,所有元素的均值,等于两个方向上先后均值。所以会直接采用如下的形式:

mean_val = mean(mean(mean(train_data,1),2),3)/255;

标准差的计算是每一个通道的对所有样本的求标准差。然后再除以255。

以上这篇Pytorch的mean和std调查实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python生成不重复随机值的方法
May 11 Python
python3抓取中文网页的方法
Jul 28 Python
Python构造自定义方法来美化字典结构输出的示例
Jun 16 Python
Python使用base64模块进行二进制数据编码详解
Jan 11 Python
python 按照固定长度分割字符串的方法小结
Apr 30 Python
python判断列表的连续数字范围并分块的方法
Nov 16 Python
详解Python并发编程之创建多线程的几种方法
Aug 23 Python
python将三维数组展开成二维数组的实现
Nov 30 Python
Python requests获取网页常用方法解析
Feb 20 Python
详解Python修复遥感影像条带的两种方式
Feb 23 Python
python3 googletrans超时报错问题及翻译工具优化方案 附源码
Dec 23 Python
使用pytorch实现线性回归
Apr 11 Python
pytorch 图像预处理之减去均值,除以方差的实例
Jan 02 #Python
Linux下升级安装python3.8并配置pip及yum的教程
Jan 02 #Python
pytorch实现focal loss的两种方式小结
Jan 02 #Python
pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解
Jan 02 #Python
基于torch.where和布尔索引的速度比较
Jan 02 #Python
Python魔法方法 容器部方法详解
Jan 02 #Python
python 图像的离散傅立叶变换实例
Jan 02 #Python
You might like
Lumen timezone 时区设置方法(慢了8个小时)
2018/01/20 PHP
laravel中短信发送验证码的实现方法
2018/04/25 PHP
原生javascript实现隔行换色
2015/01/04 Javascript
JavaScript通过prototype给对象定义属性用法实例
2015/03/23 Javascript
Javascript非构造函数的继承
2015/04/27 Javascript
浅谈JS中String()与 .toString()的区别
2016/10/20 Javascript
jQuery加载及解析XML文件的方法实例分析
2017/01/22 Javascript
jQuery插件FusionCharts绘制的2D帕累托图效果示例【附demo源码】
2017/03/28 jQuery
vue绑定class与行间样式style详解
2017/08/16 Javascript
NodeJs通过async/await处理异步的方法
2017/10/09 NodeJs
浅谈JS函数节流防抖
2017/10/18 Javascript
浅谈vue项目如何打包扔向服务器
2018/05/08 Javascript
详解vue使用vue-layer-mobile组件实现toast,loading效果
2018/08/31 Javascript
[48:54]VGJ.T vs infamous Supermajor小组赛D组败者组第一轮 BO3 第二场 6.3
2018/06/04 DOTA
[07:20]2018DOTA2国际邀请赛寻真——逐梦Mineski
2018/08/10 DOTA
复制粘贴功能的Python程序
2008/04/04 Python
用Python实现一个简单的能够发送带附件的邮件程序的教程
2015/04/08 Python
Python+Wordpress制作小说站
2017/04/14 Python
批量获取及验证HTTP代理的Python脚本
2017/04/23 Python
Python构建XML树结构的方法示例
2017/06/30 Python
Python实现小数转化为百分数的格式化输出方法示例
2017/09/20 Python
Tensorflow 利用tf.contrib.learn建立输入函数的方法
2018/02/08 Python
pygame实现俄罗斯方块游戏(基础篇3)
2019/10/29 Python
python元组的概念知识点
2019/11/19 Python
python读写Excel表格的实例代码(简单实用)
2019/12/19 Python
Python的历史与优缺点整理
2020/05/26 Python
曼城官方网上商店:Manchester City
2019/09/10 全球购物
个人投资计划书
2014/05/01 职场文书
事业单位考核材料
2014/05/21 职场文书
作风建设剖析材料
2014/10/06 职场文书
工厂见习报告范文
2014/10/31 职场文书
学校2014年度工作总结
2014/12/06 职场文书
2015年小学教师培训工作总结
2015/07/21 职场文书
高中优秀作文(范文)
2019/08/15 职场文书
深入解析NumPy中的Broadcasting广播机制
2021/05/30 Python
Python os和os.path模块详情
2022/04/02 Python