pytorch cnn 识别手写的字实现自建图片数据


Posted in Python onMay 20, 2018

本文主要介绍了pytorch cnn 识别手写的字实现自建图片数据,分享给大家,具体如下:

# library
# standard library
import os 
# third-party library
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import torchvision
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
# torch.manual_seed(1)  # reproducible 
# Hyper Parameters
EPOCH = 1        # train the training data n times, to save time, we just train 1 epoch
BATCH_SIZE = 50
LR = 0.001       # learning rate 
 
root = "./mnist/raw/"
 
def default_loader(path):
  # return Image.open(path).convert('RGB')
  return Image.open(path)
 
class MyDataset(Dataset):
  def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
    fh = open(txt, 'r')
    imgs = []
    for line in fh:
      line = line.strip('\n')
      line = line.rstrip()
      words = line.split()
      imgs.append((words[0], int(words[1])))
    self.imgs = imgs
    self.transform = transform
    self.target_transform = target_transform
    self.loader = loader
    fh.close()
  def __getitem__(self, index):
    fn, label = self.imgs[index]
    img = self.loader(fn)
    img = Image.fromarray(np.array(img), mode='L')
    if self.transform is not None:
      img = self.transform(img)
    return img,label
  def __len__(self):
    return len(self.imgs)
 
train_data = MyDataset(txt= root + 'train.txt', transform = torchvision.transforms.ToTensor())
train_loader = DataLoader(dataset = train_data, batch_size=BATCH_SIZE, shuffle=True)
 
test_data = MyDataset(txt= root + 'test.txt', transform = torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset = test_data, batch_size=BATCH_SIZE)
 
class CNN(nn.Module):
  def __init__(self):
    super(CNN, self).__init__()
    self.conv1 = nn.Sequential(     # input shape (1, 28, 28)
      nn.Conv2d(
        in_channels=1,       # input height
        out_channels=16,      # n_filters
        kernel_size=5,       # filter size
        stride=1,          # filter movement/step
        padding=2,         # if want same width and length of this image after con2d, padding=(kernel_size-1)/2 if stride=1
      ),               # output shape (16, 28, 28)
      nn.ReLU(),           # activation
      nn.MaxPool2d(kernel_size=2),  # choose max value in 2x2 area, output shape (16, 14, 14)
    )
    self.conv2 = nn.Sequential(     # input shape (16, 14, 14)
      nn.Conv2d(16, 32, 5, 1, 2),   # output shape (32, 14, 14)
      nn.ReLU(),           # activation
      nn.MaxPool2d(2),        # output shape (32, 7, 7)
    )
    self.out = nn.Linear(32 * 7 * 7, 10)  # fully connected layer, output 10 classes
 
  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = x.view(x.size(0), -1)      # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
    output = self.out(x)
    return output, x  # return x for visualization 
cnn = CNN()
print(cnn) # net architecture
 
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)  # optimize all cnn parameters
loss_func = nn.CrossEntropyLoss()            # the target label is not one-hotted 
 
# training and testing
for epoch in range(EPOCH):
  for step, (x, y) in enumerate(train_loader):  # gives batch data, normalize x when iterate train_loader
    b_x = Variable(x)  # batch x
    b_y = Variable(y)  # batch y
 
    output = cnn(b_x)[0]        # cnn output
    loss = loss_func(output, b_y)  # cross entropy loss
    optimizer.zero_grad()      # clear gradients for this training step
    loss.backward()         # backpropagation, compute gradients
    optimizer.step()        # apply gradients
 
    if step % 50 == 0:
      cnn.eval()
      eval_loss = 0.
      eval_acc = 0.
      for i, (tx, ty) in enumerate(test_loader):
        t_x = Variable(tx)
        t_y = Variable(ty)
        output = cnn(t_x)[0]
        loss = loss_func(output, t_y)
        eval_loss += loss.data[0]
        pred = torch.max(output, 1)[1]
        num_correct = (pred == t_y).sum()
        eval_acc += float(num_correct.data[0])
      acc_rate = eval_acc / float(len(test_data))
      print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(test_data)), acc_rate))

图片和label 见上一篇文章《pytorch 把MNIST数据集转换成图片和txt》

结果如下:

pytorch cnn 识别手写的字实现自建图片数据

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
利用python画一颗心的方法示例
Jan 31 Python
Python3 socket同步通信简单示例
Jun 07 Python
利用Python查看目录中的文件示例详解
Aug 28 Python
python 连接各类主流数据库的实例代码
Jan 30 Python
python中将一个全部为int的list 转化为str的list方法
Apr 09 Python
python实现二维插值的三维显示
Dec 17 Python
Python3日期与时间戳转换的几种方法详解
Jun 04 Python
python设计微型小说网站(基于Django+Bootstrap框架)
Jul 08 Python
flask框架渲染Jinja模板与传入模板变量操作详解
Jan 25 Python
Python OrderedDict字典排序方法详解
May 21 Python
Python小白不正确的使用类变量实例
May 29 Python
python文件及目录操作代码汇总
Jul 08 Python
pytorch 把MNIST数据集转换成图片和txt的方法
May 20 #Python
Python安装lz4-0.10.1遇到的坑
May 20 #Python
Python requests发送post请求的一些疑点
May 20 #Python
python中virtualenvwrapper安装与使用
May 20 #Python
django静态文件加载的方法
May 20 #Python
django中静态文件配置static的方法
May 20 #Python
Python中跳台阶、变态跳台阶与矩形覆盖问题的解决方法
May 19 #Python
You might like
PHP的FTP学习(一)[转自奥索]
2006/10/09 PHP
PHP数组遍历知识汇总(包含遍历方法、数组指针操作函数、数组遍历测速)
2014/07/05 PHP
Smarty实现页面静态化(生成HTML)的方法
2016/05/23 PHP
Yii2中DropDownList简单用法示例
2016/07/18 PHP
thinkPHP框架实现的短信接口验证码功能示例
2018/06/20 PHP
PHP接入微信H5支付的方法示例
2019/10/28 PHP
禁止刷新,回退的JS
2006/11/25 Javascript
又一个图片自动缩小的JS代码
2007/03/10 Javascript
JQuery textlimit 显示用户输入的字符数 限制用户输入的字符数
2009/05/14 Javascript
js实现GridView单选效果自动设置交替行、选中行、鼠标移动行背景色
2010/05/27 Javascript
分享14个很酷的jQuery导航菜单插件
2011/04/25 Javascript
深入理解javascript学习笔记(一) 编写高质量代码
2012/08/09 Javascript
js获得鼠标的坐标值的方法
2013/03/13 Javascript
animate动画示例(泪奔的小孩)及stop和delay的使用
2013/05/06 Javascript
ExtJS 刷新后如何默认选中刷新前最后一次选中的节点
2014/04/03 Javascript
jQuery插件制作之参数用法实例分析
2015/06/01 Javascript
JS与jQuery实现隔行变色的方法
2016/09/09 Javascript
从零开始学习Node.js系列教程二:文本提交与显示方法
2017/04/13 Javascript
js隐式转换的知识实例讲解
2018/09/28 Javascript
jquery 时间戳转日期过程详解
2019/10/12 jQuery
使用webpack将ES6转化ES5的实现方法
2019/10/13 Javascript
基于vue+echarts 数据可视化大屏展示的方法示例
2020/03/09 Javascript
从Node.js事件触发器到Vue自定义事件的深入讲解
2020/06/26 Javascript
Python单体模式的几种常见实现方法详解
2017/07/28 Python
Python遍历pandas数据方法总结
2018/02/09 Python
Python实现通过解析域名获取ip地址的方法分析
2019/05/17 Python
Python Django实现layui风格+django分页功能的例子
2019/08/29 Python
Python日期格式和字符串格式相互转换的方法
2020/02/18 Python
美国正版电视节目和电影在线观看:Hulu
2018/05/24 全球购物
德国购买门票网站:ADticket.de
2019/10/31 全球购物
面向对象编程的优势是什么
2015/12/17 面试题
简历中自我评价分享
2013/10/09 职场文书
2014年五一活动策划方案
2014/03/15 职场文书
丧事答谢词
2015/01/05 职场文书
生产现场禁烟通知
2015/04/23 职场文书
2016大学生国家助学贷款承诺书
2016/03/25 职场文书