百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 技术分类 > 正文

MMDetection3D数据加载详解

ztj100 2025-01-05 00:59 23 浏览 0 评论

OpenmmLab众多框架加载数据方式如出一辙,正好最近有项目需要使用MMDetection3D,以此为例,记录下数据加载、变换、增强,至进入网络的流程。

BaseDataset

数据在进入网络之前会被包装为DataLoader,而DataLoader在配置文件中如下:

train_dataloader = dict(
    batch_size = 8,
    num_workers = 8,
    persistent_workers = True,
    sampler = dict(type = 'DefaultSampler', shuffle = True),
    dataset = dict(
        type = dataset_type,
        data_root = data_root,
        ann_file = 'semantickitti_infos_train.pkl',
        pipeline = train_pipeline,
        metainfo = metainfo,
        modality = input_modality,
        ignore_index = ignore_label,
        backend_args = backend_args))

参数和原生的DataLoader并无太大区别,这里,我们重点关注dataset,同样是一系列参数,其中,type的真实值为预定义的数据集类型或自定义的数据集类型。以SemanticKittiDataset为例,代码如下:

class SemanticKittiDataset(Seg3DDataset):
    # 先行省略

其核心功能基本全在父类Seg3DDataset中,代码如下:

class Seg3DDataset(BaseDataset):
    # 先行省略

再顺藤摸瓜,查看BaseDataset类代码:

class BaseDataset(Dataset):

这里的父类Dataset类如大家所想,即为pytorch框架中的经典Dataset,无需赘述,我们直接查看其子类BaseDataset中的__getitem__方法:

def __getitem__(self, idx: int) -> dict:
        # 省略部分代码
        if self.test_mode:
            data = self.prepare_data(idx)
            return data

        for _ in range(self.max_refetch + 1):
            data = self.prepare_data(idx)
            if data is None:
                idx = self._rand_another()
                continue
            return data

据此,我们知道单例数据的获取方法为prepare_data(idx),核心代码如下:

def prepare_data(self, idx) -> Any:
        data_info = self.get_data_info(idx)
        return self.pipeline(data_info)

get_data_info方法可以根据索引号idx获取到对应数据的信息,包含数据的路径、名称等元信息,而pipeline则根据数据的元信息获取真正的数据。在构造方法中,我们可以看到pipeline如下:

self.pipeline = Compose(pipeline)

继续探索Compose,其代码如下:

class Compose:
    # 省略部分代码
    def __init__(self, transforms: Optional[Sequence[Union[dict, Callable]]]):
        self.transforms: List[Callable] = []
        if transforms is None:
            transforms = []

        for transform in transforms:
            if isinstance(transform, dict):
                transform = TRANSFORMS.build(transform)
                if not callable(transform):
                    raise TypeError(f'transform should be a callable object, '
                                    f'but got {type(transform)}')
                self.transforms.append(transform)
            elif callable(transform):
                self.transforms.append(transform)

    def __call__(self, data: dict) -> Optional[dict]:
        for t in self.transforms:
            data = t(data)
            if data is None:
                return None
        return data

Compose类中,__call__方法用来回调,返回数据,而数据会被依次送入transforms序列进行处理,transforms是一个回调序列或字典序列,若是后者,则也会被当作配置而创建为回调序列。

根据pipeline的构造,其参数即为数据配置文件中的pipeline,配置代码如下:

train_pipeline = [
    dict(
        type = 'LoadPointsFromFile',
        coord_type = 'LIDAR',
        load_dim = 4,
        use_dim = 4,
        backend_args = backend_args
        ),
    dict(
        type = 'LoadAnnotations3D',
        with_bbox_3d = False,
        with_label_3d = False,
        with_seg_3d = True,
        seg_3d_dtype = 'np.int32',
        seg_offset = 2**16,
        dataset_type = 'semantickitti',
        backend_args = backend_args),
    dict(type = 'PointSegClassMapping'),
    dict(type = 'PointSample', num_points = 0.9),
    
    # 省略部分代码
    dict(type = 'Pack3DDetInputs', keys = ['points', 'pts_semantic_mask'])
]

至此,数据加载流程逐渐明朗,对数据的所有操作(包括读取、增强、封装等)全部分解于pipeline序列的各个类中。数据类(如SemanticKittiDataset)在初始化的过程中会从pickle文件中获取到数据的元信息,并根据配置文件中的pipeline序列对其中的操作类型进行实例化,并按照顺序逐个应用于数据。接下来,我们继续探索pipeline中的数据操作类。

BaseTransform

pipeline中的第一个类,即LoadPointsFromFile为例,其代码如下:

class LoadPointsFromFile(BaseTransform):
    # 省略部分代码
    def _load_points(self, pts_filename: str) -> np.ndarray:
        try:
            pts_bytes = get(pts_filename, backend_args=self.backend_args)
            points = np.frombuffer(pts_bytes, dtype=np.float32)
        except ConnectionError:
            mmengine.check_file_exist(pts_filename)
            if pts_filename.endswith('.npy'):
                points = np.load(pts_filename)
            else:
                points = np.fromfile(pts_filename, dtype=np.float32)

        return points

    def transform(self, results: dict) -> dict:
        pts_file_path = results['lidar_points']['lidar_path']
        points = self._load_points(pts_file_path)
        points = points.reshape(-1, self.load_dim)
        points = points[:, self.use_dim]
        if self.norm_intensity:
            assert len(self.use_dim) >= 4, \
                f'When using intensity norm, expect used dimensions >= 4, got {len(self.use_dim)}'  # noqa: E501
            points[:, 3] = np.tanh(points[:, 3])
        if self.norm_elongation:
            assert len(self.use_dim) >= 5, \
                f'When using elongation norm, expect used dimensions >= 5, got {len(self.use_dim)}'  # noqa: E501
            points[:, 4] = np.tanh(points[:, 4])
        attribute_dims = None

        if self.shift_height:
            floor_height = np.percentile(points[:, 2], 0.99)
            height = points[:, 2] - floor_height
            points = np.concatenate(
                [points[:, :3],
                 np.expand_dims(height, 1), points[:, 3:]], 1)
            attribute_dims = dict(height=3)

        if self.use_color:
            assert len(self.use_dim) >= 6
            if attribute_dims is None:
                attribute_dims = dict()
            attribute_dims.update(
                dict(color=[
                    points.shape[1] - 3,
                    points.shape[1] - 2,
                    points.shape[1] - 1,
                ]))

        points_class = get_points_type(self.coord_type)
        points = points_class(
            points, points_dim=points.shape[-1], attribute_dims=attribute_dims)
        results['points'] = points
        return results

根据代码,_load_points方法是根据文件名读取点云数据,而transform方法则是负责解析出点云文件名,并将_load_points方法读取的点云进行后续处理,最终将修剪后的点云放入results字典中。这里补上BaseTransform的代码:

class BaseTransform(metaclass=ABCMeta):
    def __call__(self, results: Dict) -> Optional[Union[Dict, Tuple[List, List]]]:
        return self.transform(results)

    @abstractmethod
    def transform(self, results: Dict) -> Optional[Union[Dict, Tuple[List, List]]]:

典型的模板方法,子类只需要实现抽象方法transform即可。

LoadAnnotations3D

这个类是用来加载点云数据对应的标注的,根据参数中的加载类型,将对应的标注读取并填充到相应的字段中,代码如下:

class LoadAnnotations3D(LoadAnnotations):
# 省略部分代码

    def transform(self, results: dict) -> dict:
        results = super().transform(results)
        if self.with_bbox_3d:
            results = self._load_bboxes_3d(results)
        if self.with_bbox_depth:
            results = self._load_bboxes_depth(results)
        if self.with_label_3d:
            results = self._load_labels_3d(results)
        if self.with_attr_label:
            results = self._load_attr_labels(results)
        if self.with_panoptic_3d:
            results = self._load_panoptic_3d(results)
        if self.with_mask_3d:
            results = self._load_masks_3d(results)
        if self.with_seg_3d:
            results = self._load_semantic_seg_3d(results)
        return results

根据transform方法中的代码段,大家也能猜到,省略的代码中应该包含了一系列的_load方法,确实如此,每一个_load方法负责加载不同的标注,而if中的条件则由配置文件传入。需要注意的是,每经过一个Transform类处理,结果会体现在results中,如果有新增数据,则在results中增加字段,如果只是对原数据增强,则直接在修改原数据。

其它pipeline中的Transform子类作用类似,大家可以对自己感兴趣的代码加以探索研究。

相关推荐

如何将数据仓库迁移到阿里云 AnalyticDB for PostgreSQL

阿里云AnalyticDBforPostgreSQL(以下简称ADBPG,即原HybridDBforPostgreSQL)为基于PostgreSQL内核的MPP架构的实时数据仓库服务,可以...

Python数据分析:探索性分析

写在前面如果你忘记了前面的文章,可以看看加深印象:Python数据处理...

CSP-J/S冲奖第21天:插入排序

...

C++基础语法梳理:算法丨十大排序算法(二)

本期是C++基础语法分享的第十六节,今天给大家来梳理一下十大排序算法后五个!归并排序...

C 语言的标准库有哪些

C语言的标准库并不是一个单一的实体,而是由一系列头文件(headerfiles)组成的集合。每个头文件声明了一组相关的函数、宏、类型和常量。程序员通过在代码中使用#include<...

[深度学习] ncnn安装和调用基础教程

1介绍ncnn是腾讯开发的一个为手机端极致优化的高性能神经网络前向计算框架,无第三方依赖,跨平台,但是通常都需要protobuf和opencv。ncnn目前已在腾讯多款应用中使用,如QQ,Qzon...

用rust实现经典的冒泡排序和快速排序

1.假设待排序数组如下letmutarr=[5,3,8,4,2,7,1];...

ncnn+PPYOLOv2首次结合!全网最详细代码解读来了

编辑:好困LRS【新智元导读】今天给大家安利一个宝藏仓库miemiedetection,该仓库集合了PPYOLO、PPYOLOv2、PPYOLOE三个算法pytorch实现三合一,其中的PPYOL...

C++特性使用建议

1.引用参数使用引用替代指针且所有不变的引用参数必须加上const。在C语言中,如果函数需要修改变量的值,参数必须为指针,如...

Qt4/5升级到Qt6吐血经验总结V202308

00:直观总结增加了很多轮子,同时原有模块拆分的也更细致,估计为了方便拓展个管理。把一些过度封装的东西移除了(比如同样的功能有多个函数),保证了只有一个函数执行该功能。把一些Qt5中兼容Qt4的方法废...

到底什么是C++11新特性,请看下文

C++11是一个比较大的更新,引入了很多新特性,以下是对这些特性的详细解释,帮助您快速理解C++11的内容1.自动类型推导(auto和decltype)...

掌握C++11这些特性,代码简洁性、安全性和性能轻松跃升!

C++11(又称C++0x)是C++编程语言的一次重大更新,引入了许多新特性,显著提升了代码简洁性、安全性和性能。以下是主要特性的分类介绍及示例:一、核心语言特性1.自动类型推导(auto)编译器自...

经典算法——凸包算法

凸包算法(ConvexHull)一、概念与问题描述凸包是指在平面上给定一组点,找到包含这些点的最小面积或最小周长的凸多边形。这个多边形没有任何内凹部分,即从一个多边形内的任意一点画一条线到多边形边界...

一起学习c++11——c++11中的新增的容器

c++11新增的容器1:array当时的初衷是希望提供一个在栈上分配的,定长数组,而且可以使用stl中的模板算法。array的用法如下:#include<string>#includ...

C++ 编程中的一些最佳实践

1.遵循代码简洁原则尽量避免冗余代码,通过模块化设计、清晰的命名和良好的结构,让代码更易于阅读和维护...

取消回复欢迎 发表评论: