pytorch实现手写数字图片识别


Posted in Python onMay 20, 2021

本文实例为大家分享了pytorch实现手写数字图片识别的具体代码,供大家参考,具体内容如下

数据集:MNIST数据集,代码中会自动下载,不用自己手动下载。数据集很小,不需要GPU设备,可以很好的体会到pytorch的魅力。
模型+训练+预测程序:

import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
import torchvision
from matplotlib import pyplot as plt
from utils import plot_image, plot_curve, one_hot

# step1  load dataset
batch_size = 512
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,)
                                   )
                               ])),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data/', train=False, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,)
                                   )
                               ])),
    batch_size=batch_size, shuffle=False)
x , y = next(iter(train_loader))
print(x.shape, y.shape, x.min(), x.max())
plot_image(x, y, "image_sample")

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.fc1 = nn.Linear(28*28, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)
    def forward(self, x):
        # x: [b, 1, 28, 28]
        # h1 = relu(xw1 + b1)
        x = F.relu(self.fc1(x))
        # h2 = relu(h1w2 + b2)
        x = F.relu(self.fc2(x))
        # h3 = h2w3 + b3
        x = self.fc3(x)

        return x
net = Net()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

train_loss = []
for epoch in range(3):
    for batch_idx, (x, y) in enumerate(train_loader):
        #加载进来的图片是一个四维的tensor,x: [b, 1, 28, 28], y:[512]
        #但是我们网络的输入要是一个一维向量(也就是二维tensor),所以要进行展平操作
        x = x.view(x.size(0), 28*28)
        #  [b, 10]
        out = net(x)
        y_onehot = one_hot(y)
        # loss = mse(out, y_onehot)
        loss = F.mse_loss(out, y_onehot)

        optimizer.zero_grad()
        loss.backward()
        # w' = w - lr*grad
        optimizer.step()

        train_loss.append(loss.item())

        if batch_idx % 10 == 0:
            print(epoch, batch_idx, loss.item())

plot_curve(train_loss)
    # we get optimal [w1, b1, w2, b2, w3, b3]


total_correct = 0
for x,y in test_loader:
    x = x.view(x.size(0), 28*28)
    out = net(x)
    # out: [b, 10]
    pred = out.argmax(dim=1)
    correct = pred.eq(y).sum().float().item()
    total_correct += correct
total_num = len(test_loader.dataset)
acc = total_correct/total_num
print("acc:", acc)

x, y = next(iter(test_loader))
out = net(x.view(x.size(0), 28*28))
pred = out.argmax(dim=1)
plot_image(x, pred, "test")

主程序中调用的函数(注意命名为utils):

import  torch
from    matplotlib import pyplot as plt


def plot_curve(data):
    fig = plt.figure()
    plt.plot(range(len(data)), data, color='blue')
    plt.legend(['value'], loc='upper right')
    plt.xlabel('step')
    plt.ylabel('value')
    plt.show()


def plot_image(img, label, name):

    fig = plt.figure()
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        plt.tight_layout()
        plt.imshow(img[i][0]*0.3081+0.1307, cmap='gray', interpolation='none')
        plt.title("{}: {}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()


def one_hot(label, depth=10):
    out = torch.zeros(label.size(0), depth)
    idx = torch.LongTensor(label).view(-1, 1)
    out.scatter_(dim=1, index=idx, value=1)
    return out

打印出损失下降的曲线图:

pytorch实现手写数字图片识别

训练3个epoch之后,在测试集上的精度就可以89%左右,可见模型的准确度还是很不错的。
输出六张测试集的图片以及预测结果:

pytorch实现手写数字图片识别

六张图片的预测全部正确。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解Python中contextlib上下文管理模块的用法
Jun 28 Python
Python实现对字符串的加密解密方法示例
Apr 29 Python
Python 多线程Threading初学教程
Aug 22 Python
pandas修改DataFrame列名的方法
Apr 08 Python
python在每个字符后添加空格的实例
May 07 Python
Django的性能优化实现解析
Jul 30 Python
python psutil模块使用方法解析
Aug 01 Python
Python中pyecharts安装及安装失败的解决方法
Feb 18 Python
sklearn线性逻辑回归和非线性逻辑回归的实现
Jun 09 Python
Python matplotlib图例放在外侧保存时显示不完整问题解决
Jul 28 Python
Python 实现二叉查找树的示例代码
Dec 21 Python
基于 Python 实践感知器分类算法
Jan 07 Python
解决python3安装pandas出错的问题
May 20 #Python
python 如何在list中找Topk的数值和索引
May 20 #Python
Requests什么的通通爬不了的Python超强反爬虫方案!
python使用glob检索文件的操作
python opencv通过按键采集图片源码
python 如何执行控制台命令与操作剪切板
教你怎么用Python生成九宫格照片
You might like
php中\r \r\n \t的区别示例介绍
2014/02/08 PHP
php数值转换时间及时间转换数值用法示例
2017/05/18 PHP
PHP使用PDO实现mysql防注入功能详解
2019/12/20 PHP
Javascript setInterval的两种调用方法(实例讲解)
2013/11/29 Javascript
javascript实现五星评价代码(源码下载)
2015/08/11 Javascript
Bootstrap每天必学之附加导航(Affix)插件
2016/04/25 Javascript
微信小程序 欢迎界面开发的实例详解
2016/11/30 Javascript
jQuery Position方法使用和兼容性
2017/08/23 jQuery
javascript 中事件冒泡和事件捕获机制的详解
2017/09/01 Javascript
微信小程序使用input组件实现密码框功能【附源码下载】
2017/12/11 Javascript
vue 纯js监听滚动条到底部的实例讲解
2018/09/03 Javascript
vue elementUI使用tabs与导航栏联动
2019/06/21 Javascript
详解Vue3 Composition API中的提取和重用逻辑
2020/04/29 Javascript
[03:36]2014DOTA2 TI小组赛综述 八强诞生进军钥匙球馆
2014/07/15 DOTA
Python专用方法与迭代机制实例分析
2014/09/15 Python
使用Python编写一个简单的tic-tac-toe游戏的教程
2015/04/16 Python
python自定义解析简单xml格式文件的方法
2015/05/11 Python
python实现rsa加密实例详解
2017/07/19 Python
python实现类之间的方法互相调用
2018/04/29 Python
Python django使用多进程连接mysql错误的解决方法
2018/10/08 Python
Ubuntu下Python2与Python3的共存问题
2018/10/31 Python
linux查找当前python解释器的位置方法
2019/02/20 Python
为什么是 Python -m
2020/06/19 Python
Python如何实现远程方法调用
2020/08/07 Python
互斥锁解决 Python 中多线程共享全局变量的问题(推荐)
2020/09/28 Python
scrapy redis配置文件setting参数详解
2020/11/18 Python
纯CSS3制作的鼠标悬停时边框旋转
2017/01/03 HTML / CSS
澳大利亚个性化儿童礼品网站:Bright Star Kids
2019/06/14 全球购物
数据库笔试题
2013/05/09 面试题
夜大毕业生自我鉴定
2013/10/31 职场文书
学校司机岗位职责
2013/11/14 职场文书
护士自荐信范文
2013/12/15 职场文书
党员年度个人总结
2015/02/14 职场文书
故意伤害辩护词
2015/05/21 职场文书
九大龙王魂骨,山龙王留下躯干骨,榜首死的最憋屈(被捏碎)
2022/03/18 国漫
深入讲解Vue中父子组件通信与事件触发
2022/03/22 Vue.js