Keras-多输入多输出实例(多任务)


Posted in Python onJune 22, 2020

1、模型结果设计

Keras-多输入多输出实例(多任务)

2、代码

from keras import Input, Model
from keras.layers import Dense, Concatenate
import numpy as np
from keras.utils import plot_model
from numpy import random as rd

samples_n = 3000
samples_dim_01 = 2
samples_dim_02 = 2
# 样本数据
x1 = rd.rand(samples_n, samples_dim_01)
x2 = rd.rand(samples_n, samples_dim_02)
y_1 = []
y_2 = []
y_3 = []
for x11, x22 in zip(x1, x2):
  y_1.append(np.sum(x11) + np.sum(x22))
  y_2.append(np.max([np.max(x11), np.max(x22)]))
  y_3.append(np.min([np.min(x11), np.min(x22)]))
y_1 = np.array(y_1)
y_1 = np.expand_dims(y_1, axis=1)
y_2 = np.array(y_2)
y_2 = np.expand_dims(y_2, axis=1)
y_3 = np.array(y_3)
y_3 = np.expand_dims(y_3, axis=1)

# 输入层
inputs_01 = Input((samples_dim_01,), name='input_1')
inputs_02 = Input((samples_dim_02,), name='input_2')
# 全连接层
dense_01 = Dense(units=3, name="dense_01", activation='softmax')(inputs_01)
dense_011 = Dense(units=3, name="dense_011", activation='softmax')(dense_01)
dense_02 = Dense(units=6, name="dense_02", activation='softmax')(inputs_02)
# 加入合并层
merge = Concatenate()([dense_011, dense_02])
# 分成两类输出 --- 输出01
output_01 = Dense(units=6, activation="relu", name='output01')(merge)
output_011 = Dense(units=1, activation=None, name='output011')(output_01)
# 分成两类输出 --- 输出02
output_02 = Dense(units=1, activation=None, name='output02')(merge)
# 分成两类输出 --- 输出03
output_03 = Dense(units=1, activation=None, name='output03')(merge)
# 构造一个新模型
model = Model(inputs=[inputs_01, inputs_02], outputs=[output_011,
                           output_02,
                           output_03
                           ])
# 显示模型情况
plot_model(model, show_shapes=True)
print(model.summary())
# # 编译
# model.compile(optimizer="adam", loss='mean_squared_error', loss_weights=[1,
#                                     0.8,
#                                     0.8
#                                     ])
# # 训练
# model.fit([x1, x2], [y_1,
#           y_2,
#           y_3
#           ], epochs=50, batch_size=32, validation_split=0.1)

# 以下的方法可灵活设置
model.compile(optimizer='adam',
       loss={'output011': 'mean_squared_error',
          'output02': 'mean_squared_error',
          'output03': 'mean_squared_error'},
       loss_weights={'output011': 1,
              'output02': 0.8,
              'output03': 0.8})
model.fit({'input_1': x1,
      'input_2': x2},
     {'output011': y_1,
      'output02': y_2,
      'output03': y_3},
     epochs=50, batch_size=32, validation_split=0.1)

# 预测
test_x1 = rd.rand(1, 2)
test_x2 = rd.rand(1, 2)
test_y = model.predict(x=[test_x1, test_x2])
# 测试
print("测试结果:")
print("test_x1:", test_x1, "test_x2:", test_x2, "y:", test_y, np.sum(test_x1) + np.sum(test_x2))

补充知识:Keras多输出(多任务)如何设置fit_generator

在使用Keras的时候,因为需要考虑到效率问题,需要修改fit_generator来适应多输出

# create model
model = Model(inputs=x_inp, outputs=[main_pred, aux_pred])
# complie model
model.compile(
  optimizer=optimizers.Adam(lr=learning_rate),
  loss={"main": weighted_binary_crossentropy(weights), "auxiliary":weighted_binary_crossentropy(weights)},
  loss_weights={"main": 0.5, "auxiliary": 0.5},
  metrics=[metrics.binary_accuracy],
)
# Train model
model.fit_generator(
  train_gen, epochs=num_epochs, verbose=0, shuffle=True
)

看Keras官方文档:

generator: A generator or an instance of Sequence (keras.utils.Sequence) object in order to avoid duplicate data when using multiprocessing. The output of the generator must be either

a tuple (inputs, targets)

a tuple (inputs, targets, sample_weights).

Keras设计多输出(多任务)使用fit_generator的步骤如下:

根据官方文档,定义一个generator或者一个class继承Sequence

class Batch_generator(Sequence):
 """
 用于产生batch_1, batch_2(记住是numpy.array格式转换)
 """
 y_batch = {'main':batch_1,'auxiliary':batch_2}
 return X_batch, y_batch

# or in another way
def batch_generator():
 """
 用于产生batch_1, batch_2(记住是numpy.array格式转换)
 """
 yield X_batch, {'main': batch_1,'auxiliary':batch_2}

重要的事情说三遍(亲自采坑,搜了一大圈才发现滴):

如果是多输出(多任务)的时候,这里的target是字典类型

如果是多输出(多任务)的时候,这里的target是字典类型

如果是多输出(多任务)的时候,这里的target是字典类型

以上这篇Keras-多输入多输出实例(多任务)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
Python机器学习之决策树算法实例详解
Dec 06 Python
python利用OpenCV2实现人脸检测
Apr 16 Python
Numpy数组的保存与读取方法
Apr 04 Python
python docx 中文字体设置的操作方法
May 08 Python
Python3实现将本地JSON大数据文件写入MySQL数据库的方法
Jun 13 Python
Python爬虫之正则表达式的使用教程详解
Oct 25 Python
python 获取一个值在某个区间的指定倍数的值方法
Nov 12 Python
python爬取内容存入Excel实例
Feb 20 Python
Python调用.NET库的方法步骤
Dec 27 Python
Python 中由 yield 实现异步操作
May 04 Python
Python命名空间及作用域原理实例解析
Aug 12 Python
如何用tempfile库创建python进程中的临时文件
Jan 28 Python
python和c语言哪个更适合初学者
Jun 22 #Python
Virtualenv 搭建 Py项目运行环境的教程详解
Jun 22 #Python
终于搞懂了Keras中multiloss的对应关系介绍
Jun 22 #Python
keras 多任务多loss实例
Jun 22 #Python
python对execl 处理操作代码
Jun 22 #Python
Python select及selectors模块概念用法详解
Jun 22 #Python
tensorflow 2.0模式下训练的模型转成 tf1.x 版本的pb模型实例
Jun 22 #Python
You might like
无数据库的详细域名查询程序PHP版(1)
2006/10/09 PHP
Laravel 4 初级教程之Pages、表单验证
2014/10/30 PHP
如何使用GDB调试PHP程序
2015/12/08 PHP
php使用Jpgraph创建柱状图展示年度收支表效果示例
2017/02/15 PHP
Laravel开启跨域请求的方法
2019/10/13 PHP
javascript写的一个链表实现代码
2009/10/25 Javascript
JQuery的一些小应用收集
2010/03/27 Javascript
一款js和css代码压缩工具[附JAVA环境配置方法]
2010/04/16 Javascript
jQuery弹出层始终垂直居中相对于屏幕或当前窗口
2013/04/01 Javascript
自己封装的javascript事件队列函数版
2014/06/12 Javascript
js与C#进行时间戳转换
2014/11/14 Javascript
推荐4个原生javascript常用的函数
2015/01/12 Javascript
jquery实现表格隔行换色效果
2015/11/19 Javascript
AngularJS入门教程之AngularJS模型
2016/04/18 Javascript
JavaScript的Ext JS框架中的GridPanel组件使用指南
2016/05/21 Javascript
jQuery事件对象的属性和方法详解
2017/09/09 jQuery
学习jQuery中的noConflict()用法
2018/09/28 jQuery
Python实现二维有序数组查找的方法
2016/04/27 Python
非递归的输出1-N的全排列实例(推荐)
2017/04/11 Python
Python内置函数reversed()用法分析
2018/03/20 Python
PyQt5实现简单数据标注工具
2019/03/18 Python
python实现kNN算法识别手写体数字的示例代码
2019/08/16 Python
pytorch 中pad函数toch.nn.functional.pad()的用法
2020/01/08 Python
美国受欢迎的眼影品牌:BH Cosmetics
2016/10/25 全球购物
英国莱斯特松木橡木家具网上商店:Choice Furniture Superstore
2019/07/05 全球购物
澳大利亚家具商店:Freedom
2020/12/17 全球购物
应届生骨科医生求职信
2013/10/31 职场文书
小加工厂管理制度
2014/01/21 职场文书
年会活动策划方案
2014/01/23 职场文书
党员组织关系介绍信
2014/02/13 职场文书
见习期自我鉴定范文
2014/03/19 职场文书
企业节能减排实施方案
2014/03/19 职场文书
乡镇党的群众路线教育实践活动领导班子对照检查材料
2014/09/25 职场文书
党员倡议书
2015/01/19 职场文书
学校百日安全活动总结
2015/05/07 职场文书
使用Selenium实现微博爬虫(预登录、展开全文、翻页)
2021/04/13 Python