Keras-多输入多输出实例(多任务)


Posted in Python onJune 22, 2020

1、模型结果设计

Keras-多输入多输出实例(多任务)

2、代码

from keras import Input, Model
from keras.layers import Dense, Concatenate
import numpy as np
from keras.utils import plot_model
from numpy import random as rd

samples_n = 3000
samples_dim_01 = 2
samples_dim_02 = 2
# 样本数据
x1 = rd.rand(samples_n, samples_dim_01)
x2 = rd.rand(samples_n, samples_dim_02)
y_1 = []
y_2 = []
y_3 = []
for x11, x22 in zip(x1, x2):
  y_1.append(np.sum(x11) + np.sum(x22))
  y_2.append(np.max([np.max(x11), np.max(x22)]))
  y_3.append(np.min([np.min(x11), np.min(x22)]))
y_1 = np.array(y_1)
y_1 = np.expand_dims(y_1, axis=1)
y_2 = np.array(y_2)
y_2 = np.expand_dims(y_2, axis=1)
y_3 = np.array(y_3)
y_3 = np.expand_dims(y_3, axis=1)

# 输入层
inputs_01 = Input((samples_dim_01,), name='input_1')
inputs_02 = Input((samples_dim_02,), name='input_2')
# 全连接层
dense_01 = Dense(units=3, name="dense_01", activation='softmax')(inputs_01)
dense_011 = Dense(units=3, name="dense_011", activation='softmax')(dense_01)
dense_02 = Dense(units=6, name="dense_02", activation='softmax')(inputs_02)
# 加入合并层
merge = Concatenate()([dense_011, dense_02])
# 分成两类输出 --- 输出01
output_01 = Dense(units=6, activation="relu", name='output01')(merge)
output_011 = Dense(units=1, activation=None, name='output011')(output_01)
# 分成两类输出 --- 输出02
output_02 = Dense(units=1, activation=None, name='output02')(merge)
# 分成两类输出 --- 输出03
output_03 = Dense(units=1, activation=None, name='output03')(merge)
# 构造一个新模型
model = Model(inputs=[inputs_01, inputs_02], outputs=[output_011,
                           output_02,
                           output_03
                           ])
# 显示模型情况
plot_model(model, show_shapes=True)
print(model.summary())
# # 编译
# model.compile(optimizer="adam", loss='mean_squared_error', loss_weights=[1,
#                                     0.8,
#                                     0.8
#                                     ])
# # 训练
# model.fit([x1, x2], [y_1,
#           y_2,
#           y_3
#           ], epochs=50, batch_size=32, validation_split=0.1)

# 以下的方法可灵活设置
model.compile(optimizer='adam',
       loss={'output011': 'mean_squared_error',
          'output02': 'mean_squared_error',
          'output03': 'mean_squared_error'},
       loss_weights={'output011': 1,
              'output02': 0.8,
              'output03': 0.8})
model.fit({'input_1': x1,
      'input_2': x2},
     {'output011': y_1,
      'output02': y_2,
      'output03': y_3},
     epochs=50, batch_size=32, validation_split=0.1)

# 预测
test_x1 = rd.rand(1, 2)
test_x2 = rd.rand(1, 2)
test_y = model.predict(x=[test_x1, test_x2])
# 测试
print("测试结果:")
print("test_x1:", test_x1, "test_x2:", test_x2, "y:", test_y, np.sum(test_x1) + np.sum(test_x2))

补充知识:Keras多输出(多任务)如何设置fit_generator

在使用Keras的时候,因为需要考虑到效率问题,需要修改fit_generator来适应多输出

# create model
model = Model(inputs=x_inp, outputs=[main_pred, aux_pred])
# complie model
model.compile(
  optimizer=optimizers.Adam(lr=learning_rate),
  loss={"main": weighted_binary_crossentropy(weights), "auxiliary":weighted_binary_crossentropy(weights)},
  loss_weights={"main": 0.5, "auxiliary": 0.5},
  metrics=[metrics.binary_accuracy],
)
# Train model
model.fit_generator(
  train_gen, epochs=num_epochs, verbose=0, shuffle=True
)

看Keras官方文档:

generator: A generator or an instance of Sequence (keras.utils.Sequence) object in order to avoid duplicate data when using multiprocessing. The output of the generator must be either

a tuple (inputs, targets)

a tuple (inputs, targets, sample_weights).

Keras设计多输出(多任务)使用fit_generator的步骤如下:

根据官方文档,定义一个generator或者一个class继承Sequence

class Batch_generator(Sequence):
 """
 用于产生batch_1, batch_2(记住是numpy.array格式转换)
 """
 y_batch = {'main':batch_1,'auxiliary':batch_2}
 return X_batch, y_batch

# or in another way
def batch_generator():
 """
 用于产生batch_1, batch_2(记住是numpy.array格式转换)
 """
 yield X_batch, {'main': batch_1,'auxiliary':batch_2}

重要的事情说三遍(亲自采坑,搜了一大圈才发现滴):

如果是多输出(多任务)的时候,这里的target是字典类型

如果是多输出(多任务)的时候,这里的target是字典类型

如果是多输出(多任务)的时候,这里的target是字典类型

以上这篇Keras-多输入多输出实例(多任务)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解Python中的条件判断语句
May 14 Python
Python实现简单字典树的方法
Apr 29 Python
Python中类型检查的详细介绍
Feb 13 Python
Python数据结构与算法之图的最短路径(Dijkstra算法)完整实例
Dec 12 Python
Python使用matplotlib填充图形指定区域代码示例
Jan 16 Python
python pytest进阶之fixture详解
Jun 27 Python
Python socket聊天脚本代码实例
Jan 02 Python
python读取当前目录下的CSV文件数据
Mar 11 Python
django执行原始查询sql,并返回Dict字典例子
Apr 01 Python
python异常处理之try finally不报错的原因
May 18 Python
python基于机器学习预测股票交易信号
May 25 Python
Python 中 Shutil 模块详情
Nov 11 Python
python和c语言哪个更适合初学者
Jun 22 #Python
Virtualenv 搭建 Py项目运行环境的教程详解
Jun 22 #Python
终于搞懂了Keras中multiloss的对应关系介绍
Jun 22 #Python
keras 多任务多loss实例
Jun 22 #Python
python对execl 处理操作代码
Jun 22 #Python
Python select及selectors模块概念用法详解
Jun 22 #Python
tensorflow 2.0模式下训练的模型转成 tf1.x 版本的pb模型实例
Jun 22 #Python
You might like
php 数组的指针操作实现代码
2011/02/08 PHP
php实现连接access数据库并转txt写入的方法
2017/02/08 PHP
PHP中危险的file_put_contents函数详解
2017/11/04 PHP
JQuery读取XML文件数据并显示的实现代码
2009/12/16 Javascript
jQuery学习笔记之jQuery选择器的使用
2010/12/22 Javascript
原生js做的手风琴效果的导航菜单
2013/11/08 Javascript
jQuery中使用Ajax获取JSON格式数据示例代码
2013/11/26 Javascript
《JavaScript DOM 编程艺术》读书笔记之JavaScript 简史
2015/01/09 Javascript
js淡入淡出的图片轮播效果代码分享
2015/08/24 Javascript
jQuery实现美观的多级动画效果菜单代码
2015/09/06 Javascript
基于jQuery通过jQuery.form.js插件实现异步上传
2015/12/13 Javascript
AngularJS基础 ng-keypress 指令简单示例
2016/08/02 Javascript
基于jQuery实现简单人工智能聊天室
2017/02/10 Javascript
在Vue组件化中利用axios处理ajax请求的使用方法
2017/08/25 Javascript
解决linux下node.js全局模块找不到的问题
2018/05/15 Javascript
Vue2.0仿饿了么webapp单页面应用详细步骤
2018/07/08 Javascript
详解webpack打包第三方类库的正确姿势
2018/10/20 Javascript
Javascript实现html转pdf高清版(提高分辨率)
2020/02/19 Javascript
深入解析微信小程序开发中遇到的几个小问题
2020/07/11 Javascript
javascript实现京东登录显示隐藏密码
2020/08/02 Javascript
Hadoop中的Python框架的使用指南
2015/04/22 Python
python使用正则表达式匹配字符串开头并打印示例
2017/01/11 Python
Python语言生成水仙花数代码示例
2017/12/18 Python
python OpenCV学习笔记之绘制直方图的方法
2018/02/08 Python
Python打印输出数组中全部元素
2018/03/13 Python
PyQt5 窗口切换与自定义对话框的实例
2019/06/20 Python
Python使用scrapy爬取阳光热线问政平台过程解析
2019/08/14 Python
使用python将excel数据导入数据库过程详解
2019/08/27 Python
wxPython:python首选的GUI库实例分享
2019/10/05 Python
python实现的多任务版udp聊天器功能案例
2019/11/13 Python
最新PyCharm 2020.2.3永久激活码(亲测有效)
2020/11/26 Python
豆腐の盛田屋官网:日本自然派的豆乳面膜、肥皂、化妆水、乳液等
2016/10/08 全球购物
一套软件测试笔试题
2014/07/25 面试题
优秀经理事迹材料
2014/02/01 职场文书
评先进个人材料
2014/12/29 职场文书
余世维讲座观后感
2015/06/11 职场文书