深度学习 - 1.1、目标查找Faster R-CNN
ztj100 2024-12-19 17:55 33 浏览 0 评论
深度学习目标搜索,在工业上常被用于目标跟踪、缺陷定位等应用。
Faster R-CNN是目标查找算法中更高精度的应用。Faster R-CNN很多都是用Python来实现,下面发一个TensorFlow 实现Faster R-CNN的代码。
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers, losses
import numpy as np
# ======================
# 1. 构建主干网络(Backbone)
# ======================
def build_backbone():
"""
构建主干网络,用于提取图像特征。
此处使用预训练的 ResNet50,去掉全连接层。
"""
# 使用预训练的 ResNet50,去掉顶层的全连接层
base_model = tf.keras.applications.ResNet50(include_top=False, input_shape=(None, None, 3))
# 提取 C5 层的输出(conv5_block3_out)
c5_output = base_model.get_layer('conv5_block3_out').output
# 构建主干网络模型
backbone = models.Model(inputs=base_model.input, outputs=c5_output)
return backbone
# ================================
# 2. 构建区域建议网络(RPN)
# ================================
def build_rpn(feature_map):
"""
构建区域建议网络,根据特征图生成 Anchor 的目标性得分和边界框回归偏移量。
参数:
- feature_map: 主干网络提取的特征图。
返回:
- rpn_cls_score: RPN 的分类得分(前景/背景)。
- rpn_bbox_pred: RPN 的边界框回归预测。
"""
# 使用 3x3 卷积提取特征,通道数为 512
x = layers.Conv2D(512, (3, 3), padding='same', activation='relu')(feature_map)
# 分类分支,输出 9 个 Anchor 的前景得分
rpn_cls_score = layers.Conv2D(9 * 1, (1, 1), activation='sigmoid')(x)
# 回归分支,输出 9 个 Anchor 的边界框回归偏移量
rpn_bbox_pred = layers.Conv2D(9 * 4, (1, 1))(x)
return rpn_cls_score, rpn_bbox_pred
# ===========================
# 3. 生成 Anchors(锚点框)
# ===========================
def generate_anchors(feature_map_shape, scales, ratios, feature_stride):
"""
生成所有的 Anchors(锚点框)。
参数:
- feature_map_shape: 特征图的形状(batch_size, height, width, channels)。
- scales: 预定义的尺度列表(例如:[128, 256, 512])。
- ratios: 预定义的长宽比列表(例如:[0.5, 1, 2])。
- feature_stride: 特征图相对于原图的下采样倍数(例如:16)。
返回:
- anchors: 所有 Anchors 的坐标数组,形状为 (num_anchors, 4)。
"""
import itertools
feature_height, feature_width = feature_map_shape[1], feature_map_shape[2]
anchors = []
# 遍历特征图的每个像素位置
for y in range(feature_height):
for x in range(feature_width):
# 计算对应于原始图像的中心坐标
center_x = x * feature_stride + feature_stride / 2
center_y = y * feature_stride + feature_stride / 2
# 遍历每种尺度和长宽比
for scale, ratio in itertools.product(scales, ratios):
# 计算 Anchor 的宽度和高度
w = scale * np.sqrt(ratio)
h = scale / np.sqrt(ratio)
# 计算 Anchor 的坐标(x1, y1, x2, y2)
x1 = center_x - w / 2
y1 = center_y - h / 2
x2 = center_x + w / 2
y2 = center_y + h / 2
anchors.append([x1, y1, x2, y2])
anchors = np.array(anchors)
return anchors
# =====================================
# 4. 生成候选区域(Proposals)
# =====================================
def generate_proposals(rpn_cls_score, rpn_bbox_pred, anchors, img_shape, pre_nms_topN=6000, post_nms_topN=300, nms_thresh=0.7):
"""
根据 RPN 的预测结果,对 Anchors 进行调整,生成候选区域。
参数:
- rpn_cls_score: RPN 的分类得分(前景概率)。
- rpn_bbox_pred: RPN 的边界框回归预测。
- anchors: 所有的 Anchors 坐标数组。
- img_shape: 输入图像的形状(高度,高度,通道数)。
- pre_nms_topN: NMS 之前保留的候选区域数量。
- post_nms_topN: NMS 之后保留的候选区域数量。
- nms_thresh: NMS 的重叠阈值。
返回:
- proposals: 最终的候选区域坐标数组。
"""
from tensorflow.image import non_max_suppression
# 获取宽度和高度
widths = anchors[:, 2] - anchors[:, 0] + 1.0
heights = anchors[:, 3] - anchors[:, 1] + 1.0
# 计算中心点
ctr_x = anchors[:, 0] + 0.5 * widths
ctr_y = anchors[:, 1] + 0.5 * heights
# 展平 RPN 的回归预测
dx = rpn_bbox_pred[0, :, :, 0::4].numpy().flatten()
dy = rpn_bbox_pred[0, :, :, 1::4].numpy().flatten()
dw = rpn_bbox_pred[0, :, :, 2::4].numpy().flatten()
dh = rpn_bbox_pred[0, :, :, 3::4].numpy().flatten()
# 应用回归偏移量,计算调整后的中心点和宽高
pred_ctr_x = dx * widths + ctr_x
pred_ctr_y = dy * heights + ctr_y
pred_w = np.exp(dw) * widths
pred_h = np.exp(dh) * heights
# 计算调整后的边界框坐标
proposals = np.stack([
pred_ctr_x - 0.5 * pred_w,
pred_ctr_y - 0.5 * pred_h,
pred_ctr_x + 0.5 * pred_w,
pred_ctr_y + 0.5 * pred_h
], axis=1)
# 裁剪边界框到图像尺寸内
proposals[:, 0::2] = np.clip(proposals[:, 0::2], 0, img_shape[1] - 1)
proposals[:, 1::2] = np.clip(proposals[:, 1::2], 0, img_shape[0] - 1)
# 获取 RPN 的目标得分(前景概率)
scores = rpn_cls_score[0, :, :, :].numpy().flatten()
# 根据得分排序,选择前 pre_nms_topN 个候选区域
order = scores.argsort()[::-1][:pre_nms_topN]
proposals = proposals[order, :]
scores = scores[order]
# 执行非极大值抑制(NMS),去除重叠的候选区域
indices = tf.image.non_max_suppression(
proposals,
scores,
max_output_size=post_nms_topN,
iou_threshold=nms_thresh
)
# 获取保留的候选区域
proposals = tf.gather(proposals, indices).numpy()
return proposals
# ================================
# 5. 定义 ROI Pooling 层
# ================================
class ROIPoolingLayer(layers.Layer):
"""
定义 ROI Pooling 层,将不同尺寸的候选区域特征图转换为固定尺寸。
"""
def __init__(self, pool_size, **kwargs):
super(ROIPoolingLayer, self).__init__(**kwargs)
self.pool_size = pool_size
def call(self, inputs):
"""
执行 ROI Pooling 操作。
参数:
- inputs[0]: 特征图(来自主干网络)。
- inputs[1]: 候选区域(proposals),形状为 (num_proposals, 4)。
返回:
- roi_pooled_features: 经过 ROI Pooling 的特征图,形状为 (num_proposals, pool_size, pool_size, channels)。
"""
feature_map = inputs[0]
proposals = inputs[1]
# 假设 batch_size = 1
box_indices = tf.zeros((tf.shape(proposals)[0],), dtype=tf.int32)
# 将坐标归一化到 0~1
img_height = tf.shape(feature_map)[1] * 16 # 假设下采样率为 16
img_width = tf.shape(feature_map)[2] * 16
normalized_proposals = proposals / tf.cast([img_width, img_height, img_width, img_height], tf.float32)
roi_pooled_features = tf.image.crop_and_resize(
feature_map,
boxes=normalized_proposals,
box_ind=box_indices,
crop_size=(self.pool_size, self.pool_size)
)
return roi_pooled_features
# ==========================
# 6. 构建 Fast R-CNN 头部
# ==========================
def build_fast_rcnn_head(roi_pooled_features, num_classes):
"""
构建 Fast R-CNN 头部,对 ROI 特征进行分类和边界框回归。
参数:
- roi_pooled_features: ROI Pooling 的输出特征,形状为 (num_proposals, pool_size, pool_size, channels)。
- num_classes: 类别数量(包括背景类)。
返回:
- cls_score: 分类得分,形状为 (num_proposals, num_classes)。
- bbox_pred: 边界框回归预测,形状为 (num_proposals, num_classes * 4)。
"""
# 展平特征图
x = layers.TimeDistributed(layers.Flatten())(tf.expand_dims(roi_pooled_features, axis=0))
# 全连接层,维度为 1024
x = layers.TimeDistributed(layers.Dense(1024, activation='relu'))(x)
x = layers.TimeDistributed(layers.Dense(1024, activation='relu'))(x)
# 分类分支,使用 softmax 激活函数
cls_score = layers.TimeDistributed(layers.Dense(num_classes, activation='softmax'))(x)
# 边界框回归分支
bbox_pred = layers.TimeDistributed(layers.Dense(num_classes * 4))(x)
# 去除批次维度
cls_score = tf.squeeze(cls_score, axis=0)
bbox_pred = tf.squeeze(bbox_pred, axis=0)
return cls_score, bbox_pred
# =============================
# 7. 构建 Faster R-CNN 模型
# =============================
def build_model(num_classes, img_shape=(None, None, 3)):
"""
构建完整的 Faster R-CNN 模型。
参数:
- num_classes: 类别数量(包括背景类)。
- img_shape: 输入图像的形状。
返回:
- model: TensorFlow Keras 模型实例。
"""
# 输入层
inputs = layers.Input(shape=img_shape)
# 1. 主干网络,提取特征图
backbone = build_backbone()
feature_map = backbone(inputs)
# 2. RPN 网络,生成目标得分和边界框回归
rpn_cls_score, rpn_bbox_pred = build_rpn(feature_map)
# 3. 生成 Anchors
anchors = generate_anchors(
feature_map_shape=tf.shape(feature_map),
scales=[128, 256, 512],
ratios=[0.5, 1, 2],
feature_stride=16 # 假设下采样率为 16
)
# 4. 生成候选区域(Proposals)
proposals = layers.Lambda(lambda x: generate_proposals(
x[0], x[1], anchors, img_shape
))([rpn_cls_score, rpn_bbox_pred])
# 5. ROI Pooling,提取固定尺寸的特征
roi_pool = ROIPoolingLayer(pool_size=7)([feature_map, proposals])
# 6. Fast R-CNN 头部,进行分类和回归
cls_score, bbox_pred = build_fast_rcnn_head(roi_pool, num_classes)
# 7. 构建模型
model = models.Model(inputs=inputs, outputs=[rpn_cls_score, rpn_bbox_pred, cls_score, bbox_pred])
return model
# ==========================
# 8. 定义损失函数
# ==========================
def rpn_class_loss(y_true, y_pred):
"""
RPN 的分类损失(使用二元交叉熵)。
"""
return tf.reduce_mean(losses.binary_crossentropy(y_true, y_pred))
def rpn_regression_loss(y_true, y_pred):
"""
RPN 的回归损失(使用均方误差)。
"""
return tf.reduce_mean(losses.mean_squared_error(y_true, y_pred))
def rcnn_class_loss(y_true, y_pred):
"""
Fast R-CNN 的分类损失(使用分类交叉熵)。
"""
return tf.reduce_mean(losses.categorical_crossentropy(y_true, y_pred))
def rcnn_regression_loss(y_true, y_pred):
"""
Fast R-CNN 的回归损失(使用均方误差)。
"""
return tf.reduce_mean(losses.mean_squared_error(y_true, y_pred))
# ==========================
# 9. 模拟数据生成器
# ==========================
def data_generator(batch_size=1, num_classes=21):
"""
模拟数据生成器,生成随机的数据用于训练示例。
实际应用中,需要使用真实的图像和标注数据。
"""
while True:
# 随机生成输入图像
images = np.random.rand(batch_size, 600, 600, 3)
# 随机生成 RPN 的标签
rpn_cls_targets = np.random.randint(0, 2, (batch_size, None, None, 9))
rpn_reg_targets = np.random.rand(batch_size, None, None, 9 * 4)
# 随机生成 Fast R-CNN 的标签
cls_targets = np.random.randint(0, num_classes, (batch_size, 1))
cls_targets = tf.one_hot(cls_targets, depth=num_classes)
bbox_targets = np.random.rand(batch_size, num_classes * 4)
yield images, [rpn_cls_targets, rpn_reg_targets, cls_targets, bbox_targets]
# ==========================
# 10. 训练模型
# ==========================
# 定义类别数量(包括背景类)
num_classes = 21
# 构建模型
model = build_model(num_classes)
# 编译模型,指定优化器和损失函数
model.compile(
optimizer=optimizers.SGD(learning_rate=0.001, momentum=0.9),
loss=[
rpn_class_loss,
rpn_regression_loss,
rcnn_class_loss,
rcnn_regression_loss
]
)
# 定义训练参数
batch_size = 1
steps_per_epoch = 1000
epochs = 10
# 开始训练模型
model.fit(
data_generator(batch_size, num_classes),
steps_per_epoch=steps_per_epoch,
epochs=epochs
)
# 保存模型权重
model.save_weights('faster_rcnn_weights.h5')
相关推荐
- sharding-jdbc实现`分库分表`与`读写分离`
-
一、前言本文将基于以下环境整合...
- 三分钟了解mysql中主键、外键、非空、唯一、默认约束是什么
-
在数据库中,数据表是数据库中最重要、最基本的操作对象,是数据存储的基本单位。数据表被定义为列的集合,数据在表中是按照行和列的格式来存储的。每一行代表一条唯一的记录,每一列代表记录中的一个域。...
- MySQL8行级锁_mysql如何加行级锁
-
MySQL8行级锁版本:8.0.34基本概念...
- mysql使用小技巧_mysql使用入门
-
1、MySQL中有许多很实用的函数,好好利用它们可以省去很多时间:group_concat()将取到的值用逗号连接,可以这么用:selectgroup_concat(distinctid)fr...
- MySQL/MariaDB中如何支持全部的Unicode?
-
永远不要在MySQL中使用utf8,并且始终使用utf8mb4。utf8mb4介绍MySQL/MariaDB中,utf8字符集并不是对Unicode的真正实现,即不是真正的UTF-8编码,因...
- 聊聊 MySQL Server 可执行注释,你懂了吗?
-
前言MySQLServer当前支持如下3种注释风格:...
- MySQL系列-源码编译安装(v5.7.34)
-
一、系统环境要求...
- MySQL的锁就锁住我啦!与腾讯大佬的技术交谈,是我小看它了
-
对酒当歌,人生几何!朝朝暮暮,唯有己脱。苦苦寻觅找工作之间,殊不知今日之事乃我心之痛,难道是我不配拥有工作嘛。自面试后他所谓的等待都过去一段时日,可惜在下京东上的小金库都要见低啦。每每想到不由心中一...
- MySQL字符问题_mysql中字符串的位置
-
中文写入乱码问题:我输入的中文编码是urf8的,建的库是urf8的,但是插入mysql总是乱码,一堆"???????????????????????"我用的是ibatis,终于找到原因了,我是这么解决...
- 深圳尚学堂:mysql基本sql语句大全(三)
-
数据开发-经典1.按姓氏笔画排序:Select*FromTableNameOrderByCustomerNameCollateChinese_PRC_Stroke_ci_as//从少...
- MySQL进行行级锁的?一会next-key锁,一会间隙锁,一会记录锁?
-
大家好,是不是很多人都对MySQL加行级锁的规则搞的迷迷糊糊,一会是next-key锁,一会是间隙锁,一会又是记录锁。坦白说,确实还挺复杂的,但是好在我找点了点规律,也知道如何如何用命令分析加...
- 一文讲清怎么利用Python Django实现Excel数据表的导入导出功能
-
摘要:Python作为一门简单易学且功能强大的编程语言,广受程序员、数据分析师和AI工程师的青睐。本文系统讲解了如何使用Python的Django框架结合openpyxl库实现Excel...
- 用DataX实现两个MySQL实例间的数据同步
-
DataXDataX使用Java实现。如果可以实现数据库实例之间准实时的...
- MySQL数据库知识_mysql数据库基础知识
-
MySQL是一种关系型数据库管理系统;那废话不多说,直接上自己以前学习整理文档:查看数据库命令:(1).查看存储过程状态:showprocedurestatus;(2).显示系统变量:show...
- 如何为MySQL中的JSON字段设置索引
-
背景MySQL在2015年中发布的5.7.8版本中首次引入了JSON数据类型。自此,它成了一种逃离严格列定义的方式,可以存储各种形状和大小的JSON文档,例如审计日志、配置信息、第三方数据包、用户自定...
你 发表评论:
欢迎- 一周热门
-
-
MySQL中这14个小玩意,让人眼前一亮!
-
旗舰机新标杆 OPPO Find X2系列正式发布 售价5499元起
-
【VueTorrent】一款吊炸天的qBittorrent主题,人人都可用
-
面试官:使用int类型做加减操作,是线程安全吗
-
C++编程知识:ToString()字符串转换你用正确了吗?
-
【Spring Boot】WebSocket 的 6 种集成方式
-
PyTorch 深度学习实战(26):多目标强化学习Multi-Objective RL
-
pytorch中的 scatter_()函数使用和详解
-
与 Java 17 相比,Java 21 究竟有多快?
-
基于TensorRT_LLM的大模型推理加速与OpenAI兼容服务优化
-
- 最近发表
- 标签列表
-
- idea eval reset (50)
- vue dispatch (70)
- update canceled (42)
- order by asc (53)
- spring gateway (67)
- 简单代码编程 贪吃蛇 (40)
- transforms.resize (33)
- redisson trylock (35)
- 卸载node (35)
- np.reshape (33)
- torch.arange (34)
- npm 源 (35)
- vue3 deep (35)
- win10 ssh (35)
- vue foreach (34)
- idea设置编码为utf8 (35)
- vue 数组添加元素 (34)
- std find (34)
- tablefield注解用途 (35)
- python str转json (34)
- java websocket客户端 (34)
- tensor.view (34)
- java jackson (34)
- vmware17pro最新密钥 (34)
- mysql单表最大数据量 (35)