百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 技术分类 > 正文

结构篇| 浅析LLaMA网络架构

ztj100 2025-02-10 15:16 8 浏览 0 评论

01 前言

LLaMA(Large Language Model Meta AI)是由Meta AI 发布的一个开放且高效的大型基础语言模型。为什么突然讲这个模型,主要LLaMA 已经成为了最受欢迎的开源大语言模型之一,LLaMA 系列模型在学术界和工业界引起了广泛的 关注,对于推动大语言模型技术的开源发展做出了重要贡献。

第一,开源,去了解其内部模型具有可行性。第二,它很受欢迎,说明在LLM界还是具有很强代表性,了解它内部结构有助于深入理解LLM发展路径。第三,众多 研究人员纷纷通过指令微调或继续预训练等方法来进一步扩展 LLaMA 模型的功 能和应用范围。其中,指令微调由于相对较低的计算成本,已成为开发定制化或专业化模型的首选方法,也因此出现了庞大的 LLaMA 家族。可以说LLaMA成为现在大部分互联网拥抱的对象,了解它,有助于拿下好offer。

总之,LLaMA 将有助于使 LLM 的使用和研究平民化,是一个深度学习LLM好切入口。同时,LLaMA已在教育、法律、医疗等专业领域有重要的应用场景,这对于构建大模型生态有先天的优势。

02 LLaMA架构

和GPT 系列一样,LLaMA 模型也是 Decoder-only 架构。底座也是Transformer的一种,《Transformer原理》和《概念篇| Transformer家族》已经介绍过Transformer模型,同样是基于自回归生成(Autoregressive)。自回归生成:在生成任务中,使用自回归(Autoregressive)方式,即逐个生成输出序列中的每个Token。在解码过程中,每次生成一个Token时,使用前面已生成的内容作为上下文,来帮助预测下一个Token。

2.0 Decoder-Only

当前主流的大语言模型都基于 Transformer 模型进行设计的。Transformer 是由多层的多头自注意力(Multi-head Self-attention)模块堆叠而成的神经网络模型。原 始的 Transformer 模型由编码器和解码器两个部分构成,而这两个部分实际上可以 独立使用,之前在《概念篇| Transformer家族》介绍过,例如基于编码器架构的 BERT 模型 和解码器架构的 GPT 模型 。 与 BERT 等早期的预训练语言模型相比,大语言模型的特点是使用了更长的向量 维度、更深的层数,进而包含了更大规模的模型参数,并主要使用解码器架构,对 于 Transformer 本身的结构与配置改变并不大。

与原生的Transformer的Decoder结构相比,做了以下几点改进:

Pre-normalization : 为了提高训练稳定性,LLaMA 对每个 transformer 子层的输入进行归一化,使用 RMSNorm 归一化函数,Pre-normalization 由Zhang和Sennrich引入。使用 RMSNorm 的好处是不用计算样本的均值,速度提升了40%。

SWiGLU:为了提高模型性能,结构上使用门控线性单元,且为了保持 FFN 层参数量不变,将隐藏单元的数量调整为 234d 而不是 PaLM 论文中的 4d,同时将 ReLU 替换为 SiLU 激活,引入以提高性能。

Rotary Embeddings:为了更好地建模长序列数据,模型的输入不再使用 positional embeddings,而是在网络的每一层添加了positional embeddings (RoPE),RoPE 方法由Su等人引入。

Grouped-Query Attention GQA:为了平衡效率和性能,部分版本采用了分组查询注意力机制。

虽然从LLaMA1到LLaMA3已经发布多个版本,大体架构基本相似,接下来针对LLaMA改进点进行详细介绍,其余组件可以参考原先文章《Transformer原理》。

2.1 RMSNorm

在之前的Transformer我们提到过,LN是对单个数据的指定维度进行Norm处理与batch无关。Transformer中的Normalization层一般都是采用LayerNorm来对Tensor进行归一化,LayerNorm的公式如下:

而RMSNorm就是LayerNorm的变体,RMSNorm省去了求均值的过程,也没有了偏置 β :

RMSNorm实现源码:

class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))


    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)


    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

2.2 激活函数的改进-SwiGLU

与标准的Transformer一样,经过Attention层之后就进行FeedForward层的处理,但LLama2的FeedForward与标准的Transformer FeedForward有一些细微的差异:

SwiGLU激活函数是SwiGLU是GLU的一种变体,其中包含了GLU和Swish激活函数。

2.2.1 GLU

门控线性单元 GLU: 定义为门控线性单元( Gated Linear Units, GLU),定义为输入的两个线性变换的逐元素乘积,其中一个经过了 sigmoid 激活(也可以用其他激活函数替换)。

2.2.2 FFN_SwiGLU

FFN_SwiGLU 原版实现使用 Swish 稍有不同,LLaMA 官方提供的代码使用 F.silu() 激活函函数:

class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
    ):
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)


        self.w1 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
        )
        self.w2 = RowParallelLinear(
            hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
        )
        self.w3 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
        )


    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

2.3 RoPE 旋转位置编码

必须使用位置编码,是因为纯粹的 Attention 模块是无法捕捉输入顺序的,即无法理解不同位置的 token 代表的意义不同。熟悉《》都知道,在做注意力机制的时候,并没有考虑词语的顺序,但是NLP不用文字顺序会影响文本的根本性意思。比如,输入文本为“我爱吃肉包”或“肉包爱吃我”,模型会将这两句话视为相同的内容,因为嵌入中并没有明确的顺序信息让模型去学习。

2.3.1 绝对编码与相对编码

在标准的Transformer中通常是在整个网络进入Transformer Block之前做一个位置编码:

Transformer论文中,使用正余弦函数表示绝对位置,通过两者乘积得到相对位置。因为正余弦函数具有周期性,可以很好地表示序列中单词的相对位置。比较经典的位置编码用公式表达就是:

其中,i表示token在序列中的位置,设句子长度为 L,则i=0,.......,L-1 。p是token的位置向量,p(i,2t)表示这个位置向量里的第t个元素,t表示奇数维度,2t表示偶数维度;d表示token的维度。

除了绝对编码,还有一种相对编码,相对位置编码是根据单词之间的相对位置关系来计算位置编码。这种编码方式更加灵活,能够捕捉到不同单词之间的相对位置信息,有助于模型更好地理解序列中单词之间的关系。但是也有缺点,计算效率低下,同时大部分相对编码都没有落地可行性。

RoPE(Rotary Position Embedding)旋转位置编码,由模型 RoFormer: Enhanced Transformer with Rotary Position Embedding 提出。RoPE 的核心思想是将位置编码与词向量通过旋转矩阵相乘,使得词向量不仅包含词汇的语义信息,还融入了位置信息,其具有以下优点:

相对位置感知:使用绝对位置编码来达到相对位置编码的效果,RoPE 能够自然地捕捉词汇之间的相对位置关系。

无需额外的计算:位置编码与词向量的结合在计算上是高效的。

适应不同长度的序列:RoPE 可以灵活处理不同长度的输入序列。

具体如何做到呢?,这里面涉及了很多数学上推理,大家可以看一下
https://spaces.ac.cn/archives/8130,我这里只做个简单介绍一下:RoPE 借助了复数的思想,出发点是通过绝对位置编码的方式实现相对位置编码。根据复数乘法的几何意义,上述变换实际上是对应向量旋转,所以位置向量称为“旋转式位置编 码”。本质还是利用绝对位置编码,通过内积方式,得到相对位置表达式。

根据内积满足线性叠加的性质,任意偶数维的 RoPE,都可以表示为二维情形的拼接,即:

如果放在二维空间维度来看,可以用极坐标来理解, 旋转角度不影响轴长度:

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device, dtype=torch.float32)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis




def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)




def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
    
    # 在attention 模块利用
    xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

2.4 GQA

在说这个之前,我们先回顾一下Transformer的多头注意力机制。

在 Transformer 中,注意力模块会并行多次重复计算。每个并行计算称为一个注意力头(Attention Head)。注意力模块将其查询 Query 、键 Key和值 Value的参数矩阵进行 N 次拆分,并将每次拆分独立通过一个单独的注意力头。最后,所有这些相同的注意力计算会合并在一起,产生最终的注意力分数。能够更细致地捕捉并表达每个词汇之间的多种联系和微妙差异。

  • MHA(Multi Head Attention) 中,每个头有自己单独的 key-value 对;标准的多头注意力机制,h个Query、Key 和 Value 矩阵。
  • MQA(Multi Query Attention) 中只会有一组 key-value 对;多查询注意力的一种变体,也是用于自回归解码的一种注意力机制。与MHA不同的是,MQA 让所有的头之间共享同一份 Key 和 Value 矩阵,每个头只单独保留了一份 Query 参数,从而大大减少 Key 和 Value 矩阵的参数量。
  • GQA(Grouped Query Attention)中,会对 attention 进行分组操作,query 被分为 N 组,每个组共享一个 Key 和 Value 矩阵GQA将查询头分成G组,每个组共享一个Key 和 Value 矩阵。GQA-G是指具有G组的grouped-query attention。GQA-1具有单个组,因此具有单个Key 和 Value,等效于MQA。而GQA-H具有与头数相等的组,等效于MHA。

GQA介于MHA和MQA之间。GQA 综合 MHA 和 MQA ,既不损失太多性能,又能利用 MQA 的推理加速。不是所有 Q 头共享一组 KV,而是分组一定头数 Q 共享一组 KV,比如上图中就是两组 Q 共享一组 KV。现在LLaMA3基本都使用了GQA结构。

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )
class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        model_parallel_size = fs_init.get_model_parallel_world_size()
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads


        self.wq = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wk = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wv = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wo = RowParallelLinear(
            args.n_heads * self.head_dim,
            args.dim,
            bias=False,
            input_is_parallel=True,
            init_method=lambda x: x,
        )


        self.cache_k = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()
        self.cache_v = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()


    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        bsz, seqlen, _ = x.shape
        # 计算 q、k 、v
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)


        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        # 加上Rope位置旋转
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)


        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)
        # kv 缓存
        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv


        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]
        # GQA应用
        # repeat k/v heads if n_kv_heads < n_heads
        keys = repeat_kv(
            keys, self.n_rep
        )  # (bs, cache_len + seqlen, n_local_heads, head_dim)
        values = repeat_kv(
            values, self.n_rep
        )  # (bs, cache_len + seqlen, n_local_heads, head_dim)


        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        keys = keys.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
        values = values.transpose(
            1, 2
        )  # (bs, n_local_heads, cache_len + seqlen, head_dim)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)

03 总结

Python的完整的LLaMa3代码在github可以快速找到,其核心代码也不过几百行,但其中的设计思想和理念,够我们这些小白喝一段时间,希望通过不断深入学习,提高对LLM实际的理解。通过记录所学的知识,建立自己的系统性思维。

相关推荐

再见Swagger UI 国人开源了一款超好用的 API 文档生成框架,真香

背景最近,栈长发现某些国内的开源项目都使用到了Knife4j技术,看名字就觉得很锋利啊!...

Spring Boot自动装配黑魔法:手把手教你打造高逼格自定义Starter

如果你是SpringBoot深度用户,是否经历过这样的痛苦:每个新项目都要重复配置Redis连接池,反复粘贴Swagger配置参数,在微服务架构中为统一日志格式疲于奔命?本文将为你揭开Spring...

Spring Boot(十五):集成Knife4j(spring boot 集成)

Knife4j的简介Knife4j是一个集Swagger2和OpenAPI3为一体的增强解决方案,它的前身是上一篇文章中介绍的swagger-bootstrap-ui。swagger-bootstra...

swagger-bootstrap-ui:swagger改进版本,界面更美观易于阅读

swagger作为一款在线文档生成工具,用于自动生成接口API,避免接口文档和代码不同步,但原生的界面不是很友好,下面介绍一款改进版本swagger-bootstrap-ui,界面左右侧布局,可以打开...

界面美观功能强大,终于可以告别单调的swagger ui了——knife4j

介绍knife4j是为JavaMVC框架集成Swagger生成Api文档的增强解决方案(在非Java项目中也提供了前端UI的增强解决方案),前身是swagger-bootstrap-ui,取名kni...

从 0 到 1 实战 Spring Boot 3:手把手教你构建高效 RESTful 接口

从0到1实战SpringBoot3:手把手教你构建高效RESTful接口在微服务架构盛行的今天,构建高效稳定的RESTful接口是后端开发者的核心技能。SpringBoot凭...

SpringBoot动态权限校验终极指南:3种高赞方案让老板主动加薪!

“上周用这套方案重构权限系统,CTO当着全组的面摔了祖传代码!”一位脉脉匿名网友的血泪经验:还在用硬编码写Shiro过滤器?RBAC模型搞出200张表?是时候用SpringSecurity+动态路...

一个基于 Spring Boot 的在线考试系统

今天推荐一款超级美观的在线考试系统,感兴趣可以先去预览地址看看该项目。在线Demo预览,http://129.211.88.191,账户分别是admin、teacher、student,密码是ad...

SpringBoot API开发的十大专业实践指南

在SpringBoot应用开发领域,构建高效、可靠的API需遵循系统化的开发规范。本文结合实战编码示例,详细解析10项关键开发实践,助您打造具备工业级标准的后端接口。一、RESTful...

震碎认知!将原理融会贯通到顶点的SpringBoot实战项目

SpringBoot是什么?我们知道,从2002年开始,Spring一直在飞速的发展,如今已经成为了在JavaEE(JavaEnterpriseEdition)开发中真正意义上的标准,但...

Spring Boot 整合 Knife4j 实现接口文档编写?

Knife4j增强版的SwaggerUI实现,在Knife4j中提供了很多功能并且用户体验也随之有了很大的提升。Knife4j主要基于Swagger2.0构建的,主要的用途就是在SpringBo...

前端同事老是说swagger不好用,我用了knife4j后,同事爽得不行

日常开发当中,少不了前端联调,随着协同开发的发展,前端对接口要求也变得越来越高了。所以我使用了knife4j,同事用完觉得太舒服了。knife4j简介:Knife4j...

一个基于spring boot的Java开源商城系统

前言一个基于springboot的JAVA开源商城系统,是前后端分离、为生产环境多实例完全准备、数据库为b2b2c商城系统设计、拥有完整下单流程和精美设计的java开源商城系统https://www...

再见 Swagger!国人开源了一款超好用的 API 文档生成框架真香

Knife4j是为JavaMVC框架集成Swagger生成Api文档的增强解决方案,前身是swagger-bootstrap-ui,取名kni4j是希望她能像一把匕首一样小巧,轻量,并且功能强悍!...

Spring Boot整合MybatisPlus和Druid

在Java中,我比较ORM熟悉的只有...

取消回复欢迎 发表评论: