keras的siamese(孪生网络)实现案例


Posted in Python onJune 12, 2020

代码位于keras的官方样例,并做了微量修改和大量学习?。

最终效果:

keras的siamese(孪生网络)实现案例

keras的siamese(孪生网络)实现案例

import keras
import numpy as np
import matplotlib.pyplot as plt

import random

from keras.callbacks import TensorBoard
from keras.datasets import mnist
from keras.models import Model
from keras.layers import Input, Flatten, Dense, Dropout, Lambda
from keras.optimizers import RMSprop
from keras import backend as K

num_classes = 10
epochs = 20


def euclidean_distance(vects):
 x, y = vects
 sum_square = K.sum(K.square(x - y), axis=1, keepdims=True)
 return K.sqrt(K.maximum(sum_square, K.epsilon()))


def eucl_dist_output_shape(shapes):
 shape1, shape2 = shapes
 return (shape1[0], 1)


def contrastive_loss(y_true, y_pred):
 '''Contrastive loss from Hadsell-et-al.'06
 http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
 '''
 margin = 1
 sqaure_pred = K.square(y_pred)
 margin_square = K.square(K.maximum(margin - y_pred, 0))
 return K.mean(y_true * sqaure_pred + (1 - y_true) * margin_square)


def create_pairs(x, digit_indices):
 '''Positive and negative pair creation.
 Alternates between positive and negative pairs.
 '''
 pairs = []
 labels = []
 n = min([len(digit_indices[d]) for d in range(num_classes)]) - 1
 for d in range(num_classes):
  for i in range(n):
   z1, z2 = digit_indices[d][i], digit_indices[d][i + 1]
   pairs += [[x[z1], x[z2]]]
   inc = random.randrange(1, num_classes)
   dn = (d + inc) % num_classes
   z1, z2 = digit_indices[d][i], digit_indices[dn][i]
   pairs += [[x[z1], x[z2]]]
   labels += [1, 0]
 return np.array(pairs), np.array(labels)


def create_base_network(input_shape):
 '''Base network to be shared (eq. to feature extraction).
 '''
 input = Input(shape=input_shape)
 x = Flatten()(input)
 x = Dense(128, activation='relu')(x)
 x = Dropout(0.1)(x)
 x = Dense(128, activation='relu')(x)
 x = Dropout(0.1)(x)
 x = Dense(128, activation='relu')(x)
 return Model(input, x)


def compute_accuracy(y_true, y_pred): # numpy上的操作
 '''Compute classification accuracy with a fixed threshold on distances.
 '''
 pred = y_pred.ravel() < 0.5
 return np.mean(pred == y_true)


def accuracy(y_true, y_pred): # Tensor上的操作
 '''Compute classification accuracy with a fixed threshold on distances.
 '''
 return K.mean(K.equal(y_true, K.cast(y_pred < 0.5, y_true.dtype)))

def plot_train_history(history, train_metrics, val_metrics):
 plt.plot(history.history.get(train_metrics), '-o')
 plt.plot(history.history.get(val_metrics), '-o')
 plt.ylabel(train_metrics)
 plt.xlabel('Epochs')
 plt.legend(['train', 'validation'])


# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
input_shape = x_train.shape[1:]

# create training+test positive and negative pairs
digit_indices = [np.where(y_train == i)[0] for i in range(num_classes)]
tr_pairs, tr_y = create_pairs(x_train, digit_indices)

digit_indices = [np.where(y_test == i)[0] for i in range(num_classes)]
te_pairs, te_y = create_pairs(x_test, digit_indices)

# network definition
base_network = create_base_network(input_shape)

input_a = Input(shape=input_shape)
input_b = Input(shape=input_shape)

# because we re-use the same instance `base_network`,
# the weights of the network
# will be shared across the two branches
processed_a = base_network(input_a)
processed_b = base_network(input_b)

distance = Lambda(euclidean_distance,
     output_shape=eucl_dist_output_shape)([processed_a, processed_b])

model = Model([input_a, input_b], distance)
keras.utils.plot_model(model,"siamModel.png",show_shapes=True)
model.summary()

# train
rms = RMSprop()
model.compile(loss=contrastive_loss, optimizer=rms, metrics=[accuracy])
history=model.fit([tr_pairs[:, 0], tr_pairs[:, 1]], tr_y,
   batch_size=128,
   epochs=epochs,verbose=2,
   validation_data=([te_pairs[:, 0], te_pairs[:, 1]], te_y))

plt.figure(figsize=(8, 4))
plt.subplot(1, 2, 1)
plot_train_history(history, 'loss', 'val_loss')
plt.subplot(1, 2, 2)
plot_train_history(history, 'accuracy', 'val_accuracy')
plt.show()


# compute final accuracy on training and test sets
y_pred = model.predict([tr_pairs[:, 0], tr_pairs[:, 1]])
tr_acc = compute_accuracy(tr_y, y_pred)
y_pred = model.predict([te_pairs[:, 0], te_pairs[:, 1]])
te_acc = compute_accuracy(te_y, y_pred)

print('* Accuracy on training set: %0.2f%%' % (100 * tr_acc))
print('* Accuracy on test set: %0.2f%%' % (100 * te_acc))

以上这篇keras的siamese(孪生网络)实现案例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Flask SQLAlchemy一对一,一对多的使用方法实践
Feb 10 Python
详解Python中__str__和__repr__方法的区别
Apr 17 Python
Python 通过pip安装Django详细介绍
Apr 28 Python
Python简单删除列表中相同元素的方法示例
Jun 12 Python
用python实现对比两张图片的不同
Feb 05 Python
Python基于百度AI的文字识别的示例
Apr 21 Python
python数据结构之线性表的顺序存储结构
Sep 28 Python
Python检测数据类型的方法总结
May 20 Python
基于keras输出中间层结果的2种实现方式
Jan 24 Python
Ubuntu中配置TensorFlow使用环境的方法
Apr 21 Python
python从PDF中提取数据的示例
Oct 30 Python
Python中的socket网络模块介绍
Jul 23 Python
基于python实现模拟数据结构模型
Jun 12 #Python
Python-for循环的内部机制
Jun 12 #Python
Python Scrapy图片爬取原理及代码实例
Jun 12 #Python
Python Scrapy多页数据爬取实现过程解析
Jun 12 #Python
Selenium自动化测试工具使用方法汇总
Jun 12 #Python
Python使用socketServer包搭建简易服务器过程详解
Jun 12 #Python
Django之腾讯云短信的实现
Jun 12 #Python
You might like
ThinkPHP框架实现session跨域问题的解决方法
2014/07/01 PHP
PHP生成随机数的方法实例分析
2015/01/22 PHP
PHP实现的memcache环形队列类实例
2015/07/28 PHP
php记录搜索引擎爬行记录的实现代码
2018/03/02 PHP
网页禁用右键实现代码(JavaScript代码)
2009/10/29 Javascript
取选中的radio的值
2010/01/11 Javascript
判断字符串的长度(优化版)中文占两个字符
2014/10/30 Javascript
JS替换字符串中空格方法
2015/04/17 Javascript
设置点击文本框或图片弹出日历控件的实现代码
2016/05/12 Javascript
Javascript+CSS3实现进度条效果
2016/10/28 Javascript
jQuery的extend方法【三种】
2016/12/14 Javascript
Node.js利用Net模块实现多人命令行聊天室的方法
2016/12/23 Javascript
VUE前端cookie简单操作
2017/10/17 Javascript
vue轮播图插件vue-concise-slider的使用
2018/03/13 Javascript
NVM安装nodejs的方法实用步骤
2019/01/16 NodeJs
layui 数据表格复选框实现单选功能的例子
2019/09/19 Javascript
复制粘贴功能的Python程序
2008/04/04 Python
利用Python操作消息队列RabbitMQ的方法教程
2017/07/19 Python
python实现ID3决策树算法
2017/12/20 Python
python numpy 按行归一化的实例
2019/01/21 Python
Python-Tkinter Text输入内容在界面显示的实例
2019/07/12 Python
python读取dicom图像示例(SimpleITK和dicom包实现)
2020/01/16 Python
python实现在线翻译功能
2020/03/03 Python
CSS3 选择器 伪类选择器介绍
2012/01/21 HTML / CSS
加拿大当代时尚服饰、配饰和鞋类专业零售商和制造商:LE CHÂTEAU
2017/10/06 全球购物
中层干部岗位职责
2013/12/18 职场文书
单位在职证明范本
2014/01/09 职场文书
上课迟到检讨书100字
2014/01/11 职场文书
甜美蛋糕店创业计划书
2014/01/30 职场文书
教师师德反思材料
2014/02/15 职场文书
如何写股份合作协议书
2014/09/11 职场文书
感谢信模板大全
2015/01/23 职场文书
律政俏佳人观后感
2015/06/09 职场文书
机关单位2016年创先争优活动总结
2016/04/05 职场文书
Pytorch中TensorBoard及torchsummary的使用详解
2021/05/12 Python
Python实现Hash算法
2022/03/18 Python