PyTorch: Softmax多分类实战操作


Posted in Python onJuly 07, 2020

多分类一种比较常用的做法是在最后一层加softmax归一化,值最大的维度所对应的位置则作为该样本对应的类。本文采用PyTorch框架,选用经典图像数据集mnist学习一波多分类。

MNIST数据集

MNIST 数据集(手写数字数据集)来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST). 训练集 (training set) 由来自 250 个不同人手写的数字构成, 其中 50% 是高中学生, 50% 来自人口普查局 (the Census Bureau) 的工作人员. 测试集(test set) 也是同样比例的手写数字数据。MNIST数据集下载地址:http://yann.lecun.com/exdb/mnist/。手写数字的MNIST数据库包括60,000个的训练集样本,以及10,000个测试集样本。

PyTorch: Softmax多分类实战操作

其中:

train-images-idx3-ubyte.gz (训练数据集图片)

train-labels-idx1-ubyte.gz (训练数据集标记类别)

t10k-images-idx3-ubyte.gz: (测试数据集)

t10k-labels-idx1-ubyte.gz(测试数据集标记类别)

PyTorch: Softmax多分类实战操作

MNIST数据集是经典图像数据集,包括10个类别(0到9)。每一张图片拉成向量表示,如下图784维向量作为第一层输入特征。

PyTorch: Softmax多分类实战操作

Softmax分类

softmax函数的本质就是将一个K 维的任意实数向量压缩(映射)成另一个K维的实数向量,其中向量中的每个元素取值都介于(0,1)之间,并且压缩后的K个值相加等于1(变成了概率分布)。在选用Softmax做多分类时,可以根据值的大小来进行多分类的任务,如取权重最大的一维。softmax介绍和公式网上很多,这里不介绍了。下面使用Pytorch定义一个多层网络(4个隐藏层,最后一层softmax概率归一化),输出层为10正好对应10类。

PyTorch: Softmax多分类实战操作

PyTorch实战

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable

# Training settings
batch_size = 64

# MNIST Dataset
train_dataset = datasets.MNIST(root='./mnist_data/',
                train=True,
                transform=transforms.ToTensor(),
                download=True)

test_dataset = datasets.MNIST(root='./mnist_data/',
               train=False,
               transform=transforms.ToTensor())

# Data Loader (Input Pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                      batch_size=batch_size,
                      shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                     batch_size=batch_size,
                     shuffle=False)
class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.l1 = nn.Linear(784, 520)
    self.l2 = nn.Linear(520, 320)
    self.l3 = nn.Linear(320, 240)
    self.l4 = nn.Linear(240, 120)
    self.l5 = nn.Linear(120, 10)

  def forward(self, x):
    # Flatten the data (n, 1, 28, 28) --> (n, 784)
    x = x.view(-1, 784)
    x = F.relu(self.l1(x))
    x = F.relu(self.l2(x))
    x = F.relu(self.l3(x))
    x = F.relu(self.l4(x))
    return F.log_softmax(self.l5(x), dim=1)
    #return self.l5(x)
model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
def train(epoch):

  # 每次输入barch_idx个数据
  for batch_idx, (data, target) in enumerate(train_loader):
    data, target = Variable(data), Variable(target)

    optimizer.zero_grad()
    output = model(data)
    # loss
    loss = F.nll_loss(output, target)
    loss.backward()
    # update
    optimizer.step()
    if batch_idx % 200 == 0:
      print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        epoch, batch_idx * len(data), len(train_loader.dataset),
        100. * batch_idx / len(train_loader), loss.data[0]))
def test():
  test_loss = 0
  correct = 0
  # 测试集
  for data, target in test_loader:
    data, target = Variable(data, volatile=True), Variable(target)
    output = model(data)
    # sum up batch loss
    test_loss += F.nll_loss(output, target).data[0]
    # get the index of the max
    pred = output.data.max(1, keepdim=True)[1]
    correct += pred.eq(target.data.view_as(pred)).cpu().sum()

  test_loss /= len(test_loader.dataset)
  print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))

for epoch in range(1,6):
  train(epoch)
  test()

输出结果:
Train Epoch: 1 [0/60000 (0%)]	Loss: 2.292192
Train Epoch: 1 [12800/60000 (21%)]	Loss: 2.289466
Train Epoch: 1 [25600/60000 (43%)]	Loss: 2.294221
Train Epoch: 1 [38400/60000 (64%)]	Loss: 2.169656
Train Epoch: 1 [51200/60000 (85%)]	Loss: 1.561276

Test set: Average loss: 0.0163, Accuracy: 6698/10000 (67%)

Train Epoch: 2 [0/60000 (0%)]	Loss: 0.993218
Train Epoch: 2 [12800/60000 (21%)]	Loss: 0.859608
Train Epoch: 2 [25600/60000 (43%)]	Loss: 0.499748
Train Epoch: 2 [38400/60000 (64%)]	Loss: 0.422055
Train Epoch: 2 [51200/60000 (85%)]	Loss: 0.413933

Test set: Average loss: 0.0065, Accuracy: 8797/10000 (88%)

Train Epoch: 3 [0/60000 (0%)]	Loss: 0.465154
Train Epoch: 3 [12800/60000 (21%)]	Loss: 0.321842
Train Epoch: 3 [25600/60000 (43%)]	Loss: 0.187147
Train Epoch: 3 [38400/60000 (64%)]	Loss: 0.469552
Train Epoch: 3 [51200/60000 (85%)]	Loss: 0.270332

Test set: Average loss: 0.0045, Accuracy: 9137/10000 (91%)

Train Epoch: 4 [0/60000 (0%)]	Loss: 0.197497
Train Epoch: 4 [12800/60000 (21%)]	Loss: 0.234830
Train Epoch: 4 [25600/60000 (43%)]	Loss: 0.260302
Train Epoch: 4 [38400/60000 (64%)]	Loss: 0.219375
Train Epoch: 4 [51200/60000 (85%)]	Loss: 0.292754

Test set: Average loss: 0.0037, Accuracy: 9277/10000 (93%)

Train Epoch: 5 [0/60000 (0%)]	Loss: 0.183354
Train Epoch: 5 [12800/60000 (21%)]	Loss: 0.207930
Train Epoch: 5 [25600/60000 (43%)]	Loss: 0.138435
Train Epoch: 5 [38400/60000 (64%)]	Loss: 0.120214
Train Epoch: 5 [51200/60000 (85%)]	Loss: 0.266199

Test set: Average loss: 0.0026, Accuracy: 9506/10000 (95%)
Process finished with exit code 0

随着训练迭代次数的增加,测试集的精确度还是有很大提高的。并且当迭代次数为5时,使用这种简单的网络可以达到95%的精确度。

以上这篇PyTorch: Softmax多分类实战操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中的闭包详细介绍和实例
Nov 21 Python
python学习数据结构实例代码
May 11 Python
python线程中同步锁详解
Apr 27 Python
python pandas 对series和dataframe的重置索引reindex方法
Jun 07 Python
对Python多线程读写文件加锁的实例详解
Jan 14 Python
Python爬虫动态ip代理防止被封的方法
Jul 07 Python
pip安装python库的方法总结
Aug 02 Python
python将字母转化为数字实例方法
Oct 04 Python
Python中的四种交换数值的方法解析
Nov 18 Python
对python中 math模块下 atan 和 atan2的区别详解
Jan 17 Python
Python实现弹球小游戏
Aug 01 Python
python基础之错误和异常处理
Oct 24 Python
opencv 形态学变换(开运算,闭运算,梯度运算)
Jul 07 #Python
解决pytorch 交叉熵损失输出为负数的问题
Jul 07 #Python
Python基于httpx模块实现发送请求
Jul 07 #Python
opencv 图像腐蚀和图像膨胀的实现
Jul 07 #Python
Pytorch损失函数nn.NLLLoss2d()用法说明
Jul 07 #Python
浅析Python __name__ 是什么
Jul 07 #Python
Pytorch上下采样函数--interpolate用法
Jul 07 #Python
You might like
人大复印资料处理程序_查询篇
2006/10/09 PHP
10个实用的PHP正则表达式汇总
2014/10/23 PHP
fckeditor上传文件按日期存放及重命名方法
2015/05/22 PHP
javascript的onchange事件与jQuery的change()方法比较
2009/09/28 Javascript
javascript 去字符串空格终极版(支持utf8)
2009/11/14 Javascript
JavaScript判断窗口是否最小化的代码(跨浏览器)
2010/08/01 Javascript
js自定义鼠标右键的实现原理及源码
2014/06/23 Javascript
一个JavaScript处理textarea中的字符成每一行实例
2014/09/22 Javascript
用AngularJS来实现监察表单按钮的禁用效果
2016/11/02 Javascript
微信小程序 页面跳转和数据传递实例详解
2017/01/19 Javascript
AngularJS实现的输入框字数限制提醒功能示例
2017/10/26 Javascript
微信小程序实现3D轮播图效果(非swiper组件)
2019/09/21 Javascript
vue-simple-uploader上传成功之后的response获取代码
2020/09/07 Javascript
[02:09]2018DOTA2亚洲邀请赛TNC赛前采访
2018/04/04 DOTA
[01:17]炒鸡美酒第四天TA暴走
2018/06/05 DOTA
[50:12]EG vs Fnatic 2018国际邀请赛小组赛BO2 第二场 8.19
2018/08/21 DOTA
在Python的Django框架中创建语言文件
2015/07/27 Python
Django csrf 验证问题的实现
2018/10/09 Python
详解Python3之数据指纹MD5校验与对比
2019/06/11 Python
在python plt图表中文字大小调节的方法
2019/07/08 Python
Python输出指定字符串的方法
2020/02/06 Python
PyCharm+Pipenv虚拟环境开发和依赖管理的教程详解
2020/04/16 Python
使用Python实现微信拍一拍功能的思路代码
2020/07/09 Python
flask开启多线程的具体方法
2020/08/02 Python
python和node.js生成当前时间戳的示例
2020/09/29 Python
美国中小型企业领先的办公家具供应商:Office Designs
2016/11/26 全球购物
纽约海:Sea New York
2018/11/04 全球购物
JENNIFER BEHR官网:各种耳环和发饰
2020/06/07 全球购物
专升本个人自我评价
2013/12/22 职场文书
安全生产管理责任书
2014/04/16 职场文书
村班子对照检查材料
2014/08/18 职场文书
离职感谢信怎么写
2015/01/22 职场文书
入党介绍人意见怎么写
2015/06/03 职场文书
Python下载商品数据并连接数据库且保存数据
2022/03/31 Python
《堡垒之夜》联动《刺客信条》 4月7日正式上线
2022/04/06 其他游戏
Golang入门之计时器
2022/05/04 Golang