用U-Net提取航拍图中的建筑物轮廓
ztj100 2024-12-19 17:56 43 浏览 0 评论
在我的硕士课程中,我选修了一门名为“深度学习在高维数据分析和图像处理中的应用”的课程。这个名字听起来有点吓人,但它却是我收获最多的课程之一!
在我们的一个项目中,我们的任务是构建一个 U-Net 来从航拍图像中识别和绘制建筑物。起初,这项任务让人望而生畏,但一旦我投入其中,我发现我实际上很享受这个过程。它很有挑战性,很耗时,但完全值得。
那么让我们进入计算机视觉和高维数据的世界吧;我将向你介绍如何使用 U-Net 架构从航拍图像中提取建筑物足迹——同样的方法帮助我实现了 94.7% 的准确率和 76.7% 的 Dice 指标。如果你愿意,可以在我的 GitHub 存储库中探索该项目的代码。让我们开始吧!
1、什么是 U-Net?
U-Net 是一种卷积神经网络 (CNN),专为图像分割任务而设计。U-Net 最初是为生物医学图像分割而开发的,其架构由编码器和解码器组成,这使其呈现出有趣的 U 形。
- 编码器(收缩路径):编码器是一系列卷积层(这些层可找到边缘和图案等特征),然后是最大池化层(这些层有助于减小图像大小,以帮助模型真正调整最重要的特征)。
- 解码器(扩展路径):解码器由上采样层(这些层可恢复原始图像大小)和更多卷积层组成。解码器基本上会细化和重建输出以创建一个漂亮的分割图。
- 跳过连接:U-Net 使用跳过连接将编码器和解码器中的相应层连接起来,有助于保留精细细节并提高分割准确性。将跳过连接视为子弹头列车,它直接从 U-net 的一侧飞到另一侧,跳过所有交通并确保重要乘客(详细信息)安全到达另一端。
2、数据准备
对于这个项目,我使用了以下数据集(可在我的 GitHub 存储库中找到):
- 图像:3,347 个大小为 256×256×3 的彩色栅格,每个代表马萨诸塞州的 300 平方米区域。
- 标签:从 OpenStreetMap 建筑物足迹派生的二进制掩码,指示哪些像素对应于建筑物。
未格式化的数据由 Minh 等人 (2013) 提供,是公开可用的,可以在此处找到。
数据集分为 70% 的训练集、15% 的验证集和 15% 的测试集。在将图像输入模型之前,我将像素值标准化为范围 [0, 1]。
以下是输入输出对的快速可视化:
- 左图:航拍图像。
- 右图:二元掩模,白色区域代表建筑物足迹。
3、类别不平衡
数据集显示出相当明显的类别不平衡,非建筑物像素远多于建筑物像素——这是分割任务中非常常见的挑战。
3.1 处理类别不平衡
有很多方法可以解决类别不平衡问题,例如过采样、欠采样和数据增强,我尝试了其中一些技术,但对我的模型帮助最大的一件事是创建一个特殊的损失函数,它结合了二元交叉熵 (BCE) 和骰子损失。
这如何解决类别不平衡问题?
所以基本上损失函数的工作原理是测量模型预测与基本事实的偏差——任何模型的目标都是尽量减少这种损失。
鉴于我们只关注二元分割(区分建筑物和背景),损失函数的强势选择是 BCE,因为它测量像素级损失,本质上它会查看每个预测的小像素以及它与真实像素的距离。仅使用这种类型的损失的问题在于,损失将由多数类(背景像素)主导。
这就是 Dice Loss 的作用所在——Dice Loss 通过查看预测和地面真实蒙版之间的重叠来测量损失
要理解 Dice Loss,首先让我们了解如何判断预测:
- 真正例 (TP):建筑物正确识别为建筑物。
- 假正例 (FP):背景像素被错误地标记为建筑物。
- 假负例 (FN):建筑物被误认为背景。
- 真负例 (TN):背景正确标记为背景。
Dice 系数(本质上是图像分割的 F1 分数)平衡了精度(模型标记建筑物像素的准确度)和召回率(模型识别的实际建筑物像素数)。
基本上就像在说——“让我们把重点放在建筑物上,我们正确识别了多少个建筑物,而没有犯太多错误?”
虽然 Dice 系数(F1 分数)是衡量模型性能的指标(越高越好),但 Dice Loss 用于最小化训练期间的错误(越低越好)
结合 BCE 和 Dice Loss
通过结合这两个损失函数,我们可以从逐像素监督和 Dice 对重叠的关注中受益,确保平衡对两个类的敏感度。
4、组装模型
在我进入模型架构之前,我想快速介绍两种显着提高模型性能的关键方法——这些方法在训练过程中经常被忽视。
4.1 空间 Dropout
为了防止模型过度拟合(对训练数据的学习太好以至于无法很好地预测新数据),我在模型中添加了一个空间 Dropout 层。在训练期间,它会随机隐藏每个图像中一定比例的像素。这基本上迫使模型停止依赖单个像素,而是关注更大的图景,学习建筑物周围的环境和模式。
这是一种非常酷的方法,Dropout 不仅限于图像模型中的空间层——您也可以将它用于其他类型的神经网络!
4.2 核初始化程序
我认为,在构建神经网络时,最容易被忽视的组件之一就是内核初始化程序。这可能是我在硕士课程期间最难理解的概念之一,但一旦我最终掌握了这个概念,我就明白了它在神经网络中的重要性。
那么,它到底是什么呢?简单来说,它决定了在训练开始之前如何设置网络层的权重。你可以把它想象成给房子打地基。如果地基薄弱或不平整,无论房子设计得多么好,结构都不会稳定。
对于我的 U-Net 模型,最好的初始化程序是 LeCun Normal。它根据输入层的大小按比例缩放权重,这反过来有助于最大限度地降低梯度在流经网络时变得太小(梯度消失)或太大(梯度爆炸)的风险。
4.3 Keras 函数式 API
好的,现在我们开始构建 U-Net 模型的细节。使用 Keras API,我定义了一个简化的 U-Net 模型。
Keras Functional API:一种在 Python 中构建深度学习模型的超级灵活方法,它允许您通过像构建块一样连接不同的层来定义自定义架构?。
4.4 编码器块
记住前面的内容,这是 U-Net 的收缩路径 - 它获取输入图像,捕获特征并缩小空间信息(减小图像大小)。
每个块由以下内容组成:
- 两个卷积层:每个层都有一个 relu 激活和 lecun_normal 内核初始化。
- 空间 Dropout:在每个卷积层之后添加正则化,dropout 率为 10%。
- 最大池化:将空间维度下采样 2 倍,减小图像大小,但确保保留关键特征。
def encoder_block(filters, inputs, dropout_rate=0.1, kernel_initializer='lecun_normal'):
x = Conv2D(filters, kernel_size=(3, 3), padding='same', strides=1, activation='relu',
kernel_initializer=kernel_initializer)(inputs)
x = SpatialDropout2D(dropout_rate)(x)
s = Conv2D(filters, kernel_size=(3, 3), padding='same', strides=1, activation='relu',
kernel_initializer=kernel_initializer)(x)
s = SpatialDropout2D(dropout_rate)(s)
p = MaxPooling2D(pool_size=(2, 2), padding='same')(s)
return s, p
我在我的模型中使用了四个块(当我尝试使用超过这个或更多时过滤器,模型表现更差):
- 块 1:32 个过滤器
- 块 2:64 个过滤器
- 块 3:128 个过滤器
- 块 4:256 个过滤器
4.5 瓶颈层
模型的这一部分充当 U-Net 的瓶颈,位于编码器和解码器之间。它能够在最小的空间维度上操作时捕获最深的特征(超酷)。它由以下部分构成:
- 两个卷积层
- 空间 Dropout
基础有助于为解码器从这些抽象特征重建高分辨率图像奠定基础。
def baseline_layer(filters, inputs, dropout_rate=0.1, kernel_initializer='lecun_normal'):
x = Conv2D(filters, kernel_size=(3, 3), padding='same', strides=1, activation='relu',
kernel_initializer=kernel_initializer)(inputs)
x = SpatialDropout2D(dropout_rate)(x)
x = Conv2D(filters, kernel_size=(3, 3), padding='same', strides=1, activation='relu',
kernel_initializer=kernel_initializer)(x)
x = SpatialDropout2D(dropout_rate)(x)
return x
4.6 解码器块
如果您还记得之前的内容,解码器是 U-Net 的扩展路径,它通过上采样和组合特征来重建图像。每个解码器块由以下部分组成:
- 上采样:使用具有 relu 激活的转置卷积层将空间维度加倍。
- 跳过连接:这些连接结合了相应编码器块的功能,确保网络保留了在下采样过程中丢失的所有细粒度细节。请记住,这些就像从一端开出的子弹头列 到另一个。
- 两个卷积层:细化上采样特征。
- 空间 Dropout:帮助模型学习泛化(即使在模型的这个阶段)。
def decoder_block(filters, connections, inputs, dropout_rate=0.1, kernel_initializer='lecun_normal'):
x = Conv2DTranspose(filters, kernel_size=(2, 2), padding='same', activation='relu', strides=2,
kernel_initializer=kernel_initializer)(inputs)
skip_connections = concatenate([x, connections], axis=-1)
x = Conv2D(filters, kernel_size=(3, 3), padding='same', activation='relu',
kernel_initializer=kernel_initializer)(skip_connections)
x = SpatialDropout2D(dropout_rate)(x)
x = Conv2D(filters, kernel_size=(3, 3), padding='same', activation='relu',
kernel_initializer=kernel_initializer)(x)
x = SpatialDropout2D(dropout_rate)(x)
return x
4.7 最终输出层
最后但并非最不重要的一点是,模型的最后一层是一个 1x1 卷积层,具有单个输出通道和 S 型激活函数。该层预测每个像素属于目标类别(建筑物)或背景的概率。
outputs = Conv2D(1, 1, activation = 'sigmoid')(d4)
4.8 将所有部分放在一起
这里我们有了最终的杰作:
def unet():
inputs = Input(shape = (256, 256, 3)) #defines the input layer and shape of images
#encoder
s1, p1 = encoder_block(32, inputs = inputs)
s2, p2 = encoder_block(64, inputs = p1)
s3, p3 = encoder_block(128, inputs = p2)
s4, p4 = encoder_block(256, inputs = p3)
#bottleneck
baseline = baseline_layer(512, p4)
#decoder
d1 = decoder_block(256, s4, baseline)
d2 = decoder_block(128, s3, d1)
d3 = decoder_block(64, s2, d2)
d4 = decoder_block(32, s1, d3)
#output function for binary classification of pixels
outputs = Conv2D(1, 1, activation = 'sigmoid')(d4)
#finalizing the model
model = Model(inputs = inputs, outputs = outputs, name = 'Unet')
return model
5、训练模型
以下是我训练模型的细分:
- 学习率:我使用 TensorFlow 的 ExponentialDecay 逐渐降低学习率(从 0.001 开始)。
- 优化器:我使用了 Adam 优化器,它可以动态调整每个参数的学习率,利用过去的梯度。
- 批次大小:8(降低以防止内存问题)。
- 时期:50(在 3 个时期后验证损失没有改善后提前停止)。
训练后,我通过绘制训练准确率与验证准确率以及训练损失与验证损失的图表来检查模型的性能。
训练和验证准确率都呈现上升趋势,这意味着模型能够从训练数据中学习并在验证数据上表现良好。
训练和验证损失都在稳步下降,这意味着模型能够非常有效地将损失降至最低。这都是好消息!
6、评估模型
虽然准确率显示了总体正确预测的百分比,但它并没有考虑到我们数据集中的类别不平衡。例如,如果背景像素的数量远远超过建筑物像素,那么模型只需在大多数情况下预测“背景”就可以实现高精度。所以这意味着高精度并不一定意味着模型在识别建筑物方面表现良好。
那么,我们如何才能真正评估模型分割建筑物的能力呢?
这就是 Dice 指标再次发挥作用的地方!与准确率不同,Dice 指标衡量预测蒙版与地面真实蒙版的重叠程度。 Dice 分数越高,性能越好,这意味着模型能够准确捕捉到更多预测建筑区域与实际建筑区域的交集。
以下是我用来计算 Dice 指标的代码:
def dice_metric(y_true, y_pred):
"""Calculate Dice Coefficient for ground truth and predicted masks."""
y_pred = tf.cast(y_pred > 0.5, tf.float32) #threshold set to 0.5
intersection = tf.reduce_sum(y_true * y_pred)
total_sum = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred)
dice = tf.math.divide_no_nan(2 * intersection, total_sum)
return dice
6.1 测试集结果
在测试集上评估模型,我获得了以下指标:
- 准确率:94.7%(预测的总体正确性)
- 精确率:75.2%(预测建筑物与实际建筑物的百分比)建筑物)
- 召回率:78.9%(模型正确识别的实际建筑物百分比)
- Dice 指标:76.7%(测量预测和地面实况掩码之间的重叠)
这是什么意思?
在一篇研究论文中,作者强调,Dice 得分高于 0.7 表示在图像分割任务中表现良好(来源)。Dice 得分为 76.7%,该模型在将建筑物与背景分割开来方面表现良好。
因此,虽然准确度给出了一般的性能感觉,但 Dice、精确度和召回率等指标可以更好地理解模型处理特定任务的能力
分割建筑物。根据这些结果,我们似乎有一个赢家!
6.2 使用预测概率可视化结果
为了更好地了解模型的表现,我以多种方式可视化预测,并将它们与地面实况标签进行比较。此步骤有助于了解模型的优势和不足之处,使其成为评估分割任务(如识别建筑物)的关键部分。
查看预测标签(右上角)并将其与测试标签(第二张图)进行比较,我们可以看到,在大多数情况下,模型成功识别了建筑物,尽管并不完美。
假阳性图像突出显示了模型错误地将背景区域分类为建筑物的区域,而假阴性图像揭示了模型错误地将建筑物标记为背景的位置。这些错误表明仍有改进空间。但是,总体而言,该模型在分割建筑物足迹方面做得相当不错,并显示出巨大的实际应用潜力!
7、结束语
你做到了!所以,以下是我们从这次经历中可以学到的东西:
- 模型架构很重要:事实证明,U-Net 架构是图像分割的非常强大的选择,尤其是对于识别建筑物足迹等任务。
- 损失函数产生影响:结合二元交叉熵和 Dice 损失可确保模型平衡整体准确性,重点是正确分割少数类(我们案例中的建筑物)。当我们处理类别不平衡时,这一点非常重要。
- 正则化:空间 dropout 不仅有助于防止过度拟合,而且还提高了模型推广到看不见的数据的能力。这一步使训练过程更加稳健。
- 超越准确性的评估:像 Dice 这样的指标可以更好地洞察模型的性能,尤其是对于类别不平衡会扭曲简单准确性度量的任务。
下一步:总有改进的空间,我相信下一步合乎逻辑的做法是实现数据增强。这涉及生成现有图像的略微改变的副本,例如翻转、旋转或调整亮度。通过多样化训练数据,模型可以学习更精细的细节,更好地区分建筑物和背景。
相关推荐
- SpringBoot整合SpringSecurity+JWT
-
作者|Sans_https://juejin.im/post/5da82f066fb9a04e2a73daec一.说明SpringSecurity是一个用于Java企业级应用程序的安全框架,主要包含...
- 「计算机毕设」一个精美的JAVA博客系统源码分享
-
前言大家好,我是程序员it分享师,今天给大家带来一个精美的博客系统源码!可以自己买一个便宜的云服务器,当自己的博客网站,记录一下自己学习的心得。开发技术博客系统源码基于SpringBoot,shiro...
- springboot教务管理系统+微信小程序云开发附带源码
-
今天给大家分享的程序是基于springboot的管理,前端是小程序,系统非常的nice,不管是学习还是毕设都非常的靠谱。本系统主要分为pc端后台管理和微信小程序端,pc端有三个角色:管理员、学生、教师...
- SpringBoot+LayUI后台管理系统开发脚手架
-
源码获取方式:关注,转发之后私信回复【源码】即可免费获取到!项目简介本项目本着避免重复造轮子的原则,建立一套快速开发JavaWEB项目(springboot-mini),能满足大部分后台管理系统基础开...
- Spring Boot的Security安全控制——认识SpringSecurity!
-
SpringBoot的Security安全控制在Web项目开发中,安全控制是非常重要的,不同的人配置不同的权限,这样的系统才安全。最常见的权限框架有Shiro和SpringSecurity。Shi...
- 前同事2024年接私活已入百万,都是用这几个开源的SpringBoot项目
-
前言不得不佩服SpringBoot的生态如此强大,今天给大家推荐几款优秀的后台管理系统,小伙伴们再也不用从头到尾撸一个项目了。SmartAdmin...
- 值得学习的15 个优秀开源的 Spring Boot 学习项目
-
SpringBoot算是目前Java领域最火的技术栈了,除了书呢?当然就是开源项目了,今天整理15个开源领域非常不错的SpringBoot项目供大家学习,参考。高富帅的路上只能帮你到这里了,...
- 开发企业官网就用这个基于SpringBoot的CMS系统,真香
-
前言推荐这个项目是因为使用手册部署手册非常...
- 2021年超详细的java学习路线总结—纯干货分享
-
本文整理了java开发的学习路线和相关的学习资源,非常适合零基础入门java的同学,希望大家在学习的时候,能够节省时间。纯干货,良心推荐!第一阶段:Java基础...
- jeecg-boot学习总结及使用心得(jeecgboot简单吗)
-
jeecg-boot学习总结及使用心得1.jeecg-boot是一个真正前后端分离的模版项目,便于二次开发,使用的都是较流行的新技术,后端技术主要有spring-boot2.x、shiro、Myb...
- 后勤集团原料管理系统springboot+Layui+MybatisPlus+Shiro源代码
-
本项目为前几天收费帮学妹做的一个项目,JavaEEJSP项目,在工作环境中基本使用不到,但是很多学校把这个当作编程入门的项目来做,故分享出本项目供初学者参考。一、项目描述后勤集团原料管理系统spr...
- 白卷开源SpringBoot+Vue的前后端分离入门项目
-
简介白卷是一个简单的前后端分离项目,主要采用Vue.js+SpringBoot技术栈开发。除了用作入门练习,作者还希望该项目可以作为一些常见Web项目的脚手架,帮助大家简化搭建网站的流程。...
- Spring Security 自动踢掉前一个登录用户,一个配置搞定
-
登录成功后,自动踢掉前一个登录用户,松哥第一次见到这个功能,就是在扣扣里边见到的,当时觉得挺好玩的。自己做开发后,也遇到过一模一样的需求,正好最近的SpringSecurity系列正在连载,就结...
- 收藏起来!这款开源在线考试系统,我爱了
-
大家好,我是为广大程序员兄弟操碎了心的小编,每天推荐一个小工具/源码,装满你的收藏夹,每天分享一个小技巧,让你轻松节省开发效率,实现不加班不熬夜不掉头发,是我的目标!今天小编推荐一款基于Spr...
- Shiro框架:认证和授权原理(shiro权限认证流程)
-
优质文章,及时送达前言Shiro作为解决权限问题的常用框架,常用于解决认证、授权、加密、会话管理等场景。本文将对Shiro的认证和授权原理进行介绍:Shiro可以做什么?、Shiro是由什么组成的?举...
你 发表评论:
欢迎- 一周热门
- 最近发表
-
- SpringBoot整合SpringSecurity+JWT
- 「计算机毕设」一个精美的JAVA博客系统源码分享
- springboot教务管理系统+微信小程序云开发附带源码
- SpringBoot+LayUI后台管理系统开发脚手架
- Spring Boot的Security安全控制——认识SpringSecurity!
- 前同事2024年接私活已入百万,都是用这几个开源的SpringBoot项目
- 值得学习的15 个优秀开源的 Spring Boot 学习项目
- 开发企业官网就用这个基于SpringBoot的CMS系统,真香
- 2021年超详细的java学习路线总结—纯干货分享
- jeecg-boot学习总结及使用心得(jeecgboot简单吗)
- 标签列表
-
- idea eval reset (50)
- vue dispatch (70)
- update canceled (42)
- order by asc (53)
- spring gateway (67)
- 简单代码编程 贪吃蛇 (40)
- transforms.resize (33)
- redisson trylock (35)
- 卸载node (35)
- np.reshape (33)
- torch.arange (34)
- npm 源 (35)
- vue3 deep (35)
- win10 ssh (35)
- vue foreach (34)
- idea设置编码为utf8 (35)
- vue 数组添加元素 (34)
- std find (34)
- tablefield注解用途 (35)
- python str转json (34)
- java websocket客户端 (34)
- tensor.view (34)
- java jackson (34)
- vmware17pro最新密钥 (34)
- mysql单表最大数据量 (35)