tensorflow 1.0用CNN进行图像分类


Posted in Python onApril 15, 2018

tensorflow升级到1.0之后,增加了一些高级模块: 如tf.layers, tf.metrics, 和tf.losses,使得代码稍微有些简化。

任务:花卉分类

版本:tensorflow 1.0

数据:flower-photos

花总共有五类,分别放在5个文件夹下。

闲话不多说,直接上代码,希望大家能看懂:)

复制代码

# -*- coding: utf-8 -*-

from skimage import io,transform
import glob
import os
import tensorflow as tf
import numpy as np
import time

path='e:/flower/'

#将所有的图片resize成100*100
w=100
h=100
c=3


#读取图片
def read_img(path):
 cate=[path+x for x in os.listdir(path) if os.path.isdir(path+x)]
 imgs=[]
 labels=[]
 for idx,folder in enumerate(cate):
  for im in glob.glob(folder+'/*.jpg'):
   print('reading the images:%s'%(im))
   img=io.imread(im)
   img=transform.resize(img,(w,h))
   imgs.append(img)
   labels.append(idx)
 return np.asarray(imgs,np.float32),np.asarray(labels,np.int32)
data,label=read_img(path)


#打乱顺序
num_example=data.shape[0]
arr=np.arange(num_example)
np.random.shuffle(arr)
data=data[arr]
label=label[arr]


#将所有数据分为训练集和验证集
ratio=0.8
s=np.int(num_example*ratio)
x_train=data[:s]
y_train=label[:s]
x_val=data[s:]
y_val=label[s:]

#-----------------构建网络----------------------
#占位符
x=tf.placeholder(tf.float32,shape=[None,w,h,c],name='x')
y_=tf.placeholder(tf.int32,shape=[None,],name='y_')

#第一个卷积层(100——>50)
conv1=tf.layers.conv2d(
  inputs=x,
  filters=32,
  kernel_size=[5, 5],
  padding="same",
  activation=tf.nn.relu,
  kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool1=tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)

#第二个卷积层(50->25)
conv2=tf.layers.conv2d(
  inputs=pool1,
  filters=64,
  kernel_size=[5, 5],
  padding="same",
  activation=tf.nn.relu,
  kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool2=tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)

#第三个卷积层(25->12)
conv3=tf.layers.conv2d(
  inputs=pool2,
  filters=128,
  kernel_size=[3, 3],
  padding="same",
  activation=tf.nn.relu,
  kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool3=tf.layers.max_pooling2d(inputs=conv3, pool_size=[2, 2], strides=2)

#第四个卷积层(12->6)
conv4=tf.layers.conv2d(
  inputs=pool3,
  filters=128,
  kernel_size=[3, 3],
  padding="same",
  activation=tf.nn.relu,
  kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
pool4=tf.layers.max_pooling2d(inputs=conv4, pool_size=[2, 2], strides=2)

re1 = tf.reshape(pool4, [-1, 6 * 6 * 128])

#全连接层
dense1 = tf.layers.dense(inputs=re1, 
      units=1024, 
      activation=tf.nn.relu,
      kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
      kernel_regularizer=tf.contrib.layers.l2_regularizer(0.003))
dense2= tf.layers.dense(inputs=dense1, 
      units=512, 
      activation=tf.nn.relu,
      kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
      kernel_regularizer=tf.contrib.layers.l2_regularizer(0.003))
logits= tf.layers.dense(inputs=dense2, 
      units=5, 
      activation=None,
      kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
      kernel_regularizer=tf.contrib.layers.l2_regularizer(0.003))
#---------------------------网络结束---------------------------

loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits)
train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_) 
acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32))


#定义一个函数,按批次取数据
def minibatches(inputs=None, targets=None, batch_size=None, shuffle=False):
 assert len(inputs) == len(targets)
 if shuffle:
  indices = np.arange(len(inputs))
  np.random.shuffle(indices)
 for start_idx in range(0, len(inputs) - batch_size + 1, batch_size):
  if shuffle:
   excerpt = indices[start_idx:start_idx + batch_size]
  else:
   excerpt = slice(start_idx, start_idx + batch_size)
  yield inputs[excerpt], targets[excerpt]


#训练和测试数据,可将n_epoch设置更大一些

n_epoch=10
batch_size=64
sess=tf.InteractiveSession() 
sess.run(tf.global_variables_initializer())
for epoch in range(n_epoch):
 start_time = time.time()
 
 #training
 train_loss, train_acc, n_batch = 0, 0, 0
 for x_train_a, y_train_a in minibatches(x_train, y_train, batch_size, shuffle=True):
  _,err,ac=sess.run([train_op,loss,acc], feed_dict={x: x_train_a, y_: y_train_a})
  train_loss += err; train_acc += ac; n_batch += 1
 print(" train loss: %f" % (train_loss/ n_batch))
 print(" train acc: %f" % (train_acc/ n_batch))
 
 #validation
 val_loss, val_acc, n_batch = 0, 0, 0
 for x_val_a, y_val_a in minibatches(x_val, y_val, batch_size, shuffle=False):
  err, ac = sess.run([loss,acc], feed_dict={x: x_val_a, y_: y_val_a})
  val_loss += err; val_acc += ac; n_batch += 1
 print(" validation loss: %f" % (val_loss/ n_batch))
 print(" validation acc: %f" % (val_acc/ n_batch))

sess.close()

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python随机生成数模块random使用实例
Apr 13 Python
用PyQt进行Python图形界面的程序的开发的入门指引
Apr 14 Python
Python 'takes exactly 1 argument (2 given)' Python error
Dec 13 Python
Python爬虫实现网页信息抓取功能示例【URL与正则模块】
May 18 Python
Python实现按中文排序的方法示例
Apr 25 Python
浅析Python3中的对象垃圾收集机制
Jun 06 Python
简单了解python中对象的取反运算符
Jul 01 Python
Django 全局的static和templates的使用详解
Jul 19 Python
Django框架创建mysql连接与使用示例
Jul 29 Python
Python实现代码统计工具
Sep 19 Python
Python函数式编程实例详解
Jan 17 Python
使用Python将语音转换为文本的方法
Aug 10 Python
tensorflow学习笔记之mnist的卷积神经网络实例
Apr 15 #Python
tensorflow学习笔记之简单的神经网络训练和测试
Apr 15 #Python
Pytorch入门之mnist分类实例
Apr 14 #Python
pytorch构建网络模型的4种方法
Apr 13 #Python
Python输入二维数组方法
Apr 13 #Python
Python基于递归实现电话号码映射功能示例
Apr 13 #Python
Python的多维空数组赋值方法
Apr 13 #Python
You might like
PHP中显示格式化的用户输入
2006/10/09 PHP
php表单转换textarea换行符的方法
2010/09/10 PHP
php 获取今日、昨日、上周、本月的起始时间戳和结束时间戳的方法
2013/09/28 PHP
PHP入门教程之会话控制技巧(cookie与session)
2016/09/11 PHP
chrome调试javascript详解
2015/10/21 Javascript
Nodejs进阶:核心模块net入门学习与实例讲解
2016/11/21 NodeJs
jQ处理xml文件和xml字符串的方法(详解)
2016/11/22 Javascript
JS实现动画兼容性的transition和transform实例分析
2016/12/13 Javascript
JS实现点击表头表格自动排序(含数字、字符串、日期)
2017/01/22 Javascript
如何用js判断dom是否有存在某class的值
2017/02/13 Javascript
关于vue面试题汇总
2018/03/20 Javascript
详解JavaScript 事件流
2020/09/02 Javascript
Python实现批量转换文件编码的方法
2015/07/28 Python
Python的Django框架中使用SQLAlchemy操作数据库的教程
2016/06/02 Python
如何在Python函数执行前后增加额外的行为
2016/10/20 Python
Python实现简单的四则运算计算器
2016/11/02 Python
Python 模块EasyGui详细介绍
2017/02/19 Python
python实现聚类算法原理
2018/02/12 Python
纯用NumPy实现神经网络的示例代码
2018/10/24 Python
如何使用Python自动控制windows桌面
2019/07/11 Python
Win10下python 2.7与python 3.7双环境安装教程图解
2019/10/12 Python
Python matplotlib 绘制双Y轴曲线图的示例代码
2020/06/12 Python
CSS3 函数技巧 用css 实现js实现的事情(clac Counters Tooltip)
2017/08/15 HTML / CSS
GUESS盖尔斯法国官网:美国时尚品牌
2016/09/23 全球购物
Madewell美德威尔美国官网:美国休闲服饰品牌
2016/11/25 全球购物
CK巴西官方网站:Calvin Klein巴西
2019/07/19 全球购物
java程序员面试交流
2012/11/29 面试题
预备党员入党思想汇报
2014/01/04 职场文书
大学毕业后的十年规划
2014/01/07 职场文书
班级入场式解说词
2014/02/01 职场文书
结对共建协议书
2014/08/20 职场文书
大学生党员个人总结
2015/02/13 职场文书
毕业生对母校寄语
2015/02/26 职场文书
向雷锋同志学习倡议书
2015/04/27 职场文书
小学德育工作总结2015
2015/05/12 职场文书
2016年端午节寄语
2015/12/04 职场文书