PyTorch: Softmax多分类实战操作


Posted in Python onJuly 07, 2020

多分类一种比较常用的做法是在最后一层加softmax归一化,值最大的维度所对应的位置则作为该样本对应的类。本文采用PyTorch框架,选用经典图像数据集mnist学习一波多分类。

MNIST数据集

MNIST 数据集(手写数字数据集)来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST). 训练集 (training set) 由来自 250 个不同人手写的数字构成, 其中 50% 是高中学生, 50% 来自人口普查局 (the Census Bureau) 的工作人员. 测试集(test set) 也是同样比例的手写数字数据。MNIST数据集下载地址:http://yann.lecun.com/exdb/mnist/。手写数字的MNIST数据库包括60,000个的训练集样本,以及10,000个测试集样本。

PyTorch: Softmax多分类实战操作

其中:

train-images-idx3-ubyte.gz (训练数据集图片)

train-labels-idx1-ubyte.gz (训练数据集标记类别)

t10k-images-idx3-ubyte.gz: (测试数据集)

t10k-labels-idx1-ubyte.gz(测试数据集标记类别)

PyTorch: Softmax多分类实战操作

MNIST数据集是经典图像数据集,包括10个类别(0到9)。每一张图片拉成向量表示,如下图784维向量作为第一层输入特征。

PyTorch: Softmax多分类实战操作

Softmax分类

softmax函数的本质就是将一个K 维的任意实数向量压缩(映射)成另一个K维的实数向量,其中向量中的每个元素取值都介于(0,1)之间,并且压缩后的K个值相加等于1(变成了概率分布)。在选用Softmax做多分类时,可以根据值的大小来进行多分类的任务,如取权重最大的一维。softmax介绍和公式网上很多,这里不介绍了。下面使用Pytorch定义一个多层网络(4个隐藏层,最后一层softmax概率归一化),输出层为10正好对应10类。

PyTorch: Softmax多分类实战操作

PyTorch实战

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable

# Training settings
batch_size = 64

# MNIST Dataset
train_dataset = datasets.MNIST(root='./mnist_data/',
                train=True,
                transform=transforms.ToTensor(),
                download=True)

test_dataset = datasets.MNIST(root='./mnist_data/',
               train=False,
               transform=transforms.ToTensor())

# Data Loader (Input Pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                      batch_size=batch_size,
                      shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                     batch_size=batch_size,
                     shuffle=False)
class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.l1 = nn.Linear(784, 520)
    self.l2 = nn.Linear(520, 320)
    self.l3 = nn.Linear(320, 240)
    self.l4 = nn.Linear(240, 120)
    self.l5 = nn.Linear(120, 10)

  def forward(self, x):
    # Flatten the data (n, 1, 28, 28) --> (n, 784)
    x = x.view(-1, 784)
    x = F.relu(self.l1(x))
    x = F.relu(self.l2(x))
    x = F.relu(self.l3(x))
    x = F.relu(self.l4(x))
    return F.log_softmax(self.l5(x), dim=1)
    #return self.l5(x)
model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
def train(epoch):

  # 每次输入barch_idx个数据
  for batch_idx, (data, target) in enumerate(train_loader):
    data, target = Variable(data), Variable(target)

    optimizer.zero_grad()
    output = model(data)
    # loss
    loss = F.nll_loss(output, target)
    loss.backward()
    # update
    optimizer.step()
    if batch_idx % 200 == 0:
      print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        epoch, batch_idx * len(data), len(train_loader.dataset),
        100. * batch_idx / len(train_loader), loss.data[0]))
def test():
  test_loss = 0
  correct = 0
  # 测试集
  for data, target in test_loader:
    data, target = Variable(data, volatile=True), Variable(target)
    output = model(data)
    # sum up batch loss
    test_loss += F.nll_loss(output, target).data[0]
    # get the index of the max
    pred = output.data.max(1, keepdim=True)[1]
    correct += pred.eq(target.data.view_as(pred)).cpu().sum()

  test_loss /= len(test_loader.dataset)
  print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))

for epoch in range(1,6):
  train(epoch)
  test()

输出结果:
Train Epoch: 1 [0/60000 (0%)]	Loss: 2.292192
Train Epoch: 1 [12800/60000 (21%)]	Loss: 2.289466
Train Epoch: 1 [25600/60000 (43%)]	Loss: 2.294221
Train Epoch: 1 [38400/60000 (64%)]	Loss: 2.169656
Train Epoch: 1 [51200/60000 (85%)]	Loss: 1.561276

Test set: Average loss: 0.0163, Accuracy: 6698/10000 (67%)

Train Epoch: 2 [0/60000 (0%)]	Loss: 0.993218
Train Epoch: 2 [12800/60000 (21%)]	Loss: 0.859608
Train Epoch: 2 [25600/60000 (43%)]	Loss: 0.499748
Train Epoch: 2 [38400/60000 (64%)]	Loss: 0.422055
Train Epoch: 2 [51200/60000 (85%)]	Loss: 0.413933

Test set: Average loss: 0.0065, Accuracy: 8797/10000 (88%)

Train Epoch: 3 [0/60000 (0%)]	Loss: 0.465154
Train Epoch: 3 [12800/60000 (21%)]	Loss: 0.321842
Train Epoch: 3 [25600/60000 (43%)]	Loss: 0.187147
Train Epoch: 3 [38400/60000 (64%)]	Loss: 0.469552
Train Epoch: 3 [51200/60000 (85%)]	Loss: 0.270332

Test set: Average loss: 0.0045, Accuracy: 9137/10000 (91%)

Train Epoch: 4 [0/60000 (0%)]	Loss: 0.197497
Train Epoch: 4 [12800/60000 (21%)]	Loss: 0.234830
Train Epoch: 4 [25600/60000 (43%)]	Loss: 0.260302
Train Epoch: 4 [38400/60000 (64%)]	Loss: 0.219375
Train Epoch: 4 [51200/60000 (85%)]	Loss: 0.292754

Test set: Average loss: 0.0037, Accuracy: 9277/10000 (93%)

Train Epoch: 5 [0/60000 (0%)]	Loss: 0.183354
Train Epoch: 5 [12800/60000 (21%)]	Loss: 0.207930
Train Epoch: 5 [25600/60000 (43%)]	Loss: 0.138435
Train Epoch: 5 [38400/60000 (64%)]	Loss: 0.120214
Train Epoch: 5 [51200/60000 (85%)]	Loss: 0.266199

Test set: Average loss: 0.0026, Accuracy: 9506/10000 (95%)
Process finished with exit code 0

随着训练迭代次数的增加,测试集的精确度还是有很大提高的。并且当迭代次数为5时,使用这种简单的网络可以达到95%的精确度。

以上这篇PyTorch: Softmax多分类实战操作就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python基础教程之Hello World!
Aug 29 Python
跟老齐学Python之复习if语句
Oct 02 Python
基础的十进制按位运算总结与在Python中的计算示例
Jun 28 Python
tensorflow入门之训练简单的神经网络方法
Feb 26 Python
pandas带有重复索引操作方法
Jun 08 Python
[原创]Python入门教程3. 列表基本操作【定义、运算、常用函数】
Oct 30 Python
python 构造三维全零数组的方法
Nov 12 Python
在python中使用requests 模拟浏览器发送请求数据的方法
Dec 26 Python
Python操作配置文件ini的三种方法讲解
Feb 22 Python
python assert的用处示例详解
Apr 01 Python
Python基于BeautifulSoup爬取京东商品信息
Jun 01 Python
详解Python高阶函数
Aug 15 Python
opencv 形态学变换(开运算,闭运算,梯度运算)
Jul 07 #Python
解决pytorch 交叉熵损失输出为负数的问题
Jul 07 #Python
Python基于httpx模块实现发送请求
Jul 07 #Python
opencv 图像腐蚀和图像膨胀的实现
Jul 07 #Python
Pytorch损失函数nn.NLLLoss2d()用法说明
Jul 07 #Python
浅析Python __name__ 是什么
Jul 07 #Python
Pytorch上下采样函数--interpolate用法
Jul 07 #Python
You might like
咖啡是不是喝了会上瘾?咖啡是必须品吗!
2021/03/04 新手入门
PHP.MVC的模板标签系统(一)
2006/09/05 PHP
解析php防止form重复提交的方法
2013/07/01 PHP
CodeIgniter框架过滤HTML危险代码
2014/06/12 PHP
PHP list() 将数组中的值赋给变量的简单实例
2016/06/13 PHP
PHP7新特性foreach 修改示例介绍
2016/08/26 PHP
通过修改Laravel Auth使用salt和password进行认证用户详解
2017/08/17 PHP
PHP登录验证功能示例【用户名、密码、验证码、数据库、已登陆验证、自动登录和注销登录等】
2019/02/25 PHP
对采用动态原型方式无法展示继承机制得思考
2009/12/04 Javascript
原生JS实现响应式瀑布流布局
2015/04/02 Javascript
jquery动画效果学习笔记(8种效果)
2015/11/13 Javascript
详解js中call与apply关键字的作用
2016/11/21 Javascript
利用JQuery实现datatables插件的增加和删除行功能
2017/01/06 Javascript
jQuery插件FusionWidgets实现的Cylinder图效果示例【附demo源码】
2017/03/23 jQuery
微信小程序使用audio组件播放音乐功能示例【附源码下载】
2017/12/08 Javascript
laydate如何根据开始时间或者结束时间限制范围
2018/11/15 Javascript
Vue axios全局拦截 get请求、post请求、配置请求的实例代码
2018/11/28 Javascript
uni-app之APP和小程序微信授权方法
2019/05/09 Javascript
Python3实现购物车功能
2018/04/18 Python
Python pandas.DataFrame调整列顺序及修改index名的方法
2019/06/21 Python
Django shell调试models输出的SQL语句方法
2019/08/29 Python
python常用排序算法的实现代码
2019/11/08 Python
python使用正则来处理各种匹配问题
2019/12/22 Python
python中有关时间日期格式转换问题
2019/12/25 Python
python实现简单的购物程序代码实例
2020/03/03 Python
基于python检查SSL证书到期情况代码实例
2020/04/04 Python
使用CSS3代码绘制可爱的Hello Kitty猫
2016/08/03 HTML / CSS
html5 Canvas画图教程(7)—canvas里画曲线之quadraticCurveTo方法
2013/01/09 HTML / CSS
Michael Kors香港官网:美国奢侈品品牌
2019/12/26 全球购物
实习自我评价怎么写
2013/12/02 职场文书
预防艾滋病宣传活动总结
2015/05/09 职场文书
2016七一建党节慰问信
2015/11/30 职场文书
2016八一建军节慰问信
2015/11/30 职场文书
PHP策略模式写法
2021/04/01 PHP
Python将CSV文件转化为HTML文件的操作方法
2021/06/30 Python
前端canvas中物体边框和控制点的实现示例
2022/08/05 Javascript