MMDetection3D数据加载详解
ztj100 2025-01-05 00:59 28 浏览 0 评论
OpenmmLab众多框架加载数据方式如出一辙,正好最近有项目需要使用MMDetection3D,以此为例,记录下数据加载、变换、增强,至进入网络的流程。
BaseDataset
数据在进入网络之前会被包装为DataLoader,而DataLoader在配置文件中如下:
train_dataloader = dict(
batch_size = 8,
num_workers = 8,
persistent_workers = True,
sampler = dict(type = 'DefaultSampler', shuffle = True),
dataset = dict(
type = dataset_type,
data_root = data_root,
ann_file = 'semantickitti_infos_train.pkl',
pipeline = train_pipeline,
metainfo = metainfo,
modality = input_modality,
ignore_index = ignore_label,
backend_args = backend_args))
参数和原生的DataLoader并无太大区别,这里,我们重点关注dataset,同样是一系列参数,其中,type的真实值为预定义的数据集类型或自定义的数据集类型。以SemanticKittiDataset为例,代码如下:
class SemanticKittiDataset(Seg3DDataset):
# 先行省略
其核心功能基本全在父类Seg3DDataset中,代码如下:
class Seg3DDataset(BaseDataset):
# 先行省略
再顺藤摸瓜,查看BaseDataset类代码:
class BaseDataset(Dataset):
这里的父类Dataset类如大家所想,即为pytorch框架中的经典Dataset,无需赘述,我们直接查看其子类BaseDataset中的__getitem__方法:
def __getitem__(self, idx: int) -> dict:
# 省略部分代码
if self.test_mode:
data = self.prepare_data(idx)
return data
for _ in range(self.max_refetch + 1):
data = self.prepare_data(idx)
if data is None:
idx = self._rand_another()
continue
return data
据此,我们知道单例数据的获取方法为prepare_data(idx),核心代码如下:
def prepare_data(self, idx) -> Any:
data_info = self.get_data_info(idx)
return self.pipeline(data_info)
get_data_info方法可以根据索引号idx获取到对应数据的信息,包含数据的路径、名称等元信息,而pipeline则根据数据的元信息获取真正的数据。在构造方法中,我们可以看到pipeline如下:
self.pipeline = Compose(pipeline)
继续探索Compose,其代码如下:
class Compose:
# 省略部分代码
def __init__(self, transforms: Optional[Sequence[Union[dict, Callable]]]):
self.transforms: List[Callable] = []
if transforms is None:
transforms = []
for transform in transforms:
if isinstance(transform, dict):
transform = TRANSFORMS.build(transform)
if not callable(transform):
raise TypeError(f'transform should be a callable object, '
f'but got {type(transform)}')
self.transforms.append(transform)
elif callable(transform):
self.transforms.append(transform)
def __call__(self, data: dict) -> Optional[dict]:
for t in self.transforms:
data = t(data)
if data is None:
return None
return data
在Compose类中,__call__方法用来回调,返回数据,而数据会被依次送入transforms序列进行处理,transforms是一个回调序列或字典序列,若是后者,则也会被当作配置而创建为回调序列。
根据pipeline的构造,其参数即为数据配置文件中的pipeline,配置代码如下:
train_pipeline = [
dict(
type = 'LoadPointsFromFile',
coord_type = 'LIDAR',
load_dim = 4,
use_dim = 4,
backend_args = backend_args
),
dict(
type = 'LoadAnnotations3D',
with_bbox_3d = False,
with_label_3d = False,
with_seg_3d = True,
seg_3d_dtype = 'np.int32',
seg_offset = 2**16,
dataset_type = 'semantickitti',
backend_args = backend_args),
dict(type = 'PointSegClassMapping'),
dict(type = 'PointSample', num_points = 0.9),
# 省略部分代码
dict(type = 'Pack3DDetInputs', keys = ['points', 'pts_semantic_mask'])
]
至此,数据加载流程逐渐明朗,对数据的所有操作(包括读取、增强、封装等)全部分解于pipeline序列的各个类中。数据类(如SemanticKittiDataset)在初始化的过程中会从pickle文件中获取到数据的元信息,并根据配置文件中的pipeline序列对其中的操作类型进行实例化,并按照顺序逐个应用于数据。接下来,我们继续探索pipeline中的数据操作类。
BaseTransform
以pipeline中的第一个类,即LoadPointsFromFile为例,其代码如下:
class LoadPointsFromFile(BaseTransform):
# 省略部分代码
def _load_points(self, pts_filename: str) -> np.ndarray:
try:
pts_bytes = get(pts_filename, backend_args=self.backend_args)
points = np.frombuffer(pts_bytes, dtype=np.float32)
except ConnectionError:
mmengine.check_file_exist(pts_filename)
if pts_filename.endswith('.npy'):
points = np.load(pts_filename)
else:
points = np.fromfile(pts_filename, dtype=np.float32)
return points
def transform(self, results: dict) -> dict:
pts_file_path = results['lidar_points']['lidar_path']
points = self._load_points(pts_file_path)
points = points.reshape(-1, self.load_dim)
points = points[:, self.use_dim]
if self.norm_intensity:
assert len(self.use_dim) >= 4, \
f'When using intensity norm, expect used dimensions >= 4, got {len(self.use_dim)}' # noqa: E501
points[:, 3] = np.tanh(points[:, 3])
if self.norm_elongation:
assert len(self.use_dim) >= 5, \
f'When using elongation norm, expect used dimensions >= 5, got {len(self.use_dim)}' # noqa: E501
points[:, 4] = np.tanh(points[:, 4])
attribute_dims = None
if self.shift_height:
floor_height = np.percentile(points[:, 2], 0.99)
height = points[:, 2] - floor_height
points = np.concatenate(
[points[:, :3],
np.expand_dims(height, 1), points[:, 3:]], 1)
attribute_dims = dict(height=3)
if self.use_color:
assert len(self.use_dim) >= 6
if attribute_dims is None:
attribute_dims = dict()
attribute_dims.update(
dict(color=[
points.shape[1] - 3,
points.shape[1] - 2,
points.shape[1] - 1,
]))
points_class = get_points_type(self.coord_type)
points = points_class(
points, points_dim=points.shape[-1], attribute_dims=attribute_dims)
results['points'] = points
return results
根据代码,_load_points方法是根据文件名读取点云数据,而transform方法则是负责解析出点云文件名,并将_load_points方法读取的点云进行后续处理,最终将修剪后的点云放入results字典中。这里补上BaseTransform的代码:
class BaseTransform(metaclass=ABCMeta):
def __call__(self, results: Dict) -> Optional[Union[Dict, Tuple[List, List]]]:
return self.transform(results)
@abstractmethod
def transform(self, results: Dict) -> Optional[Union[Dict, Tuple[List, List]]]:
典型的模板方法,子类只需要实现抽象方法transform即可。
LoadAnnotations3D
这个类是用来加载点云数据对应的标注的,根据参数中的加载类型,将对应的标注读取并填充到相应的字段中,代码如下:
class LoadAnnotations3D(LoadAnnotations):
# 省略部分代码
def transform(self, results: dict) -> dict:
results = super().transform(results)
if self.with_bbox_3d:
results = self._load_bboxes_3d(results)
if self.with_bbox_depth:
results = self._load_bboxes_depth(results)
if self.with_label_3d:
results = self._load_labels_3d(results)
if self.with_attr_label:
results = self._load_attr_labels(results)
if self.with_panoptic_3d:
results = self._load_panoptic_3d(results)
if self.with_mask_3d:
results = self._load_masks_3d(results)
if self.with_seg_3d:
results = self._load_semantic_seg_3d(results)
return results
根据transform方法中的代码段,大家也能猜到,省略的代码中应该包含了一系列的_load方法,确实如此,每一个_load方法负责加载不同的标注,而if中的条件则由配置文件传入。需要注意的是,每经过一个Transform类处理,结果会体现在results中,如果有新增数据,则在results中增加字段,如果只是对原数据增强,则直接在修改原数据。
其它pipeline中的Transform子类作用类似,大家可以对自己感兴趣的代码加以探索研究。
- 上一篇:异常检测汇总
- 下一篇:AI数据分析:集中度分析和离散度分析
相关推荐
- 再说圆的面积-蒙特卡洛(蒙特卡洛方法求圆周率的matlab程序)
-
在微积分-圆的面积和周长(1)介绍微积分方法求解圆的面积,本文使用蒙特卡洛方法求解圆面积。...
- python创建分类器小结(pytorch分类数据集创建)
-
简介:分类是指利用数据的特性将其分成若干类型的过程。监督学习分类器就是用带标记的训练数据建立一个模型,然后对未知数据进行分类。...
- matplotlib——绘制散点图(matplotlib散点图颜色和图例)
-
绘制散点图不同条件(维度)之间的内在关联关系观察数据的离散聚合程度...
- python实现实时绘制数据(python如何绘制)
-
方法一importmatplotlib.pyplotaspltimportnumpyasnpimporttimefrommathimport*plt.ion()#...
- 简单学Python——matplotlib库3——绘制散点图
-
前面我们学习了用matplotlib绘制折线图,今天我们学习绘制散点图。其实简单的散点图与折线图的语法基本相同,只是作图函数由plot()变成了scatter()。下面就绘制一个散点图:import...
- 数据分析-相关性分析可视化(相关性分析数据处理)
-
前面介绍了相关性分析的原理、流程和常用的皮尔逊相关系数和斯皮尔曼相关系数,具体可以参考...
- 免费Python机器学习课程一:线性回归算法
-
学习线性回归的概念并从头开始在python中开发完整的线性回归算法最基本的机器学习算法必须是具有单个变量的线性回归算法。如今,可用的高级机器学习算法,库和技术如此之多,以至于线性回归似乎并不重要。但是...
- 用Python进行机器学习(2)之逻辑回归
-
前面介绍了线性回归,本次介绍的是逻辑回归。逻辑回归虽然名字里面带有“回归”两个字,但是它是一种分类算法,通常用于解决二分类问题,比如某个邮件是否是广告邮件,比如某个评价是否为正向的评价。逻辑回归也可以...
- 【Python机器学习系列】拟合和回归傻傻分不清?一文带你彻底搞懂
-
一、拟合和回归的区别拟合...
- 推荐2个十分好用的pandas数据探索分析神器
-
作者:俊欣来源:关于数据分析与可视化...
- 向量数据库:解锁大模型记忆的关键!选型指南+实战案例全解析
-
本文较长,建议点赞收藏,以免遗失。更多AI大模型应用开发学习视频及资料,尽在...
- 用Python进行机器学习(11)-主成分分析PCA
-
我们在机器学习中有时候需要处理很多个参数,但是这些参数有时候彼此之间是有着各种关系的,这个时候我们就会想:是否可以找到一种方式来降低参数的个数呢?这就是今天我们要介绍的主成分分析,英文是Princip...
- 神经网络基础深度解析:从感知机到反向传播
-
本文较长,建议点赞收藏,以免遗失。更多AI大模型应用开发学习视频及资料,尽在...
- Python实现基于机器学习的RFM模型
-
CDA数据分析师出品作者:CDALevelⅠ持证人岗位:数据分析师行业:大数据...
你 发表评论:
欢迎- 一周热门
- 最近发表
- 标签列表
-
- idea eval reset (50)
- vue dispatch (70)
- update canceled (42)
- order by asc (53)
- spring gateway (67)
- 简单代码编程 贪吃蛇 (40)
- transforms.resize (33)
- redisson trylock (35)
- 卸载node (35)
- np.reshape (33)
- torch.arange (34)
- npm 源 (35)
- vue3 deep (35)
- win10 ssh (35)
- vue foreach (34)
- idea设置编码为utf8 (35)
- vue 数组添加元素 (34)
- std find (34)
- tablefield注解用途 (35)
- python str转json (34)
- java websocket客户端 (34)
- tensor.view (34)
- java jackson (34)
- vmware17pro最新密钥 (34)
- mysql单表最大数据量 (35)