Pytorch中accuracy和loss的计算知识点总结


Posted in Python onSeptember 10, 2019

这几天关于accuracy和loss的计算有一些疑惑,原来是自己还没有弄清楚。

给出实例

def train(train_loader, model, criteon, optimizer, epoch):
  train_loss = 0
  train_acc = 0
  num_correct= 0
  for step, (x,y) in enumerate(train_loader):

    # x: [b, 3, 224, 224], y: [b]
    x, y = x.to(device), y.to(device)

    model.train()
    logits = model(x)
    loss = criteon(logits, y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    train_loss += float(loss.item())
    train_losses.append(train_loss)
    pred = logits.argmax(dim=1)
    num_correct += torch.eq(pred, y).sum().float().item()
  logger.info("Train Epoch: {}\t Loss: {:.6f}\t Acc: {:.6f}".format(epoch,train_loss/len(train_loader),num_correct/len(train_loader.dataset)))
  return num_correct/len(train_loader.dataset), train_loss/len(train_loader)

首先这样一次训练称为一个epoch,样本总数/batchsize是走完一个epoch所需的“步数”,相对应的,len(train_loader.dataset)也就是样本总数,len(train_loader)就是这个步数。

那么,accuracy的计算也就是在整个train_loader的for循环中(步数),把每个mini_batch中判断正确的个数累加起来,然后除以样本总数就行了;

而loss的计算有讲究了,首先在这里我们是计算交叉熵,关于交叉熵,也就是涉及到两个值,一个是模型给出的logits,也就是10个类,每个类的概率分布,另一个是样本自身的

label,在Pytorch中,只要把这两个值输进去就能计算交叉熵,用的方法是nn.CrossEntropyLoss,这个方法其实是计算了一个minibatch的均值了,因此累加以后需要除以的步数,也就是

minibatch的个数,而不是像accuracy那样是样本个数,这一点非常重要。

以上就是本次介绍的全部知识点内容,感谢大家对三水点靠木的支持。

Python 相关文章推荐
Python解析树及树的遍历
Feb 03 Python
Python MD5加密实例详解
Aug 02 Python
python中如何正确使用正则表达式的详细模式(Verbose mode expression)
Nov 08 Python
Python3中正则模块re.compile、re.match及re.search函数用法详解
Jun 11 Python
flask session组件的使用示例
Dec 25 Python
解决python中画图时x,y轴名称出现中文乱码的问题
Jan 29 Python
python3正则提取字符串里的中文实例
Jan 31 Python
解决python执行不输出系统命令弹框的问题
Jun 24 Python
解决ROC曲线画出来只有一个点的问题
Feb 28 Python
Django 解决开发自定义抛出异常的问题
May 21 Python
python实现简单贪吃蛇游戏
Sep 29 Python
详解BeautifulSoup获取特定标签下内容的方法
Dec 07 Python
python3.7环境下安装Anaconda的教程图解
Sep 10 #Python
Windows10下 python3.7 安装 facenet的教程
Sep 10 #Python
python 图像处理画一个正弦函数代码实例
Sep 10 #Python
Python操作Mongodb数据库的方法小结
Sep 10 #Python
Python使用matplotlib绘制三维参数曲线操作示例
Sep 10 #Python
Python matplotlib绘制饼状图功能示例
Sep 10 #Python
numpy.random.shuffle打乱顺序函数的实现
Sep 10 #Python
You might like
PHP抓取淘宝商品的用户晒单评论+图片+搜索商品列表实例
2016/04/14 PHP
Centos7.7 64位利用本地完整安装包安装lnmp/lamp套件教程
2021/03/09 Servers
js获取元素在浏览器中的绝对位置
2010/07/24 Javascript
jQuery的写法不同导致的兼容性问题的解决方法
2010/07/29 Javascript
jQuery $.get 的妙用 访问本地文本文件
2012/07/12 Javascript
Js判断CSS文件加载完毕的具体实现
2014/01/17 Javascript
JS(JQuery)操作Array的相关方法介绍
2014/02/11 Javascript
js无刷新操作table的行和列
2014/03/27 Javascript
nodejs开发环境配置与使用
2014/11/17 NodeJs
JavaScript中的条件判断语句使用详解
2015/06/03 Javascript
AngularJS 实现弹性盒子布局的方法
2016/08/30 Javascript
jQuery实现导航回弹效果
2017/02/27 Javascript
JS实现求数组起始项到终止项之和的方法【基于数组扩展函数】
2017/06/13 Javascript
Vue组件之自定义事件的功能图解
2018/02/01 Javascript
微信小程序导航栏滑动定位功能示例(实现CSS3的positionsticky效果)
2019/01/24 Javascript
JS数据类型分类及常用判断方法
2020/11/19 Javascript
[48:35]2018DOTA2亚洲邀请赛 4.1 小组赛 A组加赛 TNC vs Optic
2018/04/03 DOTA
[48:26]VGJ.S vs infamous Supermajor 败者组 BO3 第二场 6.4
2018/06/05 DOTA
让python json encode datetime类型
2010/12/28 Python
Python基于pillow判断图片完整性的方法
2016/09/18 Python
flask使用session保存登录状态及拦截未登录请求代码
2018/01/19 Python
python使用Pycharm创建一个Django项目
2018/03/05 Python
Python实现合并两个列表的方法分析
2018/05/28 Python
pycharm 将django中多个app放到同个文件夹apps的处理方法
2018/05/30 Python
python 筛选数据集中列中value长度大于20的数据集方法
2018/06/14 Python
python实现学生成绩测评系统
2020/06/22 Python
美国第一香水网站:Perfume.com
2017/01/23 全球购物
AJAX的全称是什么
2012/11/06 面试题
院药学专业个人求职信
2013/09/21 职场文书
秋季婚礼证婚词
2014/01/11 职场文书
群众路线教育实践活动心得体会(教师)
2014/10/31 职场文书
班主任先进事迹材料
2014/12/17 职场文书
社区党员干部承诺书
2015/05/04 职场文书
高中化学教学反思
2016/02/22 职场文书
经典《舰娘》游改全新动画预告 预定11月开播
2022/04/01 日漫
TypeScript 使用 Tuple Union 声明函数重载
2022/04/07 Javascript