使用PyTorch进行深度学习的图像增强
ztj100 2024-11-03 16:15 21 浏览 0 评论
在这篇文章中,我们将了解图像增强的概念以及有什么不同的图像增强技术。我们还将使用PyTorch实现这些图像增强技术来构建一个图像分类深度学习模型。
为什么我们需要图像增强?
深度学习模型通常需要大量的数据来进行训练。通常,数据越多,模型的性能越好。但是获取海量数据面临着自身的挑战。并非每个人都拥有大公司的财力。
深度学习模型通常需要大量的训练数据。一般来说,数据越多,模型的性能越好。但获取大量数据本身也存在挑战。
缺乏数据的问题是,我们的深度学习模型可能无法从数据中学习模式或功能,因此它可能无法在不可见的数据上提供良好的性能。
在这种情况下我们能怎么做呢?我们可以利用图像增强技术来降低数据收集难度。
图像增强技术
图像旋转
图像旋转是最常用的增强技术之一。它可以帮助我们的模型对对象方向的变化变得健壮。即使我们旋转图像,图像的信息也保持不变。即从不同的角度看汽车,汽车也还是汽车:
我们可以使用此技术通过从原始图像创建旋转图像来增加数据大小。让我们看看如何旋转图像:
让我们导入图像并首先对其进行可视化:
这是原始图像。现在让我们看看如何旋转它。我将使用skimage库的rotate函数旋转图像:
图像平移
在某些情况下,图像中的对象未完全对准中心。在这些情况下,可以使用图像平移为图像添加平移不变性。
通过移动图像,我们可以更改对象在图像中的位置,从而使模型更具多样性。最终将导致更通用的模型。
图像平移是一种几何变换,可将图像中每个对象的位置映射到最终输出图像中的新位置。
平移操作之后,输入图像中位置(x,y)上存在的对象将移位到新位置(X,Y):
- X = x + dx
- Y = y + dy
此处,dx和dy是沿不同尺寸的相应位移。让我们看看如何将平移应用于图像:
平移超参数定义图像应移动的像素数。在这里,我将图像移动了(25,25)像素。我再次使用了“wrap”,用图像的其余像素填充输入边界外的点。
翻转图像
翻转是旋转的延伸。它使我们可以在左右以及上下方向上翻转图像。让我们看看如何实现翻转:
在这里,我使用了NumPy的fliplr函数来将图像从左到右翻转。它翻转每行的像素值。同样,我们可以上下翻转图像:
给图像添加噪声
图像降噪是重要的增强步骤,可让我们的深度学习模型学习如何将图像中的信号与噪声分离。
我们将使用skimage库的random_noise函数为原始图像添加一些随机噪声。我将要添加的噪声的标准偏差设为0.155(您也可以更改此值)。
图像模糊
由于图像来自不同的源,因此,图像的质量将不一样。有些图片可能是高质量的,而另一些可能是很差的。
在这种情况下,我们可以模糊图像。这有什么用呢?这有助于使我们的深度学习模型更健壮。
Sigma是标准偏差。sigma值越高,模糊效果越多。将“ Multichannel”设置为true可确保分别过滤图像的每个通道。
选择正确的增强技术的基本准则
在根据您要解决的问题确定增强技术时,有一些准则很重要。
- 任何机器学习模型构建过程的第一步都是确保输入的大小符合模型的期望。我们还必须确保所有图像的大小应该相似。为此,我们可以调整图像的大小到适当的大小。
- 假设您正在处理分类问题,并且数据样本的数量相对较少。在这种情况下,您可以使用不同的增强技术,例如图像旋转,图像降噪,翻转,移位等。请记住,所有这些操作都适用于图像中对象位置无关紧要的分类问题。
- 如果您正在执行对象检测任务,而对象的位置正是我们要检测的位置,那么这些技术可能不合适。
- 对图像像素值进行归一化是保证机器学习模型更好更快收敛的一种良好策略。如果机器学习模型有特定的要求,我们必须根据机器学习模型的要求对图像进行预处理。
案例研究:解决图像分类问题并应用图像增强
该项目的目的是将车辆图像分类为non-emergency 或emergency 。这是图像分类问题。您可以从此处下载数据集。
加载机器学习数据集
我们将应用图像增强技术,最后建立卷积神经网络(CNN)模型。让我们导入所需的Python库:
# importing the libraries from torchsummary import summary import pandas as pd import numpy as np from skimage.io import imread, imsave from tqdm import tqdm import matplotlib.pyplot as plt %matplotlib inline from sklearn.model_selection import train_test_split from sklearn.metrics import accuracy_score from skimage.transform import rotate from skimage.util import random_noise from skimage.filters import gaussian from scipy import ndimage
现在,我们将读取包含图像名称及其相应标签的CSV文件:
# loading dataset data = pd.read_csv('emergency_vs_non-emergency_dataset/emergency_train.csv') data.head()
这里的0表示该车辆为non-emergency 车辆,1表示其为emergency 车辆。现在,让我们从机器学习数据集中导入有图像:
# loading images train_img = [] for img_name in tqdm(data['image_names']): image_path = 'emergency_vs_non-emergency_dataset/images/' + img_name img = imread(image_path) img = img/255 train_img.append(img) train_x = np.array(train_img) train_y = data['emergency_or_not'].values train_x.shape, train_y.shape
机器学习数据集中共有1,646张图像。让我们将这些数据分为训练集和验证集。我们将使用验证集来评估深度学习模型在看不见的数据上的表现:
train_x, val_x, train_y, val_y = train_test_split(train_x, train_y, test_size = 0.1, random_state = 13, stratify=train_y) (train_x.shape, train_y.shape), (val_x.shape, val_y.shape)
我将test_size保持为0.1,因此将随机选择10%的数据作为验证集,其余的90%的数据将用于训练模型。我们在训练集中有1,481张图像,这不足以训练深度学习模型。
因此,接下来,我们将增强这些训练图像,以增强训练集,这可以改善深度学习模型的性能。
图像增强
final_train_data = [] final_target_train = [] for i in tqdm(range(train_x.shape[0])): final_train_data.append(train_x[i]) final_train_data.append(rotate(train_x[i], angle=45, mode = 'wrap')) final_train_data.append(np.fliplr(train_x[i])) final_train_data.append(np.flipud(train_x[i])) final_train_data.append(random_noise(train_x[i],var=0.2**2)) for j in range(5): final_target_train.append(train_y[i])
我们为训练集中的1,481张图像中的每张图像生成了4张增强图像。让我们以数组形式转换图像并验证数据集的大小:
len(final_target_train), len(final_train_data) final_train = np.array(final_train_data) final_target_train = np.array(final_target_train)
让我们可视化这些图像:
fig,ax = plt.subplots(nrows=1,ncols=5,figsize=(20,20)) for i in range(5): ax[i].imshow(final_train[i+30]) ax[i].axis('off')
这里的第一张图片是数据集中的原始图片。其余四幅图像是使用不同的图像增强技术生成的。
现在是时候定义深度学习模型的体系结构了,然后在训练集上对其进行训练了。
# PyTorch libraries and modules import torch from torch.autograd import Variable from torch.nn import Linear, ReLU, CrossEntropyLoss, Sequential, Conv2d, MaxPool2d, Module, Softmax, BatchNorm2d, Dropout from torch.optim import Adam, SGD
我们必须将训练集和验证集都转换为PyTorch格式:
# converting training images into torch format final_train = final_train.reshape(7405, 3, 224, 224) final_train = torch.from_numpy(final_train) final_train = final_train.float() # converting the target into torch format final_target_train = final_target_train.astype(int) final_target_train = torch.from_numpy(final_target_train)
同样,我们将转换验证集:
# converting validation images into torch format val_x = val_x.reshape(165, 3, 224, 224) val_x = torch.from_numpy(val_x) val_x = val_x.float() # converting the target into torch format val_y = val_y.astype(int) val_y = torch.from_numpy(val_y)
模型架构
接下来,我们将定义深度学习模型的体系结构。该体系结构包含4个卷积块,然后是4个全连接的dense层:
torch.manual_seed(0) class Net(Module): def __init__(self): super(Net, self).__init__() self.cnn_layers = Sequential( # Defining a 2D convolution layer Conv2d(3, 32, kernel_size=3, stride=1, padding=1), ReLU(inplace=True), # adding batch normalization BatchNorm2d(32), MaxPool2d(kernel_size=2, stride=2), # adding dropout Dropout(p=0.25), # Defining another 2D convolution layer Conv2d(32, 64, kernel_size=3, stride=1, padding=1), ReLU(inplace=True), # adding batch normalization BatchNorm2d(64), MaxPool2d(kernel_size=2, stride=2), # adding dropout Dropout(p=0.25), # Defining another 2D convolution layer Conv2d(64, 128, kernel_size=3, stride=1, padding=1), ReLU(inplace=True), # adding batch normalization BatchNorm2d(128), MaxPool2d(kernel_size=2, stride=2), # adding dropout Dropout(p=0.25), # Defining another 2D convolution layer Conv2d(128, 128, kernel_size=3, stride=1, padding=1), ReLU(inplace=True), # adding batch normalization BatchNorm2d(128), MaxPool2d(kernel_size=2, stride=2), # adding dropout Dropout(p=0.25), ) self.linear_layers = Sequential( Linear(128 * 14 * 14, 512), ReLU(inplace=True), Dropout(), Linear(512, 256), ReLU(inplace=True), Dropout(), Linear(256,10), ReLU(inplace=True), Dropout(), Linear(10,2) ) # Defining the forward pass def forward(self, x): x = self.cnn_layers(x) x = x.view(x.size(0), -1) x = self.linear_layers(x) return x
让我们定义模型的其他超参数,包括优化器,学习率和损失函数:
# defining the model model = Net() # defining the optimizer optimizer = Adam(model.parameters(), lr=0.000075) # defining the loss function criterion = CrossEntropyLoss() # checking if GPU is available if torch.cuda.is_available(): model = model.cuda() criterion = criterion.cuda() print(model)
训练模型
让我们训练20个epochs:
torch.manual_seed(0) # batch size of the model batch_size = 64 # number of epochs to train the model n_epochs = 20 for epoch in range(1, n_epochs+1): train_loss = 0.0 permutation = torch.randperm(final_train.size()[0]) training_loss = [] for i in tqdm(range(0,final_train.size()[0], batch_size)): indices = permutation[i:i+batch_size] batch_x, batch_y = final_train[indices], final_target_train[indices] if torch.cuda.is_available(): batch_x, batch_y = batch_x.cuda(), batch_y.cuda() optimizer.zero_grad() outputs = model(batch_x) loss = criterion(outputs,batch_y) training_loss.append(loss.item()) loss.backward() optimizer.step() training_loss = np.average(training_loss) print('epoch: \t', epoch, '\t training loss: \t', training_loss)
您会注意到,随着时间的增加,训练损失会减少。让我们保存经过训练的深度学习模型的权重:
torch.save(model, 'model.pt')
加载这个深度学习模型的Python代码:
the_model = torch.load('model.pt')
检查模型的性能
最后,让我们对训练和验证集进行预测,并检查各自的准确性:
torch.manual_seed(0) # prediction for training set prediction = [] target = [] permutation = torch.randperm(final_train.size()[0]) for i in tqdm(range(0,final_train.size()[0], batch_size)): indices = permutation[i:i+batch_size] batch_x, batch_y = final_train[indices], final_target_train[indices] if torch.cuda.is_available(): batch_x, batch_y = batch_x.cuda(), batch_y.cuda() with torch.no_grad(): output = model(batch_x.cuda()) softmax = torch.exp(output).cpu() prob = list(softmax.numpy()) predictions = np.argmax(prob, axis=1) prediction.append(predictions) target.append(batch_y) # training accuracy accuracy = [] for i in range(len(prediction)): accuracy.append(accuracy_score(target[i].cpu(),prediction[i])) print('training accuracy: \t', np.average(accuracy))
我们在训练集上的准确性超过91%!我们需要对验证集进行相同的检查:
# checking the performance on validation set torch.manual_seed(0) output = model(val_x.cuda()) softmax = torch.exp(output).cpu() prob = list(softmax.detach().numpy()) predictions = np.argmax(prob, axis=1) accuracy_score(val_y, predictions)
最后
在本文中,我们介绍了大多数常用的图像增强技术。您可以在任何图像分类问题上尝试这些图像增强技术,然后比较有无增强时的性能。
相关推荐
- 30天学会Python编程:16. Python常用标准库使用教程
-
16.1collections模块16.1.1高级数据结构16.1.2示例...
- 强烈推荐!Python 这个宝藏库 re 正则匹配
-
Python的re模块(RegularExpression正则表达式)提供各种正则表达式的匹配操作。...
- Python爬虫中正则表达式的用法,只讲如何应用,不讲原理
-
Python爬虫:正则的用法(非原理)。大家好,这节课给大家讲正则的实际用法,不讲原理,通俗易懂的讲如何用正则抓取内容。·导入re库,这里是需要从html这段字符串中提取出中间的那几个文字。实例一个对...
- Python数据分析实战-正则提取文本的URL网址和邮箱(源码和效果)
-
实现功能:Python数据分析实战-利用正则表达式提取文本中的URL网址和邮箱...
- python爬虫教程之爬取当当网 Top 500 本五星好评书籍
-
我们使用requests和re来写一个爬虫作为一个爱看书的你(说的跟真的似的)怎么能发现好书呢?所以我们爬取当当网的前500本好五星评书籍怎么样?ok接下来就是学习python的正确姿...
- 深入理解re模块:Python中的正则表达式神器解析
-
在Python中,"re"是一个强大的模块,用于处理正则表达式(regularexpressions)。正则表达式是一种强大的文本模式匹配工具,用于在字符串中查找、替换或提取特定模式...
- 如何使用正则表达式和 Python 匹配不以模式开头的字符串
-
需要在Python中使用正则表达式来匹配不以给定模式开头的字符串吗?如果是这样,你可以使用下面的语法来查找所有的字符串,除了那些不以https开始的字符串。r"^(?!https).*&...
- 先Mark后用!8分钟读懂 Python 性能优化
-
从本文总结了Python开发时,遇到的性能优化问题的定位和解决。概述:性能优化的原则——优化需要优化的部分。性能优化的一般步骤:首先,让你的程序跑起来结果一切正常。然后,运行这个结果正常的代码,看看它...
- Python“三步”即可爬取,毋庸置疑
-
声明:本实例仅供学习,切忌遵守robots协议,请不要使用多线程等方式频繁访问网站。#第一步导入模块importreimportrequests#第二步获取你想爬取的网页地址,发送请求,获取网页内...
- 简单学Python——re库(正则表达式)2(split、findall、和sub)
-
1、split():分割字符串,返回列表语法:re.split('分隔符','目标字符串')例如:importrere.split(',','...
- Lavazza拉瓦萨再度牵手上海大师赛
-
阅读此文前,麻烦您点击一下“关注”,方便您进行讨论和分享。Lavazza拉瓦萨再度牵手上海大师赛标题:2024上海大师赛:网球与咖啡的浪漫邂逅在2024年的上海劳力士大师赛上,拉瓦萨咖啡再次成为官...
- ArkUI-X构建Android平台AAR及使用
-
本教程主要讲述如何利用ArkUI-XSDK完成AndroidAAR开发,实现基于ArkTS的声明式开发范式在android平台显示。包括:1.跨平台Library工程开发介绍...
- Deepseek写歌详细教程(怎样用deepseek写歌功能)
-
以下为结合DeepSeek及相关工具实现AI写歌的详细教程,涵盖作词、作曲、演唱全流程:一、核心流程三步法1.AI生成歌词-打开DeepSeek(网页/APP/API),使用结构化提示词生成歌词:...
- “AI说唱解说影视”走红,“零基础入行”靠谱吗?本报记者实测
-
“手里翻找冻鱼,精心的布局;老漠却不言语,脸上带笑意……”《狂飙》剧情被写成歌词,再配上“科目三”背景音乐的演唱,这段1分钟30秒的视频受到了无数网友的点赞。最近一段时间随着AI技术的发展,说唱解说影...
- AI音乐制作神器揭秘!3款工具让你秒变高手
-
在音乐创作的领域里,每个人都有一颗想要成为大师的心。但是面对复杂的乐理知识和繁复的制作过程,许多人的热情被一点点消磨。...
你 发表评论:
欢迎- 一周热门
- 最近发表
-
- 30天学会Python编程:16. Python常用标准库使用教程
- 强烈推荐!Python 这个宝藏库 re 正则匹配
- Python爬虫中正则表达式的用法,只讲如何应用,不讲原理
- Python数据分析实战-正则提取文本的URL网址和邮箱(源码和效果)
- python爬虫教程之爬取当当网 Top 500 本五星好评书籍
- 深入理解re模块:Python中的正则表达式神器解析
- 如何使用正则表达式和 Python 匹配不以模式开头的字符串
- 先Mark后用!8分钟读懂 Python 性能优化
- Python“三步”即可爬取,毋庸置疑
- 简单学Python——re库(正则表达式)2(split、findall、和sub)
- 标签列表
-
- idea eval reset (50)
- vue dispatch (70)
- update canceled (42)
- order by asc (53)
- spring gateway (67)
- 简单代码编程 贪吃蛇 (40)
- transforms.resize (33)
- redisson trylock (35)
- 卸载node (35)
- np.reshape (33)
- torch.arange (34)
- npm 源 (35)
- vue3 deep (35)
- win10 ssh (35)
- vue foreach (34)
- idea设置编码为utf8 (35)
- vue 数组添加元素 (34)
- std find (34)
- tablefield注解用途 (35)
- python str转json (34)
- java websocket客户端 (34)
- tensor.view (34)
- java jackson (34)
- vmware17pro最新密钥 (34)
- mysql单表最大数据量 (35)