keras的siamese(孪生网络)实现案例


Posted in Python onJune 12, 2020

代码位于keras的官方样例,并做了微量修改和大量学习?。

最终效果:

keras的siamese(孪生网络)实现案例

keras的siamese(孪生网络)实现案例

import keras
import numpy as np
import matplotlib.pyplot as plt

import random

from keras.callbacks import TensorBoard
from keras.datasets import mnist
from keras.models import Model
from keras.layers import Input, Flatten, Dense, Dropout, Lambda
from keras.optimizers import RMSprop
from keras import backend as K

num_classes = 10
epochs = 20


def euclidean_distance(vects):
 x, y = vects
 sum_square = K.sum(K.square(x - y), axis=1, keepdims=True)
 return K.sqrt(K.maximum(sum_square, K.epsilon()))


def eucl_dist_output_shape(shapes):
 shape1, shape2 = shapes
 return (shape1[0], 1)


def contrastive_loss(y_true, y_pred):
 '''Contrastive loss from Hadsell-et-al.'06
 http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
 '''
 margin = 1
 sqaure_pred = K.square(y_pred)
 margin_square = K.square(K.maximum(margin - y_pred, 0))
 return K.mean(y_true * sqaure_pred + (1 - y_true) * margin_square)


def create_pairs(x, digit_indices):
 '''Positive and negative pair creation.
 Alternates between positive and negative pairs.
 '''
 pairs = []
 labels = []
 n = min([len(digit_indices[d]) for d in range(num_classes)]) - 1
 for d in range(num_classes):
  for i in range(n):
   z1, z2 = digit_indices[d][i], digit_indices[d][i + 1]
   pairs += [[x[z1], x[z2]]]
   inc = random.randrange(1, num_classes)
   dn = (d + inc) % num_classes
   z1, z2 = digit_indices[d][i], digit_indices[dn][i]
   pairs += [[x[z1], x[z2]]]
   labels += [1, 0]
 return np.array(pairs), np.array(labels)


def create_base_network(input_shape):
 '''Base network to be shared (eq. to feature extraction).
 '''
 input = Input(shape=input_shape)
 x = Flatten()(input)
 x = Dense(128, activation='relu')(x)
 x = Dropout(0.1)(x)
 x = Dense(128, activation='relu')(x)
 x = Dropout(0.1)(x)
 x = Dense(128, activation='relu')(x)
 return Model(input, x)


def compute_accuracy(y_true, y_pred): # numpy上的操作
 '''Compute classification accuracy with a fixed threshold on distances.
 '''
 pred = y_pred.ravel() < 0.5
 return np.mean(pred == y_true)


def accuracy(y_true, y_pred): # Tensor上的操作
 '''Compute classification accuracy with a fixed threshold on distances.
 '''
 return K.mean(K.equal(y_true, K.cast(y_pred < 0.5, y_true.dtype)))

def plot_train_history(history, train_metrics, val_metrics):
 plt.plot(history.history.get(train_metrics), '-o')
 plt.plot(history.history.get(val_metrics), '-o')
 plt.ylabel(train_metrics)
 plt.xlabel('Epochs')
 plt.legend(['train', 'validation'])


# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
input_shape = x_train.shape[1:]

# create training+test positive and negative pairs
digit_indices = [np.where(y_train == i)[0] for i in range(num_classes)]
tr_pairs, tr_y = create_pairs(x_train, digit_indices)

digit_indices = [np.where(y_test == i)[0] for i in range(num_classes)]
te_pairs, te_y = create_pairs(x_test, digit_indices)

# network definition
base_network = create_base_network(input_shape)

input_a = Input(shape=input_shape)
input_b = Input(shape=input_shape)

# because we re-use the same instance `base_network`,
# the weights of the network
# will be shared across the two branches
processed_a = base_network(input_a)
processed_b = base_network(input_b)

distance = Lambda(euclidean_distance,
     output_shape=eucl_dist_output_shape)([processed_a, processed_b])

model = Model([input_a, input_b], distance)
keras.utils.plot_model(model,"siamModel.png",show_shapes=True)
model.summary()

# train
rms = RMSprop()
model.compile(loss=contrastive_loss, optimizer=rms, metrics=[accuracy])
history=model.fit([tr_pairs[:, 0], tr_pairs[:, 1]], tr_y,
   batch_size=128,
   epochs=epochs,verbose=2,
   validation_data=([te_pairs[:, 0], te_pairs[:, 1]], te_y))

plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plot_train_history(history, 'loss', 'val_loss')
plt.subplot(1, 2, 2)
plot_train_history(history, 'accuracy', 'val_accuracy')
plt.show()


# compute final accuracy on training and test sets
y_pred = model.predict([tr_pairs[:, 0], tr_pairs[:, 1]])
tr_acc = compute_accuracy(tr_y, y_pred)
y_pred = model.predict([te_pairs[:, 0], te_pairs[:, 1]])
te_acc = compute_accuracy(te_y, y_pred)

print('* Accuracy on training set: %0.2f%%' % (100 * tr_acc))
print('* Accuracy on test set: %0.2f%%' % (100 * te_acc))

以上这篇keras的siamese(孪生网络)实现案例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
举例讲解Python中is和id的用法
Apr 03 Python
用Python实现一个简单的能够上传下载的HTTP服务器
May 05 Python
python中json格式数据输出的简单实现方法
Oct 31 Python
Python正则表达式教程之三:贪婪/非贪婪特性
Mar 02 Python
python中nan与inf转为特定数字方法示例
May 11 Python
python itchat给指定联系人发消息的方法
Jun 11 Python
python 实现快速生成连续、随机字母列表
Nov 28 Python
使用Python爬虫库BeautifulSoup遍历文档树并对标签进行操作详解
Jan 25 Python
tensorflow 获取checkpoint中的变量列表实例
Feb 11 Python
一篇文章搞懂python的转义字符及用法
Sep 03 Python
Matlab使用Plot函数实现数据动态显示方法总结
Feb 25 Python
Python实现信息管理系统
Jun 05 Python
基于python实现模拟数据结构模型
Jun 12 #Python
Python-for循环的内部机制
Jun 12 #Python
Python Scrapy图片爬取原理及代码实例
Jun 12 #Python
Python Scrapy多页数据爬取实现过程解析
Jun 12 #Python
Selenium自动化测试工具使用方法汇总
Jun 12 #Python
Python使用socketServer包搭建简易服务器过程详解
Jun 12 #Python
Django之腾讯云短信的实现
Jun 12 #Python
You might like
浅谈PHP值mysql操作类
2016/06/29 PHP
javascript 播放器 控制
2007/01/22 Javascript
jQuery 第二课 操作包装集元素代码
2010/03/14 Javascript
js对象之JS入门之Array对象操作小结
2011/01/09 Javascript
$.ajax返回的JSON无法执行success的解决方法
2011/09/09 Javascript
js调用webservice中的方法实现思路及代码
2013/02/25 Javascript
input链接页面、打开新网页等等的具体实现
2013/12/30 Javascript
Extjs Label的 fieldLabel和html属性值对齐的方法
2014/06/15 Javascript
JS实现同时搜索百度和必应的方法
2015/01/27 Javascript
JavaScript  cookie 跨域访问之广告推广
2016/04/20 Javascript
jQuery插件MovingBoxes实现左右滑动中间放大图片效果
2017/02/28 Javascript
vue的安装及element组件的安装方法
2018/03/09 Javascript
JavaScript求一组数的最小公倍数和最大公约数常用算法详解【面向对象,回归迭代和循环】
2018/05/07 Javascript
vue单页面应用打开新窗口显示跳转页面的实例
2018/09/21 Javascript
Vue slot用法(小结)
2018/10/22 Javascript
[01:12]快闪回顾DOTA2亚洲邀请赛(DAC) 静候2018新征程开启
2018/03/11 DOTA
Python中使用pprint函数进行格式化输出的教程
2015/04/07 Python
两个命令把 Vim 打造成 Python IDE的方法
2016/03/20 Python
python使用xlrd与xlwt对excel的读写和格式设定
2017/01/21 Python
Python中的默认参数实例分析
2018/01/29 Python
Tkinter中复选菜单是否被选中的判断与设置方式
2020/03/04 Python
python 读取二进制 显示图片案例
2020/04/24 Python
Python使用configparser读取ini配置文件
2020/05/25 Python
python json.dumps() json.dump()的区别详解
2020/07/14 Python
python实现图像高斯金字塔的示例代码
2020/12/11 Python
HTML5教程之html 5 本地数据库(Web Sql Database)
2014/04/03 HTML / CSS
研究生自我鉴定范文
2013/10/30 职场文书
个人评语大全
2014/05/04 职场文书
客户答谢会活动方案
2014/08/31 职场文书
财务工作疏忽检讨书
2014/09/11 职场文书
2014年转正工作总结
2014/11/08 职场文书
六年级语文下册教学计划
2015/01/22 职场文书
监考失职检讨书
2015/01/26 职场文书
协议书格式模板
2016/03/24 职场文书
Vue2.0搭建脚手架
2022/03/13 Vue.js
Vertica集成Apache Hudi重磅使用指南
2022/03/31 Servers