pytorch实现手写数字图片识别


Posted in Python onMay 20, 2021

本文实例为大家分享了pytorch实现手写数字图片识别的具体代码,供大家参考,具体内容如下

数据集:MNIST数据集,代码中会自动下载,不用自己手动下载。数据集很小,不需要GPU设备,可以很好的体会到pytorch的魅力。
模型+训练+预测程序:

import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
import torchvision
from matplotlib import pyplot as plt
from utils import plot_image, plot_curve, one_hot

# step1  load dataset
batch_size = 512
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,)
                                   )
                               ])),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data/', train=False, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,)
                                   )
                               ])),
    batch_size=batch_size, shuffle=False)
x , y = next(iter(train_loader))
print(x.shape, y.shape, x.min(), x.max())
plot_image(x, y, "image_sample")

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.fc1 = nn.Linear(28*28, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)
    def forward(self, x):
        # x: [b, 1, 28, 28]
        # h1 = relu(xw1 + b1)
        x = F.relu(self.fc1(x))
        # h2 = relu(h1w2 + b2)
        x = F.relu(self.fc2(x))
        # h3 = h2w3 + b3
        x = self.fc3(x)

        return x
net = Net()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

train_loss = []
for epoch in range(3):
    for batch_idx, (x, y) in enumerate(train_loader):
        #加载进来的图片是一个四维的tensor,x: [b, 1, 28, 28], y:[512]
        #但是我们网络的输入要是一个一维向量(也就是二维tensor),所以要进行展平操作
        x = x.view(x.size(0), 28*28)
        #  [b, 10]
        out = net(x)
        y_onehot = one_hot(y)
        # loss = mse(out, y_onehot)
        loss = F.mse_loss(out, y_onehot)

        optimizer.zero_grad()
        loss.backward()
        # w' = w - lr*grad
        optimizer.step()

        train_loss.append(loss.item())

        if batch_idx % 10 == 0:
            print(epoch, batch_idx, loss.item())

plot_curve(train_loss)
    # we get optimal [w1, b1, w2, b2, w3, b3]


total_correct = 0
for x,y in test_loader:
    x = x.view(x.size(0), 28*28)
    out = net(x)
    # out: [b, 10]
    pred = out.argmax(dim=1)
    correct = pred.eq(y).sum().float().item()
    total_correct += correct
total_num = len(test_loader.dataset)
acc = total_correct/total_num
print("acc:", acc)

x, y = next(iter(test_loader))
out = net(x.view(x.size(0), 28*28))
pred = out.argmax(dim=1)
plot_image(x, pred, "test")

主程序中调用的函数(注意命名为utils):

import  torch
from    matplotlib import pyplot as plt


def plot_curve(data):
    fig = plt.figure()
    plt.plot(range(len(data)), data, color='blue')
    plt.legend(['value'], loc='upper right')
    plt.xlabel('step')
    plt.ylabel('value')
    plt.show()


def plot_image(img, label, name):

    fig = plt.figure()
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        plt.tight_layout()
        plt.imshow(img[i][0]*0.3081+0.1307, cmap='gray', interpolation='none')
        plt.title("{}: {}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()


def one_hot(label, depth=10):
    out = torch.zeros(label.size(0), depth)
    idx = torch.LongTensor(label).view(-1, 1)
    out.scatter_(dim=1, index=idx, value=1)
    return out

打印出损失下降的曲线图:

pytorch实现手写数字图片识别

训练3个epoch之后,在测试集上的精度就可以89%左右,可见模型的准确度还是很不错的。
输出六张测试集的图片以及预测结果:

pytorch实现手写数字图片识别

六张图片的预测全部正确。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python人人网登录应用实例
Sep 26 Python
Python中使用hashlib模块处理算法的教程
Apr 28 Python
Win7下Python与Tensorflow-CPU版开发环境的安装与配置过程
Jan 04 Python
Python爬虫实战:分析《战狼2》豆瓣影评
Mar 26 Python
高效使用Python字典的清单
Apr 04 Python
Python装饰器语法糖
Jan 02 Python
将python图片转为二进制文本的实例
Jan 24 Python
pandas基于时间序列的固定时间间隔求均值的方法
Jul 04 Python
pytorch实现保证每次运行使用的随机数都相同
Feb 20 Python
python绘制动态曲线教程
Feb 24 Python
pycharm无法安装第三方库的问题及解决方法以scrapy为例(图解)
May 09 Python
Python项目实战之使用Django框架实现支付宝付款功能
Feb 23 Python
解决python3安装pandas出错的问题
May 20 #Python
python 如何在list中找Topk的数值和索引
May 20 #Python
Requests什么的通通爬不了的Python超强反爬虫方案!
python使用glob检索文件的操作
python opencv通过按键采集图片源码
python 如何执行控制台命令与操作剪切板
教你怎么用Python生成九宫格照片
You might like
php桌面中心(二) 数据库写入
2007/03/11 PHP
PHP XML备份Mysql数据库
2009/05/27 PHP
PHP高级编程实例:编写守护进程
2014/09/02 PHP
php实现遍历文件夹的方法汇总
2017/03/02 PHP
js里取容器大小、定位、距离等属性搜集整理
2013/08/19 Javascript
使用jQuery设置disabled属性与移除disabled属性
2014/08/21 Javascript
jquery trigger函数执行两次的解决方法
2016/02/29 Javascript
Jquery元素追加和删除的实现方法
2016/05/24 Javascript
jQuery纵向导航菜单效果实现方法
2016/12/19 Javascript
微信小程序 Tab页切换更新数据
2017/01/05 Javascript
微信小程序(三):网络请求
2017/01/13 Javascript
nodejs实现OAuth2.0授权服务认证
2017/12/27 NodeJs
vue 子组件向父组件传值方法
2018/02/26 Javascript
angularJs中ng-model-options设置数据同步的方法
2018/09/30 Javascript
layui 上传插件 带预览 非自动上传功能的实例(非常实用)
2019/09/23 Javascript
Vue element-ui父组件控制子组件的表单校验操作
2020/07/17 Javascript
JavaScript实现移动小精灵的案例代码
2020/12/12 Javascript
[02:51]DOTA2英雄基础教程 艾欧
2014/01/13 DOTA
[38:38]完美世界DOTA2联赛PWL S3 access vs Rebirth 第二场 12.17
2020/12/18 DOTA
Python中实现远程调用(RPC、RMI)简单例子
2014/04/28 Python
跟老齐学Python之关于循环的小伎俩
2014/10/02 Python
使用Python脚本将绝对url替换为相对url的教程
2015/04/24 Python
解决Python2.7读写文件中的中文乱码问题
2018/04/12 Python
python matlibplot绘制多条曲线图
2021/02/19 Python
Python人脸识别第三方库face_recognition接口说明文档
2019/05/03 Python
pandas按条件筛选数据的实现
2021/02/20 Python
捷克家居装饰及图书音像购物网站:Velký košík
2018/04/16 全球购物
从当地商店送来的杂货:Instacart
2018/08/19 全球购物
Ryderwear澳洲官网:澳大利亚高端健身训练装备品牌
2018/09/18 全球购物
英文版餐饮业求职信
2013/10/18 职场文书
营销人才自我鉴定范文
2013/12/25 职场文书
西北政法大学自主招生自荐信
2014/01/29 职场文书
大学生通用个人自我评价
2014/04/27 职场文书
Vue全局事件总线你了解吗
2022/02/24 Vue.js
解决Mysql中的innoDB幻读问题
2022/04/29 MySQL
python manim实现排序算法动画示例
2022/08/14 Python