Pytorch实现图像识别之数字识别(附详细注释)


Posted in Python onMay 11, 2021

使用了两个卷积层加上两个全连接层实现
本来打算从头手撕的,但是调试太耗时间了,改天有时间在从头写一份
详细过程看代码注释,参考了下一个博主的文章,但是链接没注意关了找不到了,博主看到了联系下我,我加上
代码相关的问题可以评论私聊,也可以翻看博客里的文章,部分有详细解释

Python实现代码:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader
import cv2

# 下载训练集
train_dataset = datasets.MNIST(root='E:\mnist',
                               train=True,
                               transform=transforms.ToTensor(),
                               download=True)
# 下载测试集
test_dataset = datasets.MNIST(root='E:\mnist',
                              train=False,
                              transform=transforms.ToTensor(),
                              download=True)

# dataset 参数用于指定我们载入的数据集名称
# batch_size参数设置了每个包中的图片数据个数
# 在装载的过程会将数据随机打乱顺序并进打包
batch_size = 64
# 建立一个数据迭代器
# 装载训练集
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
# 装载测试集
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=True)


# 卷积层使用 torch.nn.Conv2d
# 激活层使用 torch.nn.ReLU
# 池化层使用 torch.nn.MaxPool2d
# 全连接层使用 torch.nn.Linear
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Sequential(nn.Conv2d(1, 6, 3, 1, 2),
                                   nn.ReLU(), nn.MaxPool2d(2, 2))

        self.conv2 = nn.Sequential(nn.Conv2d(6, 16, 5), nn.ReLU(),
                                   nn.MaxPool2d(2, 2))

        self.fc1 = nn.Sequential(nn.Linear(16 * 5 * 5, 120),
                                 nn.BatchNorm1d(120), nn.ReLU())

        self.fc2 = nn.Sequential(
            nn.Linear(120, 84),
            nn.BatchNorm1d(84),
            nn.ReLU(),
            nn.Linear(84, 10))
        # 最后的结果一定要变为 10,因为数字的选项是 0 ~ 9

    def forward(self, x):
        x = self.conv1(x)
        # print("1:", x.shape)
        # 1: torch.Size([64, 6, 30, 30])
        # max pooling
        # 1: torch.Size([64, 6, 15, 15])
        x = self.conv2(x)
        # print("2:", x.shape)
        # 2: torch.Size([64, 16, 5, 5])
        # 对参数实现扁平化
        x = x.view(x.size()[0], -1)
        x = self.fc1(x)
        x = self.fc2(x)
        return x


def test_image_data(images, labels):
    # 初始输出为一段数字图像序列
    # 将一段图像序列整合到一张图片上 (make_grid会默认将图片变成三通道,默认值为0)
    # images: torch.Size([64, 1, 28, 28])
    img = torchvision.utils.make_grid(images)
    # img: torch.Size([3, 242, 242])
    # 将通道维度置在第三个维度
    img = img.numpy().transpose(1, 2, 0)
    # img: torch.Size([242, 242, 3])
    # 减小图像对比度
    std = [0.5, 0.5, 0.5]
    mean = [0.5, 0.5, 0.5]
    img = img * std + mean
    # print(labels)
    cv2.imshow('win2', img)
    key_pressed = cv2.waitKey(0)


# 初始化设备信息
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 学习速率
LR = 0.001
# 初始化网络
net = LeNet().to(device)
# 损失函数使用交叉熵
criterion = nn.CrossEntropyLoss()
# 优化函数使用 Adam 自适应优化算法
optimizer = optim.Adam(net.parameters(), lr=LR, )
epoch = 1
if __name__ == '__main__':
    for epoch in range(epoch):
        print("GPU:", torch.cuda.is_available())
        sum_loss = 0.0
        for i, data in enumerate(train_loader):
            inputs, labels = data
            # print(inputs.shape)
            # torch.Size([64, 1, 28, 28])
            # 将内存中的数据复制到gpu显存中去
            inputs, labels = Variable(inputs).cuda(), Variable(labels).cuda()
            # 将梯度归零
            optimizer.zero_grad()
            # 将数据传入网络进行前向运算
            outputs = net(inputs)
            # 得到损失函数
            loss = criterion(outputs, labels)
            # 反向传播
            loss.backward()
            # 通过梯度做一步参数更新
            optimizer.step()
            # print(loss)
            sum_loss += loss.item()
            if i % 100 == 99:
                print('[%d,%d] loss:%.03f' % (epoch + 1, i + 1, sum_loss / 100))
                sum_loss = 0.0
                # 将模型变换为测试模式
        net.eval()
        correct = 0
        total = 0
        for data_test in test_loader:
            _images, _labels = data_test
            # 将内存中的数据复制到gpu显存中去
            images, labels = Variable(_images).cuda(), Variable(_labels).cuda()
            # 图像预测结果
            output_test = net(images)
            # torch.Size([64, 10])
            # 从每行中找到最大预测索引
            _, predicted = torch.max(output_test, 1)
            # 图像可视化
            # print("predicted:", predicted)
            # test_image_data(_images, _labels)
            # 预测数据的数量
            total += labels.size(0)
            # 预测正确的数量
            correct += (predicted == labels).sum()
        print("correct1: ", correct)
        print("Test acc: {0}".format(correct.item() / total))

测试结果:

可以通过调用test_image_data函数查看测试图片

Pytorch实现图像识别之数字识别(附详细注释)

可以看到最后预测的准确度可以达到98%

Pytorch实现图像识别之数字识别(附详细注释)

到此这篇关于Pytorch实现图像识别之数字识别(附详细注释)的文章就介绍到这了,更多相关Pytorch 数字识别内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python中的类学习笔记
Sep 23 Python
详解Python 模拟实现生产者消费者模式的实例
Aug 10 Python
浅谈Python peewee 使用经验
Oct 20 Python
微信跳一跳游戏python脚本
Apr 01 Python
对python程序内存泄漏调试的记录
Jun 11 Python
python 申请内存空间,用于创建多维数组的实例
Dec 02 Python
在keras中model.fit_generator()和model.fit()的区别说明
Jun 17 Python
使用python实现下载我们想听的歌曲,速度超快
Jul 09 Python
解决阿里云邮件发送不能使用25端口问题
Aug 07 Python
教你如何用python操作摄像头以及对视频流的处理
Oct 12 Python
关于PyCharm安装后修改路径名称使其可重新打开的问题
Oct 20 Python
Python+OpenCV图像处理——实现直线检测
Oct 23 Python
浅谈Python基础之列表那些事儿
详解Python牛顿插值法
Python中使用subprocess库创建附加进程
有趣的二维码:使用MyQR和qrcode来制作二维码
python保存大型 .mat 数据文件报错超出 IO 限制的操作
May 10 #Python
Python批量将csv文件转化成xml文件的实例
python基础之爬虫入门
You might like
加强版phplib的DB类
2008/03/31 PHP
php mysql_list_dbs()函数用法示例
2017/03/29 PHP
php 删除一维数组中某一个值元素的操作方法
2018/02/01 PHP
Laravel 手动开关 Eloquent 修改器的操作方法
2019/12/30 PHP
[原创]后缀就扩展名为js的文件是什么文件
2007/12/06 Javascript
jquery下操作HTML控件的实现代码
2010/01/12 Javascript
Web开发之JavaScript
2012/03/29 Javascript
javascript实现详细时间提醒信息效果的方法
2015/03/11 Javascript
黑帽seo劫持程序,js劫持搜索引擎代码
2015/09/15 Javascript
基于Vue生产环境部署详解
2017/09/15 Javascript
Vue 实现双向绑定的四种方法
2018/03/16 Javascript
使用bootstrap实现下拉框搜索功能的实例讲解
2018/08/10 Javascript
vue-cli脚手架搭建的项目去除eslint验证的方法
2018/09/29 Javascript
JS求解两数之和算法详解
2020/04/28 Javascript
Electron整合React使用搭建开发环境的步骤详解
2020/06/07 Javascript
微信小程序地图实现展示线路
2020/07/29 Javascript
[02:26]2018DOTA2亚洲邀请赛赛前采访-Newbee篇
2018/04/03 DOTA
python通过ElementTree操作XML获取结点读取属性美化XML
2013/12/02 Python
用于统计项目中代码总行数的Python脚本分享
2015/04/21 Python
Python中对元组和列表按条件进行排序的方法示例
2015/11/10 Python
pandas获取groupby分组里最大值所在的行方法
2018/04/20 Python
解决安装python库时windows error5 报错的问题
2018/10/21 Python
Python2和Python3之间的str处理方式导致乱码的讲解
2019/01/03 Python
python生成器用法实例详解
2019/11/22 Python
Python爬取腾讯视频评论的思路详解
2019/12/19 Python
Jupyter安装拓展nbextensions及解决官网下载慢的问题
2021/03/03 Python
CSS3教程(1):什么是CSS3
2009/04/02 HTML / CSS
领先的钻石和订婚戒指零售商:Diamonds-USA
2016/12/11 全球购物
Dogeared官网:在美国手工制作的珠宝
2019/08/24 全球购物
自我评价范文
2013/12/22 职场文书
会计专业个人求职信范文
2014/01/08 职场文书
和平主题的演讲稿
2014/01/12 职场文书
如何写好自荐信
2014/04/07 职场文书
2014乡镇党政班子四风问题思想汇报
2014/09/14 职场文书
2016年教师学习教师法心得体会
2016/01/20 职场文书
SQLServer权限之只开启创建表权限
2022/04/12 SQL Server