Pytorch中的数据集划分&正则化方法


Posted in Python onMay 27, 2021

1.训练集&验证集&测试集

训练集:训练数据

验证集:验证不同算法(比如利用网格搜索对超参数进行调整等),检验哪种更有效

测试集:正确评估分类器的性能

正常流程:验证集会记录每个时间戳的参数,在加载test数据前会加载那个最好的参数,再来评估。比方说训练完6000个epoch后,发现在第3520个epoch的validation表现最好,测试时会加载第3520个epoch的参数。

import  torch
import  torch.nn as nn
import  torch.nn.functional as F
import  torch.optim as optim
from    torchvision import datasets, transforms
#超参数
batch_size=200
learning_rate=0.01
epochs=10
#获取训练数据
train_db = datasets.MNIST('../data', train=True, download=True,   #train=True则得到的是训练集
                   transform=transforms.Compose([                 #transform进行数据预处理
                       transforms.ToTensor(),                     #转成Tensor类型的数据
                       transforms.Normalize((0.1307,), (0.3081,)) #进行数据标准化(减去均值除以方差)
                   ]))
#DataLoader把训练数据分成多个小组,此函数每次抛出一组数据。直至把所有的数据都抛出。就是做一个数据的初始化
train_loader = torch.utils.data.DataLoader(train_db, batch_size=batch_size, shuffle=True)
#获取测试数据
test_db = datasets.MNIST('../data', train=False,
                   transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.1307,), (0.3081,))
                   ]))
test_loader = torch.utils.data.DataLoader(test_db, batch_size=batch_size, shuffle=True)
#将训练集拆分成训练集和验证集
print('train:', len(train_db), 'test:', len(test_db))                              #train: 60000 test: 10000
train_db, val_db = torch.utils.data.random_split(train_db, [50000, 10000])
print('db1:', len(train_db), 'db2:', len(val_db))                                  #db1: 50000 db2: 10000
train_loader = torch.utils.data.DataLoader(train_db, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_db, batch_size=batch_size, shuffle=True)
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.model = nn.Sequential(         #定义网络的每一层,
            nn.Linear(784, 200),
            nn.ReLU(inplace=True),
            nn.Linear(200, 200),
            nn.ReLU(inplace=True),
            nn.Linear(200, 10),
            nn.ReLU(inplace=True),
        )
    def forward(self, x):
        x = self.model(x)
        return x
net = MLP()
#定义sgd优化器,指明优化参数、学习率,net.parameters()得到这个类所定义的网络的参数[[w1,b1,w2,b2,...]
optimizer = optim.SGD(net.parameters(), lr=learning_rate)
criteon = nn.CrossEntropyLoss()
for epoch in range(epochs):
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.view(-1, 28*28)          #将二维的图片数据摊平[样本数,784]
        logits = net(data)                   #前向传播
        loss = criteon(logits, target)       #nn.CrossEntropyLoss()自带Softmax
        optimizer.zero_grad()                #梯度信息清空
        loss.backward()                      #反向传播获取梯度
        optimizer.step()                     #优化器更新
        if batch_idx % 100 == 0:             #每100个batch输出一次信息
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.item()))
    #验证集用来检测训练是否过拟合
    val_loss = 0
    correct = 0
    for data, target in val_loader:
        data = data.view(-1, 28 * 28)
        logits = net(data)
        val_loss += criteon(logits, target).item()
        pred = logits.data.max(dim=1)[1]
        correct += pred.eq(target.data).sum()
    val_loss /= len(val_loader.dataset)
    print('\nVAL set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        val_loss, correct, len(val_loader.dataset),
        100. * correct / len(val_loader.dataset)))
#测试集用来评估
test_loss = 0
correct = 0                                         #correct记录正确分类的样本数
for data, target in test_loader:
    data = data.view(-1, 28 * 28)
    logits = net(data)
    test_loss += criteon(logits, target).item()     #其实就是criteon(logits, target)的值,标量
    pred = logits.data.max(dim=1)[1]                #也可以写成pred=logits.argmax(dim=1)
    correct += pred.eq(target.data).sum()
test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))

2.正则化

正则化可以解决过拟合问题。

2.1L2范数(更常用)

在定义优化器的时候设定weigth_decay,即L2范数前面的λ参数。

optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate, weight_decay=0.01)

2.2L1范数

Pytorch没有直接可以调用的方法,实现如下:

Pytorch中的数据集划分&正则化方法

3.动量(Momentum)

Adam优化器内置了momentum,SGD需要手动设置。

optimizer = torch.optim.SGD(model.parameters(), args=lr, momentum=args.momentum, weight_decay=args.weight_decay)

4.学习率衰减

torch.optim.lr_scheduler 中提供了基于多种epoch数目调整学习率的方法。

4.1torch.optim.lr_scheduler.ReduceLROnPlateau:基于测量指标对学习率进行动态的下降

torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08)

训练过程中,optimizer会把learning rate 交给scheduler管理,当指标(比如loss)连续patience次数还没有改进时,需要降低学习率,factor为每次下降的比例。

scheduler.step(loss_val)每调用一次就会监听一次loss_val。

Pytorch中的数据集划分&正则化方法

4.2torch.optim.lr_scheduler.StepLR:基于epoch

torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=0.1, last_epoch=-1)

当epoch每过stop_size时,学习率都变为初始学习率的gamma倍。

Pytorch中的数据集划分&正则化方法

5.提前停止(防止overfitting)

基于经验值。

6.Dropout随机失活

遍历每一层,设置消除神经网络中的节点概率,得到精简后的一个样本。

torch.nn.Dropout(p=dropout_prob)

p表示的示的是删除节点数的比例(Tip:tensorflow中keep_prob表示保留节点数的比例,不要混淆)

Pytorch中的数据集划分&正则化方法

测试阶段无需使用dropout,所以在train之前执行net_dropped.train()相当于启用dropout,测试之前执行net_dropped.eval()相当于不启用dropout。

Pytorch中的数据集划分&正则化方法

以上为个人经验,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python使用win32com在百度空间插入html元素示例
Feb 20 Python
python实现备份目录的方法
Aug 03 Python
Python常用知识点汇总
May 08 Python
详解PyTorch批训练及优化器比较
Apr 28 Python
在linux系统下安装python librtmp包的实现方法
Jul 22 Python
Python判断字符串是否xx开始或结尾的示例
Aug 08 Python
python基于三阶贝塞尔曲线的数据平滑算法
Dec 27 Python
python实现飞机大战游戏(pygame版)
Oct 26 Python
Python读取VOC中的xml目标框实例
Mar 10 Python
python安装和pycharm环境搭建设置方法
May 27 Python
django rest framework使用django-filter用法
Jul 15 Python
Python-openpyxl表格读取写入的案例详解
Nov 02 Python
Pytorch 如何实现常用正则化
PyTorch 实现L2正则化以及Dropout的操作
Python开发之QT解决无边框界面拖动卡屏问题(附带源码)
pytorch 实现在测试的时候启用dropout
使用Python脚本对GiteePages进行一键部署的使用说明
教你使用Python pypinyin库实现汉字转拼音
基于tensorflow权重文件的解读
May 26 #Python
You might like
PHP4实际应用经验篇(4)
2006/10/09 PHP
php源码加密 仿微盾PHP加密专家(PHPCodeLock)
2010/05/06 PHP
php短址转换实现方法
2015/02/25 PHP
PHP记录搜索引擎蜘蛛访问网站足迹的方法
2015/04/15 PHP
yii2局部关闭(开启)csrf的验证的实例代码
2017/07/10 PHP
PHP正则表达式处理函数(PCRE 函数)实例小结
2019/05/09 PHP
YUI 读码日记之 YAHOO.lang.is*
2008/03/22 Javascript
javascript 火狐(firefox)不显示本地图片问题解决
2008/07/05 Javascript
JavaScript CSS修改学习第二章 样式
2010/02/19 Javascript
javascript AOP 实现ajax回调函数使用比较方便
2010/11/20 Javascript
Chrome Form多次提交表单问题的解决方法
2011/05/09 Javascript
JavaScript的parseInt 取整使用
2011/05/09 Javascript
jquery中子元素和后代元素的区别示例介绍
2014/04/02 Javascript
JavaScript获取网页中第一个图片id的方法
2015/04/03 Javascript
javascript解决IE6下hover问题的方法
2015/07/28 Javascript
js获取客户端操作系统类型的方法【测试可用】
2016/05/27 Javascript
React优化子组件render的使用
2019/05/12 Javascript
详解微信小程序胶囊按钮返回|首页自定义导航栏功能
2019/06/14 Javascript
layUI实现三级导航菜单效果
2019/07/26 Javascript
Python访问MySQL封装的常用类实例
2014/11/11 Python
python在windows命令行下输出彩色文字的方法
2015/03/19 Python
Django之form组件自动校验数据实现
2020/01/14 Python
HTML5之WebGL 3D概述(上)—WebGL原生开发开启网页3D渲染新时代
2013/01/31 HTML / CSS
Monnier Frères美国官网:法国知名奢侈品网站
2016/11/22 全球购物
杭州-飞时达软件有限公司.net笔面试
2012/04/28 面试题
档案保密承诺书
2014/06/03 职场文书
幼儿园安全责任书范本
2014/07/24 职场文书
药店采购员岗位职责
2014/09/30 职场文书
党员个人批评与自我批评
2014/10/14 职场文书
2015年新农合工作总结
2015/03/30 职场文书
话题作文之呼唤
2019/12/18 职场文书
MySQL千万级数据表的优化实战记录
2021/08/04 MySQL
分享MySQL常用 内核 Debug 几种常见方法
2022/03/17 MySQL
Ruby使用Mysql2连接操作MySQL
2022/04/19 Ruby
Redis特殊数据类型Geospatial地理空间
2022/06/01 Redis
python自动获取微信公众号最新文章的实现代码
2022/07/15 Python