pytorch cnn 识别手写的字实现自建图片数据


Posted in Python onMay 20, 2018

本文主要介绍了pytorch cnn 识别手写的字实现自建图片数据,分享给大家,具体如下:

# library
# standard library
import os 
# third-party library
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import torchvision
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
# torch.manual_seed(1)  # reproducible 
# Hyper Parameters
EPOCH = 1        # train the training data n times, to save time, we just train 1 epoch
BATCH_SIZE = 50
LR = 0.001       # learning rate 
 
root = "./mnist/raw/"
 
def default_loader(path):
  # return Image.open(path).convert('RGB')
  return Image.open(path)
 
class MyDataset(Dataset):
  def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
    fh = open(txt, 'r')
    imgs = []
    for line in fh:
      line = line.strip('\n')
      line = line.rstrip()
      words = line.split()
      imgs.append((words[0], int(words[1])))
    self.imgs = imgs
    self.transform = transform
    self.target_transform = target_transform
    self.loader = loader
    fh.close()
  def __getitem__(self, index):
    fn, label = self.imgs[index]
    img = self.loader(fn)
    img = Image.fromarray(np.array(img), mode='L')
    if self.transform is not None:
      img = self.transform(img)
    return img,label
  def __len__(self):
    return len(self.imgs)
 
train_data = MyDataset(txt= root + 'train.txt', transform = torchvision.transforms.ToTensor())
train_loader = DataLoader(dataset = train_data, batch_size=BATCH_SIZE, shuffle=True)
 
test_data = MyDataset(txt= root + 'test.txt', transform = torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset = test_data, batch_size=BATCH_SIZE)
 
class CNN(nn.Module):
  def __init__(self):
    super(CNN, self).__init__()
    self.conv1 = nn.Sequential(     # input shape (1, 28, 28)
      nn.Conv2d(
        in_channels=1,       # input height
        out_channels=16,      # n_filters
        kernel_size=5,       # filter size
        stride=1,          # filter movement/step
        padding=2,         # if want same width and length of this image after con2d, padding=(kernel_size-1)/2 if stride=1
      ),               # output shape (16, 28, 28)
      nn.ReLU(),           # activation
      nn.MaxPool2d(kernel_size=2),  # choose max value in 2x2 area, output shape (16, 14, 14)
    )
    self.conv2 = nn.Sequential(     # input shape (16, 14, 14)
      nn.Conv2d(16, 32, 5, 1, 2),   # output shape (32, 14, 14)
      nn.ReLU(),           # activation
      nn.MaxPool2d(2),        # output shape (32, 7, 7)
    )
    self.out = nn.Linear(32 * 7 * 7, 10)  # fully connected layer, output 10 classes
 
  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = x.view(x.size(0), -1)      # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
    output = self.out(x)
    return output, x  # return x for visualization 
cnn = CNN()
print(cnn) # net architecture
 
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)  # optimize all cnn parameters
loss_func = nn.CrossEntropyLoss()            # the target label is not one-hotted 
 
# training and testing
for epoch in range(EPOCH):
  for step, (x, y) in enumerate(train_loader):  # gives batch data, normalize x when iterate train_loader
    b_x = Variable(x)  # batch x
    b_y = Variable(y)  # batch y
 
    output = cnn(b_x)[0]        # cnn output
    loss = loss_func(output, b_y)  # cross entropy loss
    optimizer.zero_grad()      # clear gradients for this training step
    loss.backward()         # backpropagation, compute gradients
    optimizer.step()        # apply gradients
 
    if step % 50 == 0:
      cnn.eval()
      eval_loss = 0.
      eval_acc = 0.
      for i, (tx, ty) in enumerate(test_loader):
        t_x = Variable(tx)
        t_y = Variable(ty)
        output = cnn(t_x)[0]
        loss = loss_func(output, t_y)
        eval_loss += loss.data[0]
        pred = torch.max(output, 1)[1]
        num_correct = (pred == t_y).sum()
        eval_acc += float(num_correct.data[0])
      acc_rate = eval_acc / float(len(test_data))
      print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(test_data)), acc_rate))

图片和label 见上一篇文章《pytorch 把MNIST数据集转换成图片和txt》

结果如下:

pytorch cnn 识别手写的字实现自建图片数据

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python去除文件中空格、Tab及回车的方法
Apr 12 Python
Python简单实现enum功能的方法
Apr 25 Python
教你用python3根据关键词爬取百度百科的内容
Aug 18 Python
Anaconda2 5.2.0安装使用图文教程
Sep 19 Python
浅谈Python中的bs4基础
Oct 21 Python
浅谈pandas筛选出表中满足另一个表所有条件的数据方法
Feb 08 Python
将Python文件打包成.EXE可执行文件的方法
Aug 11 Python
python 多进程共享全局变量之Manager()详解
Aug 15 Python
python获取栅格点和面值的实现
Mar 10 Python
浅谈keras.callbacks设置模型保存策略
Jun 18 Python
python如何删除列为空的行
Jul 17 Python
python批量修改交换机密码的示例
Sep 22 Python
pytorch 把MNIST数据集转换成图片和txt的方法
May 20 #Python
Python安装lz4-0.10.1遇到的坑
May 20 #Python
Python requests发送post请求的一些疑点
May 20 #Python
python中virtualenvwrapper安装与使用
May 20 #Python
django静态文件加载的方法
May 20 #Python
django中静态文件配置static的方法
May 20 #Python
Python中跳台阶、变态跳台阶与矩形覆盖问题的解决方法
May 19 #Python
You might like
一个很方便的 XML 类!!原创的噢
2006/10/09 PHP
PHP Zip压缩 在线对文件进行压缩的函数
2010/05/26 PHP
php空间不支持socket但支持curl时recaptcha的用法
2011/11/07 PHP
PHP服务器页面间跳转实现方法
2012/08/02 PHP
一个PHP并发访问实例代码
2012/09/06 PHP
用php实现百度网盘图片直链的代码分享
2012/11/01 PHP
PHP生成word文档的三种实现方式
2016/11/14 PHP
phpMyAdmin通过密码漏洞留后门文件
2018/11/20 PHP
PHP使用PDO操作sqlite数据库应用案例
2019/03/07 PHP
关于 byval 与 byref 的区别分析总结
2007/10/08 Javascript
js获取html页面节点方法(递归方式)
2013/12/13 Javascript
javascript页面渲染速度测试脚本分享
2014/04/15 Javascript
浅谈jQuery before和insertBefore的区别
2016/12/04 Javascript
vue中实现滚动加载更多的示例
2017/11/08 Javascript
[01:20]DOTA2上海特级锦标赛现场采访:谁的ID最受青睐
2016/03/25 DOTA
使用Python绘制图表大全总结
2017/02/11 Python
Python AES加密实例解析
2018/01/18 Python
详解程序意外中断自动重启shell脚本(以Python为例)
2019/07/26 Python
40个你可能不知道的Python技巧附代码
2020/01/29 Python
Python logging模块写入中文出现乱码
2020/05/21 Python
python pillow库的基础使用教程
2021/01/13 Python
美国按摩椅批发网站:Titan Chair
2018/12/27 全球购物
湖南卫视在线视频媒体平台:芒果TV
2019/10/30 全球购物
int和Integer有什么区别
2013/05/25 面试题
前台接待的工作职责
2013/11/21 职场文书
开办化妆品公司创业计划书
2013/12/26 职场文书
工业学校毕业生自荐信范文
2014/01/03 职场文书
电台实习生求职信
2014/02/25 职场文书
《石榴》教学反思
2014/03/02 职场文书
小学生中国梦演讲稿
2014/04/23 职场文书
2014医学院领导干部四风对照检查材料思想汇报
2014/09/16 职场文书
2014社区健康教育工作总结
2014/12/16 职场文书
承诺函格式模板
2015/01/21 职场文书
2015年班组长工作总结
2015/04/10 职场文书
《藏戏》教学反思
2016/02/23 职场文书
yyds什么意思?90后已经听不懂00后讲话了……
2022/02/03 杂记