Pytorch中accuracy和loss的计算知识点总结


Posted in Python onSeptember 10, 2019

这几天关于accuracy和loss的计算有一些疑惑,原来是自己还没有弄清楚。

给出实例

def train(train_loader, model, criteon, optimizer, epoch):
  train_loss = 0
  train_acc = 0
  num_correct= 0
  for step, (x,y) in enumerate(train_loader):

    # x: [b, 3, 224, 224], y: [b]
    x, y = x.to(device), y.to(device)

    model.train()
    logits = model(x)
    loss = criteon(logits, y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    train_loss += float(loss.item())
    train_losses.append(train_loss)
    pred = logits.argmax(dim=1)
    num_correct += torch.eq(pred, y).sum().float().item()
  logger.info("Train Epoch: {}\t Loss: {:.6f}\t Acc: {:.6f}".format(epoch,train_loss/len(train_loader),num_correct/len(train_loader.dataset)))
  return num_correct/len(train_loader.dataset), train_loss/len(train_loader)

首先这样一次训练称为一个epoch,样本总数/batchsize是走完一个epoch所需的“步数”,相对应的,len(train_loader.dataset)也就是样本总数,len(train_loader)就是这个步数。

那么,accuracy的计算也就是在整个train_loader的for循环中(步数),把每个mini_batch中判断正确的个数累加起来,然后除以样本总数就行了;

而loss的计算有讲究了,首先在这里我们是计算交叉熵,关于交叉熵,也就是涉及到两个值,一个是模型给出的logits,也就是10个类,每个类的概率分布,另一个是样本自身的

label,在Pytorch中,只要把这两个值输进去就能计算交叉熵,用的方法是nn.CrossEntropyLoss,这个方法其实是计算了一个minibatch的均值了,因此累加以后需要除以的步数,也就是

minibatch的个数,而不是像accuracy那样是样本个数,这一点非常重要。

以上就是本次介绍的全部知识点内容,感谢大家对三水点靠木的支持。

Python 相关文章推荐
Python中的推导式使用详解
Jun 03 Python
Python的装饰器用法学习笔记
Jun 24 Python
Python实现遍历目录的方法【测试可用】
Mar 22 Python
Python利用QQ邮箱发送邮件的实现方法(分享)
Jun 09 Python
解析Python中的eval()、exec()及其相关函数
Dec 20 Python
Python实现简易版的Web服务器(推荐)
Jan 29 Python
使用coverage统计python web项目代码覆盖率的方法详解
Aug 05 Python
对Pytorch中Tensor的各种池化操作解析
Jan 03 Python
解决Python Matplotlib绘图数据点位置错乱问题
May 16 Python
PySide2出现“ImportError: DLL load failed: 找不到指定的模块”的问题及解决方法
Jun 10 Python
Python 游戏大作炫酷机甲闯关游戏爆肝数千行代码实现案例进阶
Oct 16 Python
用Python仅20行代码编写一个简单的端口扫描器
Apr 08 Python
python3.7环境下安装Anaconda的教程图解
Sep 10 #Python
Windows10下 python3.7 安装 facenet的教程
Sep 10 #Python
python 图像处理画一个正弦函数代码实例
Sep 10 #Python
Python操作Mongodb数据库的方法小结
Sep 10 #Python
Python使用matplotlib绘制三维参数曲线操作示例
Sep 10 #Python
Python matplotlib绘制饼状图功能示例
Sep 10 #Python
numpy.random.shuffle打乱顺序函数的实现
Sep 10 #Python
You might like
php,不用COM,生成excel文件
2006/10/09 PHP
PHP中常用的转义函数
2014/02/28 PHP
CentOS安装php v8js教程
2015/02/26 PHP
PHP面向对象程序设计实例分析
2016/01/26 PHP
PHP实现生成模糊图片的方法示例
2017/12/21 PHP
利用jQuery的deferred对象实现异步按顺序加载JS文件
2013/03/17 Javascript
JS预览图像将本地图片显示到浏览器上
2013/08/25 Javascript
js实现黑色简易的滑动门网页tab选项卡效果
2015/08/31 Javascript
String字符串截取的四种方式总结
2016/11/28 Javascript
详解微信小程序开发之城市选择器 城市切换
2017/01/17 Javascript
Angular中$state.go页面跳转并传递参数的方法
2017/05/09 Javascript
详解使用angular-cli发布i18n多国语言Angular应用
2017/05/20 Javascript
JavaScript基于扩展String实现替换字符串中index处字符的方法
2017/06/13 Javascript
vue组件父子间通信详解(三)
2017/11/07 Javascript
解决vue中无法动态修改jqgrid组件 url地址的问题
2018/03/01 Javascript
vue 组件中slot插口的具体用法
2018/04/03 Javascript
浅谈PDF.js使用心得
2018/06/07 Javascript
微信小程序引用iconfont图标的方法
2018/10/22 Javascript
nodejs中各种加密算法的实现详解
2019/07/11 NodeJs
Node.js系列之安装配置与基本使用(1)
2019/08/30 Javascript
nodejs+koa2 实现模仿springMVC框架
2020/10/21 NodeJs
[55:32]2018DOTA2亚洲邀请赛 4.4 淘汰赛 EG vs LGD 第二场
2018/04/05 DOTA
[01:04:14]OG vs Winstrike 2018国际邀请赛小组赛BO2 第二场 8.19
2018/08/21 DOTA
python利用paramiko连接远程服务器执行命令的方法
2017/10/16 Python
Python实现的爬取小说爬虫功能示例
2019/03/30 Python
详解如何用TensorFlow训练和识别/分类自定义图片
2019/08/05 Python
Django app配置多个数据库代码实例
2019/12/17 Python
Django 解决distinct无法去除重复数据的问题
2020/05/20 Python
美国著名的家居用品购物网站:Bed Bath & Beyond
2018/01/05 全球购物
幼儿园招生广告
2014/03/19 职场文书
演讲稿格式
2014/04/30 职场文书
局机关干部群众路线个人对照检查材料思想汇报
2014/10/05 职场文书
新郎新娘答谢词
2015/01/04 职场文书
公司员工体检通知
2015/04/21 职场文书
java如何实现socket连接方法封装
2021/09/25 Java/Android
SQL Server 忘记密码以及重新添加新账号
2022/04/26 SQL Server