使用TensorFlow-Slim进行图像分类的实现


Posted in Python onDecember 31, 2019

参考 https://github.com/tensorflow/models/tree/master/slim

使用TensorFlow-Slim进行图像分类

准备

安装TensorFlow

参考 https://www.tensorflow.org/install/

如在Ubuntu下安装TensorFlow with GPU support, python 2.7版本

wget https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.2.0-cp27-none-linux_x86_64.whl
pip install tensorflow_gpu-1.2.0-cp27-none-linux_x86_64.whl

下载TF-slim图像模型库

cd $WORKSPACE
git clone https://github.com/tensorflow/models/

准备数据

有不少公开数据集,这里以官网提供的Flowers为例。

官网提供了下载和转换数据的代码,为了理解代码并能使用自己的数据,这里参考官方提供的代码进行修改。

cd $WORKSPACE/data
wget http://download.tensorflow.org/example_images/flower_photos.tgz
tar zxf flower_photos.tgz

数据集文件夹结构如下:

flower_photos
├── daisy
│  ├── 100080576_f52e8ee070_n.jpg
│  └── ...
├── dandelion
├── LICENSE.txt
├── roses
├── sunflowers
└── tulips

由于实际情况中我们自己的数据集并不一定把图片按类别放在不同的文件夹里,故我们生成list.txt来表示图片路径与标签的关系。

Python代码:

import os

class_names_to_ids = {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
data_dir = 'flower_photos/'
output_path = 'list.txt'

fd = open(output_path, 'w')
for class_name in class_names_to_ids.keys():
  images_list = os.listdir(data_dir + class_name)
  for image_name in images_list:
    fd.write('{}/{} {}\n'.format(class_name, image_name, class_names_to_ids[class_name]))

fd.close()

为了方便后期查看label标签,也可以定义labels.txt:

daisy
dandelion
roses
sunflowers
tulips

随机生成训练集与验证集:

Python代码:

import random

_NUM_VALIDATION = 350
_RANDOM_SEED = 0
list_path = 'list.txt'
train_list_path = 'list_train.txt'
val_list_path = 'list_val.txt'

fd = open(list_path)
lines = fd.readlines()
fd.close()
random.seed(_RANDOM_SEED)
random.shuffle(lines)

fd = open(train_list_path, 'w')
for line in lines[_NUM_VALIDATION:]:
  fd.write(line)

fd.close()
fd = open(val_list_path, 'w')
for line in lines[:_NUM_VALIDATION]:
  fd.write(line)

fd.close()

生成TFRecord数据:

Python代码:

import sys
sys.path.insert(0, '../models/slim/')
from datasets import dataset_utils
import math
import os
import tensorflow as tf

def convert_dataset(list_path, data_dir, output_dir, _NUM_SHARDS=5):
  fd = open(list_path)
  lines = [line.split() for line in fd]
  fd.close()
  num_per_shard = int(math.ceil(len(lines) / float(_NUM_SHARDS)))
  with tf.Graph().as_default():
    decode_jpeg_data = tf.placeholder(dtype=tf.string)
    decode_jpeg = tf.image.decode_jpeg(decode_jpeg_data, channels=3)
    with tf.Session('') as sess:
      for shard_id in range(_NUM_SHARDS):
        output_path = os.path.join(output_dir,
          'data_{:05}-of-{:05}.tfrecord'.format(shard_id, _NUM_SHARDS))
        tfrecord_writer = tf.python_io.TFRecordWriter(output_path)
        start_ndx = shard_id * num_per_shard
        end_ndx = min((shard_id + 1) * num_per_shard, len(lines))
        for i in range(start_ndx, end_ndx):
          sys.stdout.write('\r>> Converting image {}/{} shard {}'.format(
            i + 1, len(lines), shard_id))
          sys.stdout.flush()
          image_data = tf.gfile.FastGFile(os.path.join(data_dir, lines[i][0]), 'rb').read()
          image = sess.run(decode_jpeg, feed_dict={decode_jpeg_data: image_data})
          height, width = image.shape[0], image.shape[1]
          example = dataset_utils.image_to_tfexample(
            image_data, b'jpg', height, width, int(lines[i][1]))
          tfrecord_writer.write(example.SerializeToString())
        tfrecord_writer.close()
  sys.stdout.write('\n')
  sys.stdout.flush()

os.system('mkdir -p train')
convert_dataset('list_train.txt', 'flower_photos', 'train/')
os.system('mkdir -p val')
convert_dataset('list_val.txt', 'flower_photos', 'val/')

得到的文件夹结构如下:

data
├── flower_photos
├── labels.txt
├── list_train.txt
├── list.txt
├── list_val.txt
├── train
│  ├── data_00000-of-00005.tfrecord
│  ├── ...
│  └── data_00004-of-00005.tfrecord
└── val
  ├── data_00000-of-00005.tfrecord
  ├── ...
  └── data_00004-of-00005.tfrecord

(可选)下载模型

官方提供了不少预训练模型,这里以Inception-ResNet-v2以例。

cd $WORKSPACE/checkpoints
wget http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz
tar zxf inception_resnet_v2_2016_08_30.tar.gz

训练

读入数据

官方提供了读入Flowers数据集的代码models/slim/datasets/flowers.py,同样这里也是参考并修改成能读入上面定义的通用数据集。

把下面代码写入models/slim/datasets/dataset_classification.py。

import os
import tensorflow as tf
slim = tf.contrib.slim

def get_dataset(dataset_dir, num_samples, num_classes, labels_to_names_path=None, file_pattern='*.tfrecord'):
  file_pattern = os.path.join(dataset_dir, file_pattern)
  keys_to_features = {
    'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
    'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),
    'image/class/label': tf.FixedLenFeature(
      [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
  }
  items_to_handlers = {
    'image': slim.tfexample_decoder.Image(),
    'label': slim.tfexample_decoder.Tensor('image/class/label'),
  }
  decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)
  items_to_descriptions = {
    'image': 'A color image of varying size.',
    'label': 'A single integer between 0 and ' + str(num_classes - 1),
  }
  labels_to_names = None
  if labels_to_names_path is not None:
    fd = open(labels_to_names_path)
    labels_to_names = {i : line.strip() for i, line in enumerate(fd)}
    fd.close()
  return slim.dataset.Dataset(
      data_sources=file_pattern,
      reader=tf.TFRecordReader,
      decoder=decoder,
      num_samples=num_samples,
      items_to_descriptions=items_to_descriptions,
      num_classes=num_classes,
      labels_to_names=labels_to_names)

构建模型

官方提供了许多模型在models/slim/nets/。

如需要自定义模型,则参考官方提供的模型并放在对应的文件夹即可。

开始训练

官方提供了训练脚本,如果使用官方的数据读入和处理,可使用以下方式开始训练。

cd $WORKSPACE/models/slim
CUDA_VISIBLE_DEVICES="0" python train_image_classifier.py \
  --train_dir=train_logs \
  --dataset_name=flowers \
  --dataset_split_name=train \
  --dataset_dir=../../data/flowers \
  --model_name=inception_resnet_v2 \
  --checkpoint_path=../../checkpoints/inception_resnet_v2_2016_08_30.ckpt \
  --checkpoint_exclude_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits \
  --trainable_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits \
  --max_number_of_steps=1000 \
  --batch_size=32 \
  --learning_rate=0.01 \
  --learning_rate_decay_type=fixed \
  --save_interval_secs=60 \
  --save_summaries_secs=60 \
  --log_every_n_steps=10 \
  --optimizer=rmsprop \
  --weight_decay=0.00004

不fine-tune把--checkpoint_path, --checkpoint_exclude_scopes和--trainable_scopes删掉。

fine-tune所有层把--checkpoint_exclude_scopes和--trainable_scopes删掉。

如果只使用CPU则加上--clone_on_cpu=True。

其它参数可删掉用默认值或自行修改。

使用自己的数据则需要修改models/slim/train_image_classifier.py:

from datasets import dataset_factory

修改为

from datasets import dataset_classification

dataset = dataset_factory.get_dataset(
  FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

修改为

dataset = dataset_classification.get_dataset(
  FLAGS.dataset_dir, FLAGS.num_samples, FLAGS.num_classes, FLAGS.labels_to_names_path)

tf.app.flags.DEFINE_string(
  'dataset_dir', None, 'The directory where the dataset files are stored.')

后加入

tf.app.flags.DEFINE_integer(
  'num_samples', 3320, 'Number of samples.')

tf.app.flags.DEFINE_integer(
  'num_classes', 5, 'Number of classes.')

tf.app.flags.DEFINE_string(
  'labels_to_names_path', None, 'Label names file path.')

训练时执行以下命令即可:

cd $WORKSPACE/models/slim
python train_image_classifier.py \
  --train_dir=train_logs \
  --dataset_dir=../../data/train \
  --num_samples=3320 \
  --num_classes=5 \
  --labels_to_names_path=../../data/labels.txt \
  --model_name=inception_resnet_v2 \
  --checkpoint_path=../../checkpoints/inception_resnet_v2_2016_08_30.ckpt \
  --checkpoint_exclude_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits \
  --trainable_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits

可视化log

可一边训练一边可视化训练的log,可看到Loss趋势。

tensorboard --logdir train_logs/

验证

官方提供了验证脚本。

python eval_image_classifier.py \
  --checkpoint_path=train_logs \
  --eval_dir=eval_logs \
  --dataset_name=flowers \
  --dataset_split_name=validation \
  --dataset_dir=../../data/flowers \
  --model_name=inception_resnet_v2

同样,如果是使用自己的数据集,则需要修改models/slim/eval_image_classifier.py:

from datasets import dataset_factory

修改为

from datasets import dataset_classification

dataset = dataset_factory.get_dataset(
  FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

修改为

dataset = dataset_classification.get_dataset(
  FLAGS.dataset_dir, FLAGS.num_samples, FLAGS.num_classes, FLAGS.labels_to_names_path)

tf.app.flags.DEFINE_string(
  'dataset_dir', None, 'The directory where the dataset files are stored.')

后加入

tf.app.flags.DEFINE_integer(
  'num_samples', 350, 'Number of samples.')

tf.app.flags.DEFINE_integer(
  'num_classes', 5, 'Number of classes.')

tf.app.flags.DEFINE_string(
  'labels_to_names_path', None, 'Label names file path.')

验证时执行以下命令即可:

python eval_image_classifier.py \
  --checkpoint_path=train_logs \
  --eval_dir=eval_logs \
  --dataset_dir=../../data/val \
  --num_samples=350 \
  --num_classes=5 \
  --model_name=inception_resnet_v2

可以一边训练一边验证,,注意使用其它的GPU或合理分配显存。

同样也可以可视化log,如果已经在可视化训练的log则建议使用其它端口,如:

tensorboard --logdir eval_logs/ --port 6007

测试

参考models/slim/eval_image_classifier.py,可编写读取图片用模型进行推导的脚本models/slim/test_image_classifier.py

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import math
import tensorflow as tf

from nets import nets_factory
from preprocessing import preprocessing_factory

slim = tf.contrib.slim

tf.app.flags.DEFINE_string(
  'master', '', 'The address of the TensorFlow master to use.')

tf.app.flags.DEFINE_string(
  'checkpoint_path', '/tmp/tfmodel/',
  'The directory where the model was written to or an absolute path to a '
  'checkpoint file.')

tf.app.flags.DEFINE_string(
  'test_path', '', 'Test image path.')

tf.app.flags.DEFINE_integer(
  'num_classes', 5, 'Number of classes.')

tf.app.flags.DEFINE_integer(
  'labels_offset', 0,
  'An offset for the labels in the dataset. This flag is primarily used to '
  'evaluate the VGG and ResNet architectures which do not use a background '
  'class for the ImageNet dataset.')

tf.app.flags.DEFINE_string(
  'model_name', 'inception_v3', 'The name of the architecture to evaluate.')

tf.app.flags.DEFINE_string(
  'preprocessing_name', None, 'The name of the preprocessing to use. If left '
  'as `None`, then the model_name flag is used.')

tf.app.flags.DEFINE_integer(
  'test_image_size', None, 'Eval image size')

FLAGS = tf.app.flags.FLAGS


def main(_):
  if not FLAGS.test_list:
    raise ValueError('You must supply the test list with --test_list')

  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default():
    tf_global_step = slim.get_or_create_global_step()

    ####################
    # Select the model #
    ####################
    network_fn = nets_factory.get_network_fn(
      FLAGS.model_name,
      num_classes=(FLAGS.num_classes - FLAGS.labels_offset),
      is_training=False)

    #####################################
    # Select the preprocessing function #
    #####################################
    preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
      preprocessing_name,
      is_training=False)

    test_image_size = FLAGS.test_image_size or network_fn.default_image_size

    if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
      checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
    else:
      checkpoint_path = FLAGS.checkpoint_path

    tf.Graph().as_default()
    with tf.Session() as sess:
      image = open(FLAGS.test_path, 'rb').read()
      image = tf.image.decode_jpeg(image, channels=3)
      processed_image = image_preprocessing_fn(image, test_image_size, test_image_size)
      processed_images = tf.expand_dims(processed_image, 0)
      logits, _ = network_fn(processed_images)
      predictions = tf.argmax(logits, 1)
      saver = tf.train.Saver()
      saver.restore(sess, checkpoint_path)
      np_image, network_input, predictions = sess.run([image, processed_image, predictions])
      print('{} {}'.format(FLAGS.test_path, predictions[0]))

if __name__ == '__main__':
  tf.app.run()

测试时执行以下命令即可:

python test_image_classifier.py \
  --checkpoint_path=train_logs/ \
  --test_path=../../data/flower_photos/tulips/6948239566_0ac0a124ee_n.jpg \
  --num_classes=5 \
  --model_name=inception_resnet_v2

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python 可爱的大小写
Sep 06 Python
Python读取图片EXIF信息类库介绍和使用实例
Jul 10 Python
零基础写python爬虫之打包生成exe文件
Nov 06 Python
wxPython使用系统剪切板的方法
Jun 16 Python
python中matplotlib实现最小二乘法拟合的过程详解
Jul 11 Python
Python实现识别手写数字大纲
Jan 29 Python
微信小程序python用户认证的实现
Jul 29 Python
python代码 FTP备份交换机配置脚本实例解析
Aug 01 Python
Python使用APScheduler实现定时任务过程解析
Sep 11 Python
Python数据库小程序源代码
Sep 15 Python
Python 装饰器@,对函数进行功能扩展操作示例【开闭原则】
Oct 17 Python
python 遍历pd.Series的index和value
Nov 26 Python
Pytorch之view及view_as使用详解
Dec 31 #Python
window环境pip切换国内源(pip安装异常缓慢的问题)
Dec 31 #Python
如何基于Python创建目录文件夹
Dec 31 #Python
Pytorch之contiguous的用法
Dec 31 #Python
python实现将json多行数据传入到mysql中使用
Dec 31 #Python
Pytorch之Variable的用法
Dec 31 #Python
Pytorch 多块GPU的使用详解
Dec 31 #Python
You might like
JAVA/JSP学习系列之二
2006/10/09 PHP
php读取torrent种子文件内容的方法(测试可用)
2016/05/03 PHP
有关PHP 中 config.m4 的探索
2020/08/26 PHP
jquery 获取json数据实现代码
2009/04/27 Javascript
jQuery AnythingSlider滑动效果插件
2010/02/07 Javascript
各种页面定时跳转(倒计时跳转)代码总结
2013/10/24 Javascript
理运用命名空间让js不产生冲突避免全局变量的泛滥
2014/06/15 Javascript
JS数组(Array)处理函数整理
2014/12/07 Javascript
JavaScript获得url查询参数的方法
2015/07/02 Javascript
javascript基础语法学习笔记
2016/01/04 Javascript
jQuery实现限制文本框的输入长度
2017/01/11 Javascript
vue 文件目录结构详解
2017/11/24 Javascript
微信小程序收藏功能的实现代码
2020/06/19 Javascript
python字符串对其居中显示的方法
2015/07/11 Python
浅谈pyhton学习中出现的各种问题(新手必看)
2017/05/17 Python
Python简单生成8位随机密码的方法
2017/05/24 Python
python版飞机大战代码分享
2018/11/20 Python
uwsgi+nginx部署Django项目操作示例
2018/12/04 Python
python web框架中实现原生分页
2019/09/08 Python
pygame实现成语填空游戏
2019/10/29 Python
python tqdm 实现滚动条不上下滚动代码(保持一行内滚动)
2020/02/19 Python
以设计师精品品质提供快速时尚:PopJulia
2018/01/09 全球购物
Street One瑞士:德国现代时装公司
2019/10/09 全球购物
在DELPHI中调用存储过程和使用内嵌SQL哪种方式更好
2016/11/22 面试题
护理专业毕业生自荐信范文
2014/01/05 职场文书
大学生求职自我评价
2014/01/16 职场文书
环境保护标语
2014/06/20 职场文书
文化大革命观后感
2015/06/17 职场文书
孕妇病假条怎么写
2015/08/17 职场文书
2016婚礼主持词开场白
2015/11/24 职场文书
2019年工作总结范文
2019/05/21 职场文书
《合作意向书》怎么写?
2019/08/20 职场文书
导游词之上海豫园
2019/10/24 职场文书
css实现两栏布局,左侧固定宽,右侧自适应的多种方法
2021/08/07 HTML / CSS
Nginx虚拟主机的搭建的实现步骤
2022/01/18 Servers
Mysql多层子查询示例代码(收藏夹案例)
2022/03/31 MySQL