深度学习小工程练习之垃圾分类详解


Posted in Python onApril 14, 2021

介绍

这是一个基于深度学习的垃圾分类小工程,用深度残差网络构建

软件架构

  1. 使用深度残差网络resnet50作为基石,在后续添加需要的层以适应不同的分类任务
  2. 模型的训练需要用生成器将数据集循环写入内存,同时图像增强以泛化模型
  3. 使用不包含网络输出部分的resnet50权重文件进行迁移学习,只训练我们在5个stage后增加的层

安装教程

  1. 需要的第三方库主要有tensorflow1.x,keras,opencv,Pillow,scikit-learn,numpy
  2. 安装方式很简单,打开terminal,例如:pip install numpy -i https://pypi.tuna.tsinghua.edu.cn/simple
  3. 数据集与权重文件比较大,所以没有上传
  4. 如果环境配置方面有问题或者需要数据集与模型权重文件,可以在评论区说明您的问题,我将远程帮助您

使用说明

  1. 文件夹theory记录了我在本次深度学习中收获的笔记,与模型训练的控制台打印信息
  2. 迁移学习需要的初始权重与模型定义文件resnet50.py放在model
  3. 下训练运行trainNet.py,训练结束会创建models文件夹,并将结果权重garclass.h5写入该文件夹
  4. datagen文件夹下的genit.py用于进行图像预处理以及数据生成器接口
  5. 使用训练好的模型进行垃圾分类,运行Demo.py

结果演示

深度学习小工程练习之垃圾分类详解

cans易拉罐

深度学习小工程练习之垃圾分类详解

代码解释

在实际的模型中,我们只使用了resnet50的5个stage,后面的输出部分需要我们自己定制,网络的结构图如下:

深度学习小工程练习之垃圾分类详解

stage5后我们的定制网络如下:

"""定制resnet后面的层"""
def custom(input_size,num_classes,pretrain):
    # 引入初始化resnet50模型
    base_model = ResNet50(weights=pretrain,
                          include_top=False,
                          pooling=None,
                          input_shape=(input_size,input_size, 3),
                          classes=num_classes)
    #由于有预权重,前部分冻结,后面进行迁移学习
    for layer in base_model.layers:
        layer.trainable = False
    #添加后面的层
    x = base_model.output
    x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
    x = layers.Dropout(0.5,name='dropout1')(x)
    #regularizers正则化层,正则化器允许在优化过程中对层的参数或层的激活情况进行惩罚
    #对损失函数进行最小化的同时,也需要让对参数添加限制,这个限制也就是正则化惩罚项,使用l2范数
    x = layers.Dense(512,activation='relu',kernel_regularizer= regularizers.l2(0.0001),name='fc2')(x)
    x = layers.BatchNormalization(name='bn_fc_01')(x)
    x = layers.Dropout(0.5,name='dropout2')(x)
    #40个分类
    x = layers.Dense(num_classes,activation='softmax')(x)
    model = Model(inputs=base_model.input,outputs=x)
    #模型编译
    model.compile(optimizer="adam",loss = 'categorical_crossentropy',metrics=['accuracy'])
    return model

网络的训练是迁移学习过程,使用已有的初始resnet50权重(5个stage已经训练过,卷积层已经能够提取特征),我们只训练后面的全连接层部分,4个epoch后再对较后面的层进行训练微调一下,获得更高准确率,训练过程如下:

class Net():
    def __init__(self,img_size,gar_num,data_dir,batch_size,pretrain):
        self.img_size=img_size
        self.gar_num=gar_num
        self.data_dir=data_dir
        self.batch_size=batch_size
        self.pretrain=pretrain
    def build_train(self):
        """迁移学习"""
        model = resnet.custom(self.img_size, self.gar_num, self.pretrain)
        model.summary()
        train_sequence, validation_sequence = genit.gendata(self.data_dir, self.batch_size, self.gar_num, self.img_size)
        epochs=4
        model.fit_generator(train_sequence,steps_per_epoch=len(train_sequence),epochs=epochs,verbose=1,validation_data=validation_sequence,
                                     max_queue_size=10,shuffle=True)
        #微调,在实际工程中,激活函数也被算进层里,所以总共181层,微调是为了重新训练部分卷积层,同时训练最后的全连接层
        layers=149
        learning_rate=1e-4
        for layer in model.layers[:layers]:
            layer.trainable = False
        for layer in model.layers[layers:]:
            layer.trainable = True
        Adam =adam(lr=learning_rate, decay=0.0005)
        model.compile(optimizer=Adam, loss='categorical_crossentropy', metrics=['accuracy'])
        model.fit_generator(train_sequence,steps_per_epoch=len(train_sequence),epochs=epochs * 2,verbose=1,
            callbacks=[
                callbacks.ModelCheckpoint('./models/garclass.h5',monitor='val_loss', save_best_only=True, mode='min'),
                callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1,patience=10, mode='min'),
                callbacks.EarlyStopping(monitor='val_loss', patience=10),],
            validation_data=validation_sequence,max_queue_size=10,shuffle=True)
        print('finish train,look for garclass.h5')

训练结果如下:

"""
    loss: 0.7949 - acc: 0.9494 - val_loss: 0.9900 - val_acc: 0.8797
    训练用了9小时左右
    """

如果使用更好的显卡,可以更快完成训练

最后

希望大家可以体验到深度学习带来的收获,能和大家学习很开心,更多关于深度学习的资料请关注三水点靠木其它相关文章!

Python 相关文章推荐
python将图片文件转换成base64编码的方法
Mar 14 Python
详解Python中heapq模块的用法
Jun 28 Python
详解Python各大聊天系统的屏蔽脏话功能原理
Dec 01 Python
python与php实现分割文件代码
Mar 06 Python
python用pickle模块实现“增删改查”的简易功能
Jun 07 Python
django实现同一个ip十分钟内只能注册一次的实例
Nov 03 Python
selenium+python自动化测试之多窗口切换
Jan 23 Python
对python while循环和双重循环的实例详解
Aug 23 Python
python烟花效果的代码实例
Feb 25 Python
Python pymysql模块安装并操作过程解析
Oct 13 Python
Python基础之hashlib模块详解
May 06 Python
教你利用python实现企业微信发送消息
May 23 Python
python3美化表格数据输出结果的实现代码
Apr 14 #Python
Python生成九宫格图片的示例代码
用Python写一个简易版弹球游戏
python urllib库的使用详解
Apr 13 #Python
用Python将库打包发布到pypi
python xlwt模块的使用解析
python 爬取豆瓣网页的示例
You might like
PHP 遍历XP文件夹下所有文件
2008/11/27 PHP
php include和require的区别深入解析
2013/06/17 PHP
thinkPHP显示不出验证码的原因与解决方法分析
2017/05/20 PHP
Mootools 1.2教程 滚动条(Slider)
2009/09/15 Javascript
使用jquery animate创建平滑滚动效果(可以是到顶部、到底部或指定地方)
2014/05/27 Javascript
JavaScript学习心得之概述
2015/01/20 Javascript
jquery使整个div区域可以点击的方法
2015/06/24 Javascript
jQuery实现向下滑出的平滑下拉菜单效果
2015/08/21 Javascript
JS实现淘宝支付宝网站的控制台菜单效果
2015/09/28 Javascript
JavaScript实现的多种鼠标拖放效果
2015/11/03 Javascript
JavaScript中关联原型链属性特性
2016/02/13 Javascript
AngularJS基础 ng-model-options 指令简单示例
2016/08/02 Javascript
jquery实现转盘抽奖功能
2017/01/06 Javascript
使用微信内嵌H5网页解决JS倒计时失效问题
2017/01/13 Javascript
Vue Cli与BootStrap结合实现表格分页功能
2017/08/18 Javascript
解决vue中对象属性改变视图不更新的问题
2018/02/23 Javascript
浅谈VueJS SSR 后端绘制内存泄漏的相关解决经验
2018/12/20 Javascript
jQuery实现简单弹幕效果
2019/11/28 jQuery
Anaconda 离线安装 python 包的操作方法
2018/06/11 Python
朴素贝叶斯分类算法原理与Python实现与使用方法案例
2018/06/26 Python
python实现五子棋小程序
2019/06/18 Python
Python Gitlab Api 使用方法
2019/08/28 Python
Python Sympy计算梯度、散度和旋度的实例
2019/12/06 Python
Django模板标签中url使用详解(url跳转到指定页面)
2020/03/19 Python
python Cartopy的基础使用详解
2020/11/01 Python
HTML5不支持标签和新增标签详解
2016/06/27 HTML / CSS
html5 横向滑动导航栏的方法示例
2020/05/08 HTML / CSS
英国在线药房:Chemist.co.uk
2019/03/26 全球购物
英语自荐信常用语句
2013/12/13 职场文书
导游的职业规划书范文
2013/12/27 职场文书
大学新闻系应届生求职信
2014/06/02 职场文书
安全环保标语
2014/06/09 职场文书
骨干教师申报材料
2014/12/17 职场文书
浅谈Python数学建模之线性规划
2021/06/23 Python
Win11运行育碧游戏总是崩溃怎么办 win11玩育碧游戏出现性能崩溃的解决办法
2022/04/06 数码科技
Python各协议下socket黏包问题原理
2022/04/12 Python