keras的siamese(孪生网络)实现案例


Posted in Python onJune 12, 2020

代码位于keras的官方样例,并做了微量修改和大量学习?。

最终效果:

keras的siamese(孪生网络)实现案例

keras的siamese(孪生网络)实现案例

import keras
import numpy as np
import matplotlib.pyplot as plt

import random

from keras.callbacks import TensorBoard
from keras.datasets import mnist
from keras.models import Model
from keras.layers import Input, Flatten, Dense, Dropout, Lambda
from keras.optimizers import RMSprop
from keras import backend as K

num_classes = 10
epochs = 20


def euclidean_distance(vects):
 x, y = vects
 sum_square = K.sum(K.square(x - y), axis=1, keepdims=True)
 return K.sqrt(K.maximum(sum_square, K.epsilon()))


def eucl_dist_output_shape(shapes):
 shape1, shape2 = shapes
 return (shape1[0], 1)


def contrastive_loss(y_true, y_pred):
 '''Contrastive loss from Hadsell-et-al.'06
 http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
 '''
 margin = 1
 sqaure_pred = K.square(y_pred)
 margin_square = K.square(K.maximum(margin - y_pred, 0))
 return K.mean(y_true * sqaure_pred + (1 - y_true) * margin_square)


def create_pairs(x, digit_indices):
 '''Positive and negative pair creation.
 Alternates between positive and negative pairs.
 '''
 pairs = []
 labels = []
 n = min([len(digit_indices[d]) for d in range(num_classes)]) - 1
 for d in range(num_classes):
  for i in range(n):
   z1, z2 = digit_indices[d][i], digit_indices[d][i + 1]
   pairs += [[x[z1], x[z2]]]
   inc = random.randrange(1, num_classes)
   dn = (d + inc) % num_classes
   z1, z2 = digit_indices[d][i], digit_indices[dn][i]
   pairs += [[x[z1], x[z2]]]
   labels += [1, 0]
 return np.array(pairs), np.array(labels)


def create_base_network(input_shape):
 '''Base network to be shared (eq. to feature extraction).
 '''
 input = Input(shape=input_shape)
 x = Flatten()(input)
 x = Dense(128, activation='relu')(x)
 x = Dropout(0.1)(x)
 x = Dense(128, activation='relu')(x)
 x = Dropout(0.1)(x)
 x = Dense(128, activation='relu')(x)
 return Model(input, x)


def compute_accuracy(y_true, y_pred): # numpy上的操作
 '''Compute classification accuracy with a fixed threshold on distances.
 '''
 pred = y_pred.ravel() < 0.5
 return np.mean(pred == y_true)


def accuracy(y_true, y_pred): # Tensor上的操作
 '''Compute classification accuracy with a fixed threshold on distances.
 '''
 return K.mean(K.equal(y_true, K.cast(y_pred < 0.5, y_true.dtype)))

def plot_train_history(history, train_metrics, val_metrics):
 plt.plot(history.history.get(train_metrics), '-o')
 plt.plot(history.history.get(val_metrics), '-o')
 plt.ylabel(train_metrics)
 plt.xlabel('Epochs')
 plt.legend(['train', 'validation'])


# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
input_shape = x_train.shape[1:]

# create training+test positive and negative pairs
digit_indices = [np.where(y_train == i)[0] for i in range(num_classes)]
tr_pairs, tr_y = create_pairs(x_train, digit_indices)

digit_indices = [np.where(y_test == i)[0] for i in range(num_classes)]
te_pairs, te_y = create_pairs(x_test, digit_indices)

# network definition
base_network = create_base_network(input_shape)

input_a = Input(shape=input_shape)
input_b = Input(shape=input_shape)

# because we re-use the same instance `base_network`,
# the weights of the network
# will be shared across the two branches
processed_a = base_network(input_a)
processed_b = base_network(input_b)

distance = Lambda(euclidean_distance,
     output_shape=eucl_dist_output_shape)([processed_a, processed_b])

model = Model([input_a, input_b], distance)
keras.utils.plot_model(model,"siamModel.png",show_shapes=True)
model.summary()

# train
rms = RMSprop()
model.compile(loss=contrastive_loss, optimizer=rms, metrics=[accuracy])
history=model.fit([tr_pairs[:, 0], tr_pairs[:, 1]], tr_y,
   batch_size=128,
   epochs=epochs,verbose=2,
   validation_data=([te_pairs[:, 0], te_pairs[:, 1]], te_y))

plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plot_train_history(history, 'loss', 'val_loss')
plt.subplot(1, 2, 2)
plot_train_history(history, 'accuracy', 'val_accuracy')
plt.show()


# compute final accuracy on training and test sets
y_pred = model.predict([tr_pairs[:, 0], tr_pairs[:, 1]])
tr_acc = compute_accuracy(tr_y, y_pred)
y_pred = model.predict([te_pairs[:, 0], te_pairs[:, 1]])
te_acc = compute_accuracy(te_y, y_pred)

print('* Accuracy on training set: %0.2f%%' % (100 * tr_acc))
print('* Accuracy on test set: %0.2f%%' % (100 * te_acc))

以上这篇keras的siamese(孪生网络)实现案例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
用Python的pandas框架操作Excel文件中的数据教程
Mar 31 Python
Python使用email模块对邮件进行编码和解码的实例教程
Jul 01 Python
Python编写一个优美的下载器
Apr 15 Python
python生成n个元素的全组合方法
Nov 13 Python
python 3.3 下载固定链接文件并保存的方法
Dec 18 Python
神经网络相关之基础概念的讲解
Dec 29 Python
python之验证码生成(gvcode与captcha)
Jan 02 Python
django Admin文档生成器使用详解
Jul 22 Python
在Python 的线程中运行协程的方法
Feb 24 Python
python输出结果刷新及进度条的实现操作
Jul 13 Python
python动态规划算法实例详解
Nov 22 Python
python实现学员管理系统(面向对象版)
Jun 05 Python
基于python实现模拟数据结构模型
Jun 12 #Python
Python-for循环的内部机制
Jun 12 #Python
Python Scrapy图片爬取原理及代码实例
Jun 12 #Python
Python Scrapy多页数据爬取实现过程解析
Jun 12 #Python
Selenium自动化测试工具使用方法汇总
Jun 12 #Python
Python使用socketServer包搭建简易服务器过程详解
Jun 12 #Python
Django之腾讯云短信的实现
Jun 12 #Python
You might like
mysql数据库差异比较的PHP代码
2012/02/05 PHP
浅谈Eclipse PDT调试PHP程序
2014/06/09 PHP
[原创]ThinkPHP中SHOW_RUN_TIME不能正常显示运行时间的解决方法
2015/10/10 PHP
客户端脚本中常常出现的一些问题和调试技巧
2007/01/09 Javascript
兼容IE与firefox火狐的回车事件(js与jquery)
2010/10/20 Javascript
CSS3 media queries结合jQuery实现响应式导航
2016/09/30 Javascript
快速解决js开发下拉框中blur与click冲突
2016/10/10 Javascript
JS匿名函数实例分析
2016/11/26 Javascript
微信小程序之拖拽排序(代码分享)
2017/01/21 Javascript
node.js + socket.io 实现点对点随机匹配聊天
2017/06/30 Javascript
详解React native全局变量的使用(跨组件的通信)
2017/09/07 Javascript
分享5个好用的javascript文件上传插件
2018/09/16 Javascript
vue实现固定位置显示功能
2019/05/30 Javascript
微信小程序自定义组件实现环形进度条
2020/11/17 Javascript
JavaScript实现轮播图效果代码实例
2019/09/28 Javascript
微信小程序实现日历小功能
2020/11/18 Javascript
[45:40]Ti4 冒泡赛第二天NEWBEE vs NaVi 1
2014/07/15 DOTA
六个窍门助你提高Python运行效率
2015/06/09 Python
Python错误处理操作示例
2018/07/18 Python
用python生成(动态彩色)二维码的方法(使用myqr库实现)
2019/06/24 Python
基于梯度爆炸的解决方法:clip gradient
2020/02/04 Python
python如何提取英语pdf内容并翻译
2020/03/03 Python
详解python算法常用技巧与内置库
2020/10/17 Python
python爬取2021猫眼票房字体加密实例
2021/02/19 Python
详解CSS3新增的背景属性
2019/12/25 HTML / CSS
印度最大的时尚购物网站:Myntra
2018/09/13 全球购物
西班牙香水和化妆品购物网站:Arenal Perfumerías
2019/03/01 全球购物
师范生教师实习自我鉴定
2013/09/27 职场文书
自荐信格式范文
2013/10/07 职场文书
数控技术专科生自我评价
2014/01/08 职场文书
公司门卫工作职责
2014/06/28 职场文书
环保项目建议书
2014/08/26 职场文书
合作协议书模板
2014/10/10 职场文书
安全员岗位职责范本
2015/04/11 职场文书
oracle覆盖导入dmp文件的2种方法
2021/05/21 Oracle
教你怎么用python实现字符串转日期
2021/05/24 Python