深度学习小工程练习之垃圾分类详解


Posted in Python onApril 14, 2021

介绍

这是一个基于深度学习的垃圾分类小工程,用深度残差网络构建

软件架构

  1. 使用深度残差网络resnet50作为基石,在后续添加需要的层以适应不同的分类任务
  2. 模型的训练需要用生成器将数据集循环写入内存,同时图像增强以泛化模型
  3. 使用不包含网络输出部分的resnet50权重文件进行迁移学习,只训练我们在5个stage后增加的层

安装教程

  1. 需要的第三方库主要有tensorflow1.x,keras,opencv,Pillow,scikit-learn,numpy
  2. 安装方式很简单,打开terminal,例如:pip install numpy -i https://pypi.tuna.tsinghua.edu.cn/simple
  3. 数据集与权重文件比较大,所以没有上传
  4. 如果环境配置方面有问题或者需要数据集与模型权重文件,可以在评论区说明您的问题,我将远程帮助您

使用说明

  1. 文件夹theory记录了我在本次深度学习中收获的笔记,与模型训练的控制台打印信息
  2. 迁移学习需要的初始权重与模型定义文件resnet50.py放在model
  3. 下训练运行trainNet.py,训练结束会创建models文件夹,并将结果权重garclass.h5写入该文件夹
  4. datagen文件夹下的genit.py用于进行图像预处理以及数据生成器接口
  5. 使用训练好的模型进行垃圾分类,运行Demo.py

结果演示

深度学习小工程练习之垃圾分类详解

cans易拉罐

深度学习小工程练习之垃圾分类详解

代码解释

在实际的模型中,我们只使用了resnet50的5个stage,后面的输出部分需要我们自己定制,网络的结构图如下:

深度学习小工程练习之垃圾分类详解

stage5后我们的定制网络如下:

"""定制resnet后面的层"""
def custom(input_size,num_classes,pretrain):
    # 引入初始化resnet50模型
    base_model = ResNet50(weights=pretrain,
                          include_top=False,
                          pooling=None,
                          input_shape=(input_size,input_size, 3),
                          classes=num_classes)
    #由于有预权重,前部分冻结,后面进行迁移学习
    for layer in base_model.layers:
        layer.trainable = False
    #添加后面的层
    x = base_model.output
    x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
    x = layers.Dropout(0.5,name='dropout1')(x)
    #regularizers正则化层,正则化器允许在优化过程中对层的参数或层的激活情况进行惩罚
    #对损失函数进行最小化的同时,也需要让对参数添加限制,这个限制也就是正则化惩罚项,使用l2范数
    x = layers.Dense(512,activation='relu',kernel_regularizer= regularizers.l2(0.0001),name='fc2')(x)
    x = layers.BatchNormalization(name='bn_fc_01')(x)
    x = layers.Dropout(0.5,name='dropout2')(x)
    #40个分类
    x = layers.Dense(num_classes,activation='softmax')(x)
    model = Model(inputs=base_model.input,outputs=x)
    #模型编译
    model.compile(optimizer="adam",loss = 'categorical_crossentropy',metrics=['accuracy'])
    return model

网络的训练是迁移学习过程,使用已有的初始resnet50权重(5个stage已经训练过,卷积层已经能够提取特征),我们只训练后面的全连接层部分,4个epoch后再对较后面的层进行训练微调一下,获得更高准确率,训练过程如下:

class Net():
    def __init__(self,img_size,gar_num,data_dir,batch_size,pretrain):
        self.img_size=img_size
        self.gar_num=gar_num
        self.data_dir=data_dir
        self.batch_size=batch_size
        self.pretrain=pretrain
    def build_train(self):
        """迁移学习"""
        model = resnet.custom(self.img_size, self.gar_num, self.pretrain)
        model.summary()
        train_sequence, validation_sequence = genit.gendata(self.data_dir, self.batch_size, self.gar_num, self.img_size)
        epochs=4
        model.fit_generator(train_sequence,steps_per_epoch=len(train_sequence),epochs=epochs,verbose=1,validation_data=validation_sequence,
                                     max_queue_size=10,shuffle=True)
        #微调,在实际工程中,激活函数也被算进层里,所以总共181层,微调是为了重新训练部分卷积层,同时训练最后的全连接层
        layers=149
        learning_rate=1e-4
        for layer in model.layers[:layers]:
            layer.trainable = False
        for layer in model.layers[layers:]:
            layer.trainable = True
        Adam =adam(lr=learning_rate, decay=0.0005)
        model.compile(optimizer=Adam, loss='categorical_crossentropy', metrics=['accuracy'])
        model.fit_generator(train_sequence,steps_per_epoch=len(train_sequence),epochs=epochs * 2,verbose=1,
            callbacks=[
                callbacks.ModelCheckpoint('./models/garclass.h5',monitor='val_loss', save_best_only=True, mode='min'),
                callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1,patience=10, mode='min'),
                callbacks.EarlyStopping(monitor='val_loss', patience=10),],
            validation_data=validation_sequence,max_queue_size=10,shuffle=True)
        print('finish train,look for garclass.h5')

训练结果如下:

"""
    loss: 0.7949 - acc: 0.9494 - val_loss: 0.9900 - val_acc: 0.8797
    训练用了9小时左右
    """

如果使用更好的显卡,可以更快完成训练

最后

希望大家可以体验到深度学习带来的收获,能和大家学习很开心,更多关于深度学习的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
Python运行的17个时新手常见错误小结
Aug 07 Python
Python中的模块和包概念介绍
Apr 13 Python
python通过openpyxl生成Excel文件的方法
May 12 Python
在Pycharm中项目解释器与环境变量的设置方法
Oct 29 Python
python3正则提取字符串里的中文实例
Jan 31 Python
Python使用pandas和xlsxwriter读写xlsx文件的方法示例
Apr 09 Python
Python Tkinter 简单登录界面的实现
Jun 14 Python
Python初学者常见错误详解
Jul 02 Python
python GUI库图形界面开发之PyQt5拖放控件实例详解
Feb 25 Python
Selenium webdriver添加cookie实现过程详解
Aug 12 Python
python实现发送QQ邮件(可加附件)
Dec 23 Python
python glom模块的使用简介
Apr 13 Python
python3美化表格数据输出结果的实现代码
Apr 14 #Python
Python生成九宫格图片的示例代码
用Python写一个简易版弹球游戏
python urllib库的使用详解
Apr 13 #Python
用Python将库打包发布到pypi
python xlwt模块的使用解析
python 爬取豆瓣网页的示例
You might like
详解php的魔术方法__get()和__set()使用介绍
2012/09/19 PHP
php解压文件代码实现php在线解压
2014/02/13 PHP
codeigniter中view通过循环显示数组数据的方法
2015/03/20 PHP
php微信公众平台示例代码分析(二)
2016/12/06 PHP
php无限级评论嵌套实现代码
2018/04/18 PHP
laravel-admin自动生成模块,及相关基础配置方法
2019/10/08 PHP
PHP pthreads v3使用中的一些坑和注意点分析
2020/02/21 PHP
js 未结束的字符串常量错误解决方法
2010/06/13 Javascript
JS小功能(操作Table--动态添加删除表格及数据)实现代码
2013/11/28 Javascript
基于promise.js实现nodejs的promises库
2014/07/06 NodeJs
jQuery实现点击任意位置弹出层外关闭弹出层效果
2016/10/19 Javascript
jquery插件canvaspercent.js实现百分比圆饼效果
2017/07/18 jQuery
用Vue.extend构建消息提示组件的方法实例
2017/08/08 Javascript
基于vue-cli创建的项目的目录结构及说明介绍
2017/11/23 Javascript
AngularJS中重新加载当前路由页面的方法
2018/03/09 Javascript
vue项目中vue-i18n和element-ui国际化开发实现过程
2018/04/25 Javascript
微信小程序学习笔记之获取位置信息操作图文详解
2019/03/29 Javascript
深入解析Vue源码实例挂载与编译流程实现思路详解
2019/05/05 Javascript
VUE路由动态加载实例代码讲解
2019/08/26 Javascript
JavaScript实现单图片上传并预览功能
2019/09/30 Javascript
浅谈vue 锚点指令v-anchor的使用
2019/11/13 Javascript
JavaScript逻辑运算符相关总结
2020/09/04 Javascript
使用Django的模版来配合字符串翻译工作
2015/07/27 Python
详解Python之数据序列化(json、pickle、shelve)
2017/03/30 Python
django 控制页面跳转的例子
2019/08/06 Python
From CSV to SQLite3 by python 导入csv到sqlite实例
2020/02/14 Python
python3中数组逆序输出方法
2020/12/01 Python
Myprotein台湾官方网站:全球领先的运动营养品牌
2018/12/10 全球购物
Hotels.com英国:全球领先的酒店住宿提供商
2019/01/24 全球购物
英国在线滑雪板和冲浪商店:The Board Basement
2020/01/11 全球购物
物流管理毕业生自荐信
2013/10/24 职场文书
万能检讨书
2015/01/27 职场文书
元旦晚会开场白
2015/05/29 职场文书
学习弘扬焦裕禄精神心得体会
2016/01/23 职场文书
如何书写公司员工保密协议?
2019/06/27 职场文书
Nginx跨域问题解析与解决
2022/08/05 Servers