pytorch中的numel函数用法说明


Posted in Python onMay 13, 2021

获取tensor中一共包含多少个元素

import torch
x = torch.randn(3,3)
print("number elements of x is ",x.numel())
y = torch.randn(3,10,5)
print("number elements of y is ",y.numel())

输出:

number elements of x is 9

number elements of y is 150

27和150分别位x和y中各有多少个元素或变量

补充:pytorch获取张量元素个数numel()的用法

numel就是"number of elements"的简写。

numel()可以直接返回int类型的元素个数

import torch 
a = torch.randn(1, 2, 3, 4)
b = a.numel()
print(type(b)) # int
print(b) # 24

通过numel()函数,我们可以迅速查看一个张量到底又多少元素。

补充:pytorch 卷积结构和numel()函数

看代码吧~

from torch import nn 
class CNN(nn.Module):
    def __init__(self, num_channels=1, d=56, s=12, m=4):
        super(CNN, self).__init__()
        self.first_part = nn.Sequential(
            nn.Conv2d(num_channels, d, kernel_size=3, padding=5//2),
            nn.Conv2d(num_channels, d, kernel_size=(1,3), padding=5//2),
            nn.Conv2d(num_channels, d, kernel_size=(3,1), padding=5//2),
            nn.PReLU(d)
        )
 
    def forward(self, x):
        x = self.first_part(x)
        return x
 
model = CNN()
for m in model.first_part:
    if isinstance(m, nn.Conv2d):
        # print('m:',m.weight.data)
        print('m:',m.weight.data[0])
        print('m:',m.weight.data[0][0])
        print('m:',m.weight.data.numel()) #numel() 计算矩阵中元素的个数
 
结果:
m: tensor([[[-0.2822,  0.0128, -0.0244],
         [-0.2329,  0.1037,  0.2262],
         [ 0.2845, -0.3094,  0.1443]]]) #卷积核大小为3x3
m: tensor([[-0.2822,  0.0128, -0.0244],
        [-0.2329,  0.1037,  0.2262],
        [ 0.2845, -0.3094,  0.1443]]) #卷积核大小为3x3
m: 504   # = 56 x (3 x 3)  输出通道数为56,卷积核大小为3x3
m: tensor([-0.0335,  0.2945,  0.2512,  0.2770,  0.2071,  0.1133, -0.1883,  0.2738,
         0.0805,  0.1339, -0.3000, -0.1911, -0.1760,  0.2855, -0.0234, -0.0843,
         0.1815,  0.2357,  0.2758,  0.2689, -0.2477, -0.2528, -0.1447, -0.0903,
         0.1870,  0.0945, -0.2786, -0.0419,  0.1577, -0.3100, -0.1335, -0.3162,
        -0.1570,  0.3080,  0.0951,  0.1953,  0.1814, -0.1936,  0.1466, -0.2911,
        -0.1286,  0.3024,  0.1143, -0.0726, -0.2694, -0.3230,  0.2031, -0.2963,
         0.2965,  0.2525, -0.2674,  0.0564, -0.3277,  0.2185, -0.0476,  0.0558]) bias偏置的值
m: tensor([[[ 0.5747, -0.3421,  0.2847]]]) 卷积核大小为1x3
m: tensor([[ 0.5747, -0.3421,  0.2847]]) 卷积核大小为1x3
m: 168 # = 56 x (1 x 3) 输出通道数为56,卷积核大小为1x3
m: tensor([ 0.5328, -0.5711, -0.1945,  0.2844,  0.2012, -0.0084,  0.4834, -0.2020,
        -0.0941,  0.4683, -0.2386,  0.2781, -0.1812, -0.2990, -0.4652,  0.1228,
        -0.0627,  0.3112, -0.2700,  0.0825,  0.4345, -0.0373, -0.3220, -0.5038,
        -0.3166, -0.3823,  0.3947, -0.3232,  0.1028,  0.2378,  0.4589,  0.1675,
        -0.3112, -0.0905, -0.0705,  0.2763,  0.5433,  0.2768, -0.3804,  0.4855,
        -0.4880, -0.4555,  0.4143,  0.5474,  0.3305, -0.0381,  0.2483,  0.5133,
        -0.3978,  0.0407,  0.2351,  0.1910, -0.5385,  0.1340,  0.1811, -0.3008]) bias偏置的值
m: tensor([[[0.0184],
         [0.0981],
         [0.1894]]]) 卷积核大小为3x1
m: tensor([[0.0184],
        [0.0981],
        [0.1894]]) 卷积核大小为3x1
m: 168 # = 56 x (3 x 1) 输出通道数为56,卷积核大小为3x1
m: tensor([-0.2951, -0.4475,  0.1301,  0.4747, -0.0512,  0.2190,  0.3533, -0.1158,
         0.2237, -0.1407, -0.4756,  0.1637, -0.4555, -0.2157,  0.0577, -0.3366,
        -0.3252,  0.2807,  0.1660,  0.2949, -0.2886, -0.5216,  0.1665,  0.2193,
         0.2038, -0.1357,  0.2626,  0.2036,  0.3255,  0.2756,  0.1283, -0.4909,
         0.5737, -0.4322, -0.4930, -0.0846,  0.2158,  0.5565,  0.3751, -0.3775,
        -0.5096, -0.4520,  0.2246, -0.5367,  0.5531,  0.3372, -0.5593, -0.2780,
        -0.5453, -0.2863,  0.5712, -0.2882,  0.4788,  0.3222, -0.4846,  0.2170]) bias偏置的值
  
'''初始化后'''
class CNN(nn.Module):
    def __init__(self, num_channels=1, d=56, s=12, m=4):
        super(CNN, self).__init__()
        self.first_part = nn.Sequential(
            nn.Conv2d(num_channels, d, kernel_size=3, padding=5//2),
            nn.Conv2d(num_channels, d, kernel_size=(1,3), padding=5//2),
            nn.Conv2d(num_channels, d, kernel_size=(3,1), padding=5//2),
            nn.PReLU(d)
        )
        self._initialize_weights()
    def _initialize_weights(self):
        for m in self.first_part:
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight.data, mean=0.0, std=math.sqrt(2/(m.out_channels*m.weight.data[0][0].numel())))
                nn.init.zeros_(m.bias.data)
 
    def forward(self, x):
        x = self.first_part(x)
        return x
 
model = CNN()
for m in model.first_part:
    if isinstance(m, nn.Conv2d):
        # print('m:',m.weight.data)
        print('m:',m.weight.data[0])
        print('m:',m.weight.data[0][0])
        print('m:',m.weight.data.numel()) #numel() 计算矩阵中元素的个数
 
结果:
m: tensor([[[-0.0284, -0.0585,  0.0271],
         [ 0.0125,  0.0554,  0.0511],
         [-0.0106,  0.0574, -0.0053]]])
m: tensor([[-0.0284, -0.0585,  0.0271],
        [ 0.0125,  0.0554,  0.0511],
        [-0.0106,  0.0574, -0.0053]])
m: 504
m: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0.])
m: tensor([[[ 0.0059,  0.0465, -0.0725]]])
m: tensor([[ 0.0059,  0.0465, -0.0725]])
m: 168
m: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0.])
m: tensor([[[ 0.0599],
         [-0.1330],
         [ 0.2456]]])
m: tensor([[ 0.0599],
        [-0.1330],
        [ 0.2456]])
m: 168
m: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0.])

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。如有错误或未考虑完全的地方,望不吝赐教。

Python 相关文章推荐
Python查询阿里巴巴关键字排名的方法
Jul 08 Python
如何处理Python3.4 使用pymssql 乱码问题
Jan 08 Python
python操作excel的方法(xlsxwriter包的使用)
Jun 11 Python
Flask之flask-session的具体使用
Jul 26 Python
基于python实现学生管理系统
Oct 17 Python
python-opencv颜色提取分割方法
Dec 08 Python
python使用Plotly绘图工具绘制柱状图
Apr 01 Python
python生成器推导式用法简单示例
Oct 08 Python
python提取xml里面的链接源码详解
Oct 15 Python
Python函数参数分类原理详解
May 28 Python
基于Python 的语音重采样函数解析
Jul 06 Python
python高级特性简介
Aug 13 Python
pytorch损失反向传播后梯度为none的问题
如何使用Python实现一个简易的ORM模型
May 12 #Python
用python删除文件夹中的重复图片(图片去重)
May 12 #Python
Pyhton模块和包相关知识总结
python 下划线的多种应用场景总结
May 12 #Python
超级详细实用的pycharm常用快捷键
pycharm 如何查看某一函数源码的快捷键
You might like
PHP 修复未正常关闭的HTML标签实现代码(支持嵌套和就近闭合)
2012/06/07 PHP
PHP常用的排序和查找算法
2015/08/06 PHP
laravel 多图上传及图片的存储例子
2019/10/14 PHP
extjs 的权限问题 要求控制的对象是 菜单,按钮,URL
2010/03/09 Javascript
javascript下对于事件、事件流、事件触发的顺序随便说说
2010/07/17 Javascript
读jQuery之十 事件模块概述
2011/06/27 Javascript
关于js注册事件的常用方法
2013/04/03 Javascript
JS 实现导航栏悬停效果(续2)
2013/09/24 Javascript
js如何判断用户是否是用微信浏览器
2014/06/05 Javascript
Javascript获取CSS伪元素属性的实现代码
2014/09/28 Javascript
png在IE6 下无法透明的解决方法汇总
2015/05/21 Javascript
jQuery动态背景图片效果实现方法
2015/07/03 Javascript
理解JS事件循环
2016/01/07 Javascript
Javascript中匿名函数的调用与写法实例详解(多种)
2016/01/26 Javascript
AngularJS转换响应内容
2016/01/27 Javascript
浅谈jquery点击label触发2次的问题
2016/06/12 Javascript
JS获取随机数和时间转换的简单实例
2016/07/10 Javascript
js实现目录链接,内容跟着目录滚动显示的简单实例
2016/10/15 Javascript
原生js更改css样式的两种方式
2017/03/15 Javascript
最全正则表达式总结:验证QQ号、手机号、Email、中文、邮编、身份证、IP地址等
2017/08/16 Javascript
Vue.js移动端左滑删除组件的实现代码
2017/09/08 Javascript
js 索引下标之li集合绑定点击事件
2018/01/12 Javascript
vue页面加载闪烁问题的解决方法
2018/03/28 Javascript
详解angular部署到iis出现404解决方案
2018/08/14 Javascript
深入理解js A*寻路算法原理与具体实现过程
2018/12/13 Javascript
微信小程序与公众号实现数据互通的方法
2019/07/25 Javascript
python批量导出导入MySQL用户的方法
2013/11/15 Python
一文了解Python并发编程的工程实现方法
2019/05/31 Python
Python 实现大整数乘法算法的示例代码
2019/09/17 Python
使用matplotlib动态刷新指定曲线实例
2020/04/23 Python
Python Json数据文件操作原理解析
2020/05/09 Python
Expedia印度:您的一站式在线旅游网站
2017/08/24 全球购物
德国足球商店:OUTFITTER
2019/05/06 全球购物
大学生职业生涯规划书参考模板
2014/03/05 职场文书
财务部岗位职责范本
2015/04/14 职场文书
上课迟到检讨书范文
2015/05/06 职场文书