keras导入weights方式


Posted in Python onJune 12, 2020

keras源码engine中toplogy.py定义了加载权重的函数:

load_weights(self, filepath, by_name=False)

其中默认by_name为False,这时候加载权重按照网络拓扑结构加载,适合直接使用keras中自带的网络模型,如VGG16

VGG19/resnet50等,源码描述如下:

If `by_name` is False (default) weights are loaded
based on the network's topology, meaning the architecture
should be the same as when the weights were saved.
Note that layers that don't have weights are not taken
into account in the topological ordering, so adding or
removing layers is fine as long as they don't have weights.

若将by_name改为True则加载权重按照layer的name进行,layer的name相同时加载权重,适合用于改变了

模型的相关结构或增加了节点但利用了原网络的主体结构情况下使用,源码描述如下:

If `by_name` is True, weights are loaded into layers
only if they share the same name. This is useful
for fine-tuning or transfer-learning models where
some of the layers have changed.

在进行边缘检测时,利用VGG网络的主体结构,网络中增加反卷积层,这时加载权重应该使用

model.load_weights(filepath,by_name=True)

补充知识:Keras下实现mnist手写数字

之前一直在用tensorflow,被同学推荐来用keras了,把之前文档中的mnist手写数字数据集拿来练手,

代码如下。

import struct
import numpy as np
import os
 
import keras
from keras.models import Sequential 
from keras.layers import Dense
from keras.optimizers import SGD
 
def load_mnist(path, kind):
  labels_path = os.path.join(path, '%s-labels.idx1-ubyte' % kind)
  images_path = os.path.join(path, '%s-images.idx3-ubyte' % kind)
  with open(labels_path, 'rb') as lbpath:
    magic, n = struct.unpack('>II', lbpath.read(8))
    labels = np.fromfile(lbpath, dtype=np.uint8)
  with open(images_path, 'rb') as imgpath:
    magic, num, rows, cols = struct.unpack(">IIII", imgpath.read(16))
    images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(labels), 784) #28*28=784
  return images, labels
 
#loading train and test data
X_train, Y_train = load_mnist('.\\data', kind='train')
X_test, Y_test = load_mnist('.\\data', kind='t10k')
 
#turn labels to one_hot code
Y_train_ohe = keras.utils.to_categorical(Y_train, num_classes=10)
 
#define models
model = Sequential()
 
model.add(Dense(input_dim=X_train.shape[1],output_dim=50,init='uniform',activation='tanh'))
model.add(Dense(input_dim=50,output_dim=50,init='uniform',activation='tanh'))
model.add(Dense(input_dim=50,output_dim=Y_train_ohe.shape[1],init='uniform',activation='softmax')) 
 
sgd = SGD(lr=0.001, decay=1e-7, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=["accuracy"])
 
#start training
model.fit(X_train,Y_train_ohe,epochs=50,batch_size=300,shuffle=True,verbose=1,validation_split=0.3)
 
#count accuracy
y_train_pred = model.predict_classes(X_train, verbose=0)
 
train_acc = np.sum(Y_train == y_train_pred, axis=0) / X_train.shape[0] 
print('Training accuracy: %.2f%%' % (train_acc * 100))
 
y_test_pred = model.predict_classes(X_test, verbose=0)
test_acc = np.sum(Y_test == y_test_pred, axis=0) / X_test.shape[0] 
print('Test accuracy: %.2f%%' % (test_acc * 100))

训练结果如下:

Epoch 45/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2174 - acc: 0.9380 - val_loss: 0.2341 - val_acc: 0.9323
Epoch 46/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2061 - acc: 0.9404 - val_loss: 0.2244 - val_acc: 0.9358
Epoch 47/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.1994 - acc: 0.9413 - val_loss: 0.2295 - val_acc: 0.9347
Epoch 48/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2003 - acc: 0.9413 - val_loss: 0.2224 - val_acc: 0.9350
Epoch 49/50
42000/42000 [==============================] - 1s 18us/step - loss: 0.2013 - acc: 0.9417 - val_loss: 0.2248 - val_acc: 0.9359
Epoch 50/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.1960 - acc: 0.9433 - val_loss: 0.2300 - val_acc: 0.9346
Training accuracy: 94.11%
Test accuracy: 93.61%

以上这篇keras导入weights方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python实现查找excel里某一列重复数据并且剔除后打印的方法
May 26 Python
Python fileinput模块使用实例
Jun 03 Python
Python中urllib+urllib2+cookielib模块编写爬虫实战
Jan 20 Python
Python学习小技巧之列表项的排序
May 20 Python
Python编程实战之Oracle数据库操作示例
Jun 21 Python
Python使用文件锁实现进程间同步功能【基于fcntl模块】
Oct 16 Python
Python内置模块turtle绘图详解
Dec 09 Python
python实现年会抽奖程序
Jan 22 Python
Pandas操作CSV文件的读写实现方法
Nov 13 Python
Python面向对象程序设计之静态方法、类方法、属性方法原理与用法分析
Mar 23 Python
numpy矩阵数值太多不能全部显示的解决
May 14 Python
利用Python判断整数是否是回文数的3种方法总结
Jul 07 Python
keras读取h5文件load_weights、load代码操作
Jun 12 #Python
Python matplotlib 绘制双Y轴曲线图的示例代码
Jun 12 #Python
keras的siamese(孪生网络)实现案例
Jun 12 #Python
基于python实现模拟数据结构模型
Jun 12 #Python
Python-for循环的内部机制
Jun 12 #Python
Python Scrapy图片爬取原理及代码实例
Jun 12 #Python
Python Scrapy多页数据爬取实现过程解析
Jun 12 #Python
You might like
Phpbean路由转发的php代码
2008/01/10 PHP
说说JSON和JSONP 也许你会豁然开朗
2012/09/02 Javascript
form表单action提交的js部分与html部分
2014/01/07 Javascript
alert和confirm功能介绍
2014/05/21 Javascript
javascript实现回车键提交表单方法总结
2015/01/10 Javascript
基于jQuery实现在线选座之高铁版
2015/08/24 Javascript
详解javascript高级定时器
2015/12/31 Javascript
javascript HTML+CSS实现经典橙色导航菜单
2016/02/16 Javascript
jQuery日历插件datepicker用法详解
2016/03/03 Javascript
jQuery插件pagination实现无刷新分页
2016/05/21 Javascript
性能优化之代码优化页面加载速度
2017/03/01 Javascript
Vue.js实现一个SPA登录页面的过程【推荐】
2017/04/29 Javascript
使用原生js封装的ajax实例(兼容jsonp)
2017/10/12 Javascript
js判断输入框不能为空格或null值的实现方法
2018/03/02 Javascript
Vue SPA单页应用首屏优化实践
2018/06/28 Javascript
Nginx设置为Node.js的前端服务器方法总结
2019/03/27 Javascript
vue指令做滚动加载和监听等
2019/05/26 Javascript
[02:39]DOTA2英雄基础教程 极限穿梭编织者
2013/12/05 DOTA
Python zip()函数用法实例分析
2018/03/17 Python
使用requests库制作Python爬虫
2018/03/25 Python
Python快速转换numpy数组中Nan和Inf的方法实例说明
2019/02/21 Python
Django Rest framework三种分页方式详解
2019/07/26 Python
如何在 Django 模板中输出 "{{"
2020/01/24 Python
video结合canvas实现视频在线截图功能
2018/06/25 HTML / CSS
kmart凯马特官网:美国最大的打折零售商和全球最大的批发商之一
2016/11/17 全球购物
古驰英国官网:GUCCI英国
2020/03/07 全球购物
牵手50台湾:专为黄金岁月的单身人士而设的交友网站
2021/02/18 全球购物
2019年Java面试必问之经典试题
2012/09/12 面试题
大学毕业生通用自荐信范文
2013/10/31 职场文书
审核会计岗位职责
2013/11/08 职场文书
心得体会开头
2014/01/01 职场文书
文明市民先进事迹
2014/05/15 职场文书
车辆转让协议书
2014/09/24 职场文书
写给导师的自荐信
2015/03/06 职场文书
计算机专业自荐信范文
2015/03/26 职场文书
文明医院的标语集锦!
2019/07/24 职场文书