深度学习中GPU和显存分析 深度学习Gpu使用率很低
ztj100 2024-12-19 17:56 57 浏览 0 评论
编者荐语
显存占用和GPU利用率是两个不一样的东西,显卡是由GPU计算单元和显存等组成的,显存和GPU的关系有点类似于内存和CPU的关系。显存可以看成是空间,类似于内存。GPU计算单元类似于CPU中的核,用来进行数值计算。
深度学习最吃硬件,耗资源,在本文,我将来科普一下在深度学习中:
- 何为“资源”
- 不同操作都耗费什么资源
- 如何充分的利用有限的资源
- 如何合理选择显卡
并纠正几个误区:
- 显存和GPU等价,使用GPU主要看显存的使用?
- Batch Size 越大,程序越快,而且近似成正比?
- 显存占用越多,程序越快?
- 显存占用大小和batch size大小成正比?
0 预备知识
nvidia-smi是Nvidia显卡命令行管理套件,基于NVML库,旨在管理和监控Nvidia GPU设备。
nvidia-smi的输出
这是nvidia-smi命令的输出,其中最重要的两个指标:
- 显存占用
- GPU利用率
显存占用和GPU利用率是两个不一样的东西,显卡是由GPU计算单元和显存等组成的,显存和GPU的关系有点类似于内存和CPU的关系。
这里推荐一个好用的小工具:gpustat,直接pip install gpustat即可安装,gpustat基于nvidia-smi,可以提供更美观简洁的展示,结合watch命令,可以动态实时监控GPU的使用情况。
watch --color -n1 gpustat -cpu
gpustat 输出
显存可以看成是空间,类似于内存。
- 显存用于存放模型,数据
- 显存越大,所能运行的网络也就越大
GPU计算单元类似于CPU中的核,用来进行数值计算。衡量计算量的单位是flop: the number of floating-point multiplication-adds,浮点数先乘后加算一个flop。计算能力越强大,速度越快。衡量计算能力的单位是flops:每秒能执行的flop数量
1. 显存分析
1.1 存储指标
K、M,G,T是以1024为底,而KB 、MB,GB,TB以1000为底。不过一般来说,在估算显存大小的时候,我们不需要严格的区分这二者。
在深度学习中会用到各种各样的数值类型,数值类型命名规范一般为TypeNum,比如Int64、Float32、Double64。
- Type:有Int,Float,Double等
- Num: 一般是 8,16,32,64,128,表示该类型所占据的比特数目
常用的数值类型如下图所示:
常用的数值类型
其中Float32 是在深度学习中最常用的数值类型,称为单精度浮点数,每一个单精度浮点数占用4Byte的显存。
举例来说:有一个1000x1000的 矩阵,float32,那么占用的显存差不多就是
2x3x256x256的四维数组(BxCxHxW)占用显存为:24M
1.2 神经网络显存占用
神经网络模型占用的显存包括:
- 模型自身的参数
- 模型的输出
举例来说,对于如下图所示的一个全连接网络(不考虑偏置项b)
模型的输入输出和参数
模型的显存占用包括:
- 参数:二维数组 W
- 模型的输出:二维数组 Y
输入X可以看成是上一层的输出,因此把它的显存占用归于上一层。
这么看来显存占用就是W和Y两个数组?
并非如此!!!
下面细细分析。
1.2.1 参数的显存占用
只有有参数的层,才会有显存占用。这部份的显存占用和输入无关,模型加载完成之后就会占用。
有参数的层主要包括:
- 卷积
- 全连接
- BatchNorm
- Embedding层
- ... ...
无参数的层:
- 多数的激活层(Sigmoid/ReLU)
- 池化层
- Dropout
- ... ...
更具体的来说,模型的参数数目(这里均不考虑偏置项b)为:
- Linear(M->N): 参数数目:M×N
- Conv2d(Cin, Cout, K): 参数数目:Cin × Cout × K × K
- BatchNorm(N): 参数数目:2N
- Embedding(N,W): 参数数目:N × W
参数占用显存 = 参数数目×n
n = 4 :float32
n = 2 : float16
n = 8 : double64
在PyTorch中,当你执行完model=MyGreatModel().cuda()之后就会占用相应的显存,占用的显存大小基本与上述分析的显存差不多(会稍大一些,因为其它开销)。
1.2.2 梯度与动量的显存占用
举例来说, 优化器如果是SGD:
这时候还需要保存动量, 因此显存x3
如果是Adam优化器,动量占用的显存更多,显存x4
总结一下,模型中与输入无关的显存占用包括:
- 参数 W
- 梯度 dW(一般与参数一样)
- 优化器的动量(普通SGD没有动量,momentum-SGD动量与梯度一样,Adam优化器动量的数量是梯度的两倍)
1.2.3 输入输出的显存占用
这部份的显存主要看输出的feature map 的形状。
feature map
比如卷积的输入输出满足以下关系:
据此可以计算出每一层输出的Tensor的形状,然后就能计算出相应的显存占用。
模型输出的显存占用,总结如下:
- 需要计算每一层的feature map的形状(多维数组的形状)
- 模型输出的显存占用与 batch size 成正比
- 需要保存输出对应的梯度用以反向传播(链式法则)
- 模型输出不需要存储相应的动量信息(因为不需要执行优化)
深度学习中神经网络的显存占用,我们可以得到如下公式:
显存占用 = 模型显存占用 + batch_size × 每个样本的显存占用
可以看出显存不是和batch-size简单的成正比,尤其是模型自身比较复杂的情况下:比如全连接很大,Embedding层很大
另外需要注意:
- 输入(数据,图片)一般不需要计算梯度
- 神经网络的每一层输入输出都需要保存下来,用来反向传播,但是在某些特殊的情况下,我们可以不要保存输入。比如ReLU,在PyTorch中,使用nn.ReLU(inplace = True) 能将激活函数ReLU的输出直接覆盖保存于模型的输入之中,节省不少显存。感兴趣的读者可以思考一下,这时候是如何反向传播的(提示:y=relu(x) -> dx = dy.copy();dx[y<=0]=0)
1.3 节省显存的方法
在深度学习中,一般占用显存最多的是卷积等层的输出,模型参数占用的显存相对较少,而且不太好优化。
节省显存一般有如下方法:
- 降低batch-size
- 下采样(NCHW -> (1/4)*NCHW)
- 减少全连接层(一般只留最后一层分类用的全连接层)
2 计算量分析
计算量的定义,之前已经讲过了,计算量越大,操作越费时,运行神经网络花费的时间越多。
2.1 常用操作的计算量
常用的操作计算量如下:
- 全连接层:BxMxN , B是batch size,M是输入形状,N是输出形状。
卷积的计算量分析
- ReLU的计算量:BHWC
2.2 AlexNet 分析
AlexNet的分析如下图,左边是每一层的参数数目(不是显存占用),右边是消耗的计算资源
AlexNet分析
可以看出:
- 全连接层占据了绝大多数的参数
- 卷积层的计算量最大
2.3 减少卷积层的计算量
今年谷歌提出的MobileNet,利用了一种被称为DepthWise Convolution的技术,将神经网络运行速度提升许多,它的核心思想就是把一个卷积操作拆分成两个相对简单的操作的组合。如图所示, 左边是原始卷积操作,右边是两个特殊而又简单的卷积操作的组合(上面类似于池化的操作,但是有权重,下面类似于全连接操作)。
Depthwise Convolution
这种操作使得:
- 显存占用变多(每一步的输出都要保存
2.4 常用模型 显存/计算复杂度/准确率
去年一篇论文(http://link.zhihu.com/?target=https%3A//arxiv.org/abs/1605.07678)总结了当时常用模型的各项指标,横座标是计算复杂度(越往右越慢,越耗时),纵座标是准确率(越高越好),圆的面积是参数数量(不是显存占用)。左上角我画了一个红色小圆,那是最理想的模型的的特点:快,效果好,占用显存小。
常见模型计算量/显存/准确率
3 总结
3.1 建议
- 时间更宝贵,尽可能使模型变快(减少flop)
- 显存占用不是和batch size简单成正比,模型自身的参数及其延伸出来的数据也要占据显存
- batch size越大,速度未必越快。在你充分利用计算资源的时候,加大batch size在速度上的提升很有限
尤其是batch-size,假定GPU处理单元已经充分利用的情况下:
- 增大batch size能增大速度,但是很有限(主要是并行计算的优化)
- 增大batch size能减缓梯度震荡,需要更少的迭代优化次数,收敛的更快,但是每次迭代耗时更长。
- 增大batch size使得一个epoch所能进行的优化次数变少,收敛可能变慢,从而需要更多时间才能收敛(比如batch_size 变成全部样本数目)。
3.2 关于显卡购买
一般显卡购买渠道就是京东自营、淘宝等电商平台,线下实体店也可以购买。 正常时期,同款显卡,京东自营的价格会略高于淘宝,主要是京东自营的售后比淘宝更好,更放心,而特殊时期,比如现在部分型号淘宝和京东自营的价格比较悬殊,我建议是淘宝购买,如果价格相差不大,优先京东自营购买。像微星不支持个人送保,我不建议在淘宝和拼多多等渠道购买,售后不方便,建议天猫旗舰店及京东自营等有售后保障的渠道购买,支持个人送保的品牌在哪里买都可以。
5月推荐入手价
本文都是针对单机单卡的分析,分布式的情况会和这个有所区别。在分析计算量的时候,只分析了前向传播,反向传播计算量一般会与前向传播有细微的差别。
相关推荐
- sharding-jdbc实现`分库分表`与`读写分离`
-
一、前言本文将基于以下环境整合...
- 三分钟了解mysql中主键、外键、非空、唯一、默认约束是什么
-
在数据库中,数据表是数据库中最重要、最基本的操作对象,是数据存储的基本单位。数据表被定义为列的集合,数据在表中是按照行和列的格式来存储的。每一行代表一条唯一的记录,每一列代表记录中的一个域。...
- MySQL8行级锁_mysql如何加行级锁
-
MySQL8行级锁版本:8.0.34基本概念...
- mysql使用小技巧_mysql使用入门
-
1、MySQL中有许多很实用的函数,好好利用它们可以省去很多时间:group_concat()将取到的值用逗号连接,可以这么用:selectgroup_concat(distinctid)fr...
- MySQL/MariaDB中如何支持全部的Unicode?
-
永远不要在MySQL中使用utf8,并且始终使用utf8mb4。utf8mb4介绍MySQL/MariaDB中,utf8字符集并不是对Unicode的真正实现,即不是真正的UTF-8编码,因...
- 聊聊 MySQL Server 可执行注释,你懂了吗?
-
前言MySQLServer当前支持如下3种注释风格:...
- MySQL系列-源码编译安装(v5.7.34)
-
一、系统环境要求...
- MySQL的锁就锁住我啦!与腾讯大佬的技术交谈,是我小看它了
-
对酒当歌,人生几何!朝朝暮暮,唯有己脱。苦苦寻觅找工作之间,殊不知今日之事乃我心之痛,难道是我不配拥有工作嘛。自面试后他所谓的等待都过去一段时日,可惜在下京东上的小金库都要见低啦。每每想到不由心中一...
- MySQL字符问题_mysql中字符串的位置
-
中文写入乱码问题:我输入的中文编码是urf8的,建的库是urf8的,但是插入mysql总是乱码,一堆"???????????????????????"我用的是ibatis,终于找到原因了,我是这么解决...
- 深圳尚学堂:mysql基本sql语句大全(三)
-
数据开发-经典1.按姓氏笔画排序:Select*FromTableNameOrderByCustomerNameCollateChinese_PRC_Stroke_ci_as//从少...
- MySQL进行行级锁的?一会next-key锁,一会间隙锁,一会记录锁?
-
大家好,是不是很多人都对MySQL加行级锁的规则搞的迷迷糊糊,一会是next-key锁,一会是间隙锁,一会又是记录锁。坦白说,确实还挺复杂的,但是好在我找点了点规律,也知道如何如何用命令分析加...
- 一文讲清怎么利用Python Django实现Excel数据表的导入导出功能
-
摘要:Python作为一门简单易学且功能强大的编程语言,广受程序员、数据分析师和AI工程师的青睐。本文系统讲解了如何使用Python的Django框架结合openpyxl库实现Excel...
- 用DataX实现两个MySQL实例间的数据同步
-
DataXDataX使用Java实现。如果可以实现数据库实例之间准实时的...
- MySQL数据库知识_mysql数据库基础知识
-
MySQL是一种关系型数据库管理系统;那废话不多说,直接上自己以前学习整理文档:查看数据库命令:(1).查看存储过程状态:showprocedurestatus;(2).显示系统变量:show...
- 如何为MySQL中的JSON字段设置索引
-
背景MySQL在2015年中发布的5.7.8版本中首次引入了JSON数据类型。自此,它成了一种逃离严格列定义的方式,可以存储各种形状和大小的JSON文档,例如审计日志、配置信息、第三方数据包、用户自定...
你 发表评论:
欢迎- 一周热门
-
-
MySQL中这14个小玩意,让人眼前一亮!
-
旗舰机新标杆 OPPO Find X2系列正式发布 售价5499元起
-
【VueTorrent】一款吊炸天的qBittorrent主题,人人都可用
-
面试官:使用int类型做加减操作,是线程安全吗
-
C++编程知识:ToString()字符串转换你用正确了吗?
-
【Spring Boot】WebSocket 的 6 种集成方式
-
PyTorch 深度学习实战(26):多目标强化学习Multi-Objective RL
-
pytorch中的 scatter_()函数使用和详解
-
与 Java 17 相比,Java 21 究竟有多快?
-
基于TensorRT_LLM的大模型推理加速与OpenAI兼容服务优化
-
- 最近发表
- 标签列表
-
- idea eval reset (50)
- vue dispatch (70)
- update canceled (42)
- order by asc (53)
- spring gateway (67)
- 简单代码编程 贪吃蛇 (40)
- transforms.resize (33)
- redisson trylock (35)
- 卸载node (35)
- np.reshape (33)
- torch.arange (34)
- npm 源 (35)
- vue3 deep (35)
- win10 ssh (35)
- vue foreach (34)
- idea设置编码为utf8 (35)
- vue 数组添加元素 (34)
- std find (34)
- tablefield注解用途 (35)
- python str转json (34)
- java websocket客户端 (34)
- tensor.view (34)
- java jackson (34)
- vmware17pro最新密钥 (34)
- mysql单表最大数据量 (35)