可逆神经网络详细解析:让神经网络更加轻量化
ztj100 2024-12-19 17:55 27 浏览 0 评论
来源:PaperWeekly本文约4600字,建议阅读10分钟本文以可逆残差网络作为基础进行分析。
为什么要用可逆网络呢?
- 因为编码和解码使用相同的参数,所以 model 是轻量级的。可逆的降噪网络 InvDN 只有 DANet 网络参数量的 4.2%,但是 InvDN 的降噪性能更好。
- 由于可逆网络是信息无损的,所以它能保留输入数据的细节信息。
- 无论网络的深度如何,可逆网络都使用恒定的内存来计算梯度。
其中最主要目的就是为了减少内存的消耗,当前所有的神经网络都采用反向传播的方式来训练,反向传播算法需要存储网络的中间结果来计算梯度,而且其对内存的消耗与网络单元数成正比。这也就意味着,网络越深越广,对内存的消耗越大,这将成为很多应用的瓶颈。
下面是 Pytorch summary 的结果,Forward/backward pass size(MB): 218.59 就是需要保存的中间变量大小,可以看出这部分占据了很大部分显存(随着网络深度的增加,中间变量占据显存量会一直增加,resnet152(size=224)的中间变量更是占据总共内存的 606.6÷836.79≈0.725 )。如果不存储中间层结果,那么就可以大幅减少 GPU 的显存占用,有助于训练更深更广的网络。
import torch
from torchvision import models
from torchsummary import summary
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vgg = models.vgg16().to(device)
summary(vgg, (3, 224, 224))
结果:
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 224, 224] 1,792
ReLU-2 [-1, 64, 224, 224] 0
Conv2d-3 [-1, 64, 224, 224] 36,928
ReLU-4 [-1, 64, 224, 224] 0
MaxPool2d-5 [-1, 64, 112, 112] 0
Conv2d-6 [-1, 128, 112, 112] 73,856
ReLU-7 [-1, 128, 112, 112] 0
Conv2d-8 [-1, 128, 112, 112] 147,584
ReLU-9 [-1, 128, 112, 112] 0
MaxPool2d-10 [-1, 128, 56, 56] 0
Conv2d-11 [-1, 256, 56, 56] 295,168
ReLU-12 [-1, 256, 56, 56] 0
Conv2d-13 [-1, 256, 56, 56] 590,080
ReLU-14 [-1, 256, 56, 56] 0
Conv2d-15 [-1, 256, 56, 56] 590,080
ReLU-16 [-1, 256, 56, 56] 0
MaxPool2d-17 [-1, 256, 28, 28] 0
Conv2d-18 [-1, 512, 28, 28] 1,180,160
ReLU-19 [-1, 512, 28, 28] 0
Conv2d-20 [-1, 512, 28, 28] 2,359,808
ReLU-21 [-1, 512, 28, 28] 0
Conv2d-22 [-1, 512, 28, 28] 2,359,808
ReLU-23 [-1, 512, 28, 28] 0
MaxPool2d-24 [-1, 512, 14, 14] 0
Conv2d-25 [-1, 512, 14, 14] 2,359,808
ReLU-26 [-1, 512, 14, 14] 0
Conv2d-27 [-1, 512, 14, 14] 2,359,808
ReLU-28 [-1, 512, 14, 14] 0
Conv2d-29 [-1, 512, 14, 14] 2,359,808
ReLU-30 [-1, 512, 14, 14] 0
MaxPool2d-31 [-1, 512, 7, 7] 0
Linear-32 [-1, 4096] 102,764,544
ReLU-33 [-1, 4096] 0
Dropout-34 [-1, 4096] 0
Linear-35 [-1, 4096] 16,781,312
ReLU-36 [-1, 4096] 0
Dropout-37 [-1, 4096] 0
Linear-38 [-1, 1000] 4,097,000
================================================================
Total params: 138,357,544
Trainable params: 138,357,544
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 218.59
Params size (MB): 527.79
Estimated Total Size (MB): 746.96
----------------------------------------------------------------
接下来我将先从可逆神经网络讲起,然后是神经网络的反向传播,最后是标准残差网络。对反向传播算法和标准残差网络比较熟悉的小伙伴,可以只看第一节:可逆神经网络。如果各位小伙伴不熟悉反向传播算法和标准残差网络,建议先看第二节:反向传播(BP)算法和第三节:残差网络(Residual Network)。本文1.2和1.3.4摘录自 @阿亮。
可逆神经网络
可逆网络具有的性质:
- 网络的输入、输出的大小必须一致。
- 网络的雅可比行列式不为 0。
1.1 什么是雅可比行列式?
雅可比行列式通常称为雅可比式(Jacobian),它是以 n 个 n 元函数的偏导数为元素的行列式 。事实上,在函数都连续可微(即偏导数都连续)的前提之下,它就是函数组的微分形式下的系数矩阵(即雅可比矩阵)的行列式。若因变量对自变量连续可微,而自变量对新变量连续可微,则因变量也对新变量连续可微。这可用行列式的乘法法则和偏导数的连锁法则直接验证。也类似于导数的连锁法则。偏导数的连锁法则也有类似的公式;这常用于重积分的计算中。
1.2 雅可比行列式与神经网络的关系
为什么神经网络会与雅可比行列式有关系?这里我借用李宏毅老师的 ppt(12-14页)。想看视频的可以到 b 站上看。
简单的来讲就是 ,他们的分布之间的关系就变为 ,又因为有 ,所以 这个网络的雅可比行列式不为 0 才行。
顺便提一下,flow-based Model 优化的损失函数如下:
其实这里跟矩阵运算很像,矩阵可逆的条件也是矩阵的雅可比行列式不为 0,雅可比矩阵可以理解为矩阵的一阶导数。
假设可逆网络的表达式为:
它的雅可比矩阵为:
其行列式为 1。
1.3 可逆残差网络(Reversible Residual Network)
论文标题:
The Reversible Residual Network: Backpropagation Without Storing Activations
论文链接:
https://arxiv.org/abs/1707.04585
多伦多大学的 Aidan N.Gomez 和 Mengye Ren 提出了可逆残差神经网络,当前层的激活结果可由下一层的结果计算得出,也就是如果我们知道网络层最后的结果,就可以反推前面每一层的中间结果。这样我们只需要存储网络的参数和最后一层的结果即可,激活结果的存储与网络的深度无关了,将大幅减少显存占用。令人惊讶的是,实验结果显示,可逆残差网络的表现并没有显著下降,与之前的标准残差网络实验结果基本旗鼓相当。
1.3.1 可逆块结构
可逆神经网络将每一层分割成两部分,分别为 和 ,每一个可逆块的输入是 ,输出是 。其结构如下:
正向计算图示:
公式表示:
逆向计算图示:
公式表示:
其中 F 和 G 都是相似的残差函数,参考上图残差网络。可逆块的跨距只能为 1,也就是说可逆块必须一个接一个连接,中间不能采用其它网络形式衔接,否则的话就会丢失信息,并且无法可逆计算了,这点与残差块不一样。如果一定要采取跟残差块相似的结构,也就是中间一部分采用普通网络形式衔接,那中间这部分的激活结果就必须显式的存起来。
1.3.2 不用存储激活结果的反向传播
为了更好地计算反向传播的步骤,我们修改一下上述正向计算和逆向计算的公式:
尽管 和 的值是相同的,但是两个变量在图中却代表不同的节点,所以在反向传播中它们的总体导数是不一样的。 的导数包含通过 产生的间接影响,而 的导数却不受 的任何影响。
在反向传播计算流程中,先给出最后一层的激活值 和误差传播的总体导数 ,然后要计算出其输入值 和对应的导数 ,以及残差函数 F 和 G 中权重参数的总体导数,求解步骤如下:
1.3.3 计算开销
一个 N 个连接的神经网络,正向计算的理论加乘开销为 N,反向传播求导的理论加乘开销为 2N(反向求导包含复合函数求导连乘),而可逆网络多一步需要反向计算输入值的操作,所以理论计算开销为 4N,比普通网络开销约多出 33% 左右。但是在实际操作中,正向和反向的计算开销在 GPU 上差不多,可以都理解为 N。那么这样的话,普通网络的整体计算开销为 2N,可逆网络的整体开销为 3N,也就是多出了约 50%。
1.3.4 雅可比行列式的计算
其编码公式如下:
其解码公式如下:
为了计算雅可比矩阵,我们更直观的写成下面的编码公式:
它的雅可比矩阵为:
其实上面这个雅可比行列式也是 1,因为这里 ,它们的系数是一样的。
有另外一种解释方式就是把这种对偶的形式切成两半:
其行列式为 1。
因为是对偶的形式,所以这里的行列式也为 1。
因为 ,所以其行列式也为 1。
反向传播(BP)算法
上图中符号的含义:
- x1,x2,x3:表示 3 个输入层节点。
- :表示从 t-1 层到 t 层的权重参数,j 表示 t 层的第 j 个节点,i 表示 t-1 层的第 i 个节点。
- :表示 t 层的第 i 个激活后输出结果。
- g(x):表示激活函数。
正向传播计算过程:
- 隐藏层(网络的第二层)
- 输出层(网络的最后一层)
反向传播计算过程:
以单个样本为例,假设输入向量是 [x1,x2,x3],目标输出值是 [y1,y2],代价函数用 L 表示。反向传播的总体原理就是根据总体输出误差,反向传播回网络,通过计算每一层节点的梯度,利用梯度下降法原理,更新每一层的网络权重 w 和偏置 b,这也是网络学习的过程。误差反向传播的优点就是可以把繁杂的导数计算以数列递推的形式来表示, 简化了计算过程。
以平方误差来计算反向传播的过程,代价函数表示如下:
根据导数的链式法则反向求解隐藏 -> 输出层、输入层 -> 隐藏层的权重表示:
引入新的误差求导表示形式,称为神经单元误差:
l=2,3 表示第几层,j 表示某一层的第几个节点。替换表示后如下:
所以我们可以归纳出一般的计算公式:
从上述公式可以看出,如果神经单元误差 δ 可以求出来,那么总误差对每一层的权重 w 和偏置 b 的偏导数就可以求出来,接下来就可以利用梯度下降法来优化参数了。
求解每一层的 δ:
- 输出层
- 隐藏层
也就是说,我们根据输出层的神经误差单元 δ 就可以直接求出隐藏层的神经误差单元,进而省去了隐藏层的繁杂的求导过程,我们可以得出更一般的计算过程:
从而得出 l 层神经单元误差和 l+1 层神经单元误差的关系。这就是误差反向传播算法,只要求出输出层的神经单元误差,其它层的神经单元误差就不需要计算偏导数了,而可以直接通过上述公式得出。
残差网络(Residual Network)
残差网络主要可以解决两个问题(其结构如下图):
- 梯度消失问题;
- 网络退化问题。
上述结构就是一个两层网络组成的残差块,残差块可以由 2、3 层甚至更多层组成,但是如果是一层的,就变成线性变换了,没什么意义了。上述图可以写成公式如下:
所以在第二层进入激活函数 ReLU之 前 F(x)+x 组成新的输入,也叫恒等映射。
恒等映射就是在这个残差块输入是 x 的情况下输出依然是 x,这样其目标就是学习让 F(X)=0。
这里有一个问题哈,为什么要额外加一个 x 呢,而不是让模型直接学习 F(x)=x?
因为让 F(x)=0 比较容易,初始化参数 W 非常小接近 0,就可以让输出接近 0,同时输出如果是负数,经过第一层 Relu 后输出依然 0,都能使得最后的 F(x)=0,也就是有多种情况都可以使得 F(x)=0;但是让 F(x)=x 确实非常难的,因为参数都必须刚刚好才能使得最后输出为 x。
恒等映射有什么作用?
恒等映射就可以解决网络退化的问题,当网络层数越来越深的时候,网络的精度却在下降,也就是说网络自身存在一个最优的层度结构,太深太浅都能使得模型精度下降。有了恒等映射存在,网络就能够自己学习到哪些层是冗余的,就可以无损通过这些层,理论上讲再深的网络都不影响其精度,解决了网络退化问题。
为什么可以解决梯度消失问题呢?
以两个残差块的结构实例图来分析,其中每个残差块有 2 层神经网络组成,如下图:
假设激活函数 ReLU 用 g(x) 函数来表示,样本实例是 [x1,y1],即输入是 x1,目标值是 y1,损失函数还是采用平方损失函数,则每一层的计算如下:
下面我们对第一个残差块的权重参数求导,根据链式求导法则,公式如下:
我们可以看到求导公式中多了一个+1项,这就将原来的链式求导中的连乘变成了连加状态,可以有效避免梯度消失了。
参考文献:
[1] PPT https://speech.ee.ntu.edu.tw/~tlkagk/courses/ML_2019/Lecture/FLOW%20(v7).pdf
[2] 神经网络的可逆形式 https://zhuanlan.zhihu.com/p/268242678
[3] 大幅减少GPU显存占用:可逆残差网络(The Reversible Residual Network) https://www.cnblogs.com/gczr/p/12181354.html
[4] 雅可比行列式 https://baike.baidu.com/item/雅可比行列式/4709261?fr=aladdin
[5] The Reversible Residual Network: Backpropagation Without Storing Activations
[6] pytorch-summary https://github.com/sksq96/pytorch-summary
相关推荐
- SpringBoot整合SpringSecurity+JWT
-
作者|Sans_https://juejin.im/post/5da82f066fb9a04e2a73daec一.说明SpringSecurity是一个用于Java企业级应用程序的安全框架,主要包含...
- 「计算机毕设」一个精美的JAVA博客系统源码分享
-
前言大家好,我是程序员it分享师,今天给大家带来一个精美的博客系统源码!可以自己买一个便宜的云服务器,当自己的博客网站,记录一下自己学习的心得。开发技术博客系统源码基于SpringBoot,shiro...
- springboot教务管理系统+微信小程序云开发附带源码
-
今天给大家分享的程序是基于springboot的管理,前端是小程序,系统非常的nice,不管是学习还是毕设都非常的靠谱。本系统主要分为pc端后台管理和微信小程序端,pc端有三个角色:管理员、学生、教师...
- SpringBoot+LayUI后台管理系统开发脚手架
-
源码获取方式:关注,转发之后私信回复【源码】即可免费获取到!项目简介本项目本着避免重复造轮子的原则,建立一套快速开发JavaWEB项目(springboot-mini),能满足大部分后台管理系统基础开...
- Spring Boot的Security安全控制——认识SpringSecurity!
-
SpringBoot的Security安全控制在Web项目开发中,安全控制是非常重要的,不同的人配置不同的权限,这样的系统才安全。最常见的权限框架有Shiro和SpringSecurity。Shi...
- 前同事2024年接私活已入百万,都是用这几个开源的SpringBoot项目
-
前言不得不佩服SpringBoot的生态如此强大,今天给大家推荐几款优秀的后台管理系统,小伙伴们再也不用从头到尾撸一个项目了。SmartAdmin...
- 值得学习的15 个优秀开源的 Spring Boot 学习项目
-
SpringBoot算是目前Java领域最火的技术栈了,除了书呢?当然就是开源项目了,今天整理15个开源领域非常不错的SpringBoot项目供大家学习,参考。高富帅的路上只能帮你到这里了,...
- 开发企业官网就用这个基于SpringBoot的CMS系统,真香
-
前言推荐这个项目是因为使用手册部署手册非常...
- 2021年超详细的java学习路线总结—纯干货分享
-
本文整理了java开发的学习路线和相关的学习资源,非常适合零基础入门java的同学,希望大家在学习的时候,能够节省时间。纯干货,良心推荐!第一阶段:Java基础...
- jeecg-boot学习总结及使用心得(jeecgboot简单吗)
-
jeecg-boot学习总结及使用心得1.jeecg-boot是一个真正前后端分离的模版项目,便于二次开发,使用的都是较流行的新技术,后端技术主要有spring-boot2.x、shiro、Myb...
- 后勤集团原料管理系统springboot+Layui+MybatisPlus+Shiro源代码
-
本项目为前几天收费帮学妹做的一个项目,JavaEEJSP项目,在工作环境中基本使用不到,但是很多学校把这个当作编程入门的项目来做,故分享出本项目供初学者参考。一、项目描述后勤集团原料管理系统spr...
- 白卷开源SpringBoot+Vue的前后端分离入门项目
-
简介白卷是一个简单的前后端分离项目,主要采用Vue.js+SpringBoot技术栈开发。除了用作入门练习,作者还希望该项目可以作为一些常见Web项目的脚手架,帮助大家简化搭建网站的流程。...
- Spring Security 自动踢掉前一个登录用户,一个配置搞定
-
登录成功后,自动踢掉前一个登录用户,松哥第一次见到这个功能,就是在扣扣里边见到的,当时觉得挺好玩的。自己做开发后,也遇到过一模一样的需求,正好最近的SpringSecurity系列正在连载,就结...
- 收藏起来!这款开源在线考试系统,我爱了
-
大家好,我是为广大程序员兄弟操碎了心的小编,每天推荐一个小工具/源码,装满你的收藏夹,每天分享一个小技巧,让你轻松节省开发效率,实现不加班不熬夜不掉头发,是我的目标!今天小编推荐一款基于Spr...
- Shiro框架:认证和授权原理(shiro权限认证流程)
-
优质文章,及时送达前言Shiro作为解决权限问题的常用框架,常用于解决认证、授权、加密、会话管理等场景。本文将对Shiro的认证和授权原理进行介绍:Shiro可以做什么?、Shiro是由什么组成的?举...
你 发表评论:
欢迎- 一周热门
- 最近发表
-
- SpringBoot整合SpringSecurity+JWT
- 「计算机毕设」一个精美的JAVA博客系统源码分享
- springboot教务管理系统+微信小程序云开发附带源码
- SpringBoot+LayUI后台管理系统开发脚手架
- Spring Boot的Security安全控制——认识SpringSecurity!
- 前同事2024年接私活已入百万,都是用这几个开源的SpringBoot项目
- 值得学习的15 个优秀开源的 Spring Boot 学习项目
- 开发企业官网就用这个基于SpringBoot的CMS系统,真香
- 2021年超详细的java学习路线总结—纯干货分享
- jeecg-boot学习总结及使用心得(jeecgboot简单吗)
- 标签列表
-
- idea eval reset (50)
- vue dispatch (70)
- update canceled (42)
- order by asc (53)
- spring gateway (67)
- 简单代码编程 贪吃蛇 (40)
- transforms.resize (33)
- redisson trylock (35)
- 卸载node (35)
- np.reshape (33)
- torch.arange (34)
- npm 源 (35)
- vue3 deep (35)
- win10 ssh (35)
- vue foreach (34)
- idea设置编码为utf8 (35)
- vue 数组添加元素 (34)
- std find (34)
- tablefield注解用途 (35)
- python str转json (34)
- java websocket客户端 (34)
- tensor.view (34)
- java jackson (34)
- vmware17pro最新密钥 (34)
- mysql单表最大数据量 (35)