Pytorch中的数据集划分&正则化方法


Posted in Python onMay 27, 2021

1.训练集&验证集&测试集

训练集:训练数据

验证集:验证不同算法(比如利用网格搜索对超参数进行调整等),检验哪种更有效

测试集:正确评估分类器的性能

正常流程:验证集会记录每个时间戳的参数,在加载test数据前会加载那个最好的参数,再来评估。比方说训练完6000个epoch后,发现在第3520个epoch的validation表现最好,测试时会加载第3520个epoch的参数。

import  torch
import  torch.nn as nn
import  torch.nn.functional as F
import  torch.optim as optim
from    torchvision import datasets, transforms
#超参数
batch_size=200
learning_rate=0.01
epochs=10
#获取训练数据
train_db = datasets.MNIST('../data', train=True, download=True,   #train=True则得到的是训练集
                   transform=transforms.Compose([                 #transform进行数据预处理
                       transforms.ToTensor(),                     #转成Tensor类型的数据
                       transforms.Normalize((0.1307,), (0.3081,)) #进行数据标准化(减去均值除以方差)
                   ]))
#DataLoader把训练数据分成多个小组,此函数每次抛出一组数据。直至把所有的数据都抛出。就是做一个数据的初始化
train_loader = torch.utils.data.DataLoader(train_db, batch_size=batch_size, shuffle=True)
#获取测试数据
test_db = datasets.MNIST('../data', train=False,
                   transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.1307,), (0.3081,))
                   ]))
test_loader = torch.utils.data.DataLoader(test_db, batch_size=batch_size, shuffle=True)
#将训练集拆分成训练集和验证集
print('train:', len(train_db), 'test:', len(test_db))                              #train: 60000 test: 10000
train_db, val_db = torch.utils.data.random_split(train_db, [50000, 10000])
print('db1:', len(train_db), 'db2:', len(val_db))                                  #db1: 50000 db2: 10000
train_loader = torch.utils.data.DataLoader(train_db, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_db, batch_size=batch_size, shuffle=True)
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.model = nn.Sequential(         #定义网络的每一层,
            nn.Linear(784, 200),
            nn.ReLU(inplace=True),
            nn.Linear(200, 200),
            nn.ReLU(inplace=True),
            nn.Linear(200, 10),
            nn.ReLU(inplace=True),
        )
    def forward(self, x):
        x = self.model(x)
        return x
net = MLP()
#定义sgd优化器,指明优化参数、学习率,net.parameters()得到这个类所定义的网络的参数[[w1,b1,w2,b2,...]
optimizer = optim.SGD(net.parameters(), lr=learning_rate)
criteon = nn.CrossEntropyLoss()
for epoch in range(epochs):
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.view(-1, 28*28)          #将二维的图片数据摊平[样本数,784]
        logits = net(data)                   #前向传播
        loss = criteon(logits, target)       #nn.CrossEntropyLoss()自带Softmax
        optimizer.zero_grad()                #梯度信息清空
        loss.backward()                      #反向传播获取梯度
        optimizer.step()                     #优化器更新
        if batch_idx % 100 == 0:             #每100个batch输出一次信息
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.item()))
    #验证集用来检测训练是否过拟合
    val_loss = 0
    correct = 0
    for data, target in val_loader:
        data = data.view(-1, 28 * 28)
        logits = net(data)
        val_loss += criteon(logits, target).item()
        pred = logits.data.max(dim=1)[1]
        correct += pred.eq(target.data).sum()
    val_loss /= len(val_loader.dataset)
    print('\nVAL set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        val_loss, correct, len(val_loader.dataset),
        100. * correct / len(val_loader.dataset)))
#测试集用来评估
test_loss = 0
correct = 0                                         #correct记录正确分类的样本数
for data, target in test_loader:
    data = data.view(-1, 28 * 28)
    logits = net(data)
    test_loss += criteon(logits, target).item()     #其实就是criteon(logits, target)的值,标量
    pred = logits.data.max(dim=1)[1]                #也可以写成pred=logits.argmax(dim=1)
    correct += pred.eq(target.data).sum()
test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))

2.正则化

正则化可以解决过拟合问题。

2.1L2范数(更常用)

在定义优化器的时候设定weigth_decay,即L2范数前面的λ参数。

optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate, weight_decay=0.01)

2.2L1范数

Pytorch没有直接可以调用的方法,实现如下:

Pytorch中的数据集划分&正则化方法

3.动量(Momentum)

Adam优化器内置了momentum,SGD需要手动设置。

optimizer = torch.optim.SGD(model.parameters(), args=lr, momentum=args.momentum, weight_decay=args.weight_decay)

4.学习率衰减

torch.optim.lr_scheduler 中提供了基于多种epoch数目调整学习率的方法。

4.1torch.optim.lr_scheduler.ReduceLROnPlateau:基于测量指标对学习率进行动态的下降

torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08)

训练过程中,optimizer会把learning rate 交给scheduler管理,当指标(比如loss)连续patience次数还没有改进时,需要降低学习率,factor为每次下降的比例。

scheduler.step(loss_val)每调用一次就会监听一次loss_val。

Pytorch中的数据集划分&正则化方法

4.2torch.optim.lr_scheduler.StepLR:基于epoch

torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=0.1, last_epoch=-1)

当epoch每过stop_size时,学习率都变为初始学习率的gamma倍。

Pytorch中的数据集划分&正则化方法

5.提前停止(防止overfitting)

基于经验值。

6.Dropout随机失活

遍历每一层,设置消除神经网络中的节点概率,得到精简后的一个样本。

torch.nn.Dropout(p=dropout_prob)

p表示的示的是删除节点数的比例(Tip:tensorflow中keep_prob表示保留节点数的比例,不要混淆)

Pytorch中的数据集划分&正则化方法

测试阶段无需使用dropout,所以在train之前执行net_dropped.train()相当于启用dropout,测试之前执行net_dropped.eval()相当于不启用dropout。

Pytorch中的数据集划分&正则化方法

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
使用python实现接口的方法
Jul 07 Python
python基础之入门必看操作
Jul 26 Python
Python对象中__del__方法起作用的条件详解
Nov 01 Python
对Python的zip函数妙用,旋转矩阵详解
Dec 13 Python
详解Python计算机视觉 图像扭曲(仿射扭曲)
Mar 27 Python
django框架事务处理小结【ORM 事务及raw sql,customize sql 事务处理】
Jun 27 Python
python数据处理——对pandas进行数据变频或插值实例
Apr 22 Python
利用Vscode进行Python开发环境配置的步骤
Jun 22 Python
Python中使用aiohttp模拟服务器出现错误问题及解决方法
Oct 31 Python
python输出国际象棋棋盘的实例分享
Nov 26 Python
python实现简单区块链结构
Apr 25 Python
Python爬虫入门案例之回车桌面壁纸网美女图片采集
Oct 16 Python
Pytorch 如何实现常用正则化
PyTorch 实现L2正则化以及Dropout的操作
Python开发之QT解决无边框界面拖动卡屏问题(附带源码)
pytorch 实现在测试的时候启用dropout
使用Python脚本对GiteePages进行一键部署的使用说明
教你使用Python pypinyin库实现汉字转拼音
基于tensorflow权重文件的解读
May 26 #Python
You might like
DedeCms模板安装/制作概述
2007/03/11 PHP
php 调试利器debug_print_backtrace()
2012/07/23 PHP
php去除字符串换行符示例分享
2014/02/13 PHP
PHP中4种常用的抓取网络数据方法
2015/06/04 PHP
动态加载图片路径 保持JavaScript控件的相对独立性
2010/09/03 Javascript
事件绑定之小测试  onclick && addEventListener
2011/07/31 Javascript
电子商务网站上的常用的js放大镜效果
2011/12/08 Javascript
一个JQuery操作Table的代码分享
2012/03/30 Javascript
js 3种归并操作的实例代码
2013/10/30 Javascript
javascript对下拉列表框(select)的操作实例讲解
2013/11/29 Javascript
动态加载jquery库的方法
2014/02/12 Javascript
jquery实现弹出层遮罩效果的简单实例
2014/03/03 Javascript
全面解析Bootstrap排版使用方法(文字样式)
2015/11/30 Javascript
javascript设计模式之module(模块)模式
2016/08/19 Javascript
JavaScript 身份证号有效验证详解及实例代码
2016/10/20 Javascript
vue 使用自定义指令实现表单校验的方法
2018/08/28 Javascript
webpack@v4升级踩坑(小结)
2018/10/08 Javascript
移动端如何用下拉刷新的方式实现上拉加载
2018/12/10 Javascript
Vue3新特性之在Composition API中使用CSS Modules
2020/07/13 Javascript
原生JS生成指定位数的验证码
2020/10/28 Javascript
[02:27]2018DOTA2亚洲邀请赛赛前采访-OpTic
2018/04/03 DOTA
python实现12306火车票查询器
2017/04/20 Python
django实现登录时候输入密码错误5次锁定用户十分钟
2017/11/05 Python
解决PyCharm import torch包失败的问题
2018/10/13 Python
Python 微信爬虫完整实例【单线程与多线程】
2019/07/06 Python
python中调试或排错的五种方法示例
2019/09/12 Python
CSS3实现可爱的小黄人动画
2016/07/11 HTML / CSS
通过HTML5 Canvas API绘制弧线和圆形的教程
2016/03/14 HTML / CSS
苏格兰在线威士忌商店:The Whisky Barrel
2019/05/07 全球购物
Craghoppers德国官网:户外和旅行服装
2020/02/14 全球购物
美国在线购买内衣网站:HerRoom
2020/02/22 全球购物
酒店秘书求职信范文
2014/02/17 职场文书
店铺转让协议书(2014版)
2014/09/23 职场文书
公务员检讨书
2014/11/01 职场文书
老干部局2015年度工作总结
2015/10/22 职场文书
Java详细解析==和equals的区别
2022/04/07 Java/Android