深度学习 - 1.1、目标查找Faster R-CNN
ztj100 2024-12-19 17:55 21 浏览 0 评论
深度学习目标搜索,在工业上常被用于目标跟踪、缺陷定位等应用。
Faster R-CNN是目标查找算法中更高精度的应用。Faster R-CNN很多都是用Python来实现,下面发一个TensorFlow 实现Faster R-CNN的代码。
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers, losses
import numpy as np
# ======================
# 1. 构建主干网络(Backbone)
# ======================
def build_backbone():
"""
构建主干网络,用于提取图像特征。
此处使用预训练的 ResNet50,去掉全连接层。
"""
# 使用预训练的 ResNet50,去掉顶层的全连接层
base_model = tf.keras.applications.ResNet50(include_top=False, input_shape=(None, None, 3))
# 提取 C5 层的输出(conv5_block3_out)
c5_output = base_model.get_layer('conv5_block3_out').output
# 构建主干网络模型
backbone = models.Model(inputs=base_model.input, outputs=c5_output)
return backbone
# ================================
# 2. 构建区域建议网络(RPN)
# ================================
def build_rpn(feature_map):
"""
构建区域建议网络,根据特征图生成 Anchor 的目标性得分和边界框回归偏移量。
参数:
- feature_map: 主干网络提取的特征图。
返回:
- rpn_cls_score: RPN 的分类得分(前景/背景)。
- rpn_bbox_pred: RPN 的边界框回归预测。
"""
# 使用 3x3 卷积提取特征,通道数为 512
x = layers.Conv2D(512, (3, 3), padding='same', activation='relu')(feature_map)
# 分类分支,输出 9 个 Anchor 的前景得分
rpn_cls_score = layers.Conv2D(9 * 1, (1, 1), activation='sigmoid')(x)
# 回归分支,输出 9 个 Anchor 的边界框回归偏移量
rpn_bbox_pred = layers.Conv2D(9 * 4, (1, 1))(x)
return rpn_cls_score, rpn_bbox_pred
# ===========================
# 3. 生成 Anchors(锚点框)
# ===========================
def generate_anchors(feature_map_shape, scales, ratios, feature_stride):
"""
生成所有的 Anchors(锚点框)。
参数:
- feature_map_shape: 特征图的形状(batch_size, height, width, channels)。
- scales: 预定义的尺度列表(例如:[128, 256, 512])。
- ratios: 预定义的长宽比列表(例如:[0.5, 1, 2])。
- feature_stride: 特征图相对于原图的下采样倍数(例如:16)。
返回:
- anchors: 所有 Anchors 的坐标数组,形状为 (num_anchors, 4)。
"""
import itertools
feature_height, feature_width = feature_map_shape[1], feature_map_shape[2]
anchors = []
# 遍历特征图的每个像素位置
for y in range(feature_height):
for x in range(feature_width):
# 计算对应于原始图像的中心坐标
center_x = x * feature_stride + feature_stride / 2
center_y = y * feature_stride + feature_stride / 2
# 遍历每种尺度和长宽比
for scale, ratio in itertools.product(scales, ratios):
# 计算 Anchor 的宽度和高度
w = scale * np.sqrt(ratio)
h = scale / np.sqrt(ratio)
# 计算 Anchor 的坐标(x1, y1, x2, y2)
x1 = center_x - w / 2
y1 = center_y - h / 2
x2 = center_x + w / 2
y2 = center_y + h / 2
anchors.append([x1, y1, x2, y2])
anchors = np.array(anchors)
return anchors
# =====================================
# 4. 生成候选区域(Proposals)
# =====================================
def generate_proposals(rpn_cls_score, rpn_bbox_pred, anchors, img_shape, pre_nms_topN=6000, post_nms_topN=300, nms_thresh=0.7):
"""
根据 RPN 的预测结果,对 Anchors 进行调整,生成候选区域。
参数:
- rpn_cls_score: RPN 的分类得分(前景概率)。
- rpn_bbox_pred: RPN 的边界框回归预测。
- anchors: 所有的 Anchors 坐标数组。
- img_shape: 输入图像的形状(高度,高度,通道数)。
- pre_nms_topN: NMS 之前保留的候选区域数量。
- post_nms_topN: NMS 之后保留的候选区域数量。
- nms_thresh: NMS 的重叠阈值。
返回:
- proposals: 最终的候选区域坐标数组。
"""
from tensorflow.image import non_max_suppression
# 获取宽度和高度
widths = anchors[:, 2] - anchors[:, 0] + 1.0
heights = anchors[:, 3] - anchors[:, 1] + 1.0
# 计算中心点
ctr_x = anchors[:, 0] + 0.5 * widths
ctr_y = anchors[:, 1] + 0.5 * heights
# 展平 RPN 的回归预测
dx = rpn_bbox_pred[0, :, :, 0::4].numpy().flatten()
dy = rpn_bbox_pred[0, :, :, 1::4].numpy().flatten()
dw = rpn_bbox_pred[0, :, :, 2::4].numpy().flatten()
dh = rpn_bbox_pred[0, :, :, 3::4].numpy().flatten()
# 应用回归偏移量,计算调整后的中心点和宽高
pred_ctr_x = dx * widths + ctr_x
pred_ctr_y = dy * heights + ctr_y
pred_w = np.exp(dw) * widths
pred_h = np.exp(dh) * heights
# 计算调整后的边界框坐标
proposals = np.stack([
pred_ctr_x - 0.5 * pred_w,
pred_ctr_y - 0.5 * pred_h,
pred_ctr_x + 0.5 * pred_w,
pred_ctr_y + 0.5 * pred_h
], axis=1)
# 裁剪边界框到图像尺寸内
proposals[:, 0::2] = np.clip(proposals[:, 0::2], 0, img_shape[1] - 1)
proposals[:, 1::2] = np.clip(proposals[:, 1::2], 0, img_shape[0] - 1)
# 获取 RPN 的目标得分(前景概率)
scores = rpn_cls_score[0, :, :, :].numpy().flatten()
# 根据得分排序,选择前 pre_nms_topN 个候选区域
order = scores.argsort()[::-1][:pre_nms_topN]
proposals = proposals[order, :]
scores = scores[order]
# 执行非极大值抑制(NMS),去除重叠的候选区域
indices = tf.image.non_max_suppression(
proposals,
scores,
max_output_size=post_nms_topN,
iou_threshold=nms_thresh
)
# 获取保留的候选区域
proposals = tf.gather(proposals, indices).numpy()
return proposals
# ================================
# 5. 定义 ROI Pooling 层
# ================================
class ROIPoolingLayer(layers.Layer):
"""
定义 ROI Pooling 层,将不同尺寸的候选区域特征图转换为固定尺寸。
"""
def __init__(self, pool_size, **kwargs):
super(ROIPoolingLayer, self).__init__(**kwargs)
self.pool_size = pool_size
def call(self, inputs):
"""
执行 ROI Pooling 操作。
参数:
- inputs[0]: 特征图(来自主干网络)。
- inputs[1]: 候选区域(proposals),形状为 (num_proposals, 4)。
返回:
- roi_pooled_features: 经过 ROI Pooling 的特征图,形状为 (num_proposals, pool_size, pool_size, channels)。
"""
feature_map = inputs[0]
proposals = inputs[1]
# 假设 batch_size = 1
box_indices = tf.zeros((tf.shape(proposals)[0],), dtype=tf.int32)
# 将坐标归一化到 0~1
img_height = tf.shape(feature_map)[1] * 16 # 假设下采样率为 16
img_width = tf.shape(feature_map)[2] * 16
normalized_proposals = proposals / tf.cast([img_width, img_height, img_width, img_height], tf.float32)
roi_pooled_features = tf.image.crop_and_resize(
feature_map,
boxes=normalized_proposals,
box_ind=box_indices,
crop_size=(self.pool_size, self.pool_size)
)
return roi_pooled_features
# ==========================
# 6. 构建 Fast R-CNN 头部
# ==========================
def build_fast_rcnn_head(roi_pooled_features, num_classes):
"""
构建 Fast R-CNN 头部,对 ROI 特征进行分类和边界框回归。
参数:
- roi_pooled_features: ROI Pooling 的输出特征,形状为 (num_proposals, pool_size, pool_size, channels)。
- num_classes: 类别数量(包括背景类)。
返回:
- cls_score: 分类得分,形状为 (num_proposals, num_classes)。
- bbox_pred: 边界框回归预测,形状为 (num_proposals, num_classes * 4)。
"""
# 展平特征图
x = layers.TimeDistributed(layers.Flatten())(tf.expand_dims(roi_pooled_features, axis=0))
# 全连接层,维度为 1024
x = layers.TimeDistributed(layers.Dense(1024, activation='relu'))(x)
x = layers.TimeDistributed(layers.Dense(1024, activation='relu'))(x)
# 分类分支,使用 softmax 激活函数
cls_score = layers.TimeDistributed(layers.Dense(num_classes, activation='softmax'))(x)
# 边界框回归分支
bbox_pred = layers.TimeDistributed(layers.Dense(num_classes * 4))(x)
# 去除批次维度
cls_score = tf.squeeze(cls_score, axis=0)
bbox_pred = tf.squeeze(bbox_pred, axis=0)
return cls_score, bbox_pred
# =============================
# 7. 构建 Faster R-CNN 模型
# =============================
def build_model(num_classes, img_shape=(None, None, 3)):
"""
构建完整的 Faster R-CNN 模型。
参数:
- num_classes: 类别数量(包括背景类)。
- img_shape: 输入图像的形状。
返回:
- model: TensorFlow Keras 模型实例。
"""
# 输入层
inputs = layers.Input(shape=img_shape)
# 1. 主干网络,提取特征图
backbone = build_backbone()
feature_map = backbone(inputs)
# 2. RPN 网络,生成目标得分和边界框回归
rpn_cls_score, rpn_bbox_pred = build_rpn(feature_map)
# 3. 生成 Anchors
anchors = generate_anchors(
feature_map_shape=tf.shape(feature_map),
scales=[128, 256, 512],
ratios=[0.5, 1, 2],
feature_stride=16 # 假设下采样率为 16
)
# 4. 生成候选区域(Proposals)
proposals = layers.Lambda(lambda x: generate_proposals(
x[0], x[1], anchors, img_shape
))([rpn_cls_score, rpn_bbox_pred])
# 5. ROI Pooling,提取固定尺寸的特征
roi_pool = ROIPoolingLayer(pool_size=7)([feature_map, proposals])
# 6. Fast R-CNN 头部,进行分类和回归
cls_score, bbox_pred = build_fast_rcnn_head(roi_pool, num_classes)
# 7. 构建模型
model = models.Model(inputs=inputs, outputs=[rpn_cls_score, rpn_bbox_pred, cls_score, bbox_pred])
return model
# ==========================
# 8. 定义损失函数
# ==========================
def rpn_class_loss(y_true, y_pred):
"""
RPN 的分类损失(使用二元交叉熵)。
"""
return tf.reduce_mean(losses.binary_crossentropy(y_true, y_pred))
def rpn_regression_loss(y_true, y_pred):
"""
RPN 的回归损失(使用均方误差)。
"""
return tf.reduce_mean(losses.mean_squared_error(y_true, y_pred))
def rcnn_class_loss(y_true, y_pred):
"""
Fast R-CNN 的分类损失(使用分类交叉熵)。
"""
return tf.reduce_mean(losses.categorical_crossentropy(y_true, y_pred))
def rcnn_regression_loss(y_true, y_pred):
"""
Fast R-CNN 的回归损失(使用均方误差)。
"""
return tf.reduce_mean(losses.mean_squared_error(y_true, y_pred))
# ==========================
# 9. 模拟数据生成器
# ==========================
def data_generator(batch_size=1, num_classes=21):
"""
模拟数据生成器,生成随机的数据用于训练示例。
实际应用中,需要使用真实的图像和标注数据。
"""
while True:
# 随机生成输入图像
images = np.random.rand(batch_size, 600, 600, 3)
# 随机生成 RPN 的标签
rpn_cls_targets = np.random.randint(0, 2, (batch_size, None, None, 9))
rpn_reg_targets = np.random.rand(batch_size, None, None, 9 * 4)
# 随机生成 Fast R-CNN 的标签
cls_targets = np.random.randint(0, num_classes, (batch_size, 1))
cls_targets = tf.one_hot(cls_targets, depth=num_classes)
bbox_targets = np.random.rand(batch_size, num_classes * 4)
yield images, [rpn_cls_targets, rpn_reg_targets, cls_targets, bbox_targets]
# ==========================
# 10. 训练模型
# ==========================
# 定义类别数量(包括背景类)
num_classes = 21
# 构建模型
model = build_model(num_classes)
# 编译模型,指定优化器和损失函数
model.compile(
optimizer=optimizers.SGD(learning_rate=0.001, momentum=0.9),
loss=[
rpn_class_loss,
rpn_regression_loss,
rcnn_class_loss,
rcnn_regression_loss
]
)
# 定义训练参数
batch_size = 1
steps_per_epoch = 1000
epochs = 10
# 开始训练模型
model.fit(
data_generator(batch_size, num_classes),
steps_per_epoch=steps_per_epoch,
epochs=epochs
)
# 保存模型权重
model.save_weights('faster_rcnn_weights.h5')
相关推荐
- Java网络编程(JAVA网络编程技术)
-
网络编程三要素1.IP地址:表示设备在网络中的地址,是网络中设备的唯一标识2.端口号:应用程序在设备中唯一的标识3.协议:连接和数据在网络中传输的规则。InetAddress类Java中也有一个...
- 字节Java全能手册火了!多线程/网络/性能调优/框架啥都有
-
前言在这个技术不断更新的年代,跟不上时代变化的速度就会被刷掉,特别是咱们程序员这一群体,技术不断更新的同时也要同时进步,不然长江后浪推前浪,前浪......一个程序员从一个什么都不懂的小白在学到有一定...
- 一分钟了解java网络编程(java基础网络编程)
-
一、OSI七层网络模型应用层:Http协议、电子邮件传输、文件服务器等;表示层:数据转换,解决不同系统的兼容问题(跨语言);会话层:建立与应用程序的会话连接;传输层:提供了端口号和传输协议(TPC/U...
- Java编程-高并发情况下接口性能优化实践-提升吞吐量TPS
-
记得前段时间工作中接到一个任务是优化一个下单接口的性能提高接口的吞吐量TPS,前期通过arthas工具跟踪接口的具体方法调用链路及耗时,发现了影响此接口的性能瓶颈主要是加锁的方式,后来变更了锁的方式...
- socket 断线重连和心跳机制如何实现?
-
一、socket概念1.套接字(socket)是网络通信的基石,是支持TCP/IP协议的网络通信的基本操作单元。它是网络通信过程中端点的抽象表示,包含进行网络通信必须的五种信息:连接使用的协议,...
- 迅速了解-Java网络编程(java基础网络编程)
-
Java网络编程在JavaSE阶段,我们学习了I/O流,既然I/O流如此强大,那么能否跨越不同的主机进行I/O操作呢?这就要提到Java的网络编程了。...
- Java网络编程详解(java 网络编程)
-
网络编程基础知识最!最!最!重要网络编程基础概念网络编程不等于网站编程,网络编程即使用套接字(socket)来达到各进程间的通信,现在一般称为TCP/IP编程;网络编程分为服务端和客户端。服务端就相当...
- 「开源推荐」高性能网络通信框架 HP-Socket v5.7.2
-
简介HP-Socket是一套通用的高性能TCP/UDP/HTTP通信框架,包含服务端组件、客户端组件和Agent组件,广泛适用于各种不同应用场景的TCP/UDP/HTTP通信系统,提供C/...
- Java网络编程从入门到精通:打造属于你的网络世界
-
Java网络编程从入门到精通:打造属于你的网络世界在当今这个信息爆炸的时代,网络编程已经成为程序员必不可少的一项技能。而Java作为一种功能强大且广泛使用的编程语言,在网络编程领域也有着举足轻重的地位...
- 5分钟读懂C#中TcpClient、TcpListener和Socket三个类的角色
-
一、核心功能与定位1.Socket类:底层通信的基石-位于System.Net.Sockets命名空间,提供对网络协议栈的直接操作,支持TCP、UDP等多种协议。-手动管理连接细节:需...
- (三)谈谈 IO 模型(Socket 编程篇)
-
快过年啦,估计很多朋友已在摸鱼的路上。而我为了兄弟们年后的追逐,却在苦苦寻觅、规划,导致文章更新晚了些,各位猿粉谅解。上期分享,我们结合新春送祝福的场景,通过一坨坨的代码让BIO、NIO编程过程呈...
- 大数据编程入门:Java网络编程(大数据 编程)
-
如果想要编写出一个可以运行在多个设备上的程序,应该怎么做呢?答案是网络编程,今天小编将为大家带来大数据编程入门:Java网络编程。一、网络编程概念网络编程是指编写在通过网络连接的多个设备(计算机)上运...
- 基于JAVA的社交聊天室(java聊天设计与实现)
-
基于Java的社交聊天室一、前言随着互联网技术的迅速发展,实时通信和在线社交已成为人们日常生活的重要组成部分。基于Java的社交聊天室系统,凭借其跨平台、高性能和安全性等特点,为用户提供了一个集中、开...
- java-socket长连接demo体验(java socket长连接)
-
作者:DavidDing来源:https://zhuanlan.zhihu.com/p/56135195一、前言最近公司在预研设备app端与服务端的交互方案,主要方案有:服务端和app端通过阿里i...
- JAVA数据库编程(java数据库编程指南)
-
预计更新###第一节:什么是JAVA-JAVA的背景和历史-JAVA的特点和应用领域-如何安装和配置JAVA开发环境###第二节:JAVA基础语法-JAVA的基本数据类型和变量-运算符和...
你 发表评论:
欢迎- 一周热门
- 最近发表
- 标签列表
-
- idea eval reset (50)
- vue dispatch (70)
- update canceled (42)
- order by asc (53)
- spring gateway (67)
- 简单代码编程 贪吃蛇 (40)
- transforms.resize (33)
- redisson trylock (35)
- 卸载node (35)
- np.reshape (33)
- torch.arange (34)
- npm 源 (35)
- vue3 deep (35)
- win10 ssh (35)
- vue foreach (34)
- idea设置编码为utf8 (35)
- vue 数组添加元素 (34)
- std find (34)
- tablefield注解用途 (35)
- python str转json (34)
- java websocket客户端 (34)
- tensor.view (34)
- java jackson (34)
- vmware17pro最新密钥 (34)
- mysql单表最大数据量 (35)