keras-siamese用自己的数据集实现详解


Posted in Python onJune 10, 2020

Siamese网络不做过多介绍,思想并不难,输入两个图像,输出这两张图像的相似度,两个输入的网络结构是相同的,参数共享。

主要发现很多代码都是基于mnist数据集的,下面说一下怎么用自己的数据集实现siamese网络。

首先,先整理数据集,相同的类放到同一个文件夹下,如下图所示:

keras-siamese用自己的数据集实现详解

接下来,将pairs及对应的label写到csv中,代码如下:

import os
import random
import csv
#图片所在的路径
path = '/Users/mac/Desktop/wxd/flag/category/'
#files列表保存所有类别的路径
files=[]
same_pairs=[]
different_pairs=[]
for file in os.listdir(path):
 if file[0]=='.':
  continue
 file_path = os.path.join(path,file)
 files.append(file_path)
#该地址为csv要保存到的路径,a表示追加写入
with open('/Users/mac/Desktop/wxd/flag/data.csv','a') as f:
 #保存相同对
 writer = csv.writer(f)
 for file in files:
  imgs = os.listdir(file) 
  for i in range(0,len(imgs)-1):
   for j in range(i+1,len(imgs)):
    pairs = []
    name = file.split(sep='/')[-1]
    pairs.append(path+name+'/'+imgs[i])
    pairs.append(path+name+'/'+imgs[j])
    pairs.append(1)
    writer.writerow(pairs)
 #保存不同对
 for i in range(0,len(files)-1):
  for j in range(i+1,len(files)):
   filea = files[i]
   fileb = files[j]
   imga_li = os.listdir(filea)
   imgb_li = os.listdir(fileb)
   random.shuffle(imga_li)
   random.shuffle(imgb_li)
   a_li = imga_li[:]
   b_li = imgb_li[:]
   for p in range(len(a_li)):
    for q in range(len(b_li)):
     pairs = []
     name1 = filea.split(sep='/')[-1]
     name2 = fileb.split(sep='/')[-1]
     pairs.append(path+name1+'/'+a_li[p])
     pairs.append(path+name2+'/'+b_li[q])
     pairs.append(0)
     writer.writerow(pairs)

相当于csv每一行都包含一对结果,每一行有三列,第一列第一张图片路径,第二列第二张图片路径,第三列是不是相同的label,属于同一个类的label为1,不同类的为0,可参考下图:

keras-siamese用自己的数据集实现详解

然后,由于keras的fit函数需要将训练数据都塞入内存,而大部分训练数据都较大,因此才用fit_generator生成器的方法,便可以训练大数据,代码如下:

from __future__ import absolute_import
from __future__ import print_function
import numpy as np
from keras.models import Model
from keras.layers import Input, Dense, Dropout, BatchNormalization, Conv2D, MaxPooling2D, AveragePooling2D, concatenate, \
 Activation, ZeroPadding2D
from keras.layers import add, Flatten
from keras.utils import plot_model
from keras.metrics import top_k_categorical_accuracy
from keras.preprocessing.image import ImageDataGenerator
from keras.models import load_model
import tensorflow as tf
import random
import os
import cv2
import csv
import numpy as np
from keras.models import Model
from keras.layers import Input, Flatten, Dense, Dropout, Lambda
from keras.optimizers import RMSprop
from keras import backend as K
from keras.callbacks import ModelCheckpoint
from keras.preprocessing.image import img_to_array
 
"""
自定义的参数
"""
im_width = 224
im_height = 224
epochs = 100
batch_size = 64
iterations = 1000
csv_path = ''
model_result = ''
 
 
# 计算欧式距离
def euclidean_distance(vects):
 x, y = vects
 sum_square = K.sum(K.square(x - y), axis=1, keepdims=True)
 return K.sqrt(K.maximum(sum_square, K.epsilon()))
 
def eucl_dist_output_shape(shapes):
 shape1, shape2 = shapes
 return (shape1[0], 1)
 
# 计算loss
def contrastive_loss(y_true, y_pred):
 '''Contrastive loss from Hadsell-et-al.'06
 http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
 '''
 margin = 1
 square_pred = K.square(y_pred)
 margin_square = K.square(K.maximum(margin - y_pred, 0))
 return K.mean(y_true * square_pred + (1 - y_true) * margin_square)
 
def compute_accuracy(y_true, y_pred):
 '''计算准确率
 '''
 pred = y_pred.ravel() < 0.5
 print('pred:', pred)
 return np.mean(pred == y_true)
 
def accuracy(y_true, y_pred):
 '''Compute classification accuracy with a fixed threshold on distances.
 '''
 return K.mean(K.equal(y_true, K.cast(y_pred < 0.5, y_true.dtype)))
 
def processImg(filename):
 """
 :param filename: 图像的路径
 :return: 返回的是归一化矩阵
 """
 img = cv2.imread(filename)
 img = cv2.resize(img, (im_width, im_height))
 img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
 img = img_to_array(img)
 img /= 255
 return img
 
def Conv2d_BN(x, nb_filter, kernel_size, strides=(1, 1), padding='same', name=None):
 if name is not None:
  bn_name = name + '_bn'
  conv_name = name + '_conv'
 else:
  bn_name = None
  conv_name = None
 
 x = Conv2D(nb_filter, kernel_size, padding=padding, strides=strides, activation='relu', name=conv_name)(x)
 x = BatchNormalization(axis=3, name=bn_name)(x)
 return x
 
def bottleneck_Block(inpt, nb_filters, strides=(1, 1), with_conv_shortcut=False):
 k1, k2, k3 = nb_filters
 x = Conv2d_BN(inpt, nb_filter=k1, kernel_size=1, strides=strides, padding='same')
 x = Conv2d_BN(x, nb_filter=k2, kernel_size=3, padding='same')
 x = Conv2d_BN(x, nb_filter=k3, kernel_size=1, padding='same')
 if with_conv_shortcut:
  shortcut = Conv2d_BN(inpt, nb_filter=k3, strides=strides, kernel_size=1)
  x = add([x, shortcut])
  return x
 else:
  x = add([x, inpt])
  return x
 
def resnet_50():
 width = im_width
 height = im_height
 channel = 3
 inpt = Input(shape=(width, height, channel))
 x = ZeroPadding2D((3, 3))(inpt)
 x = Conv2d_BN(x, nb_filter=64, kernel_size=(7, 7), strides=(2, 2), padding='valid')
 x = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), padding='same')(x)
 
 # conv2_x
 x = bottleneck_Block(x, nb_filters=[64, 64, 256], strides=(1, 1), with_conv_shortcut=True)
 x = bottleneck_Block(x, nb_filters=[64, 64, 256])
 x = bottleneck_Block(x, nb_filters=[64, 64, 256])
 
 # conv3_x
 x = bottleneck_Block(x, nb_filters=[128, 128, 512], strides=(2, 2), with_conv_shortcut=True)
 x = bottleneck_Block(x, nb_filters=[128, 128, 512])
 x = bottleneck_Block(x, nb_filters=[128, 128, 512])
 x = bottleneck_Block(x, nb_filters=[128, 128, 512])
 
 # conv4_x
 x = bottleneck_Block(x, nb_filters=[256, 256, 1024], strides=(2, 2), with_conv_shortcut=True)
 x = bottleneck_Block(x, nb_filters=[256, 256, 1024])
 x = bottleneck_Block(x, nb_filters=[256, 256, 1024])
 x = bottleneck_Block(x, nb_filters=[256, 256, 1024])
 x = bottleneck_Block(x, nb_filters=[256, 256, 1024])
 x = bottleneck_Block(x, nb_filters=[256, 256, 1024])
 
 # conv5_x
 x = bottleneck_Block(x, nb_filters=[512, 512, 2048], strides=(2, 2), with_conv_shortcut=True)
 x = bottleneck_Block(x, nb_filters=[512, 512, 2048])
 x = bottleneck_Block(x, nb_filters=[512, 512, 2048])
 
 x = AveragePooling2D(pool_size=(7, 7))(x)
 x = Flatten()(x)
 x = Dense(128, activation='relu')(x)
 return Model(inpt, x)
 
def generator(imgs, batch_size):
 """
 自定义迭代器
 :param imgs: 列表,每个包含一对矩阵以及label
 :param batch_size:
 :return:
 """
 while 1:
  random.shuffle(imgs)
  li = imgs[:batch_size]
  pairs = []
  labels = []
  for i in li:
   img1 = i[0]
   img2 = i[1]
   im1 = cv2.imread(img1)
   im2 = cv2.imread(img2)
   if im1 is None or im2 is None:
    continue
   label = int(i[2])
   img1 = processImg(img1)
   img2 = processImg(img2)
   pairs.append([img1, img2])
   labels.append(label)
  pairs = np.array(pairs)
  labels = np.array(labels)
  yield [pairs[:, 0], pairs[:, 1]], labels
 
input_shape = (im_width, im_height, 3)
base_network = resnet_50()
 
input_a = Input(shape=input_shape)
input_b = Input(shape=input_shape)
 
# because we re-use the same instance `base_network`,
# the weights of the network
# will be shared across the two branches
processed_a = base_network(input_a)
processed_b = base_network(input_b)
 
distance = Lambda(euclidean_distance,
     output_shape=eucl_dist_output_shape)([processed_a, processed_b])
with tf.device("/gpu:0"):
 model = Model([input_a, input_b], distance)
 # train
 rms = RMSprop()
 rows = csv.reader(open(csv_path, 'r'), delimiter=',')
 imgs = list(rows)
 checkpoint = ModelCheckpoint(filepath=model_result+'flag_{epoch:03d}.h5', verbose=1)
 model.compile(loss=contrastive_loss, optimizer=rms, metrics=[accuracy])
 model.fit_generator(generator(imgs, batch_size), epochs=epochs, steps_per_epoch=iterations, callbacks=[checkpoint])

用了回调函数保存了每一个epoch后的模型,也可以保存最好的,之后需要对模型进行测试。

测试时直接用load_model会报错,而应该变成如下形式调用:

model = load_model(model_path,custom_objects={'contrastive_loss': contrastive_loss }) #选取自己的.h模型名称

emmm,到这里,就成功训练测试完了~~~写的比较粗,因为这个代码在官方给的mnist上的改动不大,只是方便大家用自己的数据集,大家如果有更好的方法可以提出意见~~~希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python之eval()函数危险性浅析
Jul 03 Python
python的类方法和静态方法
Dec 13 Python
分析Python编程时利用wxPython来支持多线程的方法
Apr 07 Python
实例解析Python的Twisted框架中Deferred对象的用法
May 25 Python
Python处理json字符串转化为字典的简单实现
Jul 07 Python
在PyCharm中实现关闭一个死循环程序的方法
Nov 29 Python
flask-restful使用总结
Dec 04 Python
在Python函数中输入任意数量参数的实例
Jul 16 Python
python multiprocessing模块用法及原理介绍
Aug 20 Python
你可能不知道的Python 技巧小结
Jan 29 Python
Python 基于FIR实现Hilbert滤波器求信号包络详解
Feb 26 Python
利用python画出AUC曲线的实例
Feb 28 Python
python实现mean-shift聚类算法
Jun 10 #Python
Keras之自定义损失(loss)函数用法说明
Jun 10 #Python
Python xlwt模块使用代码实例
Jun 10 #Python
python中def是做什么的
Jun 10 #Python
keras实现调用自己训练的模型,并去掉全连接层
Jun 09 #Python
Python基于os.environ从windows获取环境变量
Jun 09 #Python
新手学习Python2和Python3中print不同的用法
Jun 09 #Python
You might like
php定时执行任务设置详解
2015/02/06 PHP
CI框架的安全性分析
2016/05/18 PHP
yii2 modal弹窗之ActiveForm ajax表单异步验证
2016/06/13 PHP
JS控制表格隔行变色
2006/06/26 Javascript
JavaScript与Image加载事件(onload)、加载状态(complete)
2011/02/14 Javascript
dwz 如何去掉ajaxloading具体代码
2013/05/22 Javascript
js获取和设置属性的方法
2014/02/20 Javascript
一款基jquery超炫的动画导航菜单可响应单击事件
2014/11/02 Javascript
node.js中的fs.linkSync方法使用说明
2014/12/15 Javascript
浅谈 javascript 事件处理
2015/01/04 Javascript
自定义函数实现IE7与IE8不兼容js中trim函数的问题
2015/02/03 Javascript
Node.js+Express配置入门教程
2016/05/19 Javascript
从零开始学习Node.js系列教程五:服务器监听方法示例
2017/04/13 Javascript
微信小程序开发教程之增加mixin扩展
2017/08/09 Javascript
Vue2.0实现组件数据的双向绑定问题
2018/03/06 Javascript
jQuery实现表单动态添加与删除数据操作示例
2018/07/03 jQuery
如何在selenium中使用js实现定位
2020/08/18 Javascript
python 队列详解及实例代码
2016/10/18 Python
python生成不重复随机数和对list乱序的解决方法
2018/04/09 Python
Python实现的简单排列组合算法示例
2018/07/04 Python
使用Python正则表达式操作文本数据的方法
2019/05/14 Python
Django使用模板后无法找到静态资源文件问题解决
2019/07/19 Python
python两种获取剪贴板内容的方法
2020/11/06 Python
CSS3字体效果的设置方法小结
2016/06/13 HTML / CSS
html5使用canvas绘制一张图片
2014/12/15 HTML / CSS
在校生汽车维修实习自我鉴定
2013/09/19 职场文书
财政局长自荐信范文
2013/12/22 职场文书
汽车制造与装配专业自荐信范文
2014/01/02 职场文书
幼儿园秋游活动方案
2014/01/21 职场文书
求职自荐信怎么写
2014/03/06 职场文书
党员教师工作决心书
2014/03/13 职场文书
促销活动总结报告
2014/04/26 职场文书
后备干部培训方案
2014/05/22 职场文书
给校长的一封检讨书
2014/09/20 职场文书
2014年党的群众路线活动个人整改措施
2014/10/28 职场文书
利用JuiceFS使MySQL 备份验证性能提升 10 倍
2022/03/17 MySQL