使用TensorFlow-Slim进行图像分类的实现


Posted in Python onDecember 31, 2019

参考 https://github.com/tensorflow/models/tree/master/slim

使用TensorFlow-Slim进行图像分类

准备

安装TensorFlow

参考 https://www.tensorflow.org/install/

如在Ubuntu下安装TensorFlow with GPU support, python 2.7版本

wget https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.2.0-cp27-none-linux_x86_64.whl
pip install tensorflow_gpu-1.2.0-cp27-none-linux_x86_64.whl

下载TF-slim图像模型库

cd $WORKSPACE
git clone https://github.com/tensorflow/models/

准备数据

有不少公开数据集,这里以官网提供的Flowers为例。

官网提供了下载和转换数据的代码,为了理解代码并能使用自己的数据,这里参考官方提供的代码进行修改。

cd $WORKSPACE/data
wget http://download.tensorflow.org/example_images/flower_photos.tgz
tar zxf flower_photos.tgz

数据集文件夹结构如下:

flower_photos
├── daisy
│  ├── 100080576_f52e8ee070_n.jpg
│  └── ...
├── dandelion
├── LICENSE.txt
├── roses
├── sunflowers
└── tulips

由于实际情况中我们自己的数据集并不一定把图片按类别放在不同的文件夹里,故我们生成list.txt来表示图片路径与标签的关系。

Python代码:

import os

class_names_to_ids = {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
data_dir = 'flower_photos/'
output_path = 'list.txt'

fd = open(output_path, 'w')
for class_name in class_names_to_ids.keys():
  images_list = os.listdir(data_dir + class_name)
  for image_name in images_list:
    fd.write('{}/{} {}\n'.format(class_name, image_name, class_names_to_ids[class_name]))

fd.close()

为了方便后期查看label标签,也可以定义labels.txt:

daisy
dandelion
roses
sunflowers
tulips

随机生成训练集与验证集:

Python代码:

import random

_NUM_VALIDATION = 350
_RANDOM_SEED = 0
list_path = 'list.txt'
train_list_path = 'list_train.txt'
val_list_path = 'list_val.txt'

fd = open(list_path)
lines = fd.readlines()
fd.close()
random.seed(_RANDOM_SEED)
random.shuffle(lines)

fd = open(train_list_path, 'w')
for line in lines[_NUM_VALIDATION:]:
  fd.write(line)

fd.close()
fd = open(val_list_path, 'w')
for line in lines[:_NUM_VALIDATION]:
  fd.write(line)

fd.close()

生成TFRecord数据:

Python代码:

import sys
sys.path.insert(0, '../models/slim/')
from datasets import dataset_utils
import math
import os
import tensorflow as tf

def convert_dataset(list_path, data_dir, output_dir, _NUM_SHARDS=5):
  fd = open(list_path)
  lines = [line.split() for line in fd]
  fd.close()
  num_per_shard = int(math.ceil(len(lines) / float(_NUM_SHARDS)))
  with tf.Graph().as_default():
    decode_jpeg_data = tf.placeholder(dtype=tf.string)
    decode_jpeg = tf.image.decode_jpeg(decode_jpeg_data, channels=3)
    with tf.Session('') as sess:
      for shard_id in range(_NUM_SHARDS):
        output_path = os.path.join(output_dir,
          'data_{:05}-of-{:05}.tfrecord'.format(shard_id, _NUM_SHARDS))
        tfrecord_writer = tf.python_io.TFRecordWriter(output_path)
        start_ndx = shard_id * num_per_shard
        end_ndx = min((shard_id + 1) * num_per_shard, len(lines))
        for i in range(start_ndx, end_ndx):
          sys.stdout.write('\r>> Converting image {}/{} shard {}'.format(
            i + 1, len(lines), shard_id))
          sys.stdout.flush()
          image_data = tf.gfile.FastGFile(os.path.join(data_dir, lines[i][0]), 'rb').read()
          image = sess.run(decode_jpeg, feed_dict={decode_jpeg_data: image_data})
          height, width = image.shape[0], image.shape[1]
          example = dataset_utils.image_to_tfexample(
            image_data, b'jpg', height, width, int(lines[i][1]))
          tfrecord_writer.write(example.SerializeToString())
        tfrecord_writer.close()
  sys.stdout.write('\n')
  sys.stdout.flush()

os.system('mkdir -p train')
convert_dataset('list_train.txt', 'flower_photos', 'train/')
os.system('mkdir -p val')
convert_dataset('list_val.txt', 'flower_photos', 'val/')

得到的文件夹结构如下:

data
├── flower_photos
├── labels.txt
├── list_train.txt
├── list.txt
├── list_val.txt
├── train
│  ├── data_00000-of-00005.tfrecord
│  ├── ...
│  └── data_00004-of-00005.tfrecord
└── val
  ├── data_00000-of-00005.tfrecord
  ├── ...
  └── data_00004-of-00005.tfrecord

(可选)下载模型

官方提供了不少预训练模型,这里以Inception-ResNet-v2以例。

cd $WORKSPACE/checkpoints
wget http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz
tar zxf inception_resnet_v2_2016_08_30.tar.gz

训练

读入数据

官方提供了读入Flowers数据集的代码models/slim/datasets/flowers.py,同样这里也是参考并修改成能读入上面定义的通用数据集。

把下面代码写入models/slim/datasets/dataset_classification.py。

import os
import tensorflow as tf
slim = tf.contrib.slim

def get_dataset(dataset_dir, num_samples, num_classes, labels_to_names_path=None, file_pattern='*.tfrecord'):
  file_pattern = os.path.join(dataset_dir, file_pattern)
  keys_to_features = {
    'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
    'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),
    'image/class/label': tf.FixedLenFeature(
      [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
  }
  items_to_handlers = {
    'image': slim.tfexample_decoder.Image(),
    'label': slim.tfexample_decoder.Tensor('image/class/label'),
  }
  decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)
  items_to_descriptions = {
    'image': 'A color image of varying size.',
    'label': 'A single integer between 0 and ' + str(num_classes - 1),
  }
  labels_to_names = None
  if labels_to_names_path is not None:
    fd = open(labels_to_names_path)
    labels_to_names = {i : line.strip() for i, line in enumerate(fd)}
    fd.close()
  return slim.dataset.Dataset(
      data_sources=file_pattern,
      reader=tf.TFRecordReader,
      decoder=decoder,
      num_samples=num_samples,
      items_to_descriptions=items_to_descriptions,
      num_classes=num_classes,
      labels_to_names=labels_to_names)

构建模型

官方提供了许多模型在models/slim/nets/。

如需要自定义模型,则参考官方提供的模型并放在对应的文件夹即可。

开始训练

官方提供了训练脚本,如果使用官方的数据读入和处理,可使用以下方式开始训练。

cd $WORKSPACE/models/slim
CUDA_VISIBLE_DEVICES="0" python train_image_classifier.py \
  --train_dir=train_logs \
  --dataset_name=flowers \
  --dataset_split_name=train \
  --dataset_dir=../../data/flowers \
  --model_name=inception_resnet_v2 \
  --checkpoint_path=../../checkpoints/inception_resnet_v2_2016_08_30.ckpt \
  --checkpoint_exclude_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits \
  --trainable_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits \
  --max_number_of_steps=1000 \
  --batch_size=32 \
  --learning_rate=0.01 \
  --learning_rate_decay_type=fixed \
  --save_interval_secs=60 \
  --save_summaries_secs=60 \
  --log_every_n_steps=10 \
  --optimizer=rmsprop \
  --weight_decay=0.00004

不fine-tune把--checkpoint_path, --checkpoint_exclude_scopes和--trainable_scopes删掉。

fine-tune所有层把--checkpoint_exclude_scopes和--trainable_scopes删掉。

如果只使用CPU则加上--clone_on_cpu=True。

其它参数可删掉用默认值或自行修改。

使用自己的数据则需要修改models/slim/train_image_classifier.py:

from datasets import dataset_factory

修改为

from datasets import dataset_classification

dataset = dataset_factory.get_dataset(
  FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

修改为

dataset = dataset_classification.get_dataset(
  FLAGS.dataset_dir, FLAGS.num_samples, FLAGS.num_classes, FLAGS.labels_to_names_path)

tf.app.flags.DEFINE_string(
  'dataset_dir', None, 'The directory where the dataset files are stored.')

后加入

tf.app.flags.DEFINE_integer(
  'num_samples', 3320, 'Number of samples.')

tf.app.flags.DEFINE_integer(
  'num_classes', 5, 'Number of classes.')

tf.app.flags.DEFINE_string(
  'labels_to_names_path', None, 'Label names file path.')

训练时执行以下命令即可:

cd $WORKSPACE/models/slim
python train_image_classifier.py \
  --train_dir=train_logs \
  --dataset_dir=../../data/train \
  --num_samples=3320 \
  --num_classes=5 \
  --labels_to_names_path=../../data/labels.txt \
  --model_name=inception_resnet_v2 \
  --checkpoint_path=../../checkpoints/inception_resnet_v2_2016_08_30.ckpt \
  --checkpoint_exclude_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits \
  --trainable_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits

可视化log

可一边训练一边可视化训练的log,可看到Loss趋势。

tensorboard --logdir train_logs/

验证

官方提供了验证脚本。

python eval_image_classifier.py \
  --checkpoint_path=train_logs \
  --eval_dir=eval_logs \
  --dataset_name=flowers \
  --dataset_split_name=validation \
  --dataset_dir=../../data/flowers \
  --model_name=inception_resnet_v2

同样,如果是使用自己的数据集,则需要修改models/slim/eval_image_classifier.py:

from datasets import dataset_factory

修改为

from datasets import dataset_classification

dataset = dataset_factory.get_dataset(
  FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

修改为

dataset = dataset_classification.get_dataset(
  FLAGS.dataset_dir, FLAGS.num_samples, FLAGS.num_classes, FLAGS.labels_to_names_path)

tf.app.flags.DEFINE_string(
  'dataset_dir', None, 'The directory where the dataset files are stored.')

后加入

tf.app.flags.DEFINE_integer(
  'num_samples', 350, 'Number of samples.')

tf.app.flags.DEFINE_integer(
  'num_classes', 5, 'Number of classes.')

tf.app.flags.DEFINE_string(
  'labels_to_names_path', None, 'Label names file path.')

验证时执行以下命令即可:

python eval_image_classifier.py \
  --checkpoint_path=train_logs \
  --eval_dir=eval_logs \
  --dataset_dir=../../data/val \
  --num_samples=350 \
  --num_classes=5 \
  --model_name=inception_resnet_v2

可以一边训练一边验证,,注意使用其它的GPU或合理分配显存。

同样也可以可视化log,如果已经在可视化训练的log则建议使用其它端口,如:

tensorboard --logdir eval_logs/ --port 6007

测试

参考models/slim/eval_image_classifier.py,可编写读取图片用模型进行推导的脚本models/slim/test_image_classifier.py

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import math
import tensorflow as tf

from nets import nets_factory
from preprocessing import preprocessing_factory

slim = tf.contrib.slim

tf.app.flags.DEFINE_string(
  'master', '', 'The address of the TensorFlow master to use.')

tf.app.flags.DEFINE_string(
  'checkpoint_path', '/tmp/tfmodel/',
  'The directory where the model was written to or an absolute path to a '
  'checkpoint file.')

tf.app.flags.DEFINE_string(
  'test_path', '', 'Test image path.')

tf.app.flags.DEFINE_integer(
  'num_classes', 5, 'Number of classes.')

tf.app.flags.DEFINE_integer(
  'labels_offset', 0,
  'An offset for the labels in the dataset. This flag is primarily used to '
  'evaluate the VGG and ResNet architectures which do not use a background '
  'class for the ImageNet dataset.')

tf.app.flags.DEFINE_string(
  'model_name', 'inception_v3', 'The name of the architecture to evaluate.')

tf.app.flags.DEFINE_string(
  'preprocessing_name', None, 'The name of the preprocessing to use. If left '
  'as `None`, then the model_name flag is used.')

tf.app.flags.DEFINE_integer(
  'test_image_size', None, 'Eval image size')

FLAGS = tf.app.flags.FLAGS


def main(_):
  if not FLAGS.test_list:
    raise ValueError('You must supply the test list with --test_list')

  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default():
    tf_global_step = slim.get_or_create_global_step()

    ####################
    # Select the model #
    ####################
    network_fn = nets_factory.get_network_fn(
      FLAGS.model_name,
      num_classes=(FLAGS.num_classes - FLAGS.labels_offset),
      is_training=False)

    #####################################
    # Select the preprocessing function #
    #####################################
    preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
      preprocessing_name,
      is_training=False)

    test_image_size = FLAGS.test_image_size or network_fn.default_image_size

    if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
      checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
    else:
      checkpoint_path = FLAGS.checkpoint_path

    tf.Graph().as_default()
    with tf.Session() as sess:
      image = open(FLAGS.test_path, 'rb').read()
      image = tf.image.decode_jpeg(image, channels=3)
      processed_image = image_preprocessing_fn(image, test_image_size, test_image_size)
      processed_images = tf.expand_dims(processed_image, 0)
      logits, _ = network_fn(processed_images)
      predictions = tf.argmax(logits, 1)
      saver = tf.train.Saver()
      saver.restore(sess, checkpoint_path)
      np_image, network_input, predictions = sess.run([image, processed_image, predictions])
      print('{} {}'.format(FLAGS.test_path, predictions[0]))

if __name__ == '__main__':
  tf.app.run()

测试时执行以下命令即可:

python test_image_classifier.py \
  --checkpoint_path=train_logs/ \
  --test_path=../../data/flower_photos/tulips/6948239566_0ac0a124ee_n.jpg \
  --num_classes=5 \
  --model_name=inception_resnet_v2

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
老生常谈Python startswith()函数与endswith函数
Sep 08 Python
Windows系统下多版本pip的共存问题详解
Oct 10 Python
CentOS 6.5中安装Python 3.6.2的方法步骤
Dec 03 Python
python实现学生管理系统
Jan 11 Python
Python3生成手写体数字方法
Jan 30 Python
Python并发之多进程的方法实例代码
Aug 15 Python
详解python3安装pillow后报错没有pillow模块以及没有PIL模块问题解决
Apr 17 Python
django-初始配置(纯手写)详解
Jul 30 Python
Django中的用户身份验证示例详解
Aug 07 Python
Python @property原理解析和用法实例
Feb 11 Python
python GUI库图形界面开发之PyQt5浏览器控件QWebEngineView详细使用方法
Feb 26 Python
pip已经安装好第三方库但pycharm中import时还是标红的解决方案
Oct 09 Python
Pytorch之view及view_as使用详解
Dec 31 #Python
window环境pip切换国内源(pip安装异常缓慢的问题)
Dec 31 #Python
如何基于Python创建目录文件夹
Dec 31 #Python
Pytorch之contiguous的用法
Dec 31 #Python
python实现将json多行数据传入到mysql中使用
Dec 31 #Python
Pytorch之Variable的用法
Dec 31 #Python
Pytorch 多块GPU的使用详解
Dec 31 #Python
You might like
完美实现GIF动画缩略图的php代码
2011/01/02 PHP
php获取网卡的MAC地址支持WIN/LINUX系统
2014/04/30 PHP
PHP 错误处理机制
2015/07/06 PHP
基于PHP代码实现中奖概率算法可用于刮刮卡、大转盘等抽奖算法
2015/12/20 PHP
PHP中的print_r 与 var_dump 输出数组
2016/06/13 PHP
基于thinkPHP框架实现留言板的方法
2016/10/17 PHP
thinkPHP5.0框架简单配置作用域的方法
2017/03/17 PHP
用javascript实现点击链接弹出"图片另存为"而不是直接打开
2007/08/15 Javascript
jquery插件之信息弹出框showInfoDialog(成功/错误/警告/通知/背景遮罩)
2013/01/09 Javascript
jQuery获取复选框被选中数量及判断选择值的方法详解
2016/05/25 Javascript
jQuery实现弹出窗口弹出div层的实例代码
2017/01/09 Javascript
详解Vue2+Echarts实现多种图表数据可视化Dashboard(附源码)
2017/03/21 Javascript
JavaScript门面模式详解
2017/10/19 Javascript
纯js实现隔行变色效果
2017/11/29 Javascript
mpvue中配置vuex并持久化到本地Storage图文教程解析
2018/03/15 Javascript
Vue 组件传值几种常用方法【总结】
2018/05/28 Javascript
vue移动端html5页面根据屏幕适配的四种解决方法
2018/10/19 Javascript
js/jquery遍历对象和数组的方法分析【forEach,map与each方法】
2019/02/27 jQuery
微信小程序实现搜索历史功能
2020/03/26 Javascript
微信小程序入口场景的问题集合与相关解决方法
2019/06/26 Javascript
javascript 模块依赖管理的本质深入详解
2020/04/30 Javascript
解决vue字符串换行问题(绝对管用)
2020/08/06 Javascript
解决vue项目中遇到 Cannot find module ‘chalk‘ 报错的问题
2020/11/05 Javascript
python实现爬虫统计学校BBS男女比例之多线程爬虫(二)
2015/12/31 Python
使用python爬虫实现网络股票信息爬取的demo
2018/01/05 Python
mac PyCharm添加Python解释器及添加package路径的方法
2018/10/29 Python
python pygame实现滚动横版射击游戏城市之战
2019/11/25 Python
基于HTML5新特性Mutation Observer实现编辑器的撤销和回退操作
2016/01/11 HTML / CSS
程序运行正确, 但退出时却"core dump"了,怎么回事
2014/02/19 面试题
银行实习生的自我评价
2014/01/13 职场文书
承诺书样本
2014/08/30 职场文书
怎样写离婚协议书
2015/01/26 职场文书
大二学年个人总结
2015/03/03 职场文书
文化苦旅读书笔记
2015/06/29 职场文书
Django如何与Ajax交互
2021/04/29 Python
解析探秘fescar分布式事务实现原理
2022/02/28 Java/Android