pytorch实现手写数字图片识别


Posted in Python onMay 20, 2021

本文实例为大家分享了pytorch实现手写数字图片识别的具体代码,供大家参考,具体内容如下

数据集:MNIST数据集,代码中会自动下载,不用自己手动下载。数据集很小,不需要GPU设备,可以很好的体会到pytorch的魅力。
模型+训练+预测程序:

import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
import torchvision
from matplotlib import pyplot as plt
from utils import plot_image, plot_curve, one_hot

# step1  load dataset
batch_size = 512
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,)
                                   )
                               ])),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data/', train=False, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081,)
                                   )
                               ])),
    batch_size=batch_size, shuffle=False)
x , y = next(iter(train_loader))
print(x.shape, y.shape, x.min(), x.max())
plot_image(x, y, "image_sample")

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.fc1 = nn.Linear(28*28, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)
    def forward(self, x):
        # x: [b, 1, 28, 28]
        # h1 = relu(xw1 + b1)
        x = F.relu(self.fc1(x))
        # h2 = relu(h1w2 + b2)
        x = F.relu(self.fc2(x))
        # h3 = h2w3 + b3
        x = self.fc3(x)

        return x
net = Net()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

train_loss = []
for epoch in range(3):
    for batch_idx, (x, y) in enumerate(train_loader):
        #加载进来的图片是一个四维的tensor,x: [b, 1, 28, 28], y:[512]
        #但是我们网络的输入要是一个一维向量(也就是二维tensor),所以要进行展平操作
        x = x.view(x.size(0), 28*28)
        #  [b, 10]
        out = net(x)
        y_onehot = one_hot(y)
        # loss = mse(out, y_onehot)
        loss = F.mse_loss(out, y_onehot)

        optimizer.zero_grad()
        loss.backward()
        # w' = w - lr*grad
        optimizer.step()

        train_loss.append(loss.item())

        if batch_idx % 10 == 0:
            print(epoch, batch_idx, loss.item())

plot_curve(train_loss)
    # we get optimal [w1, b1, w2, b2, w3, b3]


total_correct = 0
for x,y in test_loader:
    x = x.view(x.size(0), 28*28)
    out = net(x)
    # out: [b, 10]
    pred = out.argmax(dim=1)
    correct = pred.eq(y).sum().float().item()
    total_correct += correct
total_num = len(test_loader.dataset)
acc = total_correct/total_num
print("acc:", acc)

x, y = next(iter(test_loader))
out = net(x.view(x.size(0), 28*28))
pred = out.argmax(dim=1)
plot_image(x, pred, "test")

主程序中调用的函数(注意命名为utils):

import  torch
from    matplotlib import pyplot as plt


def plot_curve(data):
    fig = plt.figure()
    plt.plot(range(len(data)), data, color='blue')
    plt.legend(['value'], loc='upper right')
    plt.xlabel('step')
    plt.ylabel('value')
    plt.show()


def plot_image(img, label, name):

    fig = plt.figure()
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        plt.tight_layout()
        plt.imshow(img[i][0]*0.3081+0.1307, cmap='gray', interpolation='none')
        plt.title("{}: {}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()


def one_hot(label, depth=10):
    out = torch.zeros(label.size(0), depth)
    idx = torch.LongTensor(label).view(-1, 1)
    out.scatter_(dim=1, index=idx, value=1)
    return out

打印出损失下降的曲线图:

pytorch实现手写数字图片识别

训练3个epoch之后,在测试集上的精度就可以89%左右,可见模型的准确度还是很不错的。
输出六张测试集的图片以及预测结果:

pytorch实现手写数字图片识别

六张图片的预测全部正确。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python利用有道翻译实现"语言翻译器"的功能实例
Nov 14 Python
Windows下安装Django框架的方法简明教程
Mar 28 Python
python实现超简单的视频对象提取功能
Jun 04 Python
django之跨表查询及添加记录的示例代码
Oct 16 Python
Python爬虫实现爬取百度百科词条功能实例
Apr 05 Python
Python3.5模块的定义、导入、优化操作图文详解
Apr 27 Python
Python3 批量扫描端口的例子
Jul 25 Python
python集合常见运算案例解析
Oct 17 Python
python栈的基本定义与使用方法示例【初始化、赋值、入栈、出栈等】
Oct 24 Python
python中return的返回和执行实例
Dec 24 Python
详解python with 上下文管理器
Sep 02 Python
python-for x in range的用法(注意要点、细节)
May 10 Python
解决python3安装pandas出错的问题
May 20 #Python
python 如何在list中找Topk的数值和索引
May 20 #Python
Requests什么的通通爬不了的Python超强反爬虫方案!
python使用glob检索文件的操作
python opencv通过按键采集图片源码
python 如何执行控制台命令与操作剪切板
教你怎么用Python生成九宫格照片
You might like
php对大文件进行读取操作的实现代码
2013/01/23 PHP
PHP中将ip地址转成十进制数的两种实用方法
2013/08/15 PHP
php文件上传 你真的掌握了吗
2016/11/28 PHP
PHP闭包定义与使用简单示例
2018/04/13 PHP
php与js的区别是什么
2013/08/05 Javascript
教你用jquery实现iframe自适应高度
2014/06/11 Javascript
基于jQuery实现表单提交验证
2014/11/24 Javascript
js获取数组的最后一个元素
2015/04/14 Javascript
JS版元素周期表实现方法
2015/08/05 Javascript
jQuery内容折叠效果插件用法实例分析(附demo源码)
2016/04/28 Javascript
jQuery Form表单取值的方法
2017/01/11 Javascript
深入解析js轮播插件核心代码的实现过程
2017/04/14 Javascript
Angular.JS中select下拉框设置value的方法
2017/06/20 Javascript
webpack处理 css\less\sass 样式的方法
2017/08/21 Javascript
JavaScript实现滑动导航栏效果
2017/08/30 Javascript
jQuery 开发之EasyUI 添加数据的实例
2017/09/26 jQuery
vue单个组件实现无限层级多选菜单功能
2018/04/10 Javascript
vue 下列表侧滑操作实例代码详解
2018/07/24 Javascript
微信小程序 JS动态修改样式的实现方法
2018/12/16 Javascript
Javascript读取上传文件内容/类型/字节数
2019/04/30 Javascript
浅谈Vue SSR中的Bundle的具有使用
2019/11/21 Javascript
python进阶教程之循环对象
2014/08/30 Python
Python操作MySQL数据库9个实用实例
2015/12/11 Python
分析python动态规划的递归、非递归实现
2018/03/04 Python
Python操作mongodb数据库的方法详解
2018/12/08 Python
sklearn中的交叉验证的实现(Cross-Validation)
2021/02/22 Python
美国医生配方营养补充剂供应商:Healthy Directions
2019/07/10 全球购物
日本AOKI官方商城:AOKI西装
2020/06/11 全球购物
什么是JNDI的上下文?如何初始化JNDI上下文
2012/03/10 面试题
做一个有道德的人演讲稿
2014/05/14 职场文书
中学生运动会口号
2014/06/07 职场文书
励志演讲稿200字
2014/08/21 职场文书
基层组织建设年活动总结
2015/05/09 职场文书
谢师宴家长致辞
2015/07/27 职场文书
少儿励志名言(80句)
2019/08/14 职场文书
你知道哪几种MYSQL的连接查询
2021/06/03 MySQL