keras分类模型中的输入数据与标签的维度实例


Posted in Python onJuly 03, 2020

在《python深度学习》这本书中。

一、21页mnist十分类

导入数据集
from keras.datasets import mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

初始数据维度:
>>> train_images.shape
(60000, 28, 28)
>>> len(train_labels)
60000
>>> train_labels
array([5, 0, 4, ..., 5, 6, 8], dtype=uint8)

数据预处理:
train_images = train_images.reshape((60000, 28 * 28))
train_images = train_images.astype('float32') / 255
train_labels = to_categorical(train_labels)
  
之后:
print(train_images, type(train_images), train_images.shape, train_images.dtype)
print(train_labels, type(train_labels), train_labels.shape, train_labels.dtype)
结果:
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]] <class 'numpy.ndarray'> (60000, 784) float32
[[0. 0. 0. ... 0. 0. 0.]
 [1. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 1. 0.]] <class 'numpy.ndarray'> (60000, 10) float32

二、51页IMDB二分类

导入数据:

from keras.datasets import imdb (train_data, train_labels), (test_data, test_labels) = imdb.load_data(num_words=10000)

参数 num_words=10000 的意思是仅保留训练数据中前 10 000 个最常出现的单词。

train_data和test_data都是numpy.ndarray类型,都是一维的(共25000个元素,相当于25000个list),其中每个list代表一条评论,每个list中的每个元素的值范围在0-9999 ,代表10000个最常见单词的每个单词的索引,每个list长度不一,因为每条评论的长度不一,例如train_data中的list最短的为11,最长的为189。

train_labels和test_labels都是含25000个元素(元素的值要不0或者1,代表两类)的list。

数据预处理:

# 将整数序列编码为二进制矩阵
def vectorize_sequences(sequences, dimension=10000):
 # Create an all-zero matrix of shape (len(sequences), dimension)
 results = np.zeros((len(sequences), dimension))
 for i, sequence in enumerate(sequences):
  results[i, sequence] = 1. # set specific indices of results[i] to 1s
 return results


x_train = vectorize_sequences(train_data)
x_test = vectorize_sequences(test_data)

第一种方式:shape为(25000,)
y_train = np.asarray(train_labels).astype('float32') #就用这种方式就行了
y_test = np.asarray(test_labels).astype('float32')
第二种方式:shape为(25000,1)
y_train = np.asarray(train_labels).astype('float32').reshape(25000, 1)
y_test = np.asarray(test_labels).astype('float32').reshape(25000, 1)
第三种方式:shape为(25000,2)
y_train = to_categorical(train_labels) #变成one-hot向量
y_test = to_categorical(test_labels)

第三种方式,相当于把二分类看成了多分类,所以网络的结构同时需要更改,

最后输出的维度:1->2

最后的激活函数:sigmoid->softmax

损失函数:binary_crossentropy->categorical_crossentropy

预处理之后,train_data和test_data变成了shape为(25000,10000),dtype为float32的ndarray(one-hot向量),train_labels和test_labels变成了shape为(25000,)的一维ndarray,或者(25000,1)的二维ndarray,或者shape为(25000,2)的one-hot向量。

注:

1.sigmoid对应binary_crossentropy,softmax对应categorical_crossentropy

2.网络的所有输入和目标都必须是浮点数张量

补充知识:keras输入数据的方法:model.fit和model.fit_generator

1.第一种,普通的不用数据增强的

from keras.datasets import mnist,cifar10,cifar100
(X_train, y_train), (X_valid, Y_valid) = cifar10.load_data() 
model.fit(X_train, Y_train, batch_size=batch_size, nb_epoch=nb_epoch, shuffle=True,
    verbose=1, validation_data=(X_valid, Y_valid), )

2.第二种,带数据增强的 ImageDataGenerator,可以旋转角度、平移等操作。

from keras.preprocessing.image import ImageDataGenerator
(trainX, trainY), (testX, testY) = cifar100.load_data()
trainX = trainX.astype('float32')
testX = testX.astype('float32')
trainX /= 255.
testX /= 255.
Y_train = np_utils.to_categorical(trainY, nb_classes)
Y_test = np_utils.to_categorical(testY, nb_classes)
generator = ImageDataGenerator(rotation_range=15,
        width_shift_range=5./32,
        height_shift_range=5./32)
generator.fit(trainX, seed=0)
model.fit_generator(generator.flow(trainX, Y_train, batch_size=batch_size),
     steps_per_epoch=len(trainX) // batch_size, epochs=nb_epoch,
     callbacks=callbacks,
     validation_data=(testX, Y_test),
     validation_steps=testX.shape[0] // batch_size, verbose=1)

以上这篇keras分类模型中的输入数据与标签的维度实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python算法演练_One Rule 算法(详解)
May 17 Python
Python环境搭建之OpenCV的步骤方法
Oct 20 Python
基于Django filter中用contains和icontains的区别(详解)
Dec 12 Python
利用Anaconda简单安装scrapy框架的方法
Jun 13 Python
Django中日期处理注意事项与自定义时间格式转换详解
Aug 06 Python
python读取xlsx的方法
Dec 25 Python
python文件写入write()的操作
May 14 Python
python+opencv像素的加减和加权操作的实现
Jul 14 Python
python自动化工具之pywinauto实例详解
Aug 26 Python
在python shell中运行python文件的实现
Dec 21 Python
python GUI库图形界面开发之PyQt5简单绘图板实例与代码分析
Mar 08 Python
python使用隐式循环快速求和的实现示例
Sep 11 Python
keras自动编码器实现系列之卷积自动编码器操作
Jul 03 #Python
Python with语句用法原理详解
Jul 03 #Python
Keras搭建自编码器操作
Jul 03 #Python
python 识别登录验证码图片功能的实现代码(完整代码)
Jul 03 #Python
python图片验证码识别最新模块muggle_ocr的示例代码
Jul 03 #Python
keras topN显示,自编写代码案例
Jul 03 #Python
python如何使用代码运行助手
Jul 03 #Python
You might like
PHP编程最快明白(第一讲 软件环境和准备工作)
2010/10/25 PHP
PHP+SQL 注入攻击的技术实现以及预防办法
2010/12/29 PHP
解决PHP4.0 和 PHP5.0类构造函数的兼容问题
2013/08/01 PHP
php function用法如何递归及return和echo区别
2014/03/07 PHP
php 使用GD库为页面增加水印示例代码
2014/03/24 PHP
PHP抓取淘宝商品的用户晒单评论+图片+搜索商品列表实例
2016/04/14 PHP
YII框架中使用memcache的方法详解
2017/08/02 PHP
PHP实现基于图的深度优先遍历输出1,2,3...n的全排列功能
2017/11/10 PHP
php中用unset销毁变量并释放内存
2020/05/10 PHP
超越Jquery_01_isPlainObject分析与重构
2010/10/20 Javascript
JavaScript声明变量时为什么要加var关键字
2014/09/29 Javascript
如何用js 实现依赖注入的思想,后端框架思想搬到前端来
2015/08/03 Javascript
JavaScript设计模式经典之命令模式
2016/02/24 Javascript
javascript HTML5文件上传FileReader API
2020/03/27 Javascript
BootStrap的弹出框(Popover)支持鼠标移到弹出层上弹窗层不隐藏的原因及解决办法
2016/04/03 Javascript
JS for...in 遍历语句用法实例分析
2016/08/24 Javascript
JS实现“隐藏与显示”功能(多种方法)
2016/11/24 Javascript
js 实现获取name 相同的页面元素并循环遍历的方法
2017/02/14 Javascript
js实现倒计时效果(小于10补零)
2017/03/08 Javascript
使用pkg打包ThinkJS项目的方法步骤
2019/12/30 Javascript
解决vuex数据页面刷新后初始化操作
2020/07/26 Javascript
Python下的twisted框架入门指引
2015/04/15 Python
PyChar学习教程之自定义文件与代码模板详解
2017/07/17 Python
python实现彩票系统
2020/06/28 Python
python实现RabbitMQ的消息队列的示例代码
2018/11/08 Python
Python模块zipfile原理及使用方法详解
2020/08/04 Python
HTML5 Canvas概述
2009/08/26 HTML / CSS
Kidsroom台湾:来自德国的婴儿用品
2017/12/11 全球购物
华纳兄弟工作室的官方授权商店:WB Shop
2018/11/30 全球购物
Athleta官网:购买女士瑜伽服、技术运动服和休闲运动服
2020/11/12 全球购物
什么是用户模式(User Mode)与内核模式(Kernel Mode) ?
2015/09/07 面试题
四年级评语大全
2014/04/21 职场文书
银行竞聘演讲稿范文
2014/04/23 职场文书
乡镇八一建军节活动方案
2014/08/24 职场文书
90行Python代码开发个人云盘应用
2021/04/20 Python
使用MybatisPlus打印sql语句
2022/04/22 SQL Server