在Pytorch中使用样本权重(sample_weight)的正确方法


Posted in Python onAugust 17, 2019

step:

1.将标签转换为one-hot形式。

2.将每一个one-hot标签中的1改为预设样本权重的值

即可在Pytorch中使用样本权重。

eg:

对于单个样本:loss = - Q * log(P),如下:

P = [0.1,0.2,0.4,0.3]
Q = [0,0,1,0]
loss = -Q * np.log(P)

增加样本权重则为loss = - Q * log(P) *sample_weight

P = [0.1,0.2,0.4,0.3]
Q = [0,0,sample_weight,0]
loss_samle_weight = -Q * np.log(P)

在pytorch中示例程序

train_data = np.load(open('train_data.npy','rb'))
train_labels = []
for i in range(8):
  train_labels += [i] *100
train_labels = np.array(train_labels)
train_labels = to_categorical(train_labels).astype("float32")
sample_1 = [random.random() for i in range(len(train_data))]
for i in range(len(train_data)):
  floor = i / 100
  train_labels[i][floor] = sample_1[i]
train_data = torch.from_numpy(train_data) 
train_labels = torch.from_numpy(train_labels) 
dataset = dataf.TensorDataset(train_data,train_labels) 
trainloader = dataf.DataLoader(dataset, batch_size=batch_size, shuffle=True)

对应one-target的多分类交叉熵损失函数如下:

def my_loss(outputs, targets):
  
  output2 = outputs - torch.max(outputs, 1, True)[0]
 
 
  P = torch.exp(output2) / torch.sum(torch.exp(output2), 1,True) + 1e-10
 
 
  loss = -torch.mean(targets * torch.log(P))
 
 
  return loss

以上这篇在Pytorch中使用样本权重(sample_weight)的正确方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python数组复制拷贝的实现方法
Jun 09 Python
Python+matplotlib实现华丽的文本框演示代码
Jan 22 Python
Python3 XML 获取雅虎天气的实现方法
Feb 01 Python
python实现决策树分类(2)
Aug 30 Python
Python正则表达式和元字符详解
Nov 29 Python
Python设计模式之解释器模式原理与用法实例分析
Jan 10 Python
Python range、enumerate和zip函数用法详解
Sep 11 Python
tensorflow2.0与tensorflow1.0的性能区别介绍
Feb 07 Python
Django Admin后台添加数据库视图过程解析
Apr 01 Python
Python使用Excel将数据写入多个sheet
May 16 Python
解决使用Pandas 读取超过65536行的Excel文件问题
Nov 10 Python
python反编译教程之2048小游戏实例
Mar 03 Python
获取Pytorch中间某一层权重或者特征的例子
Aug 17 #Python
pyenv与virtualenv安装实现python多版本多项目管理
Aug 17 #Python
pytorch 获取层权重,对特定层注入hook, 提取中间层输出的方法
Aug 17 #Python
关于PyTorch源码解读之torchvision.models
Aug 17 #Python
django项目用higcharts统计最近七天文章点击量
Aug 17 #Python
Django对models里的objects的使用详解
Aug 17 #Python
python3.6中@property装饰器的使用方法示例
Aug 17 #Python
You might like
php 中文字符入库或显示乱码问题的解决方法
2010/04/12 PHP
PHP常用函数和常见疑难问题解答
2014/03/05 PHP
Codeigniter实现智能裁剪图片的方法
2014/06/12 PHP
php简单实现MVC
2015/02/05 PHP
php中namespace use用法实例分析
2016/01/22 PHP
PHP XML Expat解析器知识点总结
2019/02/15 PHP
laravel异步监控定时调度器实例详解
2019/06/21 PHP
用JavaScript脚本实现Web页面信息交互
2006/10/11 Javascript
YUI 读码日记之 YAHOO.util.Dom - Part.1
2008/03/22 Javascript
jquery简单瀑布流实现原理及ie8下测试代码
2013/01/23 Javascript
javascript使用百度地图api和html5特性获取浏览器位置
2014/01/10 Javascript
jQuery Validate初步体验(一)
2015/12/12 Javascript
详解JavaScript数组和字符串中去除重复值的方法
2016/03/07 Javascript
Javascript的无new构建实例详解
2016/05/15 Javascript
Javascript中字符串replace方法的第二个参数探究
2016/12/05 Javascript
js实现的在线调色板功能完整实例
2016/12/21 Javascript
如何提高数据访问速度
2016/12/26 Javascript
强大的 Angular 表单验证功能详细介绍
2017/05/23 Javascript
详解ES6中的代理模式——Proxy
2018/01/08 Javascript
nuxt框架中路由鉴权之Koa和Session的用法
2018/05/09 Javascript
Vue+element-ui 实现表格的分页功能示例
2018/08/18 Javascript
在AngularJs中设置请求头信息(headers)的方法及不同方法的比较
2018/09/04 Javascript
vue将毫秒数转化为正常日期格式的实例
2018/09/16 Javascript
Python中的Matplotlib模块入门教程
2015/04/15 Python
Python实现简单截取中文字符串的方法
2015/06/15 Python
Python的re模块正则表达式操作
2016/05/25 Python
Python 查看list中是否含有某元素的方法
2018/06/27 Python
python抓取网页内容并进行语音播报的方法
2018/12/24 Python
如何基于Python创建目录文件夹
2019/12/31 Python
python 爬虫 实现增量去重和定时爬取实例
2020/02/28 Python
Python Tornado实现WEB服务器Socket服务器共存并实现交互的方法
2020/05/26 Python
HTML5 video标签(播放器)学习笔记(一):使用入门
2015/04/24 HTML / CSS
来自世界各地的饮料:Flavourly
2019/05/06 全球购物
思想政治自我鉴定
2013/10/06 职场文书
《七月的天山》教学反思
2016/02/19 职场文书
为什么中国式养孩子很累?
2019/08/07 职场文书