keras-siamese用自己的数据集实现详解


Posted in Python onJune 10, 2020

Siamese网络不做过多介绍,思想并不难,输入两个图像,输出这两张图像的相似度,两个输入的网络结构是相同的,参数共享。

主要发现很多代码都是基于mnist数据集的,下面说一下怎么用自己的数据集实现siamese网络。

首先,先整理数据集,相同的类放到同一个文件夹下,如下图所示:

keras-siamese用自己的数据集实现详解

接下来,将pairs及对应的label写到csv中,代码如下:

import os
import random
import csv
#图片所在的路径
path = '/Users/mac/Desktop/wxd/flag/category/'
#files列表保存所有类别的路径
files=[]
same_pairs=[]
different_pairs=[]
for file in os.listdir(path):
 if file[0]=='.':
  continue
 file_path = os.path.join(path,file)
 files.append(file_path)
#该地址为csv要保存到的路径,a表示追加写入
with open('/Users/mac/Desktop/wxd/flag/data.csv','a') as f:
 #保存相同对
 writer = csv.writer(f)
 for file in files:
  imgs = os.listdir(file) 
  for i in range(0,len(imgs)-1):
   for j in range(i+1,len(imgs)):
    pairs = []
    name = file.split(sep='/')[-1]
    pairs.append(path+name+'/'+imgs[i])
    pairs.append(path+name+'/'+imgs[j])
    pairs.append(1)
    writer.writerow(pairs)
 #保存不同对
 for i in range(0,len(files)-1):
  for j in range(i+1,len(files)):
   filea = files[i]
   fileb = files[j]
   imga_li = os.listdir(filea)
   imgb_li = os.listdir(fileb)
   random.shuffle(imga_li)
   random.shuffle(imgb_li)
   a_li = imga_li[:]
   b_li = imgb_li[:]
   for p in range(len(a_li)):
    for q in range(len(b_li)):
     pairs = []
     name1 = filea.split(sep='/')[-1]
     name2 = fileb.split(sep='/')[-1]
     pairs.append(path+name1+'/'+a_li[p])
     pairs.append(path+name2+'/'+b_li[q])
     pairs.append(0)
     writer.writerow(pairs)

相当于csv每一行都包含一对结果,每一行有三列,第一列第一张图片路径,第二列第二张图片路径,第三列是不是相同的label,属于同一个类的label为1,不同类的为0,可参考下图:

keras-siamese用自己的数据集实现详解

然后,由于keras的fit函数需要将训练数据都塞入内存,而大部分训练数据都较大,因此才用fit_generator生成器的方法,便可以训练大数据,代码如下:

from __future__ import absolute_import
from __future__ import print_function
import numpy as np
from keras.models import Model
from keras.layers import Input, Dense, Dropout, BatchNormalization, Conv2D, MaxPooling2D, AveragePooling2D, concatenate, \
 Activation, ZeroPadding2D
from keras.layers import add, Flatten
from keras.utils import plot_model
from keras.metrics import top_k_categorical_accuracy
from keras.preprocessing.image import ImageDataGenerator
from keras.models import load_model
import tensorflow as tf
import random
import os
import cv2
import csv
import numpy as np
from keras.models import Model
from keras.layers import Input, Flatten, Dense, Dropout, Lambda
from keras.optimizers import RMSprop
from keras import backend as K
from keras.callbacks import ModelCheckpoint
from keras.preprocessing.image import img_to_array
 
"""
自定义的参数
"""
im_width = 224
im_height = 224
epochs = 100
batch_size = 64
iterations = 1000
csv_path = ''
model_result = ''
 
 
# 计算欧式距离
def euclidean_distance(vects):
 x, y = vects
 sum_square = K.sum(K.square(x - y), axis=1, keepdims=True)
 return K.sqrt(K.maximum(sum_square, K.epsilon()))
 
def eucl_dist_output_shape(shapes):
 shape1, shape2 = shapes
 return (shape1[0], 1)
 
# 计算loss
def contrastive_loss(y_true, y_pred):
 '''Contrastive loss from Hadsell-et-al.'06
 http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
 '''
 margin = 1
 square_pred = K.square(y_pred)
 margin_square = K.square(K.maximum(margin - y_pred, 0))
 return K.mean(y_true * square_pred + (1 - y_true) * margin_square)
 
def compute_accuracy(y_true, y_pred):
 '''计算准确率
 '''
 pred = y_pred.ravel() < 0.5
 print('pred:', pred)
 return np.mean(pred == y_true)
 
def accuracy(y_true, y_pred):
 '''Compute classification accuracy with a fixed threshold on distances.
 '''
 return K.mean(K.equal(y_true, K.cast(y_pred < 0.5, y_true.dtype)))
 
def processImg(filename):
 """
 :param filename: 图像的路径
 :return: 返回的是归一化矩阵
 """
 img = cv2.imread(filename)
 img = cv2.resize(img, (im_width, im_height))
 img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
 img = img_to_array(img)
 img /= 255
 return img
 
def Conv2d_BN(x, nb_filter, kernel_size, strides=(1, 1), padding='same', name=None):
 if name is not None:
  bn_name = name + '_bn'
  conv_name = name + '_conv'
 else:
  bn_name = None
  conv_name = None
 
 x = Conv2D(nb_filter, kernel_size, padding=padding, strides=strides, activation='relu', name=conv_name)(x)
 x = BatchNormalization(axis=3, name=bn_name)(x)
 return x
 
def bottleneck_Block(inpt, nb_filters, strides=(1, 1), with_conv_shortcut=False):
 k1, k2, k3 = nb_filters
 x = Conv2d_BN(inpt, nb_filter=k1, kernel_size=1, strides=strides, padding='same')
 x = Conv2d_BN(x, nb_filter=k2, kernel_size=3, padding='same')
 x = Conv2d_BN(x, nb_filter=k3, kernel_size=1, padding='same')
 if with_conv_shortcut:
  shortcut = Conv2d_BN(inpt, nb_filter=k3, strides=strides, kernel_size=1)
  x = add([x, shortcut])
  return x
 else:
  x = add([x, inpt])
  return x
 
def resnet_50():
 width = im_width
 height = im_height
 channel = 3
 inpt = Input(shape=(width, height, channel))
 x = ZeroPadding2D((3, 3))(inpt)
 x = Conv2d_BN(x, nb_filter=64, kernel_size=(7, 7), strides=(2, 2), padding='valid')
 x = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), padding='same')(x)
 
 # conv2_x
 x = bottleneck_Block(x, nb_filters=[64, 64, 256], strides=(1, 1), with_conv_shortcut=True)
 x = bottleneck_Block(x, nb_filters=[64, 64, 256])
 x = bottleneck_Block(x, nb_filters=[64, 64, 256])
 
 # conv3_x
 x = bottleneck_Block(x, nb_filters=[128, 128, 512], strides=(2, 2), with_conv_shortcut=True)
 x = bottleneck_Block(x, nb_filters=[128, 128, 512])
 x = bottleneck_Block(x, nb_filters=[128, 128, 512])
 x = bottleneck_Block(x, nb_filters=[128, 128, 512])
 
 # conv4_x
 x = bottleneck_Block(x, nb_filters=[256, 256, 1024], strides=(2, 2), with_conv_shortcut=True)
 x = bottleneck_Block(x, nb_filters=[256, 256, 1024])
 x = bottleneck_Block(x, nb_filters=[256, 256, 1024])
 x = bottleneck_Block(x, nb_filters=[256, 256, 1024])
 x = bottleneck_Block(x, nb_filters=[256, 256, 1024])
 x = bottleneck_Block(x, nb_filters=[256, 256, 1024])
 
 # conv5_x
 x = bottleneck_Block(x, nb_filters=[512, 512, 2048], strides=(2, 2), with_conv_shortcut=True)
 x = bottleneck_Block(x, nb_filters=[512, 512, 2048])
 x = bottleneck_Block(x, nb_filters=[512, 512, 2048])
 
 x = AveragePooling2D(pool_size=(7, 7))(x)
 x = Flatten()(x)
 x = Dense(128, activation='relu')(x)
 return Model(inpt, x)
 
def generator(imgs, batch_size):
 """
 自定义迭代器
 :param imgs: 列表,每个包含一对矩阵以及label
 :param batch_size:
 :return:
 """
 while 1:
  random.shuffle(imgs)
  li = imgs[:batch_size]
  pairs = []
  labels = []
  for i in li:
   img1 = i[0]
   img2 = i[1]
   im1 = cv2.imread(img1)
   im2 = cv2.imread(img2)
   if im1 is None or im2 is None:
    continue
   label = int(i[2])
   img1 = processImg(img1)
   img2 = processImg(img2)
   pairs.append([img1, img2])
   labels.append(label)
  pairs = np.array(pairs)
  labels = np.array(labels)
  yield [pairs[:, 0], pairs[:, 1]], labels
 
input_shape = (im_width, im_height, 3)
base_network = resnet_50()
 
input_a = Input(shape=input_shape)
input_b = Input(shape=input_shape)
 
# because we re-use the same instance `base_network`,
# the weights of the network
# will be shared across the two branches
processed_a = base_network(input_a)
processed_b = base_network(input_b)
 
distance = Lambda(euclidean_distance,
     output_shape=eucl_dist_output_shape)([processed_a, processed_b])
with tf.device("/gpu:0"):
 model = Model([input_a, input_b], distance)
 # train
 rms = RMSprop()
 rows = csv.reader(open(csv_path, 'r'), delimiter=',')
 imgs = list(rows)
 checkpoint = ModelCheckpoint(filepath=model_result+'flag_{epoch:03d}.h5', verbose=1)
 model.compile(loss=contrastive_loss, optimizer=rms, metrics=[accuracy])
 model.fit_generator(generator(imgs, batch_size), epochs=epochs, steps_per_epoch=iterations, callbacks=[checkpoint])

用了回调函数保存了每一个epoch后的模型,也可以保存最好的,之后需要对模型进行测试。

测试时直接用load_model会报错,而应该变成如下形式调用:

model = load_model(model_path,custom_objects={'contrastive_loss': contrastive_loss }) #选取自己的.h模型名称

emmm,到这里,就成功训练测试完了~~~写的比较粗,因为这个代码在官方给的mnist上的改动不大,只是方便大家用自己的数据集,大家如果有更好的方法可以提出意见~~~希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python中暂存上传图片的方法
Feb 18 Python
Python生成随机MAC地址
Mar 10 Python
以一段代码为实例快速入门Python2.7
Mar 31 Python
python实现数值积分的Simpson方法实例分析
Jun 05 Python
详解Python3中的Sequence type的使用
Aug 01 Python
python 获取list特定元素下标的实例讲解
Apr 09 Python
Python 3.7新功能之dataclass装饰器详解
Apr 21 Python
详解Python安装scrapy的正确姿势
Jun 26 Python
python根据txt文本批量创建文件夹
Dec 08 Python
python使用Plotly绘图工具绘制气泡图
Apr 01 Python
python 有效的括号的实现代码示例
Nov 11 Python
Python计算信息熵实例
Jun 18 Python
python实现mean-shift聚类算法
Jun 10 #Python
Keras之自定义损失(loss)函数用法说明
Jun 10 #Python
Python xlwt模块使用代码实例
Jun 10 #Python
python中def是做什么的
Jun 10 #Python
keras实现调用自己训练的模型,并去掉全连接层
Jun 09 #Python
Python基于os.environ从windows获取环境变量
Jun 09 #Python
新手学习Python2和Python3中print不同的用法
Jun 09 #Python
You might like
关于PHP中操作MySQL数据库的一些要注意的问题
2006/10/09 PHP
使用TinyButStrong模板引擎来做WEB开发
2007/03/16 PHP
解析array splice的移除数组中指定键的值,返回一个新的数组
2013/07/02 PHP
PHP fopen()和 file_get_contents()应用与差异介绍
2014/03/19 PHP
php使用function_exists判断函数可用的方法
2014/11/19 PHP
关于ThinkPhp 框架表单验证及ajax验证问题
2017/07/19 PHP
PHP扩展mcrypt实现的AES加密功能示例
2019/01/29 PHP
一个简单的javascript类定义例子
2009/09/12 Javascript
JS 两个字符串时间的天数差计算
2013/08/25 Javascript
从零学jquery之如何使用回调函数
2014/05/16 Javascript
DOM基础教程之使用DOM控制表单
2015/01/20 Javascript
Node.js实用代码段之正确拼接Buffer
2016/03/17 Javascript
AngularJS入门教程之REST和定制服务详解
2016/08/19 Javascript
jQuery实现上传图片前预览效果功能
2017/08/03 jQuery
nodejs读取并去重excel文件
2018/04/22 NodeJs
vue.js使用v-if实现显示与隐藏功能示例
2018/07/06 Javascript
echarts实现地图定时切换散点与多图表级联联动详解
2018/08/07 Javascript
electron中使用bootstrap的示例代码
2018/11/06 Javascript
layui动态渲染生成左侧3级菜单的方法(根据后台返回数据)
2019/09/23 Javascript
关于javascript中的promise的用法和注意事项(推荐)
2021/01/15 Javascript
Python简单实现子网掩码转换的方法
2016/04/13 Python
Python错误: SyntaxError: Non-ASCII character解决办法
2017/06/08 Python
python多线程socket编程之多客户端接入
2017/09/12 Python
python中正则表达式 re.findall 用法
2018/10/23 Python
Python3实现的简单工资管理系统示例
2019/03/12 Python
Python列表操作方法详解
2020/02/09 Python
python GUI库图形界面开发之PyQt5动态(可拖动控件大小)布局控件QSplitter详细使用方法与实例
2020/03/06 Python
Python基于Webhook实现github自动化部署
2020/11/28 Python
使用CSS3的ruby-position固定注音位置的用法示例
2016/07/05 HTML / CSS
奥地利智能家居和智能生活网上商店:tink.at
2019/10/07 全球购物
写出SQL四条最基本的数据操作语句(DML)
2012/12/12 面试题
社会实践活动总结范文
2014/07/03 职场文书
大学生就业协议书范本(适用于公司企业)
2014/10/07 职场文书
2014年保安个人工作总结
2014/11/13 职场文书
Pytorch使用shuffle打乱数据的操作
2021/05/20 Python
使用@Value值注入及配置文件组件扫描
2021/07/09 Java/Android