Pytorch实现图像识别之数字识别(附详细注释)


Posted in Python onMay 11, 2021

使用了两个卷积层加上两个全连接层实现
本来打算从头手撕的,但是调试太耗时间了,改天有时间在从头写一份
详细过程看代码注释,参考了下一个博主的文章,但是链接没注意关了找不到了,博主看到了联系下我,我加上
代码相关的问题可以评论私聊,也可以翻看博客里的文章,部分有详细解释

Python实现代码:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader
import cv2

# 下载训练集
train_dataset = datasets.MNIST(root='E:\mnist',
                               train=True,
                               transform=transforms.ToTensor(),
                               download=True)
# 下载测试集
test_dataset = datasets.MNIST(root='E:\mnist',
                              train=False,
                              transform=transforms.ToTensor(),
                              download=True)

# dataset 参数用于指定我们载入的数据集名称
# batch_size参数设置了每个包中的图片数据个数
# 在装载的过程会将数据随机打乱顺序并进打包
batch_size = 64
# 建立一个数据迭代器
# 装载训练集
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
# 装载测试集
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=True)


# 卷积层使用 torch.nn.Conv2d
# 激活层使用 torch.nn.ReLU
# 池化层使用 torch.nn.MaxPool2d
# 全连接层使用 torch.nn.Linear
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Sequential(nn.Conv2d(1, 6, 3, 1, 2),
                                   nn.ReLU(), nn.MaxPool2d(2, 2))

        self.conv2 = nn.Sequential(nn.Conv2d(6, 16, 5), nn.ReLU(),
                                   nn.MaxPool2d(2, 2))

        self.fc1 = nn.Sequential(nn.Linear(16 * 5 * 5, 120),
                                 nn.BatchNorm1d(120), nn.ReLU())

        self.fc2 = nn.Sequential(
            nn.Linear(120, 84),
            nn.BatchNorm1d(84),
            nn.ReLU(),
            nn.Linear(84, 10))
        # 最后的结果一定要变为 10,因为数字的选项是 0 ~ 9

    def forward(self, x):
        x = self.conv1(x)
        # print("1:", x.shape)
        # 1: torch.Size([64, 6, 30, 30])
        # max pooling
        # 1: torch.Size([64, 6, 15, 15])
        x = self.conv2(x)
        # print("2:", x.shape)
        # 2: torch.Size([64, 16, 5, 5])
        # 对参数实现扁平化
        x = x.view(x.size()[0], -1)
        x = self.fc1(x)
        x = self.fc2(x)
        return x


def test_image_data(images, labels):
    # 初始输出为一段数字图像序列
    # 将一段图像序列整合到一张图片上 (make_grid会默认将图片变成三通道,默认值为0)
    # images: torch.Size([64, 1, 28, 28])
    img = torchvision.utils.make_grid(images)
    # img: torch.Size([3, 242, 242])
    # 将通道维度置在第三个维度
    img = img.numpy().transpose(1, 2, 0)
    # img: torch.Size([242, 242, 3])
    # 减小图像对比度
    std = [0.5, 0.5, 0.5]
    mean = [0.5, 0.5, 0.5]
    img = img * std + mean
    # print(labels)
    cv2.imshow('win2', img)
    key_pressed = cv2.waitKey(0)


# 初始化设备信息
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 学习速率
LR = 0.001
# 初始化网络
net = LeNet().to(device)
# 损失函数使用交叉熵
criterion = nn.CrossEntropyLoss()
# 优化函数使用 Adam 自适应优化算法
optimizer = optim.Adam(net.parameters(), lr=LR, )
epoch = 1
if __name__ == '__main__':
    for epoch in range(epoch):
        print("GPU:", torch.cuda.is_available())
        sum_loss = 0.0
        for i, data in enumerate(train_loader):
            inputs, labels = data
            # print(inputs.shape)
            # torch.Size([64, 1, 28, 28])
            # 将内存中的数据复制到gpu显存中去
            inputs, labels = Variable(inputs).cuda(), Variable(labels).cuda()
            # 将梯度归零
            optimizer.zero_grad()
            # 将数据传入网络进行前向运算
            outputs = net(inputs)
            # 得到损失函数
            loss = criterion(outputs, labels)
            # 反向传播
            loss.backward()
            # 通过梯度做一步参数更新
            optimizer.step()
            # print(loss)
            sum_loss += loss.item()
            if i % 100 == 99:
                print('[%d,%d] loss:%.03f' % (epoch + 1, i + 1, sum_loss / 100))
                sum_loss = 0.0
                # 将模型变换为测试模式
        net.eval()
        correct = 0
        total = 0
        for data_test in test_loader:
            _images, _labels = data_test
            # 将内存中的数据复制到gpu显存中去
            images, labels = Variable(_images).cuda(), Variable(_labels).cuda()
            # 图像预测结果
            output_test = net(images)
            # torch.Size([64, 10])
            # 从每行中找到最大预测索引
            _, predicted = torch.max(output_test, 1)
            # 图像可视化
            # print("predicted:", predicted)
            # test_image_data(_images, _labels)
            # 预测数据的数量
            total += labels.size(0)
            # 预测正确的数量
            correct += (predicted == labels).sum()
        print("correct1: ", correct)
        print("Test acc: {0}".format(correct.item() / total))

测试结果:

可以通过调用test_image_data函数查看测试图片

Pytorch实现图像识别之数字识别(附详细注释)

可以看到最后预测的准确度可以达到98%

Pytorch实现图像识别之数字识别(附详细注释)

到此这篇关于Pytorch实现图像识别之数字识别(附详细注释)的文章就介绍到这了,更多相关Pytorch 数字识别内容请搜索三水点靠木以前的文章或继续浏览下面的相关文章希望大家以后多多支持三水点靠木!

Python 相关文章推荐
Python爬虫代理IP池实现方法
Jan 05 Python
json跨域调用python的方法详解
Jan 11 Python
Python文件和流(实例讲解)
Sep 12 Python
Python基于identicon库创建类似Github上用的头像功能
Sep 25 Python
Python反射用法实例简析
Dec 22 Python
基于python批量处理dat文件及科学计算方法详解
May 08 Python
Python基于jieba库进行简单分词及词云功能实现方法
Jun 16 Python
Django添加favicon.ico图标的示例代码
Aug 07 Python
Python3打包exe代码2种方法实例解析
Feb 17 Python
Python基于jieba, wordcloud库生成中文词云
May 13 Python
在python下实现word2vec词向量训练与加载实例
Jun 09 Python
聊聊Python中关于a=[[]]*3的反思
Jun 02 Python
浅谈Python基础之列表那些事儿
详解Python牛顿插值法
Python中使用subprocess库创建附加进程
有趣的二维码:使用MyQR和qrcode来制作二维码
python保存大型 .mat 数据文件报错超出 IO 限制的操作
May 10 #Python
Python批量将csv文件转化成xml文件的实例
python基础之爬虫入门
You might like
使用PHP获取网络文件的实现代码
2010/01/01 PHP
php获取mysql数据库中的所有表名的代码
2011/04/23 PHP
php中文乱码怎么办如何让浏览器自动识别utf-8
2014/01/15 PHP
destoon安全设置中需要设置可写权限的目录及文件
2014/06/21 PHP
php版微信js-sdk支付接口类用法示例
2016/10/12 PHP
php微信开发之关注事件
2018/06/14 PHP
PHP基于array_unique实现二维数组去重
2020/07/14 PHP
javascript IFrame 强制刷新代码
2009/07/23 Javascript
JavaScript开发人员的10个关键习惯小结
2014/12/05 Javascript
jQuery使用after()方法在元素后面添加多项内容的方法
2015/03/26 Javascript
用nodejs的实现原理和搭建服务器(动态)
2016/08/10 NodeJs
浅谈angular懒加载的一些坑
2016/08/20 Javascript
有趣的bootstrap走动进度条
2016/12/01 Javascript
AngularJS过滤器filter用法总结
2016/12/13 Javascript
json数据处理及数据绑定
2017/01/25 Javascript
Angular 2 ngForm中的ngModel、[ngModel]和[(ngModel)]的写法
2017/06/29 Javascript
基于JavaScript实现选项卡效果
2017/07/21 Javascript
微信小程序实现登录遮罩效果
2018/11/01 Javascript
Nuxt.js之自动路由原理的实现方法
2018/11/21 Javascript
使用Vue.observable()进行状态管理的实例代码详解
2019/05/26 Javascript
深入了解响应式React Native Echarts组件
2019/05/29 Javascript
vue获取验证码倒计时组件
2019/08/26 Javascript
微信小程序如何加载数据库真实数据的实现
2020/03/04 Javascript
python编程嵌套函数实例代码
2018/02/11 Python
Python使用修饰器进行异常日志记录操作示例
2019/03/19 Python
基于python实现微信好友数据分析(简单)
2020/02/16 Python
Vs Code中8个好用的python 扩展插件
2020/10/12 Python
HTML5的自定义属性data-*详细介绍和JS操作实例
2014/04/10 HTML / CSS
关于canvas绘制模糊问题的解决方法
2019/09/24 HTML / CSS
Vince官网:全球著名设计师品牌,休闲而优雅的服饰
2017/01/15 全球购物
中学生打架检讨书
2014/02/10 职场文书
预备党员承诺书
2014/03/25 职场文书
公司授权委托书范文
2014/08/02 职场文书
工厂标语大全
2014/10/06 职场文书
原生CSS实现文字无限轮播的通用方法
2021/03/30 HTML / CSS
SQL实现LeetCode(180.连续的数字)
2021/08/04 MySQL