pytorch实现手写数字图片识别


Posted in Python onMay 20, 2021

本文实例为大家分享了pytorch实现手写数字图片识别的具体代码,供大家参考,具体内容如下

数据集:MNIST数据集,代码中会自动下载,不用自己手动下载。数据集很小,不需要GPU设备,可以很好的体会到pytorch的魅力。
模型+训练+预测程序:

import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
import torchvision
from matplotlib import pyplot as plt
from utils import plot_image, plot_curve, one_hot

# step1  load dataset
batch_size = 512
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,)
                                   )
                               ])),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data/', train=False, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,)
                                   )
                               ])),
    batch_size=batch_size, shuffle=False)
x , y = next(iter(train_loader))
print(x.shape, y.shape, x.min(), x.max())
plot_image(x, y, "image_sample")

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.fc1 = nn.Linear(28*28, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)
    def forward(self, x):
        # x: [b, 1, 28, 28]
        # h1 = relu(xw1 + b1)
        x = F.relu(self.fc1(x))
        # h2 = relu(h1w2 + b2)
        x = F.relu(self.fc2(x))
        # h3 = h2w3 + b3
        x = self.fc3(x)

        return x
net = Net()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

train_loss = []
for epoch in range(3):
    for batch_idx, (x, y) in enumerate(train_loader):
        #加载进来的图片是一个四维的tensor,x: [b, 1, 28, 28], y:[512]
        #但是我们网络的输入要是一个一维向量(也就是二维tensor),所以要进行展平操作
        x = x.view(x.size(0), 28*28)
        #  [b, 10]
        out = net(x)
        y_onehot = one_hot(y)
        # loss = mse(out, y_onehot)
        loss = F.mse_loss(out, y_onehot)

        optimizer.zero_grad()
        loss.backward()
        # w' = w - lr*grad
        optimizer.step()

        train_loss.append(loss.item())

        if batch_idx % 10 == 0:
            print(epoch, batch_idx, loss.item())

plot_curve(train_loss)
    # we get optimal [w1, b1, w2, b2, w3, b3]


total_correct = 0
for x,y in test_loader:
    x = x.view(x.size(0), 28*28)
    out = net(x)
    # out: [b, 10]
    pred = out.argmax(dim=1)
    correct = pred.eq(y).sum().float().item()
    total_correct += correct
total_num = len(test_loader.dataset)
acc = total_correct/total_num
print("acc:", acc)

x, y = next(iter(test_loader))
out = net(x.view(x.size(0), 28*28))
pred = out.argmax(dim=1)
plot_image(x, pred, "test")

主程序中调用的函数(注意命名为utils):

import  torch
from    matplotlib import pyplot as plt


def plot_curve(data):
    fig = plt.figure()
    plt.plot(range(len(data)), data, color='blue')
    plt.legend(['value'], loc='upper right')
    plt.xlabel('step')
    plt.ylabel('value')
    plt.show()


def plot_image(img, label, name):

    fig = plt.figure()
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        plt.tight_layout()
        plt.imshow(img[i][0]*0.3081+0.1307, cmap='gray', interpolation='none')
        plt.title("{}: {}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()


def one_hot(label, depth=10):
    out = torch.zeros(label.size(0), depth)
    idx = torch.LongTensor(label).view(-1, 1)
    out.scatter_(dim=1, index=idx, value=1)
    return out

打印出损失下降的曲线图:

pytorch实现手写数字图片识别

训练3个epoch之后,在测试集上的精度就可以89%左右,可见模型的准确度还是很不错的。
输出六张测试集的图片以及预测结果:

pytorch实现手写数字图片识别

六张图片的预测全部正确。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python学习笔记之os模块使用总结
Nov 03 Python
Python 2.7.x 和 3.x 版本的重要区别小结
Nov 28 Python
仅用500行Python代码实现一个英文解析器的教程
Apr 02 Python
简单说明Python中的装饰器的用法
Apr 24 Python
Python打印斐波拉契数列实例
Jul 07 Python
使用Python多线程爬虫爬取电影天堂资源
Sep 23 Python
详解Python3除法之真除法、截断除法和下取整对比
May 23 Python
PYTHON绘制雷达图代码实例
Oct 15 Python
Python模块常用四种安装方式
Oct 20 Python
python将下载到本地m3u8视频合成MP4的代码详解
Nov 24 Python
python sleep和wait对比总结
Feb 03 Python
对Pytorch 中的contiguous理解说明
Mar 03 Python
解决python3安装pandas出错的问题
May 20 #Python
python 如何在list中找Topk的数值和索引
May 20 #Python
Requests什么的通通爬不了的Python超强反爬虫方案!
python使用glob检索文件的操作
python opencv通过按键采集图片源码
python 如何执行控制台命令与操作剪切板
教你怎么用Python生成九宫格照片
You might like
我的群发邮件程序
2006/10/09 PHP
thinkphp实现上一篇与下一篇的方法
2014/12/08 PHP
php写入文件不覆盖的实例讲解
2019/09/17 PHP
javascript 贪吃蛇实现代码
2008/11/22 Javascript
JQuery优缺点分析说明
2010/06/09 Javascript
jQuery 验证插件 Web前端设计模式(asp.net)
2010/10/17 Javascript
javascript弹出层输入框(示例代码)
2013/12/11 Javascript
js+csss实现的一个带复选框的下拉框
2014/09/29 Javascript
Node.js中防止错误导致的进程阻塞的方法
2016/08/11 Javascript
nodeJS删除文件方法示例
2016/12/25 NodeJs
关于vue.js发布后路径引用的问题解决
2017/08/15 Javascript
jQuery实现网页拼图游戏
2020/04/22 jQuery
3分钟读懂移动端rem使用方法(推荐)
2019/05/06 Javascript
微信小程序本地存储实现每日签到、连续签到功能
2019/10/09 Javascript
jQuery实现可编辑的表格
2019/12/11 jQuery
[03:44]2015国际邀请赛选手档案—Cloud9.NoTail
2015/07/28 DOTA
[01:29:42]Liquid vs VP Supermajor决赛 BO 第一场 6.10
2018/07/05 DOTA
python获得linux下所有挂载点(mount points)的方法
2015/04/29 Python
Python使用正则表达式过滤或替换HTML标签的方法详解
2017/09/25 Python
python多线程下信号处理程序示例
2019/05/31 Python
对numpy下的轴交换transpose和swapaxes的示例解读
2019/06/26 Python
pandas按行按列遍历Dataframe的几种方式
2019/10/23 Python
python机器学习库xgboost的使用
2020/01/20 Python
Django调用百度AI接口实现人脸注册登录代码实例
2020/04/23 Python
Senreve官网:美国旧金山的奢侈手袋品牌
2019/03/21 全球购物
初中女生自我鉴定
2013/12/19 职场文书
社区工作者思想汇报
2014/01/13 职场文书
网页美工求职信
2014/02/15 职场文书
会计电算化应届生自荐信
2014/02/25 职场文书
开服装店计划书
2014/08/15 职场文书
2014年加油站站长工作总结
2014/12/23 职场文书
婚礼新人答谢词
2015/01/04 职场文书
高三复习计划
2015/01/19 职场文书
小学生读书笔记范文
2015/06/30 职场文书
投诉信范文
2015/07/02 职场文书
vue中data里面的数据相互使用方式
2022/06/05 Vue.js