使用PyTorch进行深度学习的图像增强
ztj100 2024-11-03 16:15 13 浏览 0 评论
在这篇文章中,我们将了解图像增强的概念以及有什么不同的图像增强技术。我们还将使用PyTorch实现这些图像增强技术来构建一个图像分类深度学习模型。
为什么我们需要图像增强?
深度学习模型通常需要大量的数据来进行训练。通常,数据越多,模型的性能越好。但是获取海量数据面临着自身的挑战。并非每个人都拥有大公司的财力。
深度学习模型通常需要大量的训练数据。一般来说,数据越多,模型的性能越好。但获取大量数据本身也存在挑战。
缺乏数据的问题是,我们的深度学习模型可能无法从数据中学习模式或功能,因此它可能无法在不可见的数据上提供良好的性能。
在这种情况下我们能怎么做呢?我们可以利用图像增强技术来降低数据收集难度。
图像增强技术
图像旋转
图像旋转是最常用的增强技术之一。它可以帮助我们的模型对对象方向的变化变得健壮。即使我们旋转图像,图像的信息也保持不变。即从不同的角度看汽车,汽车也还是汽车:
我们可以使用此技术通过从原始图像创建旋转图像来增加数据大小。让我们看看如何旋转图像:
让我们导入图像并首先对其进行可视化:
这是原始图像。现在让我们看看如何旋转它。我将使用skimage库的rotate函数旋转图像:
图像平移
在某些情况下,图像中的对象未完全对准中心。在这些情况下,可以使用图像平移为图像添加平移不变性。
通过移动图像,我们可以更改对象在图像中的位置,从而使模型更具多样性。最终将导致更通用的模型。
图像平移是一种几何变换,可将图像中每个对象的位置映射到最终输出图像中的新位置。
平移操作之后,输入图像中位置(x,y)上存在的对象将移位到新位置(X,Y):
- X = x + dx
- Y = y + dy
此处,dx和dy是沿不同尺寸的相应位移。让我们看看如何将平移应用于图像:
平移超参数定义图像应移动的像素数。在这里,我将图像移动了(25,25)像素。我再次使用了“wrap”,用图像的其余像素填充输入边界外的点。
翻转图像
翻转是旋转的延伸。它使我们可以在左右以及上下方向上翻转图像。让我们看看如何实现翻转:
在这里,我使用了NumPy的fliplr函数来将图像从左到右翻转。它翻转每行的像素值。同样,我们可以上下翻转图像:
给图像添加噪声
图像降噪是重要的增强步骤,可让我们的深度学习模型学习如何将图像中的信号与噪声分离。
我们将使用skimage库的random_noise函数为原始图像添加一些随机噪声。我将要添加的噪声的标准偏差设为0.155(您也可以更改此值)。
图像模糊
由于图像来自不同的源,因此,图像的质量将不一样。有些图片可能是高质量的,而另一些可能是很差的。
在这种情况下,我们可以模糊图像。这有什么用呢?这有助于使我们的深度学习模型更健壮。
Sigma是标准偏差。sigma值越高,模糊效果越多。将“ Multichannel”设置为true可确保分别过滤图像的每个通道。
选择正确的增强技术的基本准则
在根据您要解决的问题确定增强技术时,有一些准则很重要。
- 任何机器学习模型构建过程的第一步都是确保输入的大小符合模型的期望。我们还必须确保所有图像的大小应该相似。为此,我们可以调整图像的大小到适当的大小。
- 假设您正在处理分类问题,并且数据样本的数量相对较少。在这种情况下,您可以使用不同的增强技术,例如图像旋转,图像降噪,翻转,移位等。请记住,所有这些操作都适用于图像中对象位置无关紧要的分类问题。
- 如果您正在执行对象检测任务,而对象的位置正是我们要检测的位置,那么这些技术可能不合适。
- 对图像像素值进行归一化是保证机器学习模型更好更快收敛的一种良好策略。如果机器学习模型有特定的要求,我们必须根据机器学习模型的要求对图像进行预处理。
案例研究:解决图像分类问题并应用图像增强
该项目的目的是将车辆图像分类为non-emergency 或emergency 。这是图像分类问题。您可以从此处下载数据集。
加载机器学习数据集
我们将应用图像增强技术,最后建立卷积神经网络(CNN)模型。让我们导入所需的Python库:
# importing the libraries from torchsummary import summary import pandas as pd import numpy as np from skimage.io import imread, imsave from tqdm import tqdm import matplotlib.pyplot as plt %matplotlib inline from sklearn.model_selection import train_test_split from sklearn.metrics import accuracy_score from skimage.transform import rotate from skimage.util import random_noise from skimage.filters import gaussian from scipy import ndimage
现在,我们将读取包含图像名称及其相应标签的CSV文件:
# loading dataset data = pd.read_csv('emergency_vs_non-emergency_dataset/emergency_train.csv') data.head()
这里的0表示该车辆为non-emergency 车辆,1表示其为emergency 车辆。现在,让我们从机器学习数据集中导入有图像:
# loading images train_img = [] for img_name in tqdm(data['image_names']): image_path = 'emergency_vs_non-emergency_dataset/images/' + img_name img = imread(image_path) img = img/255 train_img.append(img) train_x = np.array(train_img) train_y = data['emergency_or_not'].values train_x.shape, train_y.shape
机器学习数据集中共有1,646张图像。让我们将这些数据分为训练集和验证集。我们将使用验证集来评估深度学习模型在看不见的数据上的表现:
train_x, val_x, train_y, val_y = train_test_split(train_x, train_y, test_size = 0.1, random_state = 13, stratify=train_y) (train_x.shape, train_y.shape), (val_x.shape, val_y.shape)
我将test_size保持为0.1,因此将随机选择10%的数据作为验证集,其余的90%的数据将用于训练模型。我们在训练集中有1,481张图像,这不足以训练深度学习模型。
因此,接下来,我们将增强这些训练图像,以增强训练集,这可以改善深度学习模型的性能。
图像增强
final_train_data = [] final_target_train = [] for i in tqdm(range(train_x.shape[0])): final_train_data.append(train_x[i]) final_train_data.append(rotate(train_x[i], angle=45, mode = 'wrap')) final_train_data.append(np.fliplr(train_x[i])) final_train_data.append(np.flipud(train_x[i])) final_train_data.append(random_noise(train_x[i],var=0.2**2)) for j in range(5): final_target_train.append(train_y[i])
我们为训练集中的1,481张图像中的每张图像生成了4张增强图像。让我们以数组形式转换图像并验证数据集的大小:
len(final_target_train), len(final_train_data) final_train = np.array(final_train_data) final_target_train = np.array(final_target_train)
让我们可视化这些图像:
fig,ax = plt.subplots(nrows=1,ncols=5,figsize=(20,20)) for i in range(5): ax[i].imshow(final_train[i+30]) ax[i].axis('off')
这里的第一张图片是数据集中的原始图片。其余四幅图像是使用不同的图像增强技术生成的。
现在是时候定义深度学习模型的体系结构了,然后在训练集上对其进行训练了。
# PyTorch libraries and modules import torch from torch.autograd import Variable from torch.nn import Linear, ReLU, CrossEntropyLoss, Sequential, Conv2d, MaxPool2d, Module, Softmax, BatchNorm2d, Dropout from torch.optim import Adam, SGD
我们必须将训练集和验证集都转换为PyTorch格式:
# converting training images into torch format final_train = final_train.reshape(7405, 3, 224, 224) final_train = torch.from_numpy(final_train) final_train = final_train.float() # converting the target into torch format final_target_train = final_target_train.astype(int) final_target_train = torch.from_numpy(final_target_train)
同样,我们将转换验证集:
# converting validation images into torch format val_x = val_x.reshape(165, 3, 224, 224) val_x = torch.from_numpy(val_x) val_x = val_x.float() # converting the target into torch format val_y = val_y.astype(int) val_y = torch.from_numpy(val_y)
模型架构
接下来,我们将定义深度学习模型的体系结构。该体系结构包含4个卷积块,然后是4个全连接的dense层:
torch.manual_seed(0) class Net(Module): def __init__(self): super(Net, self).__init__() self.cnn_layers = Sequential( # Defining a 2D convolution layer Conv2d(3, 32, kernel_size=3, stride=1, padding=1), ReLU(inplace=True), # adding batch normalization BatchNorm2d(32), MaxPool2d(kernel_size=2, stride=2), # adding dropout Dropout(p=0.25), # Defining another 2D convolution layer Conv2d(32, 64, kernel_size=3, stride=1, padding=1), ReLU(inplace=True), # adding batch normalization BatchNorm2d(64), MaxPool2d(kernel_size=2, stride=2), # adding dropout Dropout(p=0.25), # Defining another 2D convolution layer Conv2d(64, 128, kernel_size=3, stride=1, padding=1), ReLU(inplace=True), # adding batch normalization BatchNorm2d(128), MaxPool2d(kernel_size=2, stride=2), # adding dropout Dropout(p=0.25), # Defining another 2D convolution layer Conv2d(128, 128, kernel_size=3, stride=1, padding=1), ReLU(inplace=True), # adding batch normalization BatchNorm2d(128), MaxPool2d(kernel_size=2, stride=2), # adding dropout Dropout(p=0.25), ) self.linear_layers = Sequential( Linear(128 * 14 * 14, 512), ReLU(inplace=True), Dropout(), Linear(512, 256), ReLU(inplace=True), Dropout(), Linear(256,10), ReLU(inplace=True), Dropout(), Linear(10,2) ) # Defining the forward pass def forward(self, x): x = self.cnn_layers(x) x = x.view(x.size(0), -1) x = self.linear_layers(x) return x
让我们定义模型的其他超参数,包括优化器,学习率和损失函数:
# defining the model model = Net() # defining the optimizer optimizer = Adam(model.parameters(), lr=0.000075) # defining the loss function criterion = CrossEntropyLoss() # checking if GPU is available if torch.cuda.is_available(): model = model.cuda() criterion = criterion.cuda() print(model)
训练模型
让我们训练20个epochs:
torch.manual_seed(0) # batch size of the model batch_size = 64 # number of epochs to train the model n_epochs = 20 for epoch in range(1, n_epochs+1): train_loss = 0.0 permutation = torch.randperm(final_train.size()[0]) training_loss = [] for i in tqdm(range(0,final_train.size()[0], batch_size)): indices = permutation[i:i+batch_size] batch_x, batch_y = final_train[indices], final_target_train[indices] if torch.cuda.is_available(): batch_x, batch_y = batch_x.cuda(), batch_y.cuda() optimizer.zero_grad() outputs = model(batch_x) loss = criterion(outputs,batch_y) training_loss.append(loss.item()) loss.backward() optimizer.step() training_loss = np.average(training_loss) print('epoch: \t', epoch, '\t training loss: \t', training_loss)
您会注意到,随着时间的增加,训练损失会减少。让我们保存经过训练的深度学习模型的权重:
torch.save(model, 'model.pt')
加载这个深度学习模型的Python代码:
the_model = torch.load('model.pt')
检查模型的性能
最后,让我们对训练和验证集进行预测,并检查各自的准确性:
torch.manual_seed(0) # prediction for training set prediction = [] target = [] permutation = torch.randperm(final_train.size()[0]) for i in tqdm(range(0,final_train.size()[0], batch_size)): indices = permutation[i:i+batch_size] batch_x, batch_y = final_train[indices], final_target_train[indices] if torch.cuda.is_available(): batch_x, batch_y = batch_x.cuda(), batch_y.cuda() with torch.no_grad(): output = model(batch_x.cuda()) softmax = torch.exp(output).cpu() prob = list(softmax.numpy()) predictions = np.argmax(prob, axis=1) prediction.append(predictions) target.append(batch_y) # training accuracy accuracy = [] for i in range(len(prediction)): accuracy.append(accuracy_score(target[i].cpu(),prediction[i])) print('training accuracy: \t', np.average(accuracy))
我们在训练集上的准确性超过91%!我们需要对验证集进行相同的检查:
# checking the performance on validation set torch.manual_seed(0) output = model(val_x.cuda()) softmax = torch.exp(output).cpu() prob = list(softmax.detach().numpy()) predictions = np.argmax(prob, axis=1) accuracy_score(val_y, predictions)
最后
在本文中,我们介绍了大多数常用的图像增强技术。您可以在任何图像分类问题上尝试这些图像增强技术,然后比较有无增强时的性能。
相关推荐
- 使用 Pinia ORM 管理 Vue 中的状态
-
转载说明:原创不易,未经授权,谢绝任何形式的转载状态管理是构建任何Web应用程序的重要组成部分。虽然Vue提供了管理简单状态的技术,但随着应用程序复杂性的增加,处理状态可能变得更具挑战性。这就是为什么...
- Vue3开发企业级音乐Web App 明星讲师带你学习大厂高质量代码
-
Vue3开发企业级音乐WebApp明星讲师带你学习大厂高质量代码下栽课》jzit.top/392/...
- 一篇文章说清 webpack、vite、vue-cli、create-vue 的区别
-
webpack、vite、vue-cli、create-vue这些都是什么?看着有点晕,不要怕,我们一起来分辨一下。...
- 超赞 vue2/3 可视化打印设计VuePluginPrint
-
今天来给大家推荐一款非常不错的Vue可拖拽打印设计器Hiprint。引入使用//main.js中引入安装import{hiPrintPlugin}from'vue-plugin-...
- 搭建Trae+Vue3的AI开发环境(vue3 ts开发)
-
从2024年2025年,不断的有各种AI工具会在自媒体中火起来,号称各种效率王炸,而在AI是否会替代打工人的话题中,程序员又首当其冲。...
- Vue中mixin怎么理解?(vue的mixins有什么用)
-
作者:qdmryt转发链接:https://mp.weixin.qq.com/s/JHF3oIGSTnRegpvE6GSZhg前言...
- Vue脚手架安装,初始化项目,打包并用Tomcat和Nginx部署
-
1.创建Vue脚手架#1.在本地文件目录创建my-first-vue文件夹,安装vue-cli脚手架:npminstall-gvue-cli安装过程如下图所示:创建my-first-vue...
- 新手如何搭建个人网站(小白如何搭建个人网站)
-
ElementUl是饿了么前端团队推出的桌面端UI框架,具有是简洁、直观、强悍和低学习成本等优势,非常适合初学者使用。因此,本次项目使用ElementUI框架来完成个人博客的主体开发,欢迎大家讨论...
- 零基础入门vue开发(vue快速入门与实战开发)
-
上面一节我们已经成功的安装了nodejs,并且配置了npm的全局环境变量,那么这一节我们就来正式的安装vue-cli,然后在webstorm开发者工具里运行我们的vue项目。这一节有两种创建vue项目...
- .net core集成vue(.net core集成vue3)
-
react、angular、vue你更熟悉哪个?下边这个是vue的。要求需要你的计算机安装有o.netcore2.0以上版本onode、webpack、vue-cli、vue(npm...
- 使用 Vue 脚手架,为什么要学 webpack?(一)
-
先问大家一个很简单的问题:vueinitwebpackprjectName与vuecreateprojectName有什么区别呢?它们是Vue-cli2和Vue-cli3创建...
- vue 构建和部署(vue项目部署服务器)
-
普通的搭建方式(安装指令)安装Node.js检查node是否已安装,终端输入node-v会使用命令行(安装)npminstallvue-cli-首先安装vue-clivueinitwe...
- Vue.js 环境配置(vue的环境搭建)
-
说明:node.js和vue.js的关系:Node.js是一个基于ChromeV8引擎的JavaScript运行时环境;类比:Java的jvm(虚拟机)...
- vue项目完整搭建步骤(vuecli项目搭建)
-
简介为了让一些不太清楚搭建前端项目的小白,更快上手。今天我将一步一步带领你们进行前端项目的搭建。前端开发中需要用到框架,那vue作为三大框架主流之一,在工作中很常用。所以就以vue为例。...
你 发表评论:
欢迎- 一周热门
- 最近发表
-
- 使用 Pinia ORM 管理 Vue 中的状态
- Vue3开发企业级音乐Web App 明星讲师带你学习大厂高质量代码
- 一篇文章说清 webpack、vite、vue-cli、create-vue 的区别
- 超赞 vue2/3 可视化打印设计VuePluginPrint
- 搭建Trae+Vue3的AI开发环境(vue3 ts开发)
- 如何在现有的Vue项目中嵌入 Blazor项目?
- Vue中mixin怎么理解?(vue的mixins有什么用)
- Vue脚手架安装,初始化项目,打包并用Tomcat和Nginx部署
- 新手如何搭建个人网站(小白如何搭建个人网站)
- 零基础入门vue开发(vue快速入门与实战开发)
- 标签列表
-
- idea eval reset (50)
- vue dispatch (70)
- update canceled (42)
- order by asc (53)
- spring gateway (67)
- 简单代码编程 贪吃蛇 (40)
- transforms.resize (33)
- redisson trylock (35)
- 卸载node (35)
- np.reshape (33)
- torch.arange (34)
- node卸载 (33)
- npm 源 (35)
- vue3 deep (35)
- win10 ssh (35)
- exceptionininitializererror (33)
- vue foreach (34)
- idea设置编码为utf8 (35)
- vue 数组添加元素 (34)
- std find (34)
- tablefield注解用途 (35)
- python str转json (34)
- java websocket客户端 (34)
- tensor.view (34)
- java jackson (34)