tensorflow识别自己手写数字


Posted in Python onMarch 14, 2018

tensorflow作为google开源的项目,现在赶超了caffe,好像成为最受欢迎的深度学习框架。确实在编写的时候更能感受到代码的真实存在,这点和caffe不同,caffe通过编写配置文件进行网络的生成。环境tensorflow是0.10的版本,注意其他版本有的语句会有错误,这是tensorflow版本之间的兼容问题。

还需要安装PIL:pip install Pillow

图片的格式: 

? 图像标准化,可安装在20×20像素的框内,同时保留其长宽比。
? 图片都集中在一个28×28的图像中。
? 像素以列为主进行排序。像素值0到255,0表示背景(白色),255表示前景(黑色)。

创建一个.png的文件,背景是白色的,手写的字体是黑色的,

下面是数据测试的代码,一个两层的卷积神经网,然后用save进行模型的保存。

# coding: UTF-8 
import tensorflow as tf 
import numpy as np 
import matplotlib.pyplot as plt 
import input_data 
''''' 
得到数据 
''' 
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) 
 
training = mnist.train.images 
trainlable = mnist.train.labels 
testing = mnist.test.images 
testlabel = mnist.test.labels 
 
print ("MNIST loaded") 
# 获取交互式的方式 
sess = tf.InteractiveSession() 
# 初始化变量 
x = tf.placeholder("float", shape=[None, 784]) 
y_ = tf.placeholder("float", shape=[None, 10]) 
W = tf.Variable(tf.zeros([784, 10])) 
b = tf.Variable(tf.zeros([10])) 
''''' 
生成权重函数,其中shape是数据的形状 
''' 
def weight_variable(shape): 
  initial = tf.truncated_normal(shape, stddev=0.1) 
  return tf.Variable(initial) 
''''' 
生成偏执项 其中shape是数据形状 
''' 
def bias_variable(shape): 
  initial = tf.constant(0.1, shape=shape) 
  return tf.Variable(initial) 
 
def conv2d(x, W): 
  return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') 
 
def max_pool_2x2(x): 
  return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], 
             strides=[1, 2, 2, 1], padding='SAME') 
 
W_conv1 = weight_variable([5, 5, 1, 32]) 
b_conv1 = bias_variable([32]) 
x_image = tf.reshape(x, [-1, 28, 28, 1]) 
 
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1) 
h_pool1 = max_pool_2x2(h_conv1) 
 
W_conv2 = weight_variable([5, 5, 32, 64]) 
b_conv2 = bias_variable([64]) 
 
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2) 
h_pool2 = max_pool_2x2(h_conv2) 
 
 
W_fc1 = weight_variable([7 * 7 * 64, 1024]) 
b_fc1 = bias_variable([1024]) 
 
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64]) 
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1) 
 
keep_prob = tf.placeholder("float") 
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob) 
 
W_fc2 = weight_variable([1024, 10]) 
b_fc2 = bias_variable([10]) 
 
y_conv=tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2) 
 
cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv)) 
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) 
correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1)) 
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float")) 
 
# 保存网络训练的参数 
saver = tf.train.Saver() 
sess.run(tf.initialize_all_variables()) 
for i in range(8000): 
 batch = mnist.train.next_batch(50) 
 if i%100 == 0: 
  train_accuracy = accuracy.eval(feed_dict={ 
    x:batch[0], y_: batch[1], keep_prob: 1.0}) 
  print "step %d, training accuracy %g"%(i, train_accuracy) 
 train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5}) 
 
save_path = saver.save(sess, "model_mnist.ckpt") 
print("Model saved in life:", save_path) 
 
print "test accuracy %g"%accuracy.eval(feed_dict={ 
  x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0})

其中input_data.py如下代码,是进行mnist数据集的下载的:代码是由mnist数据集提供的官方下载的版本。

# Copyright 2015 Google Inc. All Rights Reserved. 
# 
# Licensed under the Apache License, Version 2.0 (the "License"); 
# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at 
# 
#   http://www.apache.org/licenses/LICENSE-2.0 
# 
# Unless required by applicable law or agreed to in writing, software 
# distributed under the License is distributed on an "AS IS" BASIS, 
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and 
# limitations under the License. 
# ============================================================================== 
"""Functions for downloading and reading MNIST data.""" 
from __future__ import absolute_import 
from __future__ import division 
from __future__ import print_function 
import gzip 
import os 
import tensorflow.python.platform 
import numpy 
from six.moves import urllib 
from six.moves import xrange # pylint: disable=redefined-builtin 
import tensorflow as tf 
SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/' 
def maybe_download(filename, work_directory): 
 """Download the data from Yann's website, unless it's already here.""" 
 if not os.path.exists(work_directory): 
  os.mkdir(work_directory) 
 filepath = os.path.join(work_directory, filename) 
 if not os.path.exists(filepath): 
  filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath) 
  statinfo = os.stat(filepath) 
  print('Successfully downloaded', filename, statinfo.st_size, 'bytes.') 
 return filepath 
def _read32(bytestream): 
 dt = numpy.dtype(numpy.uint32).newbyteorder('>') 
 return numpy.frombuffer(bytestream.read(4), dtype=dt)[0] 
def extract_images(filename): 
 """Extract the images into a 4D uint8 numpy array [index, y, x, depth].""" 
 print('Extracting', filename) 
 with gzip.open(filename) as bytestream: 
  magic = _read32(bytestream) 
  if magic != 2051: 
   raise ValueError( 
     'Invalid magic number %d in MNIST image file: %s' % 
     (magic, filename)) 
  num_images = _read32(bytestream) 
  rows = _read32(bytestream) 
  cols = _read32(bytestream) 
  buf = bytestream.read(rows * cols * num_images) 
  data = numpy.frombuffer(buf, dtype=numpy.uint8) 
  data = data.reshape(num_images, rows, cols, 1) 
  return data 
def dense_to_one_hot(labels_dense, num_classes=10): 
 """Convert class labels from scalars to one-hot vectors.""" 
 num_labels = labels_dense.shape[0] 
 index_offset = numpy.arange(num_labels) * num_classes 
 labels_one_hot = numpy.zeros((num_labels, num_classes)) 
 labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1 
 return labels_one_hot 
def extract_labels(filename, one_hot=False): 
 """Extract the labels into a 1D uint8 numpy array [index].""" 
 print('Extracting', filename) 
 with gzip.open(filename) as bytestream: 
  magic = _read32(bytestream) 
  if magic != 2049: 
   raise ValueError( 
     'Invalid magic number %d in MNIST label file: %s' % 
     (magic, filename)) 
  num_items = _read32(bytestream) 
  buf = bytestream.read(num_items) 
  labels = numpy.frombuffer(buf, dtype=numpy.uint8) 
  if one_hot: 
   return dense_to_one_hot(labels) 
  return labels 
class DataSet(object): 
 def __init__(self, images, labels, fake_data=False, one_hot=False, 
        dtype=tf.float32): 
  """Construct a DataSet. 
  one_hot arg is used only if fake_data is true. `dtype` can be either 
  `uint8` to leave the input as `[0, 255]`, or `float32` to rescale into 
  `[0, 1]`. 
  """ 
  dtype = tf.as_dtype(dtype).base_dtype 
  if dtype not in (tf.uint8, tf.float32): 
   raise TypeError('Invalid image dtype %r, expected uint8 or float32' % 
           dtype) 
  if fake_data: 
   self._num_examples = 10000 
   self.one_hot = one_hot 
  else: 
   assert images.shape[0] == labels.shape[0], ( 
     'images.shape: %s labels.shape: %s' % (images.shape, 
                         labels.shape)) 
   self._num_examples = images.shape[0] 
   # Convert shape from [num examples, rows, columns, depth] 
   # to [num examples, rows*columns] (assuming depth == 1) 
   assert images.shape[3] == 1 
   images = images.reshape(images.shape[0], 
               images.shape[1] * images.shape[2]) 
   if dtype == tf.float32: 
    # Convert from [0, 255] -> [0.0, 1.0]. 
    images = images.astype(numpy.float32) 
    images = numpy.multiply(images, 1.0 / 255.0) 
  self._images = images 
  self._labels = labels 
  self._epochs_completed = 0 
  self._index_in_epoch = 0 
 @property 
 def images(self): 
  return self._images 
 @property 
 def labels(self): 
  return self._labels 
 @property 
 def num_examples(self): 
  return self._num_examples 
 @property 
 def epochs_completed(self): 
  return self._epochs_completed 
 def next_batch(self, batch_size, fake_data=False): 
  """Return the next `batch_size` examples from this data set.""" 
  if fake_data: 
   fake_image = [1] * 784 
   if self.one_hot: 
    fake_label = [1] + [0] * 9 
   else: 
    fake_label = 0 
   return [fake_image for _ in xrange(batch_size)], [ 
     fake_label for _ in xrange(batch_size)] 
  start = self._index_in_epoch 
  self._index_in_epoch += batch_size 
  if self._index_in_epoch > self._num_examples: 
   # Finished epoch 
   self._epochs_completed += 1 
   # Shuffle the data 
   perm = numpy.arange(self._num_examples) 
   numpy.random.shuffle(perm) 
   self._images = self._images[perm] 
   self._labels = self._labels[perm] 
   # Start next epoch 
   start = 0 
   self._index_in_epoch = batch_size 
   assert batch_size <= self._num_examples 
  end = self._index_in_epoch 
  return self._images[start:end], self._labels[start:end] 
def read_data_sets(train_dir, fake_data=False, one_hot=False, dtype=tf.float32): 
 class DataSets(object): 
  pass 
 data_sets = DataSets() 
 if fake_data: 
  def fake(): 
   return DataSet([], [], fake_data=True, one_hot=one_hot, dtype=dtype) 
  data_sets.train = fake() 
  data_sets.validation = fake() 
  data_sets.test = fake() 
  return data_sets 
 TRAIN_IMAGES = 'train-images-idx3-ubyte.gz' 
 TRAIN_LABELS = 'train-labels-idx1-ubyte.gz' 
 TEST_IMAGES = 't10k-images-idx3-ubyte.gz' 
 TEST_LABELS = 't10k-labels-idx1-ubyte.gz' 
 VALIDATION_SIZE = 5000 
 local_file = maybe_download(TRAIN_IMAGES, train_dir) 
 train_images = extract_images(local_file) 
 local_file = maybe_download(TRAIN_LABELS, train_dir) 
 train_labels = extract_labels(local_file, one_hot=one_hot) 
 local_file = maybe_download(TEST_IMAGES, train_dir) 
 test_images = extract_images(local_file) 
 local_file = maybe_download(TEST_LABELS, train_dir) 
 test_labels = extract_labels(local_file, one_hot=one_hot) 
 validation_images = train_images[:VALIDATION_SIZE] 
 validation_labels = train_labels[:VALIDATION_SIZE] 
 train_images = train_images[VALIDATION_SIZE:] 
 train_labels = train_labels[VALIDATION_SIZE:] 
 data_sets.train = DataSet(train_images, train_labels, dtype=dtype) 
 data_sets.validation = DataSet(validation_images, validation_labels, 
                 dtype=dtype) 
 data_sets.test = DataSet(test_images, test_labels, dtype=dtype) 
 return data_sets

然后进行代码的测试:

# import modules 
import sys 
import tensorflow as tf 
from PIL import Image, ImageFilter 
 
 
def predictint(imvalue): 
  """ 
  This function returns the predicted integer. 
  The imput is the pixel values from the imageprepare() function. 
  """ 
 
  # Define the model (same as when creating the model file) 
  x = tf.placeholder(tf.float32, [None, 784]) 
  W = tf.Variable(tf.zeros([784, 10])) 
  b = tf.Variable(tf.zeros([10])) 
 
  def weight_variable(shape): 
    initial = tf.truncated_normal(shape, stddev=0.1) 
    return tf.Variable(initial) 
 
  def bias_variable(shape): 
    initial = tf.constant(0.1, shape=shape) 
    return tf.Variable(initial) 
 
  def conv2d(x, W): 
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') 
 
  def max_pool_2x2(x): 
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') 
 
  W_conv1 = weight_variable([5, 5, 1, 32]) 
  b_conv1 = bias_variable([32]) 
 
  x_image = tf.reshape(x, [-1, 28, 28, 1]) 
  h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1) 
  h_pool1 = max_pool_2x2(h_conv1) 
 
  W_conv2 = weight_variable([5, 5, 32, 64]) 
  b_conv2 = bias_variable([64]) 
 
  h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2) 
  h_pool2 = max_pool_2x2(h_conv2) 
 
  W_fc1 = weight_variable([7 * 7 * 64, 1024]) 
  b_fc1 = bias_variable([1024]) 
 
  h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64]) 
  h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1) 
 
  keep_prob = tf.placeholder(tf.float32) 
  h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob) 
 
  W_fc2 = weight_variable([1024, 10]) 
  b_fc2 = bias_variable([10]) 
 
  y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2) 
 
  init_op = tf.initialize_all_variables() 
  saver = tf.train.Saver() 
 
  """ 
  Load the model_mnist.ckpt file 
  file is stored in the same directory as this python script is started 
  Use the model to predict the integer. Integer is returend as list. 
  Based on the documentatoin at 
  https://www.tensorflow.org/versions/master/how_tos/variables/index.html 
  """ 
  with tf.Session() as sess: 
    sess.run(init_op) 
    saver.restore(sess, "model_mnist.ckpt") 
    # print ("Model restored.") 
 
    prediction = tf.argmax(y_conv, 1) 
    return prediction.eval(feed_dict={x: [imvalue], keep_prob: 1.0}, session=sess) 
 
 
def imageprepare(argv): 
  """ 
  This function returns the pixel values. 
  The imput is a png file location. 
  """ 
  im = Image.open(argv).convert('L') 
  width = float(im.size[0]) 
  height = float(im.size[1]) 
  newImage = Image.new('L', (28, 28), (255)) # creates white canvas of 28x28 pixels 
 
  if width > height: # check which dimension is bigger 
    # Width is bigger. Width becomes 20 pixels. 
    nheight = int(round((20.0 / width * height), 0)) # resize height according to ratio width 
    if (nheight == 0): # rare case but minimum is 1 pixel 
      nheigth = 1 
      # resize and sharpen 
    img = im.resize((20, nheight), Image.ANTIALIAS).filter(ImageFilter.SHARPEN) 
    wtop = int(round(((28 - nheight) / 2), 0)) # caculate horizontal pozition 
    newImage.paste(img, (4, wtop)) # paste resized image on white canvas 
  else: 
    # Height is bigger. Heigth becomes 20 pixels. 
    nwidth = int(round((20.0 / height * width), 0)) # resize width according to ratio height 
    if (nwidth == 0): # rare case but minimum is 1 pixel 
      nwidth = 1 
      # resize and sharpen 
    img = im.resize((nwidth, 20), Image.ANTIALIAS).filter(ImageFilter.SHARPEN) 
    wleft = int(round(((28 - nwidth) / 2), 0)) # caculate vertical pozition 
    newImage.paste(img, (wleft, 4)) # paste resized image on white canvas 
 
  # newImage.save("sample.png") 
 
  tv = list(newImage.getdata()) # get pixel values 
 
  # normalize pixels to 0 and 1. 0 is pure white, 1 is pure black. 
  tva = [(255 - x) * 1.0 / 255.0 for x in tv] 
  return tva 
  # print(tva) 
 
 
def main(argv): 
  """ 
  Main function. 
  """ 
  imvalue = imageprepare(argv) 
  predint = predictint(imvalue) 
  print (predint[0]) # first value in list 
 
 
if __name__ == "__main__": 
  main('2.png')

其中我用于测试的代码如下:

tensorflow识别自己手写数字

可以将图片另存到路径下面,然后进行测试。

(1)载入我的手写数字的图像。
(2)将图像转换为黑白(模式“L”)
(3)确定原始图像的尺寸是最大的
(4)调整图像的大小,使得最大尺寸(醚的高度及宽度)为20像素,并且以相同的比例最小化尺寸刻度。
(5)锐化图像。这会极大地强化结果。
(6)把图像粘贴在28×28像素的白色画布上。在最大的尺寸上从顶部或侧面居中图像4个像素。最大尺寸始终是20个像素和4 + 20 + 4 = 28,最小尺寸被定位在28和缩放的图像的新的大小之间差的一半。
(7)获取新的图像(画布+居中的图像)的像素值。
(8)归一化像素值到0和1之间的一个值(这也在TensorFlow MNIST教程中完成)。其中0是白色的,1是纯黑色。从步骤7得到的像素值是与之相反的,其中255是白色的,0黑色,所以数值必须反转。下述公式包括反转和规格化(255-X)* 1.0 / 255.0

Python 相关文章推荐
二种python发送邮件实例讲解(python发邮件附件可以使用email模块实现)
Dec 03 Python
python爬取网页转换为PDF文件
Jun 07 Python
Python使用Flask-SQLAlchemy连接数据库操作示例
Aug 31 Python
使用python对文件中的单词进行提取的方法示例
Dec 21 Python
通过pykafka接收Kafka消息队列的方法
Dec 27 Python
三步实现Django Paginator分页的方法
Jun 11 Python
python将四元数变换为旋转矩阵的实例
Dec 04 Python
Python基于pyjnius库实现访问java类
Jul 31 Python
python实现粒子群算法
Oct 15 Python
Python爬虫之Selenium实现窗口截图
Dec 04 Python
pycharm配置python 设置pip安装源为豆瓣源
Feb 05 Python
python 判断字符串当中是否包含字符(str.contain)
Jun 01 Python
磁盘垃圾文件清理器python代码实现
Aug 24 #Python
Django自定义用户认证示例详解
Mar 14 #Python
python如何压缩新文件到已有ZIP文件
Mar 14 #Python
python中format()函数的简单使用教程
Mar 14 #Python
Python批量提取PDF文件中文本的脚本
Mar 14 #Python
深入理解Django的中间件middleware
Mar 14 #Python
python批量设置多个Excel文件页眉页脚的脚本
Mar 14 #Python
You might like
PHP中使用mktime获取时间戳的一个黑色幽默分析
2012/05/31 PHP
wampserver改变默认网站目录的办法
2015/08/05 PHP
基于PHP后台的Android新闻浏览客户端
2016/05/23 PHP
PHP的RSA加密解密方法以及开发接口使用
2018/02/11 PHP
js+CSS 图片等比缩小并垂直居中实现代码
2008/12/01 Javascript
巧用jquery解决下拉菜单被Div遮挡的相关问题
2014/02/13 Javascript
javascript中attribute和property的区别详解
2014/06/05 Javascript
js实现图片轮播效果
2015/12/19 Javascript
详解JavaScript中的事件流和事件处理程序
2016/05/20 Javascript
JavaScript基于原型链的继承
2016/06/22 Javascript
Angular自定义组件实现数据双向数据绑定的实例
2017/12/11 Javascript
使用Bootstrap4 + Vue2实现分页查询的示例代码
2017/12/21 Javascript
vue.js中$set与数组更新方法
2018/03/08 Javascript
JS中原始值和引用值的储存方式示例详解
2018/03/23 Javascript
vue 动态绑定背景图片的方法
2018/08/10 Javascript
vue-cli3环境变量与分环境打包的方法示例
2019/02/18 Javascript
详解jQuery-each()方法
2019/03/13 jQuery
JS数组方法shift()、unshift()用法实例分析
2020/01/18 Javascript
JS求解两数之和算法详解
2020/04/28 Javascript
深入理解 ES6中的 Reflect用法
2020/07/18 Javascript
vue中使用腾讯云Im的示例
2020/10/23 Javascript
python实现简单的单变量线性回归方法
2018/11/08 Python
Python3.5 + sklearn利用SVM自动识别字母验证码方法示例
2019/05/10 Python
查看keras的默认backend实现方式
2020/06/19 Python
Python必须了解的35个关键词
2020/07/16 Python
python 从list中随机取值的方法
2020/11/16 Python
详解Python调用系统命令的六种方法
2021/01/28 Python
Infababy英国:婴儿推车、Travel System婴儿车和婴儿汽车座椅销售
2018/05/23 全球购物
女士鞋子、包包和服装在线,第一款10美元:ShoeDazzle
2019/07/26 全球购物
俄罗斯最大的香水和化妆品网上商店:Randewoo
2020/11/05 全球购物
公司廉洁自律承诺书
2014/03/27 职场文书
事业单位竞聘上岗实施方案
2014/03/28 职场文书
初婚初育证明范本
2014/11/24 职场文书
出纳工作检讨书范文
2014/12/27 职场文书
家电创业计划书
2019/08/05 职场文书
企业转让协议书(范文2篇)
2019/08/15 职场文书