keras导入weights方式


Posted in Python onJune 12, 2020

keras源码engine中toplogy.py定义了加载权重的函数:

load_weights(self, filepath, by_name=False)

其中默认by_name为False,这时候加载权重按照网络拓扑结构加载,适合直接使用keras中自带的网络模型,如VGG16

VGG19/resnet50等,源码描述如下:

If `by_name` is False (default) weights are loaded
based on the network's topology, meaning the architecture
should be the same as when the weights were saved.
Note that layers that don't have weights are not taken
into account in the topological ordering, so adding or
removing layers is fine as long as they don't have weights.

若将by_name改为True则加载权重按照layer的name进行,layer的name相同时加载权重,适合用于改变了

模型的相关结构或增加了节点但利用了原网络的主体结构情况下使用,源码描述如下:

If `by_name` is True, weights are loaded into layers
only if they share the same name. This is useful
for fine-tuning or transfer-learning models where
some of the layers have changed.

在进行边缘检测时,利用VGG网络的主体结构,网络中增加反卷积层,这时加载权重应该使用

model.load_weights(filepath,by_name=True)

补充知识:Keras下实现mnist手写数字

之前一直在用tensorflow,被同学推荐来用keras了,把之前文档中的mnist手写数字数据集拿来练手,

代码如下。

import struct
import numpy as np
import os
 
import keras
from keras.models import Sequential 
from keras.layers import Dense
from keras.optimizers import SGD
 
def load_mnist(path, kind):
  labels_path = os.path.join(path, '%s-labels.idx1-ubyte' % kind)
  images_path = os.path.join(path, '%s-images.idx3-ubyte' % kind)
  with open(labels_path, 'rb') as lbpath:
    magic, n = struct.unpack('>II', lbpath.read(8))
    labels = np.fromfile(lbpath, dtype=np.uint8)
  with open(images_path, 'rb') as imgpath:
    magic, num, rows, cols = struct.unpack(">IIII", imgpath.read(16))
    images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(labels), 784) #28*28=784
  return images, labels
 
#loading train and test data
X_train, Y_train = load_mnist('.\\data', kind='train')
X_test, Y_test = load_mnist('.\\data', kind='t10k')
 
#turn labels to one_hot code
Y_train_ohe = keras.utils.to_categorical(Y_train, num_classes=10)
 
#define models
model = Sequential()
 
model.add(Dense(input_dim=X_train.shape[1],output_dim=50,init='uniform',activation='tanh'))
model.add(Dense(input_dim=50,output_dim=50,init='uniform',activation='tanh'))
model.add(Dense(input_dim=50,output_dim=Y_train_ohe.shape[1],init='uniform',activation='softmax')) 
 
sgd = SGD(lr=0.001, decay=1e-7, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=["accuracy"])
 
#start training
model.fit(X_train,Y_train_ohe,epochs=50,batch_size=300,shuffle=True,verbose=1,validation_split=0.3)
 
#count accuracy
y_train_pred = model.predict_classes(X_train, verbose=0)
 
train_acc = np.sum(Y_train == y_train_pred, axis=0) / X_train.shape[0] 
print('Training accuracy: %.2f%%' % (train_acc * 100))
 
y_test_pred = model.predict_classes(X_test, verbose=0)
test_acc = np.sum(Y_test == y_test_pred, axis=0) / X_test.shape[0] 
print('Test accuracy: %.2f%%' % (test_acc * 100))

训练结果如下:

Epoch 45/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2174 - acc: 0.9380 - val_loss: 0.2341 - val_acc: 0.9323
Epoch 46/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2061 - acc: 0.9404 - val_loss: 0.2244 - val_acc: 0.9358
Epoch 47/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.1994 - acc: 0.9413 - val_loss: 0.2295 - val_acc: 0.9347
Epoch 48/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2003 - acc: 0.9413 - val_loss: 0.2224 - val_acc: 0.9350
Epoch 49/50
42000/42000 [==============================] - 1s 18us/step - loss: 0.2013 - acc: 0.9417 - val_loss: 0.2248 - val_acc: 0.9359
Epoch 50/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.1960 - acc: 0.9433 - val_loss: 0.2300 - val_acc: 0.9346
Training accuracy: 94.11%
Test accuracy: 93.61%

以上这篇keras导入weights方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python open()文件处理使用介绍
Nov 30 Python
Python Queue模块详解
Nov 30 Python
wxPython中listbox用法实例详解
Jun 01 Python
Python有序字典简单实现方法示例
Sep 28 Python
python利用小波分析进行特征提取的实例
Jan 09 Python
让你Python到很爽的加速递归函数的装饰器
May 26 Python
Python爬虫爬取Bilibili弹幕过程解析
Oct 10 Python
python 动态调用函数实例解析
Oct 21 Python
tensorflow 实现打印pb模型的所有节点
Jan 23 Python
Python实现AI自动抠图实例解析
Mar 05 Python
matplotlib自定义鼠标光标坐标格式的实现
Jan 08 Python
Python万能模板案例之matplotlib绘制直方图的基本配置
Apr 13 Python
keras读取h5文件load_weights、load代码操作
Jun 12 #Python
Python matplotlib 绘制双Y轴曲线图的示例代码
Jun 12 #Python
keras的siamese(孪生网络)实现案例
Jun 12 #Python
基于python实现模拟数据结构模型
Jun 12 #Python
Python-for循环的内部机制
Jun 12 #Python
Python Scrapy图片爬取原理及代码实例
Jun 12 #Python
Python Scrapy多页数据爬取实现过程解析
Jun 12 #Python
You might like
PHP+Mysql+jQuery实现发布微博程序 jQuery篇
2011/10/08 PHP
页面只有一个text的时候,回车自动submit的解决方法
2010/08/12 Javascript
关于js datetime的那点事
2011/11/15 Javascript
DOM操作和jQuery实现选项移动操作的简单实例
2016/06/07 Javascript
jQuery animate easing使用方法图文详解
2016/06/17 Javascript
jQuery选择器总结之常用元素查找方法
2016/08/04 Javascript
JS实现颜色的10进制转化成rgba格式的方法
2017/09/04 Javascript
详解在Vue中有条件地使用CSS类
2017/09/30 Javascript
vue计算属性get和set用法示例
2019/02/08 Javascript
[00:53]2015国际邀请赛 中国区预选赛一触即发
2015/05/14 DOTA
[02:12]打造更好的电竞完美世界:完美盛典回顾篇
2018/12/19 DOTA
python应用程序在windows下不出现cmd窗口的办法
2014/05/29 Python
python函数装饰器用法实例详解
2015/06/04 Python
python 基础教程之Map使用方法
2017/01/17 Python
详谈python read readline readlines的区别
2017/09/22 Python
基于python中的TCP及UDP(详解)
2017/11/06 Python
PyQt5实现让QScrollArea支持鼠标拖动的操作方法
2019/06/19 Python
Numpy之reshape()使用详解
2019/12/26 Python
Python @property装饰器原理解析
2020/01/22 Python
python实现跨excel sheet复制代码实例
2020/03/03 Python
keras自定义损失函数并且模型加载的写法介绍
2020/06/15 Python
使用python修改文件并立即写回到原始位置操作(inplace读写)
2020/06/28 Python
如何在windows下安装配置python工具Ulipad
2020/10/27 Python
HTML5中Localstorage的使用教程
2015/07/09 HTML / CSS
整理HTML5移动端开发的常用触摸事件
2016/04/15 HTML / CSS
阿联酋航空丹麦官方网站:Emirates DK
2019/08/25 全球购物
应届大学生求职信
2013/12/01 职场文书
就业推荐表自我鉴定范文
2014/03/21 职场文书
家长寄语大全
2014/04/02 职场文书
创先争优宣传标语
2014/10/08 职场文书
村干部群众路线整改措施思想汇报
2014/10/12 职场文书
考试没考好检讨书(精选篇)
2014/11/16 职场文书
班主任自我评价范文
2015/03/11 职场文书
《抽屉原理》教学反思
2016/02/20 职场文书
python实现socket简单通信的示例代码
2021/04/13 Python
Java中的Kotlin 内部类原理
2022/06/16 Java/Android