tensorflow识别自己手写数字


Posted in Python onMarch 14, 2018

tensorflow作为google开源的项目,现在赶超了caffe,好像成为最受欢迎的深度学习框架。确实在编写的时候更能感受到代码的真实存在,这点和caffe不同,caffe通过编写配置文件进行网络的生成。环境tensorflow是0.10的版本,注意其他版本有的语句会有错误,这是tensorflow版本之间的兼容问题。

还需要安装PIL:pip install Pillow

图片的格式: 

? 图像标准化,可安装在20×20像素的框内,同时保留其长宽比。
? 图片都集中在一个28×28的图像中。
? 像素以列为主进行排序。像素值0到255,0表示背景(白色),255表示前景(黑色)。

创建一个.png的文件,背景是白色的,手写的字体是黑色的,

下面是数据测试的代码,一个两层的卷积神经网,然后用save进行模型的保存。

# coding: UTF-8 
import tensorflow as tf 
import numpy as np 
import matplotlib.pyplot as plt 
import input_data 
''''' 
得到数据 
''' 
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) 
 
training = mnist.train.images 
trainlable = mnist.train.labels 
testing = mnist.test.images 
testlabel = mnist.test.labels 
 
print ("MNIST loaded") 
# 获取交互式的方式 
sess = tf.InteractiveSession() 
# 初始化变量 
x = tf.placeholder("float", shape=[None, 784]) 
y_ = tf.placeholder("float", shape=[None, 10]) 
W = tf.Variable(tf.zeros([784, 10])) 
b = tf.Variable(tf.zeros([10])) 
''''' 
生成权重函数,其中shape是数据的形状 
''' 
def weight_variable(shape): 
  initial = tf.truncated_normal(shape, stddev=0.1) 
  return tf.Variable(initial) 
''''' 
生成偏执项 其中shape是数据形状 
''' 
def bias_variable(shape): 
  initial = tf.constant(0.1, shape=shape) 
  return tf.Variable(initial) 
 
def conv2d(x, W): 
  return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') 
 
def max_pool_2x2(x): 
  return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], 
             strides=[1, 2, 2, 1], padding='SAME') 
 
W_conv1 = weight_variable([5, 5, 1, 32]) 
b_conv1 = bias_variable([32]) 
x_image = tf.reshape(x, [-1, 28, 28, 1]) 
 
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1) 
h_pool1 = max_pool_2x2(h_conv1) 
 
W_conv2 = weight_variable([5, 5, 32, 64]) 
b_conv2 = bias_variable([64]) 
 
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2) 
h_pool2 = max_pool_2x2(h_conv2) 
 
 
W_fc1 = weight_variable([7 * 7 * 64, 1024]) 
b_fc1 = bias_variable([1024]) 
 
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64]) 
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1) 
 
keep_prob = tf.placeholder("float") 
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob) 
 
W_fc2 = weight_variable([1024, 10]) 
b_fc2 = bias_variable([10]) 
 
y_conv=tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2) 
 
cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv)) 
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) 
correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1)) 
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float")) 
 
# 保存网络训练的参数 
saver = tf.train.Saver() 
sess.run(tf.initialize_all_variables()) 
for i in range(8000): 
 batch = mnist.train.next_batch(50) 
 if i%100 == 0: 
  train_accuracy = accuracy.eval(feed_dict={ 
    x:batch[0], y_: batch[1], keep_prob: 1.0}) 
  print "step %d, training accuracy %g"%(i, train_accuracy) 
 train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5}) 
 
save_path = saver.save(sess, "model_mnist.ckpt") 
print("Model saved in life:", save_path) 
 
print "test accuracy %g"%accuracy.eval(feed_dict={ 
  x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0})

其中input_data.py如下代码,是进行mnist数据集的下载的:代码是由mnist数据集提供的官方下载的版本。

# Copyright 2015 Google Inc. All Rights Reserved. 
# 
# Licensed under the Apache License, Version 2.0 (the "License"); 
# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at 
# 
#   http://www.apache.org/licenses/LICENSE-2.0 
# 
# Unless required by applicable law or agreed to in writing, software 
# distributed under the License is distributed on an "AS IS" BASIS, 
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and 
# limitations under the License. 
# ============================================================================== 
"""Functions for downloading and reading MNIST data.""" 
from __future__ import absolute_import 
from __future__ import division 
from __future__ import print_function 
import gzip 
import os 
import tensorflow.python.platform 
import numpy 
from six.moves import urllib 
from six.moves import xrange # pylint: disable=redefined-builtin 
import tensorflow as tf 
SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/' 
def maybe_download(filename, work_directory): 
 """Download the data from Yann's website, unless it's already here.""" 
 if not os.path.exists(work_directory): 
  os.mkdir(work_directory) 
 filepath = os.path.join(work_directory, filename) 
 if not os.path.exists(filepath): 
  filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath) 
  statinfo = os.stat(filepath) 
  print('Successfully downloaded', filename, statinfo.st_size, 'bytes.') 
 return filepath 
def _read32(bytestream): 
 dt = numpy.dtype(numpy.uint32).newbyteorder('>') 
 return numpy.frombuffer(bytestream.read(4), dtype=dt)[0] 
def extract_images(filename): 
 """Extract the images into a 4D uint8 numpy array [index, y, x, depth].""" 
 print('Extracting', filename) 
 with gzip.open(filename) as bytestream: 
  magic = _read32(bytestream) 
  if magic != 2051: 
   raise ValueError( 
     'Invalid magic number %d in MNIST image file: %s' % 
     (magic, filename)) 
  num_images = _read32(bytestream) 
  rows = _read32(bytestream) 
  cols = _read32(bytestream) 
  buf = bytestream.read(rows * cols * num_images) 
  data = numpy.frombuffer(buf, dtype=numpy.uint8) 
  data = data.reshape(num_images, rows, cols, 1) 
  return data 
def dense_to_one_hot(labels_dense, num_classes=10): 
 """Convert class labels from scalars to one-hot vectors.""" 
 num_labels = labels_dense.shape[0] 
 index_offset = numpy.arange(num_labels) * num_classes 
 labels_one_hot = numpy.zeros((num_labels, num_classes)) 
 labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1 
 return labels_one_hot 
def extract_labels(filename, one_hot=False): 
 """Extract the labels into a 1D uint8 numpy array [index].""" 
 print('Extracting', filename) 
 with gzip.open(filename) as bytestream: 
  magic = _read32(bytestream) 
  if magic != 2049: 
   raise ValueError( 
     'Invalid magic number %d in MNIST label file: %s' % 
     (magic, filename)) 
  num_items = _read32(bytestream) 
  buf = bytestream.read(num_items) 
  labels = numpy.frombuffer(buf, dtype=numpy.uint8) 
  if one_hot: 
   return dense_to_one_hot(labels) 
  return labels 
class DataSet(object): 
 def __init__(self, images, labels, fake_data=False, one_hot=False, 
        dtype=tf.float32): 
  """Construct a DataSet. 
  one_hot arg is used only if fake_data is true. `dtype` can be either 
  `uint8` to leave the input as `[0, 255]`, or `float32` to rescale into 
  `[0, 1]`. 
  """ 
  dtype = tf.as_dtype(dtype).base_dtype 
  if dtype not in (tf.uint8, tf.float32): 
   raise TypeError('Invalid image dtype %r, expected uint8 or float32' % 
           dtype) 
  if fake_data: 
   self._num_examples = 10000 
   self.one_hot = one_hot 
  else: 
   assert images.shape[0] == labels.shape[0], ( 
     'images.shape: %s labels.shape: %s' % (images.shape, 
                         labels.shape)) 
   self._num_examples = images.shape[0] 
   # Convert shape from [num examples, rows, columns, depth] 
   # to [num examples, rows*columns] (assuming depth == 1) 
   assert images.shape[3] == 1 
   images = images.reshape(images.shape[0], 
               images.shape[1] * images.shape[2]) 
   if dtype == tf.float32: 
    # Convert from [0, 255] -> [0.0, 1.0]. 
    images = images.astype(numpy.float32) 
    images = numpy.multiply(images, 1.0 / 255.0) 
  self._images = images 
  self._labels = labels 
  self._epochs_completed = 0 
  self._index_in_epoch = 0 
 @property 
 def images(self): 
  return self._images 
 @property 
 def labels(self): 
  return self._labels 
 @property 
 def num_examples(self): 
  return self._num_examples 
 @property 
 def epochs_completed(self): 
  return self._epochs_completed 
 def next_batch(self, batch_size, fake_data=False): 
  """Return the next `batch_size` examples from this data set.""" 
  if fake_data: 
   fake_image = [1] * 784 
   if self.one_hot: 
    fake_label = [1] + [0] * 9 
   else: 
    fake_label = 0 
   return [fake_image for _ in xrange(batch_size)], [ 
     fake_label for _ in xrange(batch_size)] 
  start = self._index_in_epoch 
  self._index_in_epoch += batch_size 
  if self._index_in_epoch > self._num_examples: 
   # Finished epoch 
   self._epochs_completed += 1 
   # Shuffle the data 
   perm = numpy.arange(self._num_examples) 
   numpy.random.shuffle(perm) 
   self._images = self._images[perm] 
   self._labels = self._labels[perm] 
   # Start next epoch 
   start = 0 
   self._index_in_epoch = batch_size 
   assert batch_size <= self._num_examples 
  end = self._index_in_epoch 
  return self._images[start:end], self._labels[start:end] 
def read_data_sets(train_dir, fake_data=False, one_hot=False, dtype=tf.float32): 
 class DataSets(object): 
  pass 
 data_sets = DataSets() 
 if fake_data: 
  def fake(): 
   return DataSet([], [], fake_data=True, one_hot=one_hot, dtype=dtype) 
  data_sets.train = fake() 
  data_sets.validation = fake() 
  data_sets.test = fake() 
  return data_sets 
 TRAIN_IMAGES = 'train-images-idx3-ubyte.gz' 
 TRAIN_LABELS = 'train-labels-idx1-ubyte.gz' 
 TEST_IMAGES = 't10k-images-idx3-ubyte.gz' 
 TEST_LABELS = 't10k-labels-idx1-ubyte.gz' 
 VALIDATION_SIZE = 5000 
 local_file = maybe_download(TRAIN_IMAGES, train_dir) 
 train_images = extract_images(local_file) 
 local_file = maybe_download(TRAIN_LABELS, train_dir) 
 train_labels = extract_labels(local_file, one_hot=one_hot) 
 local_file = maybe_download(TEST_IMAGES, train_dir) 
 test_images = extract_images(local_file) 
 local_file = maybe_download(TEST_LABELS, train_dir) 
 test_labels = extract_labels(local_file, one_hot=one_hot) 
 validation_images = train_images[:VALIDATION_SIZE] 
 validation_labels = train_labels[:VALIDATION_SIZE] 
 train_images = train_images[VALIDATION_SIZE:] 
 train_labels = train_labels[VALIDATION_SIZE:] 
 data_sets.train = DataSet(train_images, train_labels, dtype=dtype) 
 data_sets.validation = DataSet(validation_images, validation_labels, 
                 dtype=dtype) 
 data_sets.test = DataSet(test_images, test_labels, dtype=dtype) 
 return data_sets

然后进行代码的测试:

# import modules 
import sys 
import tensorflow as tf 
from PIL import Image, ImageFilter 
 
 
def predictint(imvalue): 
  """ 
  This function returns the predicted integer. 
  The imput is the pixel values from the imageprepare() function. 
  """ 
 
  # Define the model (same as when creating the model file) 
  x = tf.placeholder(tf.float32, [None, 784]) 
  W = tf.Variable(tf.zeros([784, 10])) 
  b = tf.Variable(tf.zeros([10])) 
 
  def weight_variable(shape): 
    initial = tf.truncated_normal(shape, stddev=0.1) 
    return tf.Variable(initial) 
 
  def bias_variable(shape): 
    initial = tf.constant(0.1, shape=shape) 
    return tf.Variable(initial) 
 
  def conv2d(x, W): 
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') 
 
  def max_pool_2x2(x): 
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') 
 
  W_conv1 = weight_variable([5, 5, 1, 32]) 
  b_conv1 = bias_variable([32]) 
 
  x_image = tf.reshape(x, [-1, 28, 28, 1]) 
  h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1) 
  h_pool1 = max_pool_2x2(h_conv1) 
 
  W_conv2 = weight_variable([5, 5, 32, 64]) 
  b_conv2 = bias_variable([64]) 
 
  h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2) 
  h_pool2 = max_pool_2x2(h_conv2) 
 
  W_fc1 = weight_variable([7 * 7 * 64, 1024]) 
  b_fc1 = bias_variable([1024]) 
 
  h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64]) 
  h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1) 
 
  keep_prob = tf.placeholder(tf.float32) 
  h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob) 
 
  W_fc2 = weight_variable([1024, 10]) 
  b_fc2 = bias_variable([10]) 
 
  y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2) 
 
  init_op = tf.initialize_all_variables() 
  saver = tf.train.Saver() 
 
  """ 
  Load the model_mnist.ckpt file 
  file is stored in the same directory as this python script is started 
  Use the model to predict the integer. Integer is returend as list. 
  Based on the documentatoin at 
  https://www.tensorflow.org/versions/master/how_tos/variables/index.html 
  """ 
  with tf.Session() as sess: 
    sess.run(init_op) 
    saver.restore(sess, "model_mnist.ckpt") 
    # print ("Model restored.") 
 
    prediction = tf.argmax(y_conv, 1) 
    return prediction.eval(feed_dict={x: [imvalue], keep_prob: 1.0}, session=sess) 
 
 
def imageprepare(argv): 
  """ 
  This function returns the pixel values. 
  The imput is a png file location. 
  """ 
  im = Image.open(argv).convert('L') 
  width = float(im.size[0]) 
  height = float(im.size[1]) 
  newImage = Image.new('L', (28, 28), (255)) # creates white canvas of 28x28 pixels 
 
  if width > height: # check which dimension is bigger 
    # Width is bigger. Width becomes 20 pixels. 
    nheight = int(round((20.0 / width * height), 0)) # resize height according to ratio width 
    if (nheight == 0): # rare case but minimum is 1 pixel 
      nheigth = 1 
      # resize and sharpen 
    img = im.resize((20, nheight), Image.ANTIALIAS).filter(ImageFilter.SHARPEN) 
    wtop = int(round(((28 - nheight) / 2), 0)) # caculate horizontal pozition 
    newImage.paste(img, (4, wtop)) # paste resized image on white canvas 
  else: 
    # Height is bigger. Heigth becomes 20 pixels. 
    nwidth = int(round((20.0 / height * width), 0)) # resize width according to ratio height 
    if (nwidth == 0): # rare case but minimum is 1 pixel 
      nwidth = 1 
      # resize and sharpen 
    img = im.resize((nwidth, 20), Image.ANTIALIAS).filter(ImageFilter.SHARPEN) 
    wleft = int(round(((28 - nwidth) / 2), 0)) # caculate vertical pozition 
    newImage.paste(img, (wleft, 4)) # paste resized image on white canvas 
 
  # newImage.save("sample.png") 
 
  tv = list(newImage.getdata()) # get pixel values 
 
  # normalize pixels to 0 and 1. 0 is pure white, 1 is pure black. 
  tva = [(255 - x) * 1.0 / 255.0 for x in tv] 
  return tva 
  # print(tva) 
 
 
def main(argv): 
  """ 
  Main function. 
  """ 
  imvalue = imageprepare(argv) 
  predint = predictint(imvalue) 
  print (predint[0]) # first value in list 
 
 
if __name__ == "__main__": 
  main('2.png')

其中我用于测试的代码如下:

tensorflow识别自己手写数字

可以将图片另存到路径下面,然后进行测试。

(1)载入我的手写数字的图像。
(2)将图像转换为黑白(模式“L”)
(3)确定原始图像的尺寸是最大的
(4)调整图像的大小,使得最大尺寸(醚的高度及宽度)为20像素,并且以相同的比例最小化尺寸刻度。
(5)锐化图像。这会极大地强化结果。
(6)把图像粘贴在28×28像素的白色画布上。在最大的尺寸上从顶部或侧面居中图像4个像素。最大尺寸始终是20个像素和4 + 20 + 4 = 28,最小尺寸被定位在28和缩放的图像的新的大小之间差的一半。
(7)获取新的图像(画布+居中的图像)的像素值。
(8)归一化像素值到0和1之间的一个值(这也在TensorFlow MNIST教程中完成)。其中0是白色的,1是纯黑色。从步骤7得到的像素值是与之相反的,其中255是白色的,0黑色,所以数值必须反转。下述公式包括反转和规格化(255-X)* 1.0 / 255.0

Python 相关文章推荐
Python编程中用close()方法关闭文件的教程
May 24 Python
用python写一个windows下的定时关机脚本(推荐)
Mar 21 Python
Python删除Java源文件中全部注释的实现方法
Aug 30 Python
python中利用Future对象回调别的函数示例代码
Sep 07 Python
Python实现读取json文件到excel表
Nov 18 Python
python自动发送邮件脚本
Jun 20 Python
在pandas多重索引multiIndex中选定指定索引的行方法
Nov 16 Python
Python3.5装饰器原理及应用实例详解
Apr 30 Python
Python 等分切分数据及规则命名的实例代码
Aug 16 Python
Python缓存技术实现过程详解
Sep 25 Python
Python 中判断列表是否为空的方法
Nov 24 Python
python 使用事件对象asyncio.Event来同步协程的操作
May 04 Python
磁盘垃圾文件清理器python代码实现
Aug 24 #Python
Django自定义用户认证示例详解
Mar 14 #Python
python如何压缩新文件到已有ZIP文件
Mar 14 #Python
python中format()函数的简单使用教程
Mar 14 #Python
Python批量提取PDF文件中文本的脚本
Mar 14 #Python
深入理解Django的中间件middleware
Mar 14 #Python
python批量设置多个Excel文件页眉页脚的脚本
Mar 14 #Python
You might like
一段批量给页面上的控件赋值js
2010/06/19 Javascript
基于jquery的动态创建表格的插件
2011/04/05 Javascript
javascript window.confirm确认 取消对话框实现代码小结
2012/10/21 Javascript
JQuery事件e参数的方法preventDefault()取消默认行为
2013/09/26 Javascript
jquery预加载图片的方法
2015/05/27 Javascript
javascript获取网页宽高方法汇总
2015/07/19 Javascript
jQuery实现无限往下滚动效果代码
2016/04/16 Javascript
浅谈js的url解析函数封装
2016/06/28 Javascript
js实现日历与定时器
2017/02/22 Javascript
jQuery初级教程之网站品牌列表效果
2017/08/02 jQuery
JS实现定时任务每隔N秒请求后台setInterval定时和ajax请求问题
2017/10/15 Javascript
Vue数据双向绑定原理及简单实现方法
2018/05/18 Javascript
Vue2.x和Vue3.x的双向绑定原理详解
2020/11/05 Javascript
[01:06]DOTA2小知识课堂 Ep.02 吹风竟可解梦境缠绕
2019/12/05 DOTA
Python的__builtin__模块中的一些要点知识
2015/05/02 Python
Python使用multiprocessing实现一个最简单的分布式作业调度系统
2016/03/14 Python
在Python中通过threading模块定义和调用线程的方法
2016/07/12 Python
Python基于matplotlib绘制栈式直方图的方法示例
2017/08/09 Python
Django的分页器实例(paginator)
2017/12/01 Python
wxPython的安装图文教程(Windows)
2017/12/28 Python
python实现人脸识别经典算法(一) 特征脸法
2018/03/13 Python
Python基于递归算法实现的汉诺塔与Fibonacci数列示例
2018/04/18 Python
详解python itertools功能
2020/02/07 Python
Python使用Socket实现简单聊天程序
2020/02/28 Python
python中for in的用法详解
2020/04/17 Python
TensorFlow的reshape操作 tf.reshape的实现
2020/04/19 Python
python进度条显示之tqmd模块
2020/08/22 Python
Booking.com英国官网:全球酒店在线预订网站
2018/04/21 全球购物
全球精选男装和家居用品:Article
2020/04/13 全球购物
骨干教师培训感言
2014/01/16 职场文书
某某同志考察材料
2014/05/28 职场文书
十一国庆节“向国旗敬礼”主题班会活动方案
2014/09/27 职场文书
工作会议简报
2015/07/20 职场文书
初任公务员培训心得体会
2016/01/08 职场文书
不会写演讲稿,快来看看这篇文章!
2019/08/06 职场文书
Django分页器的用法你都了解吗
2021/05/26 Python