Keras-多输入多输出实例(多任务)


Posted in Python onJune 22, 2020

1、模型结果设计

Keras-多输入多输出实例(多任务)

2、代码

from keras import Input, Model
from keras.layers import Dense, Concatenate
import numpy as np
from keras.utils import plot_model
from numpy import random as rd

samples_n = 3000
samples_dim_01 = 2
samples_dim_02 = 2
# 样本数据
x1 = rd.rand(samples_n, samples_dim_01)
x2 = rd.rand(samples_n, samples_dim_02)
y_1 = []
y_2 = []
y_3 = []
for x11, x22 in zip(x1, x2):
  y_1.append(np.sum(x11) + np.sum(x22))
  y_2.append(np.max([np.max(x11), np.max(x22)]))
  y_3.append(np.min([np.min(x11), np.min(x22)]))
y_1 = np.array(y_1)
y_1 = np.expand_dims(y_1, axis=1)
y_2 = np.array(y_2)
y_2 = np.expand_dims(y_2, axis=1)
y_3 = np.array(y_3)
y_3 = np.expand_dims(y_3, axis=1)

# 输入层
inputs_01 = Input((samples_dim_01,), name='input_1')
inputs_02 = Input((samples_dim_02,), name='input_2')
# 全连接层
dense_01 = Dense(units=3, name="dense_01", activation='softmax')(inputs_01)
dense_011 = Dense(units=3, name="dense_011", activation='softmax')(dense_01)
dense_02 = Dense(units=6, name="dense_02", activation='softmax')(inputs_02)
# 加入合并层
merge = Concatenate()([dense_011, dense_02])
# 分成两类输出 --- 输出01
output_01 = Dense(units=6, activation="relu", name='output01')(merge)
output_011 = Dense(units=1, activation=None, name='output011')(output_01)
# 分成两类输出 --- 输出02
output_02 = Dense(units=1, activation=None, name='output02')(merge)
# 分成两类输出 --- 输出03
output_03 = Dense(units=1, activation=None, name='output03')(merge)
# 构造一个新模型
model = Model(inputs=[inputs_01, inputs_02], outputs=[output_011,
                           output_02,
                           output_03
                           ])
# 显示模型情况
plot_model(model, show_shapes=True)
print(model.summary())
# # 编译
# model.compile(optimizer="adam", loss='mean_squared_error', loss_weights=[1,
#                                     0.8,
#                                     0.8
#                                     ])
# # 训练
# model.fit([x1, x2], [y_1,
#           y_2,
#           y_3
#           ], epochs=50, batch_size=32, validation_split=0.1)

# 以下的方法可灵活设置
model.compile(optimizer='adam',
       loss={'output011': 'mean_squared_error',
          'output02': 'mean_squared_error',
          'output03': 'mean_squared_error'},
       loss_weights={'output011': 1,
              'output02': 0.8,
              'output03': 0.8})
model.fit({'input_1': x1,
      'input_2': x2},
     {'output011': y_1,
      'output02': y_2,
      'output03': y_3},
     epochs=50, batch_size=32, validation_split=0.1)

# 预测
test_x1 = rd.rand(1, 2)
test_x2 = rd.rand(1, 2)
test_y = model.predict(x=[test_x1, test_x2])
# 测试
print("测试结果:")
print("test_x1:", test_x1, "test_x2:", test_x2, "y:", test_y, np.sum(test_x1) + np.sum(test_x2))

补充知识:Keras多输出(多任务)如何设置fit_generator

在使用Keras的时候,因为需要考虑到效率问题,需要修改fit_generator来适应多输出

# create model
model = Model(inputs=x_inp, outputs=[main_pred, aux_pred])
# complie model
model.compile(
  optimizer=optimizers.Adam(lr=learning_rate),
  loss={"main": weighted_binary_crossentropy(weights), "auxiliary":weighted_binary_crossentropy(weights)},
  loss_weights={"main": 0.5, "auxiliary": 0.5},
  metrics=[metrics.binary_accuracy],
)
# Train model
model.fit_generator(
  train_gen, epochs=num_epochs, verbose=0, shuffle=True
)

看Keras官方文档:

generator: A generator or an instance of Sequence (keras.utils.Sequence) object in order to avoid duplicate data when using multiprocessing. The output of the generator must be either

a tuple (inputs, targets)

a tuple (inputs, targets, sample_weights).

Keras设计多输出(多任务)使用fit_generator的步骤如下:

根据官方文档,定义一个generator或者一个class继承Sequence

class Batch_generator(Sequence):
 """
 用于产生batch_1, batch_2(记住是numpy.array格式转换)
 """
 y_batch = {'main':batch_1,'auxiliary':batch_2}
 return X_batch, y_batch

# or in another way
def batch_generator():
 """
 用于产生batch_1, batch_2(记住是numpy.array格式转换)
 """
 yield X_batch, {'main': batch_1,'auxiliary':batch_2}

重要的事情说三遍(亲自采坑,搜了一大圈才发现滴):

如果是多输出(多任务)的时候,这里的target是字典类型

如果是多输出(多任务)的时候,这里的target是字典类型

如果是多输出(多任务)的时候,这里的target是字典类型

以上这篇Keras-多输入多输出实例(多任务)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
详解python中executemany和序列的使用方法
Aug 12 Python
Python 基础教程之str和repr的详解
Aug 20 Python
Python实现小数转化为百分数的格式化输出方法示例
Sep 20 Python
详解Python3.6的py文件打包生成exe
Jul 13 Python
python 在指定范围内随机生成不重复的n个数实例
Jan 28 Python
Python mutiprocessing多线程池pool操作示例
Jan 30 Python
python生成每日报表数据(Excel)并邮件发送的实例
Feb 03 Python
window7下的python2.7版本和python3.5版本的opencv-python安装过程
Oct 24 Python
计算pytorch标准化(Normalize)所需要数据集的均值和方差实例
Jan 15 Python
基于python的docx模块处理word和WPS的docx格式文件方式
Feb 13 Python
Python sklearn库实现PCA教程(以鸢尾花分类为例)
Feb 24 Python
Python实现邮件发送的详细设置方法(遇到问题)
Jan 18 Python
python和c语言哪个更适合初学者
Jun 22 #Python
Virtualenv 搭建 Py项目运行环境的教程详解
Jun 22 #Python
终于搞懂了Keras中multiloss的对应关系介绍
Jun 22 #Python
keras 多任务多loss实例
Jun 22 #Python
python对execl 处理操作代码
Jun 22 #Python
Python select及selectors模块概念用法详解
Jun 22 #Python
tensorflow 2.0模式下训练的模型转成 tf1.x 版本的pb模型实例
Jun 22 #Python
You might like
php 无限级 SelectTree 类
2009/05/19 PHP
屏蔽机器人从你的网站搜取email地址的php代码
2012/11/14 PHP
如何解决phpmyadmin导入数据库文件最大限制2048KB
2015/10/09 PHP
PHP中PCRE正则解析代码详解
2019/04/26 PHP
XRegExp 0.2: Now With Named Capture
2007/11/30 Javascript
jQuery设置div一直在页面顶部显示的方法
2013/10/24 Javascript
javascript获取所有同类checkbox选项(实例代码)
2013/11/07 Javascript
js跳转页面方法实现汇总
2014/02/11 Javascript
轻松创建nodejs服务器(10):处理上传图片
2014/12/18 NodeJs
jQuery简单实现仿京东分类导航层效果
2016/06/07 Javascript
connection reset by peer问题总结及解决方案
2016/10/21 Javascript
JS实现获取来自百度,Google,soso,sogou关键词的方法
2016/12/21 Javascript
JavaScript实现汉字转换为拼音的库文件示例
2016/12/22 Javascript
基于JS递归函数细化认识及实用实例(推荐)
2017/08/07 Javascript
EasyUI在Panel上动态添加LinkButton按钮
2017/08/11 Javascript
vue结合axios与后端进行ajax交互的方法
2018/07/06 Javascript
图片文字识别(OCR)插件Ocrad.js教程
2018/11/26 Javascript
什么时候不能在 Node.js 中使用 Lock Files
2019/06/24 Javascript
Vue按时间段查询数据组件使用详解
2020/08/21 Javascript
vscode中的vue项目报错Property ‘xxx‘ does not exist on type ‘CombinedVueInstance<{ readyOnly...Vetur(2339)
2020/09/11 Javascript
[00:09]DOTA2全国高校联赛 精彩活动引爆全场
2018/05/30 DOTA
用python实现面向对像的ASP程序实例
2014/11/10 Python
Python脚本实现自动发带图的微博
2016/04/27 Python
强悍的Python读取大文件的解决方案
2019/02/16 Python
python3编写ThinkPHP命令执行Getshell的方法
2019/02/26 Python
Django 导出项目依赖库到 requirements.txt过程解析
2019/08/23 Python
PyQt5实现画布小程序
2020/05/30 Python
如何在Anaconda中打开python自带idle
2020/09/21 Python
全球最大的房车租赁市场:Outdoorsy
2018/09/19 全球购物
自1926年以来就为冰岛保持温暖:66°North
2020/11/27 全球购物
单位实习证明怎么写
2014/01/17 职场文书
本科毕业生应聘自荐信范文
2014/06/26 职场文书
群众路线教育实践活动实施方案
2014/10/31 职场文书
出国留学单位推荐信
2015/03/26 职场文书
微信小程序APP的事件绑定以及传递参数时的冒泡和捕获
2022/04/19 Javascript
springboot集成redis存对象乱码的问题及解决
2022/06/16 Java/Android