PyTorch: Softmax多分类实战操作


Posted in Python onJuly 07, 2020

多分类一种比较常用的做法是在最后一层加softmax归一化,值最大的维度所对应的位置则作为该样本对应的类。本文采用PyTorch框架,选用经典图像数据集mnist学习一波多分类。

MNIST数据集

MNIST 数据集(手写数字数据集)来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST). 训练集 (training set) 由来自 250 个不同人手写的数字构成, 其中 50% 是高中学生, 50% 来自人口普查局 (the Census Bureau) 的工作人员. 测试集(test set) 也是同样比例的手写数字数据。MNIST数据集下载地址:http://yann.lecun.com/exdb/mnist/。手写数字的MNIST数据库包括60,000个的训练集样本,以及10,000个测试集样本。

PyTorch: Softmax多分类实战操作

其中:

train-images-idx3-ubyte.gz (训练数据集图片)

train-labels-idx1-ubyte.gz (训练数据集标记类别)

t10k-images-idx3-ubyte.gz: (测试数据集)

t10k-labels-idx1-ubyte.gz(测试数据集标记类别)

PyTorch: Softmax多分类实战操作

MNIST数据集是经典图像数据集,包括10个类别(0到9)。每一张图片拉成向量表示,如下图784维向量作为第一层输入特征。

PyTorch: Softmax多分类实战操作

Softmax分类

softmax函数的本质就是将一个K 维的任意实数向量压缩(映射)成另一个K维的实数向量,其中向量中的每个元素取值都介于(0,1)之间,并且压缩后的K个值相加等于1(变成了概率分布)。在选用Softmax做多分类时,可以根据值的大小来进行多分类的任务,如取权重最大的一维。softmax介绍和公式网上很多,这里不介绍了。下面使用Pytorch定义一个多层网络(4个隐藏层,最后一层softmax概率归一化),输出层为10正好对应10类。

PyTorch: Softmax多分类实战操作

PyTorch实战

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable

# Training settings
batch_size = 64

# MNIST Dataset
train_dataset = datasets.MNIST(root='./mnist_data/',
                train=True,
                transform=transforms.ToTensor(),
                download=True)

test_dataset = datasets.MNIST(root='./mnist_data/',
               train=False,
               transform=transforms.ToTensor())

# Data Loader (Input Pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                      batch_size=batch_size,
                      shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                     batch_size=batch_size,
                     shuffle=False)
class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.l1 = nn.Linear(784, 520)
    self.l2 = nn.Linear(520, 320)
    self.l3 = nn.Linear(320, 240)
    self.l4 = nn.Linear(240, 120)
    self.l5 = nn.Linear(120, 10)

  def forward(self, x):
    # Flatten the data (n, 1, 28, 28) --> (n, 784)
    x = x.view(-1, 784)
    x = F.relu(self.l1(x))
    x = F.relu(self.l2(x))
    x = F.relu(self.l3(x))
    x = F.relu(self.l4(x))
    return F.log_softmax(self.l5(x), dim=1)
    #return self.l5(x)
model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
def train(epoch):

  # 每次输入barch_idx个数据
  for batch_idx, (data, target) in enumerate(train_loader):
    data, target = Variable(data), Variable(target)

    optimizer.zero_grad()
    output = model(data)
    # loss
    loss = F.nll_loss(output, target)
    loss.backward()
    # update
    optimizer.step()
    if batch_idx % 200 == 0:
      print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        epoch, batch_idx * len(data), len(train_loader.dataset),
        100. * batch_idx / len(train_loader), loss.data[0]))
def test():
  test_loss = 0
  correct = 0
  # 测试集
  for data, target in test_loader:
    data, target = Variable(data, volatile=True), Variable(target)
    output = model(data)
    # sum up batch loss
    test_loss += F.nll_loss(output, target).data[0]
    # get the index of the max
    pred = output.data.max(1, keepdim=True)[1]
    correct += pred.eq(target.data.view_as(pred)).cpu().sum()

  test_loss /= len(test_loader.dataset)
  print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))

for epoch in range(1,6):
  train(epoch)
  test()

输出结果:
Train Epoch: 1 [0/60000 (0%)]	Loss: 2.292192
Train Epoch: 1 [12800/60000 (21%)]	Loss: 2.289466
Train Epoch: 1 [25600/60000 (43%)]	Loss: 2.294221
Train Epoch: 1 [38400/60000 (64%)]	Loss: 2.169656
Train Epoch: 1 [51200/60000 (85%)]	Loss: 1.561276

Test set: Average loss: 0.0163, Accuracy: 6698/10000 (67%)

Train Epoch: 2 [0/60000 (0%)]	Loss: 0.993218
Train Epoch: 2 [12800/60000 (21%)]	Loss: 0.859608
Train Epoch: 2 [25600/60000 (43%)]	Loss: 0.499748
Train Epoch: 2 [38400/60000 (64%)]	Loss: 0.422055
Train Epoch: 2 [51200/60000 (85%)]	Loss: 0.413933

Test set: Average loss: 0.0065, Accuracy: 8797/10000 (88%)

Train Epoch: 3 [0/60000 (0%)]	Loss: 0.465154
Train Epoch: 3 [12800/60000 (21%)]	Loss: 0.321842
Train Epoch: 3 [25600/60000 (43%)]	Loss: 0.187147
Train Epoch: 3 [38400/60000 (64%)]	Loss: 0.469552
Train Epoch: 3 [51200/60000 (85%)]	Loss: 0.270332

Test set: Average loss: 0.0045, Accuracy: 9137/10000 (91%)

Train Epoch: 4 [0/60000 (0%)]	Loss: 0.197497
Train Epoch: 4 [12800/60000 (21%)]	Loss: 0.234830
Train Epoch: 4 [25600/60000 (43%)]	Loss: 0.260302
Train Epoch: 4 [38400/60000 (64%)]	Loss: 0.219375
Train Epoch: 4 [51200/60000 (85%)]	Loss: 0.292754

Test set: Average loss: 0.0037, Accuracy: 9277/10000 (93%)

Train Epoch: 5 [0/60000 (0%)]	Loss: 0.183354
Train Epoch: 5 [12800/60000 (21%)]	Loss: 0.207930
Train Epoch: 5 [25600/60000 (43%)]	Loss: 0.138435
Train Epoch: 5 [38400/60000 (64%)]	Loss: 0.120214
Train Epoch: 5 [51200/60000 (85%)]	Loss: 0.266199

Test set: Average loss: 0.0026, Accuracy: 9506/10000 (95%)
Process finished with exit code 0

随着训练迭代次数的增加,测试集的精确度还是有很大提高的。并且当迭代次数为5时,使用这种简单的网络可以达到95%的精确度。

以上这篇PyTorch: Softmax多分类实战操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
深入理解Python中的super()方法
Nov 20 Python
Python 打印中文字符的三种方法
Aug 14 Python
Python3 SSH远程连接服务器的方法示例
Dec 29 Python
python 实现将文件或文件夹用相对路径打包为 tar.gz 文件的方法
Jun 10 Python
django的聚合函数和aggregate、annotate方法使用详解
Jul 23 Python
python迭代器常见用法实例分析
Nov 22 Python
Python装饰器原理与基本用法分析
Jan 07 Python
pytorch 图像中的数据预处理和批标准化实例
Jan 15 Python
win7上tensorflow2.2.0安装成功 引用DLL load failed时找不到指定模块 tensorflow has no attribute xxx 解决方法
May 20 Python
Python实现图片查找轮廓、多边形拟合、最小外接矩形代码
Jul 14 Python
详解解决jupyter不能使用pytorch的问题
Feb 18 Python
Python实现简单得递归下降Parser
May 02 Python
opencv 形态学变换(开运算,闭运算,梯度运算)
Jul 07 #Python
解决pytorch 交叉熵损失输出为负数的问题
Jul 07 #Python
Python基于httpx模块实现发送请求
Jul 07 #Python
opencv 图像腐蚀和图像膨胀的实现
Jul 07 #Python
Pytorch损失函数nn.NLLLoss2d()用法说明
Jul 07 #Python
浅析Python __name__ 是什么
Jul 07 #Python
Pytorch上下采样函数--interpolate用法
Jul 07 #Python
You might like
php下载远程文件类(支持断点续传)
2008/11/14 PHP
PHP教程之PHP中shell脚本的使用方法分享
2012/02/23 PHP
那些年一起学习的PHP(一)
2012/03/21 PHP
PHP+MySQL之Insert Into数据插入用法分析
2015/09/27 PHP
JS中使用Array函数shift和pop创建可忽略参数的例子
2014/05/28 Javascript
jQuery的:parent选择器定义和用法
2014/07/01 Javascript
详解angular2采用自定义指令(Directive)方式加载jquery插件
2017/02/09 Javascript
基于JavaScript实现类名的添加与移除
2017/04/23 Javascript
webpack写jquery插件的环境配置
2017/12/21 jQuery
JavaScript中filter的用法实例分析
2019/02/27 Javascript
Python中文件遍历的两种方法
2014/06/16 Python
python错误:AttributeError: 'module' object has no attribute 'setdefaultencoding'问题的解决方法
2014/08/22 Python
Django中url的反向查询的方法
2018/03/14 Python
python实现图书馆研习室自动预约功能
2018/04/27 Python
python3 实现对图片进行局部切割的方法
2018/12/05 Python
对python中字典keys,values,items的使用详解
2019/02/03 Python
在Django model中设置多个字段联合唯一约束的实例
2019/07/17 Python
python正则表达式匹配不包含某几个字符的字符串方法
2019/07/23 Python
python 实现绘制整齐的表格
2019/11/18 Python
TensorFlow Saver:保存和读取模型参数.ckpt实例
2020/02/10 Python
使用Python实现音频双通道分离
2020/12/25 Python
使用HTML5的链接预取功能(link prefetching)给网站提速
2012/12/13 HTML / CSS
使用html5 canvas 画时钟代码实例分享
2015/11/11 HTML / CSS
德国奢侈品网上商城:Mytheresa
2016/08/24 全球购物
Geekbuying波兰:购买中国电子产品
2019/10/20 全球购物
应届毕业生个人自荐信范文
2013/11/30 职场文书
自行车租赁公司创业计划书
2014/01/28 职场文书
幼儿教师工作感言
2014/02/14 职场文书
新年团拜会主持词
2014/04/02 职场文书
代领学位证书毕业证书委托书
2014/09/30 职场文书
单位作风建设自查报告
2014/10/23 职场文书
2014年学习委员工作总结
2014/11/14 职场文书
2015年财务工作总结范文
2015/03/31 职场文书
计算机实训心得体会
2016/01/14 职场文书
企业版Windows 11有哪些新功能? Win11适用于企业的功能介绍
2021/11/21 数码科技
Python+腾讯云服务器实现每日自动健康打卡
2021/12/06 Python