keras导入weights方式


Posted in Python onJune 12, 2020

keras源码engine中toplogy.py定义了加载权重的函数:

load_weights(self, filepath, by_name=False)

其中默认by_name为False,这时候加载权重按照网络拓扑结构加载,适合直接使用keras中自带的网络模型,如VGG16

VGG19/resnet50等,源码描述如下:

If `by_name` is False (default) weights are loaded
based on the network's topology, meaning the architecture
should be the same as when the weights were saved.
Note that layers that don't have weights are not taken
into account in the topological ordering, so adding or
removing layers is fine as long as they don't have weights.

若将by_name改为True则加载权重按照layer的name进行,layer的name相同时加载权重,适合用于改变了

模型的相关结构或增加了节点但利用了原网络的主体结构情况下使用,源码描述如下:

If `by_name` is True, weights are loaded into layers
only if they share the same name. This is useful
for fine-tuning or transfer-learning models where
some of the layers have changed.

在进行边缘检测时,利用VGG网络的主体结构,网络中增加反卷积层,这时加载权重应该使用

model.load_weights(filepath,by_name=True)

补充知识:Keras下实现mnist手写数字

之前一直在用tensorflow,被同学推荐来用keras了,把之前文档中的mnist手写数字数据集拿来练手,

代码如下。

import struct
import numpy as np
import os
 
import keras
from keras.models import Sequential 
from keras.layers import Dense
from keras.optimizers import SGD
 
def load_mnist(path, kind):
  labels_path = os.path.join(path, '%s-labels.idx1-ubyte' % kind)
  images_path = os.path.join(path, '%s-images.idx3-ubyte' % kind)
  with open(labels_path, 'rb') as lbpath:
    magic, n = struct.unpack('>II', lbpath.read(8))
    labels = np.fromfile(lbpath, dtype=np.uint8)
  with open(images_path, 'rb') as imgpath:
    magic, num, rows, cols = struct.unpack(">IIII", imgpath.read(16))
    images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(labels), 784) #28*28=784
  return images, labels
 
#loading train and test data
X_train, Y_train = load_mnist('.\\data', kind='train')
X_test, Y_test = load_mnist('.\\data', kind='t10k')
 
#turn labels to one_hot code
Y_train_ohe = keras.utils.to_categorical(Y_train, num_classes=10)
 
#define models
model = Sequential()
 
model.add(Dense(input_dim=X_train.shape[1],output_dim=50,init='uniform',activation='tanh'))
model.add(Dense(input_dim=50,output_dim=50,init='uniform',activation='tanh'))
model.add(Dense(input_dim=50,output_dim=Y_train_ohe.shape[1],init='uniform',activation='softmax')) 
 
sgd = SGD(lr=0.001, decay=1e-7, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=["accuracy"])
 
#start training
model.fit(X_train,Y_train_ohe,epochs=50,batch_size=300,shuffle=True,verbose=1,validation_split=0.3)
 
#count accuracy
y_train_pred = model.predict_classes(X_train, verbose=0)
 
train_acc = np.sum(Y_train == y_train_pred, axis=0) / X_train.shape[0] 
print('Training accuracy: %.2f%%' % (train_acc * 100))
 
y_test_pred = model.predict_classes(X_test, verbose=0)
test_acc = np.sum(Y_test == y_test_pred, axis=0) / X_test.shape[0] 
print('Test accuracy: %.2f%%' % (test_acc * 100))

训练结果如下:

Epoch 45/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2174 - acc: 0.9380 - val_loss: 0.2341 - val_acc: 0.9323
Epoch 46/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2061 - acc: 0.9404 - val_loss: 0.2244 - val_acc: 0.9358
Epoch 47/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.1994 - acc: 0.9413 - val_loss: 0.2295 - val_acc: 0.9347
Epoch 48/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.2003 - acc: 0.9413 - val_loss: 0.2224 - val_acc: 0.9350
Epoch 49/50
42000/42000 [==============================] - 1s 18us/step - loss: 0.2013 - acc: 0.9417 - val_loss: 0.2248 - val_acc: 0.9359
Epoch 50/50
42000/42000 [==============================] - 1s 17us/step - loss: 0.1960 - acc: 0.9433 - val_loss: 0.2300 - val_acc: 0.9346
Training accuracy: 94.11%
Test accuracy: 93.61%

以上这篇keras导入weights方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python自动化测试Eclipse+Pydev 搭建开发环境
Aug 15 Python
Python爬取十篇新闻统计TF-IDF
Jan 03 Python
浅谈pandas中Dataframe的查询方法([], loc, iloc, at, iat, ix)
Apr 10 Python
python TKinter获取文本框内容的方法
Oct 11 Python
python中的json总结
Oct 11 Python
python requests.post带head和body的实例
Jan 02 Python
Python3.4学习笔记之类型判断,异常处理,终止程序操作小结
Mar 01 Python
对python中UDP,socket的使用详解
Aug 22 Python
Python 读取xml数据,cv2裁剪图片实例
Mar 10 Python
Python使用plt.boxplot() 参数绘制箱线图
Jun 04 Python
使用Python中tkinter库简单gui界面制作及打包成exe的操作方法(二)
Oct 12 Python
python编写五子棋游戏
May 25 Python
keras读取h5文件load_weights、load代码操作
Jun 12 #Python
Python matplotlib 绘制双Y轴曲线图的示例代码
Jun 12 #Python
keras的siamese(孪生网络)实现案例
Jun 12 #Python
基于python实现模拟数据结构模型
Jun 12 #Python
Python-for循环的内部机制
Jun 12 #Python
Python Scrapy图片爬取原理及代码实例
Jun 12 #Python
Python Scrapy多页数据爬取实现过程解析
Jun 12 #Python
You might like
PHP,ASP.JAVA,JAVA代码格式化工具整理
2010/06/15 PHP
基于flush()不能按顺序输出时的解决办法
2013/06/29 PHP
PHP实现获取域名的方法小结
2014/11/05 PHP
PHP生成json和xml类型接口数据格式
2015/05/17 PHP
php中让人头疼的浮点数运算分析
2016/10/10 PHP
php 后端实现JWT认证方法示例
2018/09/04 PHP
PHP parse_ini_file函数的应用与扩展操作示例
2019/01/07 PHP
javascript中的几个运算符
2007/06/29 Javascript
javascript replace()正则替换实现代码
2010/02/26 Javascript
jQuery中:disabled选择器用法实例
2015/01/04 Javascript
Vuex 进阶之模块化组织详解
2018/01/12 Javascript
Angularjs中的$apply及优化使用详解
2018/07/02 Javascript
React组件内事件传参实现tab切换的示例代码
2018/07/04 Javascript
详解mpvue小程序中怎么引入iconfont字体图标
2018/10/01 Javascript
Vue编程式跳转的实例代码详解
2019/07/10 Javascript
weui中的picker使用js进行动态绑定数据问题
2019/11/06 Javascript
[00:39]DOTA2上海特级锦标赛 Liquid战队宣传片
2016/03/04 DOTA
[00:17]天涯墨客一技能展示
2018/08/25 DOTA
Django unittest 设置跳过某些case的方法
2018/12/26 Python
Django框架会话技术实例分析【Cookie与Session】
2019/05/24 Python
python批量修改ssh密码的实现
2019/08/08 Python
PyQt5+Caffe+Opencv搭建人脸识别登录界面
2019/08/28 Python
OpenCV灰度化之后图片为绿色的解决
2020/12/01 Python
利用CSS3参考手册和CSS3代码生成工具加速来学习网页制
2012/07/11 HTML / CSS
css3和jquery实现的可折叠导航菜单适合放在手机网页的导航菜单
2014/09/02 HTML / CSS
Bluebella美国官网:英国性感内衣品牌
2018/10/04 全球购物
美国时尚大码女装购物网站:Avenue
2019/05/24 全球购物
物业管理员岗位职责范文
2013/11/25 职场文书
车间班长岗位职责
2013/11/30 职场文书
同居协议书范本
2014/04/23 职场文书
信息管理与信息系统专业求职信
2014/06/21 职场文书
张家口市高新区党工委群众路线教育实践活动整改方案
2014/10/25 职场文书
2014年勤工助学工作总结
2014/11/24 职场文书
发布会邀请函
2015/01/31 职场文书
交通事故赔偿起诉书
2015/05/20 职场文书
nginx请求限制配置方法
2021/07/09 Servers