关于pytorch处理类别不平衡的问题


Posted in Python onDecember 31, 2019

当训练样本不均匀时,我们可以采用过采样、欠采样、数据增强等手段来避免过拟合。今天遇到一个3d点云数据集合,样本分布极不均匀,正例与负例相差4-5个数量级。数据增强效果就不会太好了,另外过采样也不太合适,因为是空间数据,新增的点有可能会对真实分布产生未知影响。所以采用欠采样来缓解类别不平衡的问题。

下面的代码展示了如何使用WeightedRandomSampler来完成抽样。

numDataPoints = 1000
data_dim = 5
bs = 100

# Create dummy data with class imbalance 9 to 1
data = torch.FloatTensor(numDataPoints, data_dim)
target = np.hstack((np.zeros(int(numDataPoints * 0.9), dtype=np.int32),
     np.ones(int(numDataPoints * 0.1), dtype=np.int32)))

print 'target train 0/1: {}/{}'.format(
 len(np.where(target == 0)[0]), len(np.where(target == 1)[0]))

class_sample_count = np.array(
 [len(np.where(target == t)[0]) for t in np.unique(target)])
weight = 1. / class_sample_count
samples_weight = np.array([weight[t] for t in target])

samples_weight = torch.from_numpy(samples_weight)
samples_weight = samples_weight.double()
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

target = torch.from_numpy(target).long()
train_dataset = torch.utils.data.TensorDataset(data, target)

train_loader = DataLoader(
 train_dataset, batch_size=bs, num_workers=1, sampler=sampler)

for i, (data, target) in enumerate(train_loader):
 print "batch index {}, 0/1: {}/{}".format(
  i,
  len(np.where(target.numpy() == 0)[0]),
  len(np.where(target.numpy() == 1)[0]))

核心部分为实际使用时替换下变量把sampler传递给DataLoader即可,注意使用了sampler就不能使用shuffle,另外需要指定采样点个数:

class_sample_count = np.array(
 [len(np.where(target == t)[0]) for t in np.unique(target)])
weight = 1. / class_sample_count
samples_weight = np.array([weight[t] for t in target])

samples_weight = torch.from_numpy(samples_weight)
samples_weight = samples_weight.double()
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

参考:https://discuss.pytorch.org/t/how-to-handle-imbalanced-classes/11264/2

以上这篇关于pytorch处理类别不平衡的问题就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
合并Excel工作薄中成绩表的VBA代码,非常适合教育一线的朋友
Apr 09 Python
python 实现插入排序算法
Jun 05 Python
ubuntu系统下 python链接mysql数据库的方法
Jan 09 Python
numpy.random.seed()的使用实例解析
Feb 03 Python
python使用jieba实现中文分词去停用词方法示例
Mar 11 Python
pandas将numpy数组写入到csv的实例
Jul 04 Python
Django实现表单验证
Sep 08 Python
Python3爬虫学习之将爬取的信息保存到本地的方法详解
Dec 12 Python
Python:slice与indices的用法
Nov 25 Python
通用的Django注册功能模块实现方法
Feb 05 Python
asyncio异步编程之Task对象详解
Mar 13 Python
Python读取和写入Excel数据
Apr 20 Python
pytorch 指定gpu训练与多gpu并行训练示例
Dec 31 #Python
浅析Django中关于session的使用
Dec 30 #Python
使用pickle存储数据dump 和 load实例讲解
Dec 30 #Python
在Python中利用pickle保存变量的实例
Dec 30 #Python
python Popen 获取输出,等待运行完成示例
Dec 30 #Python
Python3常见函数range()用法详解
Dec 30 #Python
Python Pickle 实现在同一个文件中序列化多个对象
Dec 30 #Python
You might like
Laravel 4 初级教程之Pages、表单验证
2014/10/30 PHP
php实现paypal 授权登录
2015/05/28 PHP
PHP+AjaxForm异步带进度条上传文件实例代码
2017/08/14 PHP
PHP实现统计所有字符在字符串中出现次数的方法
2017/10/17 PHP
php 删除一维数组中某一个值元素的操作方法
2018/02/01 PHP
解析使用js判断只能输入数字、字母等验证的方法(总结)
2013/05/14 Javascript
当某个文本框成为焦点时即清除文本框内容
2014/04/28 Javascript
Web 开发中Ajax的Session 超时处理方法
2017/01/19 Javascript
js实现百度登录框鼠标拖拽效果
2017/03/07 Javascript
Bootstrap + AngularJS 实现简单的数据过滤字符查找功能
2017/07/27 Javascript
基于js 本地存储(详解)
2017/08/16 Javascript
在vue中添加Echarts图表的基本使用教程
2017/11/22 Javascript
实例讲解Vue.js中router传参
2018/04/22 Javascript
详解js类型判断
2018/05/22 Javascript
微信小程序实现slideUp、slideDown滑动效果及点击空白隐藏功能示例
2018/12/11 Javascript
element-ui多文件上传的实现示例
2019/04/10 Javascript
layui时间控件选择时间范围的实现方法
2019/09/28 Javascript
js实现图片上传到服务器和回显
2020/01/19 Javascript
详解webpack-dev-middleware 源码解读
2020/03/23 Javascript
Python使用MySQLdb for Python操作数据库教程
2014/10/11 Python
Python中的字符串操作和编码Unicode详解
2017/01/18 Python
解决项目pycharm能运行,在终端却无法运行的问题
2019/01/19 Python
python ddt数据驱动最简实例代码
2019/02/22 Python
利用Django模版生成树状结构实例代码
2019/05/19 Python
Python 文件数据读写的具体实现
2020/01/24 Python
基于Python fminunc 的替代方法
2020/02/29 Python
python实例化对象的具体方法
2020/06/17 Python
Numpy中np.random.rand()和np.random.randn() 用法和区别详解
2020/10/23 Python
美国高街时尚品牌:OASAP
2016/07/24 全球购物
税务干部鉴定材料
2014/02/11 职场文书
企业安全生产目标责任书
2014/07/23 职场文书
老兵退伍标语
2014/10/07 职场文书
工作失职检讨书(精华篇)
2014/10/15 职场文书
2015年大学生社会实践评语
2015/03/26 职场文书
《有余数的除法》教学反思
2016/02/22 职场文书
CSS3实现指纹特效代码
2022/03/17 HTML / CSS