PyTorch: Softmax多分类实战操作


Posted in Python onJuly 07, 2020

多分类一种比较常用的做法是在最后一层加softmax归一化,值最大的维度所对应的位置则作为该样本对应的类。本文采用PyTorch框架,选用经典图像数据集mnist学习一波多分类。

MNIST数据集

MNIST 数据集(手写数字数据集)来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST). 训练集 (training set) 由来自 250 个不同人手写的数字构成, 其中 50% 是高中学生, 50% 来自人口普查局 (the Census Bureau) 的工作人员. 测试集(test set) 也是同样比例的手写数字数据。MNIST数据集下载地址:http://yann.lecun.com/exdb/mnist/。手写数字的MNIST数据库包括60,000个的训练集样本,以及10,000个测试集样本。

PyTorch: Softmax多分类实战操作

其中:

train-images-idx3-ubyte.gz (训练数据集图片)

train-labels-idx1-ubyte.gz (训练数据集标记类别)

t10k-images-idx3-ubyte.gz: (测试数据集)

t10k-labels-idx1-ubyte.gz(测试数据集标记类别)

PyTorch: Softmax多分类实战操作

MNIST数据集是经典图像数据集,包括10个类别(0到9)。每一张图片拉成向量表示,如下图784维向量作为第一层输入特征。

PyTorch: Softmax多分类实战操作

Softmax分类

softmax函数的本质就是将一个K 维的任意实数向量压缩(映射)成另一个K维的实数向量,其中向量中的每个元素取值都介于(0,1)之间,并且压缩后的K个值相加等于1(变成了概率分布)。在选用Softmax做多分类时,可以根据值的大小来进行多分类的任务,如取权重最大的一维。softmax介绍和公式网上很多,这里不介绍了。下面使用Pytorch定义一个多层网络(4个隐藏层,最后一层softmax概率归一化),输出层为10正好对应10类。

PyTorch: Softmax多分类实战操作

PyTorch实战

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable

# Training settings
batch_size = 64

# MNIST Dataset
train_dataset = datasets.MNIST(root='./mnist_data/',
                train=True,
                transform=transforms.ToTensor(),
                download=True)

test_dataset = datasets.MNIST(root='./mnist_data/',
               train=False,
               transform=transforms.ToTensor())

# Data Loader (Input Pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                      batch_size=batch_size,
                      shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                     batch_size=batch_size,
                     shuffle=False)
class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.l1 = nn.Linear(784, 520)
    self.l2 = nn.Linear(520, 320)
    self.l3 = nn.Linear(320, 240)
    self.l4 = nn.Linear(240, 120)
    self.l5 = nn.Linear(120, 10)

  def forward(self, x):
    # Flatten the data (n, 1, 28, 28) --> (n, 784)
    x = x.view(-1, 784)
    x = F.relu(self.l1(x))
    x = F.relu(self.l2(x))
    x = F.relu(self.l3(x))
    x = F.relu(self.l4(x))
    return F.log_softmax(self.l5(x), dim=1)
    #return self.l5(x)
model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
def train(epoch):

  # 每次输入barch_idx个数据
  for batch_idx, (data, target) in enumerate(train_loader):
    data, target = Variable(data), Variable(target)

    optimizer.zero_grad()
    output = model(data)
    # loss
    loss = F.nll_loss(output, target)
    loss.backward()
    # update
    optimizer.step()
    if batch_idx % 200 == 0:
      print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        epoch, batch_idx * len(data), len(train_loader.dataset),
        100. * batch_idx / len(train_loader), loss.data[0]))
def test():
  test_loss = 0
  correct = 0
  # 测试集
  for data, target in test_loader:
    data, target = Variable(data, volatile=True), Variable(target)
    output = model(data)
    # sum up batch loss
    test_loss += F.nll_loss(output, target).data[0]
    # get the index of the max
    pred = output.data.max(1, keepdim=True)[1]
    correct += pred.eq(target.data.view_as(pred)).cpu().sum()

  test_loss /= len(test_loader.dataset)
  print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))

for epoch in range(1,6):
  train(epoch)
  test()

输出结果:
Train Epoch: 1 [0/60000 (0%)]	Loss: 2.292192
Train Epoch: 1 [12800/60000 (21%)]	Loss: 2.289466
Train Epoch: 1 [25600/60000 (43%)]	Loss: 2.294221
Train Epoch: 1 [38400/60000 (64%)]	Loss: 2.169656
Train Epoch: 1 [51200/60000 (85%)]	Loss: 1.561276

Test set: Average loss: 0.0163, Accuracy: 6698/10000 (67%)

Train Epoch: 2 [0/60000 (0%)]	Loss: 0.993218
Train Epoch: 2 [12800/60000 (21%)]	Loss: 0.859608
Train Epoch: 2 [25600/60000 (43%)]	Loss: 0.499748
Train Epoch: 2 [38400/60000 (64%)]	Loss: 0.422055
Train Epoch: 2 [51200/60000 (85%)]	Loss: 0.413933

Test set: Average loss: 0.0065, Accuracy: 8797/10000 (88%)

Train Epoch: 3 [0/60000 (0%)]	Loss: 0.465154
Train Epoch: 3 [12800/60000 (21%)]	Loss: 0.321842
Train Epoch: 3 [25600/60000 (43%)]	Loss: 0.187147
Train Epoch: 3 [38400/60000 (64%)]	Loss: 0.469552
Train Epoch: 3 [51200/60000 (85%)]	Loss: 0.270332

Test set: Average loss: 0.0045, Accuracy: 9137/10000 (91%)

Train Epoch: 4 [0/60000 (0%)]	Loss: 0.197497
Train Epoch: 4 [12800/60000 (21%)]	Loss: 0.234830
Train Epoch: 4 [25600/60000 (43%)]	Loss: 0.260302
Train Epoch: 4 [38400/60000 (64%)]	Loss: 0.219375
Train Epoch: 4 [51200/60000 (85%)]	Loss: 0.292754

Test set: Average loss: 0.0037, Accuracy: 9277/10000 (93%)

Train Epoch: 5 [0/60000 (0%)]	Loss: 0.183354
Train Epoch: 5 [12800/60000 (21%)]	Loss: 0.207930
Train Epoch: 5 [25600/60000 (43%)]	Loss: 0.138435
Train Epoch: 5 [38400/60000 (64%)]	Loss: 0.120214
Train Epoch: 5 [51200/60000 (85%)]	Loss: 0.266199

Test set: Average loss: 0.0026, Accuracy: 9506/10000 (95%)
Process finished with exit code 0

随着训练迭代次数的增加,测试集的精确度还是有很大提高的。并且当迭代次数为5时,使用这种简单的网络可以达到95%的精确度。

以上这篇PyTorch: Softmax多分类实战操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python 图片验证码代码分享
Jul 04 Python
python计算程序开始到程序结束的运行时间和程序运行的CPU时间
Nov 28 Python
Python操作json数据的一个简单例子
Apr 17 Python
pygame学习笔记(1):矩形、圆型画图实例
Apr 15 Python
在Django的session中使用User对象的方法
Jul 23 Python
Python安装使用命令行交互模块pexpect的基础教程
May 12 Python
python中列表和元组的区别
Dec 18 Python
高效使用Python字典的清单
Apr 04 Python
新年快乐! python实现绚烂的烟花绽放效果
Jan 30 Python
python基于paramiko将文件上传到服务器代码实现
Jul 08 Python
浅谈django 重载str 方法
May 19 Python
Python进行特征提取的示例代码
Oct 15 Python
opencv 形态学变换(开运算,闭运算,梯度运算)
Jul 07 #Python
解决pytorch 交叉熵损失输出为负数的问题
Jul 07 #Python
Python基于httpx模块实现发送请求
Jul 07 #Python
opencv 图像腐蚀和图像膨胀的实现
Jul 07 #Python
Pytorch损失函数nn.NLLLoss2d()用法说明
Jul 07 #Python
浅析Python __name__ 是什么
Jul 07 #Python
Pytorch上下采样函数--interpolate用法
Jul 07 #Python
You might like
如何使用动态共享对象的模式来安装PHP
2006/10/09 PHP
谈谈PHP语法(4)
2006/10/09 PHP
PHP经典的给图片加水印程序
2006/12/06 PHP
PHP代码重构方法漫谈
2018/04/17 PHP
PHP如何通过表单直接提交大文件详解
2019/01/08 PHP
javascript 框架小结 个人工作经验
2009/06/13 Javascript
JavaScript中的this实例分析
2011/04/28 Javascript
定时器(setTimeout/setInterval)调用带参函数失效解决方法
2013/03/26 Javascript
小结Node.js中非阻塞IO和事件循环
2014/09/18 Javascript
关于Bootstrap弹出框无法调用问题的解决办法
2016/03/10 Javascript
Bootstrap table使用方法总结
2017/05/10 Javascript
JS点击缩略图整屏居中放大图片效果
2017/07/04 Javascript
vue项目使用axios发送请求让ajax请求头部携带cookie的方法
2018/09/26 Javascript
js异步上传多张图片插件的使用方法
2018/10/22 Javascript
Vue 实现把表单form数据 转化成json格式的数据
2019/10/29 Javascript
[02:08]我的刀塔不可能这么可爱 胡晓桃_1
2014/06/20 DOTA
[02:43]DOTA2亚洲邀请赛场馆攻略——带你走进东方体育中心
2018/03/19 DOTA
[01:08:24]DOTA2-DPC中国联赛 正赛 RNG vs Phoenix BO3 第一场 2月5日
2021/03/11 DOTA
使用Python程序抓取新浪在国内的所有IP的教程
2015/05/04 Python
PHP网页抓取之抓取百度贴吧邮箱数据代码分享
2016/04/13 Python
使用python编写监听端
2018/04/12 Python
tensorflow 获取模型所有参数总和数量的方法
2018/06/14 Python
使用tensorflow实现线性svm
2018/09/07 Python
python得到单词模式的示例
2018/10/15 Python
Python三元运算与lambda表达式实例解析
2019/11/30 Python
Python接口自动化测试的实现
2020/08/28 Python
CSS3 Columns分列式布局方法简介
2014/05/03 HTML / CSS
团员个人的自我评价
2013/12/02 职场文书
自我介绍演讲稿
2014/01/15 职场文书
法律七进实施方案
2014/03/15 职场文书
单位工作证明格式模板
2014/10/04 职场文书
井冈山红色之旅感想
2014/10/07 职场文书
广播稿:校园广播稿范文
2019/04/17 职场文书
python数字转对应中文的方法总结
2021/08/02 Python
【海涛dota解说】海涛小满开黑4v5被破两路翻盘潮汐第一视角解说
2022/04/01 DOTA
Python图像处理库PIL详细使用说明
2022/04/06 Python