深度学习,图像分类入门(图像分类实战)
ztj100 2024-10-31 16:12 92 浏览 0 评论
本文的文字及图片来源于网络,仅供学习、交流使用,不具有任何商业用途,,版权归原作者所有,如有问题请及时联系我们以作处理
作者:weixin_46389668 来源:CSDN
本文链接:https://blog.csdn.net/weixin_46389668/article/details/111185617
pytorch数据集读取方法
**CIFAR10数据集的定义方法如下:**
`dataset_dir = '../../../dataset/'
torchvision.datasets.CIFAR10(dataset_dir, train=True, transform=None, target_transform=None, download=False) `
dataset_dir:存放数据集的路径。
train(bool,可选)–如果为True,则构建训练集,否则构建测试集。
transform:定义数据预处理,数据增强方案都是在这里指定。
target_transform:标注的预处理,分类任务不常用。
download:是否下载,若为True则从互联网下载,如果已经在dataset_dir下存在,就不会再次下载
读取示例1(从网上自动下载)
train_data = torchvision.datasets.CIFAR10('../../../dataset',
train=True,
transform=None,
target_transform=None,
download=True)
读取示例2(示例1基础上附带数据增强)
# 读取训练集
custom_transform=transforms.transforms.Compose([
transforms.Resize((64, 64)), # 缩放到指定大小 64*64
transforms.ColorJitter(0.2, 0.2, 0.2), # 随机颜色变换
transforms.RandomRotation(5), # 随机旋转
transforms.Normalize([0.485,0.456,0.406], # 对图像像素进行归一化
[0.229,0.224,0.225])])
train_data=torchvision.datasets.CIFAR10('../../../dataset',
train=True,
transform=custom_transforms,
target_transform=None,
download=False)
数据加载
# 读取数据集
train_data=torchvision.datasets.CIFAR10('../../../dataset', train=True,
transform=None,
target_transform=None,
download=True)
# 实现数据批量读取
train_loader = torch.utils.data.DataLoader(train_data,
batch_size=2,
shuffle=True,
num_workers=4)
自定义数据集及读取方法
简单的对pytorch读取数据一般化pipeline的描述,就是下面的这个流程:
图像数据 ? 图像索引文件 ? 使用Dataset构建数据集 ? 使用DataLoader读取数据
2.2.1 图像索引文件制作
下载MNIST的图像和标签数据到Dive-into-CV-PyTorch/dataset/MNIST/目录下,得到下面的压缩文件并解压暂存,以用来充当自己的图像数据集。
train-images-idx3-ubyte.gz: training set images (9912422 bytes) ? train-images-idx3-ubyte(解压后)
train-labels-idx1-ubyte.gz: training set labels (28881 bytes) ? train-labels-idx1-ubyte(解压后)
t10k-images-idx3-ubyte.gz: test set images (1648877 bytes) ? t10k-images-idx3-ubyte(解压后)
t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes) ? t10k-labels-idx1-ubyte(解压后)
运行如下代码,实现图像数据的本地存储和索引文件的制作,我们将图像按照训练集和测试集分别存放,并且分别制作训练集和测试集的索引文件,在索引文件中将记录图像的文件名和标签信息。
import os
from skimage import io
import torchvision.datasets.mnist as mnist
# 数据文件读取
root = r'./MNIST/' # MNIST解压文件根目录
train_set = (
mnist.read_image_file(os.path.join(root, 'train-images-idx3-ubyte')),
mnist.read_label_file(os.path.join(root, 'train-labels-idx1-ubyte'))
)
test_set = (
mnist.read_image_file(os.path.join(root, 't10k-images-idx3-ubyte')),
mnist.read_label_file(os.path.join(root, 't10k-labels-idx1-ubyte'))
)
# 数据量展示
print('train set:', train_set[0].size())
print('test set:', test_set[0].size())
def convert_to_img(save_path, train=True):
'''
将图片存储在本地,并制作索引文件
@para: save_path 图像保存路径,将在路径下创建train、test文件夹分别存储训练集和测试集
@para: train 默认True,本地存储训练集图像,否则本地存储测试集图像
'''
if train:
f = open(save_path + 'train.txt', 'w')
data_path = save_path + '/train/'
if (not os.path.exists(data_path)):
os.makedirs(data_path)
for i, (img, label) in enumerate(zip(train_set[0], train_set[1])):
img_path = data_path + str(i) + '.jpg'
io.imsave(img_path, img.numpy())
int_label = str(label).replace('tensor(', '')
int_label = int_label.replace(')', '')
f.write(str(i)+'.jpg' + ',' + str(int_label) + '\n')
f.close()
else:
f = open(save_path + 'test.txt', 'w')
data_path = save_path + '/test/'
if (not os.path.exists(data_path)):
os.makedirs(data_path)
for i, (img, label) in enumerate(zip(test_set[0], test_set[1])):
img_path = data_path + str(i) + '.jpg'
io.imsave(img_path, img.numpy())
int_label = str(label).replace('tensor(', '')
int_label = int_label.replace(')', '')
f.write(str(i)+'.jpg' + ',' + str(int_label) + '\n')
f.close()
# 根据需求本地存储训练集或测试集
save_path = r'./MNIST/mnist_data/'
convert_to_img(save_path, True)
convert_to_img(save_path, False)
2.2.2 构建自己的Dataset
from torch.utils.data.dataset import Dataset
class MyDataset(Dataset): # 继承Dataset类
def __init__(self):
# 初始化图像文件路径或图像文件名列表等
pass
def __getitem__(self, index):
# 1.根据索引index从文件中读取一个数据(例如,使用numpy.fromfile,PIL.Image.open,cv2.imread)
# 2.预处理数据(例如torchvision.Transform)
# 3.返回数据对(例如图像和标签)
pass
def __len__(self):
return count # 返回数据量
__init__() : 初始化模块,初始化该类的一些基本参数
__getitem__() : 接收一个index,这个index通常指的是一个list的index,这个list的每个元素就包含了图片数据的路径和标签信息,返回数据对(图像和标签)
__len__() : 返回所有数据的数量
import pandas as pd
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms
class MnistDataset(Dataset):
def __init__(self, image_path, image_label, transform=None):
super(MnistDataset, self).__init__()
self.image_path = image_path # 初始化图像路径列表
self.image_label = image_label # 初始化图像标签列表
self.transform = transform # 初始化数据增强方法
def __getitem__(self, index):
"""
获取对应index的图像,并视情况进行数据增强
"""
image = Image.open(self.image_path[index])
image = np.array(image)
label = float(self.image_label[index])
if self.transform is not None:
image = self.transform(image)
return image, torch.tensor(label)
def __len__(self):
return len(self.image_path)
def get_path_label(img_root, label_file_path):
"""
获取数字图像的路径和标签并返回对应列表
@para: img_root: 保存图像的根目录
@para:label_file_path: 保存图像标签数据的文件路径 .csv 或 .txt 分隔符为','
@return: 图像的路径列表和对应标签列表
"""
data = pd.read_csv(label_file_path, names=['img', 'label'])
data['img'] = data['img'].apply(lambda x: img_root + x)
return data['img'].tolist(), data['label'].tolist()
# 获取训练集路径列表和标签列表
train_data_root = './dataset/MNIST/mnist_data/train/'
train_label = './dataset/MNIST/mnist_data/train.txt'
train_img_list, train_label_list = get_path_label(train_data_root, train_label)
# 训练集dataset
train_dataset = MnistDataset(train_img_list,
train_label_list,
transform=transforms.Compose([transforms.ToTensor()]))
# 获取测试集路径列表和标签列表
test_data_root = './dataset/MNIST/mnist_data/test/'
test_label = './dataset/MNIST/mnist_data/test.txt'
test_img_list, test_label_list = get_path_label(test_data_root, test_label)
# 测试集sdataset
test_dataset = MnistDataset(test_img_list,
test_label_list,
transform=transforms.Compose([transforms.ToTensor()]))
2.2.3 使用DataLoader批量读取数据
使用 DataLoader 批量的读取数据,相当于帮我们完成一个batch的数据组装工作。
Dataloader 为一个迭代器,最基本的使用方法就是传入一个 Dataset 对象,在Dataloader中,会触发Dataset对象中的 gititem() 函数,逐次读取数据,并根据 batch_size 产生一个 batch 的数据,实现批量化的数据读取。
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False)
dataset:加载的数据集(Dataset对象)
batch_size:一个批量数目大小
shuffle::是否打乱数据顺序
sampler: 样本抽样方式
num_workers:使用多进程加载的进程数,0代表不使用多进程
collate_fn: 将多个样本数据组成一个batch的方式,一般使用默认的拼接方式,可以通过自定义这个函数来完成一些特殊的读取逻辑。
pin_memory:是否将数据保存在pin memory区,pin memory中的数据转到GPU会快一些
drop_last:为True时,dataset中的数据个数不是batch_size整数倍时,将多出来不足一个batch的数据丢弃
Dataset是对本地数据读取逻辑的定义;而DataLoader是对Dataset对象的封装,执行调度,将一个batch size的图像数据组装在一起,实现批量读取数据。
对于图像分类问题,torchvision还提供了一种文件目录组织形式可供调用,即ImageFolder,因为利用了分类任务的特性,此时就不用再另行创建一份标签文件了。这种文件目录组织形式,要求数据集已经自觉按照待分配的类别分成了不同的文件夹,一种类别的文件夹下面只存放同一种类别的图片。
相关推荐
- 如何将数据仓库迁移到阿里云 AnalyticDB for PostgreSQL
-
阿里云AnalyticDBforPostgreSQL(以下简称ADBPG,即原HybridDBforPostgreSQL)为基于PostgreSQL内核的MPP架构的实时数据仓库服务,可以...
- Python数据分析:探索性分析
-
写在前面如果你忘记了前面的文章,可以看看加深印象:Python数据处理...
- C++基础语法梳理:算法丨十大排序算法(二)
-
本期是C++基础语法分享的第十六节,今天给大家来梳理一下十大排序算法后五个!归并排序...
- C 语言的标准库有哪些
-
C语言的标准库并不是一个单一的实体,而是由一系列头文件(headerfiles)组成的集合。每个头文件声明了一组相关的函数、宏、类型和常量。程序员通过在代码中使用#include<...
- [深度学习] ncnn安装和调用基础教程
-
1介绍ncnn是腾讯开发的一个为手机端极致优化的高性能神经网络前向计算框架,无第三方依赖,跨平台,但是通常都需要protobuf和opencv。ncnn目前已在腾讯多款应用中使用,如QQ,Qzon...
- 用rust实现经典的冒泡排序和快速排序
-
1.假设待排序数组如下letmutarr=[5,3,8,4,2,7,1];...
- ncnn+PPYOLOv2首次结合!全网最详细代码解读来了
-
编辑:好困LRS【新智元导读】今天给大家安利一个宝藏仓库miemiedetection,该仓库集合了PPYOLO、PPYOLOv2、PPYOLOE三个算法pytorch实现三合一,其中的PPYOL...
- C++特性使用建议
-
1.引用参数使用引用替代指针且所有不变的引用参数必须加上const。在C语言中,如果函数需要修改变量的值,参数必须为指针,如...
- Qt4/5升级到Qt6吐血经验总结V202308
-
00:直观总结增加了很多轮子,同时原有模块拆分的也更细致,估计为了方便拓展个管理。把一些过度封装的东西移除了(比如同样的功能有多个函数),保证了只有一个函数执行该功能。把一些Qt5中兼容Qt4的方法废...
- 到底什么是C++11新特性,请看下文
-
C++11是一个比较大的更新,引入了很多新特性,以下是对这些特性的详细解释,帮助您快速理解C++11的内容1.自动类型推导(auto和decltype)...
- 掌握C++11这些特性,代码简洁性、安全性和性能轻松跃升!
-
C++11(又称C++0x)是C++编程语言的一次重大更新,引入了许多新特性,显著提升了代码简洁性、安全性和性能。以下是主要特性的分类介绍及示例:一、核心语言特性1.自动类型推导(auto)编译器自...
- 经典算法——凸包算法
-
凸包算法(ConvexHull)一、概念与问题描述凸包是指在平面上给定一组点,找到包含这些点的最小面积或最小周长的凸多边形。这个多边形没有任何内凹部分,即从一个多边形内的任意一点画一条线到多边形边界...
- 一起学习c++11——c++11中的新增的容器
-
c++11新增的容器1:array当时的初衷是希望提供一个在栈上分配的,定长数组,而且可以使用stl中的模板算法。array的用法如下:#include<string>#includ...
- C++ 编程中的一些最佳实践
-
1.遵循代码简洁原则尽量避免冗余代码,通过模块化设计、清晰的命名和良好的结构,让代码更易于阅读和维护...
你 发表评论:
欢迎- 一周热门
- 最近发表
- 标签列表
-
- idea eval reset (50)
- vue dispatch (70)
- update canceled (42)
- order by asc (53)
- spring gateway (67)
- 简单代码编程 贪吃蛇 (40)
- transforms.resize (33)
- redisson trylock (35)
- 卸载node (35)
- np.reshape (33)
- torch.arange (34)
- node卸载 (33)
- npm 源 (35)
- vue3 deep (35)
- win10 ssh (35)
- exceptionininitializererror (33)
- vue foreach (34)
- idea设置编码为utf8 (35)
- vue 数组添加元素 (34)
- std find (34)
- tablefield注解用途 (35)
- python str转json (34)
- java websocket客户端 (34)
- tensor.view (34)
- java jackson (34)