Keras:Unet网络实现多类语义分割方式


Posted in Python onJune 11, 2020

1 介绍

U-Net最初是用来对医学图像的语义分割,后来也有人将其应用于其他领域。但大多还是用来进行二分类,即将原始图像分成两个灰度级或者色度,依次找到图像中感兴趣的目标部分。

本文主要利用U-Net网络结构实现了多类的语义分割,并展示了部分测试效果,希望对你有用!

2 源代码

(1)训练模型

from __future__ import print_function
import os
import datetime
import numpy as np
from keras.models import Model
from keras.layers import Input, concatenate, Conv2D, MaxPooling2D, Conv2DTranspose, AveragePooling2D, Dropout, \
 BatchNormalization
from keras.optimizers import Adam
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.callbacks import ModelCheckpoint
from keras import backend as K
from keras.layers.advanced_activations import LeakyReLU, ReLU
import cv2
 
PIXEL = 512 #set your image size
BATCH_SIZE = 5
lr = 0.001
EPOCH = 100
X_CHANNEL = 3 # training images channel
Y_CHANNEL = 1 # label iamges channel
X_NUM = 422 # your traning data number
 
pathX = 'I:\\Pascal VOC Dataset\\train1\\images\\' #change your file path
pathY = 'I:\\Pascal VOC Dataset\\train1\\SegmentationObject\\' #change your file path
 
#data processing
def generator(pathX, pathY,BATCH_SIZE):
 while 1:
  X_train_files = os.listdir(pathX)
  Y_train_files = os.listdir(pathY)
  a = (np.arange(1, X_NUM))
  X = []
  Y = []
  for i in range(BATCH_SIZE):
   index = np.random.choice(a)
   # print(index)
   img = cv2.imread(pathX + X_train_files[index], 1)
   img = np.array(img).reshape(PIXEL, PIXEL, X_CHANNEL)
   X.append(img)
   img1 = cv2.imread(pathY + Y_train_files[index], 1)
   img1 = np.array(img1).reshape(PIXEL, PIXEL, Y_CHANNEL)
   Y.append(img1)
 
  X = np.array(X)
  Y = np.array(Y)
  yield X, Y
 
 #creat unet network
inputs = Input((PIXEL, PIXEL, 3))
conv1 = Conv2D(8, 3, activation='relu', padding='same', kernel_initializer='he_normal')(inputs)
pool1 = AveragePooling2D(pool_size=(2, 2))(conv1) # 16
 
conv2 = BatchNormalization(momentum=0.99)(pool1)
conv2 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv2)
conv2 = BatchNormalization(momentum=0.99)(conv2)
conv2 = Conv2D(64, 1, activation='relu', padding='same', kernel_initializer='he_normal')(conv2)
conv2 = Dropout(0.02)(conv2)
pool2 = AveragePooling2D(pool_size=(2, 2))(conv2) # 8
 
conv3 = BatchNormalization(momentum=0.99)(pool2)
conv3 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv3)
conv3 = BatchNormalization(momentum=0.99)(conv3)
conv3 = Conv2D(128, 1, activation='relu', padding='same', kernel_initializer='he_normal')(conv3)
conv3 = Dropout(0.02)(conv3)
pool3 = AveragePooling2D(pool_size=(2, 2))(conv3) # 4
 
conv4 = BatchNormalization(momentum=0.99)(pool3)
conv4 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv4)
conv4 = BatchNormalization(momentum=0.99)(conv4)
conv4 = Conv2D(256, 1, activation='relu', padding='same', kernel_initializer='he_normal')(conv4)
conv4 = Dropout(0.02)(conv4)
pool4 = AveragePooling2D(pool_size=(2, 2))(conv4)
 
conv5 = BatchNormalization(momentum=0.99)(pool4)
conv5 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv5)
conv5 = BatchNormalization(momentum=0.99)(conv5)
conv5 = Conv2D(512, 1, activation='relu', padding='same', kernel_initializer='he_normal')(conv5)
conv5 = Dropout(0.02)(conv5)
pool4 = AveragePooling2D(pool_size=(2, 2))(conv4)
# conv5 = Conv2D(35, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv4)
# drop4 = Dropout(0.02)(conv5)
pool4 = AveragePooling2D(pool_size=(2, 2))(pool3) # 2
pool5 = AveragePooling2D(pool_size=(2, 2))(pool4) # 1
 
conv6 = BatchNormalization(momentum=0.99)(pool5)
conv6 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv6)
 
conv7 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv6)
up7 = (UpSampling2D(size=(2, 2))(conv7)) # 2
conv7 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(up7)
merge7 = concatenate([pool4, conv7], axis=3)
 
conv8 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge7)
up8 = (UpSampling2D(size=(2, 2))(conv8)) # 4
conv8 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(up8)
merge8 = concatenate([pool3, conv8], axis=3)
 
conv9 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge8)
up9 = (UpSampling2D(size=(2, 2))(conv9)) # 8
conv9 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(up9)
merge9 = concatenate([pool2, conv9], axis=3)
 
conv10 = Conv2D(32, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge9)
up10 = (UpSampling2D(size=(2, 2))(conv10)) # 16
conv10 = Conv2D(32, 3, activation='relu', padding='same', kernel_initializer='he_normal')(up10)
 
conv11 = Conv2D(16, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv10)
up11 = (UpSampling2D(size=(2, 2))(conv11)) # 32
conv11 = Conv2D(8, 3, activation='relu', padding='same', kernel_initializer='he_normal')(up11)
 
# conv12 = Conv2D(3, 1, activation='relu', padding='same', kernel_initializer='he_normal')(conv11)
conv12 = Conv2D(3, 1, activation='relu', padding='same', kernel_initializer='he_normal')(conv11)
 
model = Model(input=inputs, output=conv12)
print(model.summary())
model.compile(optimizer=Adam(lr=1e-3), loss='mse', metrics=['accuracy'])
 
history = model.fit_generator(generator(pathX, pathY,BATCH_SIZE),
        steps_per_epoch=600, nb_epoch=EPOCH)
end_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
 
 #save your training model
model.save(r'V1_828.h5')
 
#save your loss data
mse = np.array((history.history['loss']))
np.save(r'V1_828.npy', mse)

(2)测试模型

from keras.models import load_model
import numpy as np
import matplotlib.pyplot as plt
import os
import cv2
 
model = load_model('V1_828.h5')
test_images_path = 'I:\\Pascal VOC Dataset\\test\\test_images\\'
test_gt_path = 'I:\\Pascal VOC Dataset\\test\\SegmentationObject\\'
pre_path = 'I:\\Pascal VOC Dataset\\test\\pre\\'
 
X = []
for info in os.listdir(test_images_path):
 A = cv2.imread(test_images_path + info)
 X.append(A)
 # i += 1
X = np.array(X)
print(X.shape)
Y = model.predict(X)
 
groudtruth = []
for info in os.listdir(test_gt_path):
 A = cv2.imread(test_gt_path + info)
 groudtruth.append(A)
groudtruth = np.array(groudtruth)
 
i = 0
for info in os.listdir(test_images_path):
 cv2.imwrite(pre_path + info,Y[i])
 i += 1
 
a = range(10)
n = np.random.choice(a)
cv2.imwrite('prediction.png',Y[n])
cv2.imwrite('groudtruth.png',groudtruth[n])
fig, axs = plt.subplots(1, 3)
# cnt = 1
# for j in range(1):
axs[0].imshow(np.abs(X[n]))
axs[0].axis('off')
axs[1].imshow(np.abs(Y[n]))
axs[1].axis('off')
axs[2].imshow(np.abs(groudtruth[n]))
axs[2].axis('off')
 # cnt += 1
fig.savefig("imagestest.png")
plt.close()

3 效果展示

说明:从左到右依次是预测图像,真实图像,标注图像。可以看出,对于部分数据的分割效果还有待改进,主要原因还是数据集相对复杂,模型难于找到其中的规律。

Keras:Unet网络实现多类语义分割方式

以上这篇Keras:Unet网络实现多类语义分割方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持三水点靠木。

Python 相关文章推荐
python编写的最短路径算法
Mar 25 Python
自己编程中遇到的Python错误和解决方法汇总整理
Jun 03 Python
5种Python单例模式的实现方式
Jan 14 Python
Python自动化运维和部署项目工具Fabric使用实例
Sep 18 Python
python学习教程之Numpy和Pandas的使用
Sep 11 Python
pandas 条件搜索返回列表的方法
Oct 30 Python
python处理multipart/form-data的请求方法
Dec 26 Python
python 画二维、三维点之间的线段实现方法
Jul 07 Python
使用python切片实现二维数组复制示例
Nov 26 Python
在Ubuntu 20.04中安装Pycharm 2020.1的图文教程
Apr 30 Python
Keras Convolution1D与Convolution2D区别说明
May 22 Python
PyQt QMainWindow的使用示例
Mar 24 Python
Pycharm中配置远程Docker运行环境的教程图解
Jun 11 #Python
Keras 快速解决OOM超内存的问题
Jun 11 #Python
python3.6.8 + pycharm + PyQt5 环境搭建的图文教程
Jun 11 #Python
使用keras实现孪生网络中的权值共享教程
Jun 11 #Python
查看keras各种网络结构各层的名字方式
Jun 11 #Python
python datetime时间格式的相互转换问题
Jun 11 #Python
完美解决keras保存好的model不能成功加载问题
Jun 11 #Python
You might like
PHP类中Static方法效率测试代码
2010/10/17 PHP
PHP中header和session_start前不能有输出原因分析
2013/01/11 PHP
thinkPHP5.0框架应用请求生命周期分析
2017/03/25 PHP
PHP中define() 与 const定义常量的区别详解
2019/06/25 PHP
JavaScript接口实现代码 (Interfaces In JavaScript)
2010/06/11 Javascript
JS模板实现方法
2013/04/03 Javascript
JQuery中DOM事件合成用法实例分析
2015/06/13 Javascript
javascript密码强度校验代码(两种方法)
2015/08/10 Javascript
实例详解angularjs和ajax的结合使用
2015/10/22 Javascript
jQuery实现的指纹扫描效果实例(附演示与demo源码下载)
2016/01/26 Javascript
jQuery常用的一些技巧汇总
2016/03/26 Javascript
网页前端登录js按Enter回车键实现登陆的两种方法
2016/05/10 Javascript
AngularJS Bootstrap详细介绍及实例代码
2016/07/28 Javascript
详解nodejs 文本操作模块-fs模块(一)
2016/12/22 NodeJs
js如何判断是否在iframe中及防止网页被别站用iframe嵌套
2017/01/11 Javascript
JavaScript实现的选择排序算法实例分析
2017/04/14 Javascript
JS实现json的序列化和反序列化功能示例
2017/06/13 Javascript
vue项目中实现缓存的最佳方案详解
2019/07/11 Javascript
python命令行参数sys.argv使用示例
2014/01/28 Python
利用Python中SocketServer 实现客户端与服务器间非阻塞通信
2016/12/15 Python
详解Python之unittest单元测试代码
2018/01/24 Python
python实现梯度下降算法的实例详解
2020/08/17 Python
python实现企业微信定时发送文本消息的实例代码
2020/11/25 Python
实例教程 HTML5 Canvas 超炫酷烟花绽放动画实现代码
2014/11/05 HTML / CSS
Html5页面获取微信公众号的openid的方法
2020/05/12 HTML / CSS
德国百年厨具品牌WMF美国站:WMF美国
2016/09/12 全球购物
Snapfish爱尔兰:在线照片打印和个性化照片礼品
2018/09/17 全球购物
优衣库台湾官网:UNIQLO台湾
2019/02/01 全球购物
国培计划培训感言
2014/03/11 职场文书
公司开业庆典主持词
2014/03/21 职场文书
绿色家庭事迹材料
2014/05/01 职场文书
电视节目策划方案
2014/05/16 职场文书
2015年中学元旦晚会活动方案
2014/12/09 职场文书
专家推荐信怎么写
2015/03/25 职场文书
大学生读书笔记大全
2015/07/01 职场文书
导游词之山海关
2019/12/10 职场文书