百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 技术分类 > 正文

结构篇| 浅析LLaMA网络架构

ztj100 2025-02-10 15:16 24 浏览 0 评论

01 前言

LLaMA(Large Language Model Meta AI)是由Meta AI 发布的一个开放且高效的大型基础语言模型。为什么突然讲这个模型,主要LLaMA 已经成为了最受欢迎的开源大语言模型之一,LLaMA 系列模型在学术界和工业界引起了广泛的 关注,对于推动大语言模型技术的开源发展做出了重要贡献。

第一,开源,去了解其内部模型具有可行性。第二,它很受欢迎,说明在LLM界还是具有很强代表性,了解它内部结构有助于深入理解LLM发展路径。第三,众多 研究人员纷纷通过指令微调或继续预训练等方法来进一步扩展 LLaMA 模型的功 能和应用范围。其中,指令微调由于相对较低的计算成本,已成为开发定制化或专业化模型的首选方法,也因此出现了庞大的 LLaMA 家族。可以说LLaMA成为现在大部分互联网拥抱的对象,了解它,有助于拿下好offer。

总之,LLaMA 将有助于使 LLM 的使用和研究平民化,是一个深度学习LLM好切入口。同时,LLaMA已在教育、法律、医疗等专业领域有重要的应用场景,这对于构建大模型生态有先天的优势。

02 LLaMA架构

和GPT 系列一样,LLaMA 模型也是 Decoder-only 架构。底座也是Transformer的一种,《Transformer原理》和《概念篇| Transformer家族》已经介绍过Transformer模型,同样是基于自回归生成(Autoregressive)。自回归生成:在生成任务中,使用自回归(Autoregressive)方式,即逐个生成输出序列中的每个Token。在解码过程中,每次生成一个Token时,使用前面已生成的内容作为上下文,来帮助预测下一个Token。

2.0 Decoder-Only

当前主流的大语言模型都基于 Transformer 模型进行设计的。Transformer 是由多层的多头自注意力(Multi-head Self-attention)模块堆叠而成的神经网络模型。原 始的 Transformer 模型由编码器和解码器两个部分构成,而这两个部分实际上可以 独立使用,之前在《概念篇| Transformer家族》介绍过,例如基于编码器架构的 BERT 模型 和解码器架构的 GPT 模型 。 与 BERT 等早期的预训练语言模型相比,大语言模型的特点是使用了更长的向量 维度、更深的层数,进而包含了更大规模的模型参数,并主要使用解码器架构,对 于 Transformer 本身的结构与配置改变并不大。

与原生的Transformer的Decoder结构相比,做了以下几点改进:

Pre-normalization : 为了提高训练稳定性,LLaMA 对每个 transformer 子层的输入进行归一化,使用 RMSNorm 归一化函数,Pre-normalization 由Zhang和Sennrich引入。使用 RMSNorm 的好处是不用计算样本的均值,速度提升了40%。

SWiGLU:为了提高模型性能,结构上使用门控线性单元,且为了保持 FFN 层参数量不变,将隐藏单元的数量调整为 234d 而不是 PaLM 论文中的 4d,同时将 ReLU 替换为 SiLU 激活,引入以提高性能。

Rotary Embeddings:为了更好地建模长序列数据,模型的输入不再使用 positional embeddings,而是在网络的每一层添加了positional embeddings (RoPE),RoPE 方法由Su等人引入。

Grouped-Query Attention GQA:为了平衡效率和性能,部分版本采用了分组查询注意力机制。

虽然从LLaMA1到LLaMA3已经发布多个版本,大体架构基本相似,接下来针对LLaMA改进点进行详细介绍,其余组件可以参考原先文章《Transformer原理》。

2.1 RMSNorm

在之前的Transformer我们提到过,LN是对单个数据的指定维度进行Norm处理与batch无关。Transformer中的Normalization层一般都是采用LayerNorm来对Tensor进行归一化,LayerNorm的公式如下:

而RMSNorm就是LayerNorm的变体,RMSNorm省去了求均值的过程,也没有了偏置 β :

RMSNorm实现源码:

class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))


    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)


    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

2.2 激活函数的改进-SwiGLU

与标准的Transformer一样,经过Attention层之后就进行FeedForward层的处理,但LLama2的FeedForward与标准的Transformer FeedForward有一些细微的差异:

SwiGLU激活函数是SwiGLU是GLU的一种变体,其中包含了GLU和Swish激活函数。

2.2.1 GLU

门控线性单元 GLU: 定义为门控线性单元( Gated Linear Units, GLU),定义为输入的两个线性变换的逐元素乘积,其中一个经过了 sigmoid 激活(也可以用其他激活函数替换)。

2.2.2 FFN_SwiGLU

FFN_SwiGLU 原版实现使用 Swish 稍有不同,LLaMA 官方提供的代码使用 F.silu() 激活函函数:

class FeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
    ):
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)


        self.w1 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
        )
        self.w2 = RowParallelLinear(
            hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
        )
        self.w3 = ColumnParallelLinear(
            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
        )


    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

2.3 RoPE 旋转位置编码

必须使用位置编码,是因为纯粹的 Attention 模块是无法捕捉输入顺序的,即无法理解不同位置的 token 代表的意义不同。熟悉《》都知道,在做注意力机制的时候,并没有考虑词语的顺序,但是NLP不用文字顺序会影响文本的根本性意思。比如,输入文本为“我爱吃肉包”或“肉包爱吃我”,模型会将这两句话视为相同的内容,因为嵌入中并没有明确的顺序信息让模型去学习。

2.3.1 绝对编码与相对编码

在标准的Transformer中通常是在整个网络进入Transformer Block之前做一个位置编码:

Transformer论文中,使用正余弦函数表示绝对位置,通过两者乘积得到相对位置。因为正余弦函数具有周期性,可以很好地表示序列中单词的相对位置。比较经典的位置编码用公式表达就是:

其中,i表示token在序列中的位置,设句子长度为 L,则i=0,.......,L-1 。p是token的位置向量,p(i,2t)表示这个位置向量里的第t个元素,t表示奇数维度,2t表示偶数维度;d表示token的维度。

除了绝对编码,还有一种相对编码,相对位置编码是根据单词之间的相对位置关系来计算位置编码。这种编码方式更加灵活,能够捕捉到不同单词之间的相对位置信息,有助于模型更好地理解序列中单词之间的关系。但是也有缺点,计算效率低下,同时大部分相对编码都没有落地可行性。

RoPE(Rotary Position Embedding)旋转位置编码,由模型 RoFormer: Enhanced Transformer with Rotary Position Embedding 提出。RoPE 的核心思想是将位置编码与词向量通过旋转矩阵相乘,使得词向量不仅包含词汇的语义信息,还融入了位置信息,其具有以下优点:

相对位置感知:使用绝对位置编码来达到相对位置编码的效果,RoPE 能够自然地捕捉词汇之间的相对位置关系。

无需额外的计算:位置编码与词向量的结合在计算上是高效的。

适应不同长度的序列:RoPE 可以灵活处理不同长度的输入序列。

具体如何做到呢?,这里面涉及了很多数学上推理,大家可以看一下
https://spaces.ac.cn/archives/8130,我这里只做个简单介绍一下:RoPE 借助了复数的思想,出发点是通过绝对位置编码的方式实现相对位置编码。根据复数乘法的几何意义,上述变换实际上是对应向量旋转,所以位置向量称为“旋转式位置编 码”。本质还是利用绝对位置编码,通过内积方式,得到相对位置表达式。

根据内积满足线性叠加的性质,任意偶数维的 RoPE,都可以表示为二维情形的拼接,即:

如果放在二维空间维度来看,可以用极坐标来理解, 旋转角度不影响轴长度:

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device, dtype=torch.float32)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis




def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)




def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
    
    # 在attention 模块利用
    xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

2.4 GQA

在说这个之前,我们先回顾一下Transformer的多头注意力机制。

在 Transformer 中,注意力模块会并行多次重复计算。每个并行计算称为一个注意力头(Attention Head)。注意力模块将其查询 Query 、键 Key和值 Value的参数矩阵进行 N 次拆分,并将每次拆分独立通过一个单独的注意力头。最后,所有这些相同的注意力计算会合并在一起,产生最终的注意力分数。能够更细致地捕捉并表达每个词汇之间的多种联系和微妙差异。

  • MHA(Multi Head Attention) 中,每个头有自己单独的 key-value 对;标准的多头注意力机制,h个Query、Key 和 Value 矩阵。
  • MQA(Multi Query Attention) 中只会有一组 key-value 对;多查询注意力的一种变体,也是用于自回归解码的一种注意力机制。与MHA不同的是,MQA 让所有的头之间共享同一份 Key 和 Value 矩阵,每个头只单独保留了一份 Query 参数,从而大大减少 Key 和 Value 矩阵的参数量。
  • GQA(Grouped Query Attention)中,会对 attention 进行分组操作,query 被分为 N 组,每个组共享一个 Key 和 Value 矩阵GQA将查询头分成G组,每个组共享一个Key 和 Value 矩阵。GQA-G是指具有G组的grouped-query attention。GQA-1具有单个组,因此具有单个Key 和 Value,等效于MQA。而GQA-H具有与头数相等的组,等效于MHA。

GQA介于MHA和MQA之间。GQA 综合 MHA 和 MQA ,既不损失太多性能,又能利用 MQA 的推理加速。不是所有 Q 头共享一组 KV,而是分组一定头数 Q 共享一组 KV,比如上图中就是两组 Q 共享一组 KV。现在LLaMA3基本都使用了GQA结构。

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )
class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        model_parallel_size = fs_init.get_model_parallel_world_size()
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads


        self.wq = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wk = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wv = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wo = RowParallelLinear(
            args.n_heads * self.head_dim,
            args.dim,
            bias=False,
            input_is_parallel=True,
            init_method=lambda x: x,
        )


        self.cache_k = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()
        self.cache_v = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()


    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        bsz, seqlen, _ = x.shape
        # 计算 q、k 、v
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)


        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        # 加上Rope位置旋转
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)


        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)
        # kv 缓存
        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv


        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]
        # GQA应用
        # repeat k/v heads if n_kv_heads < n_heads
        keys = repeat_kv(
            keys, self.n_rep
        )  # (bs, cache_len + seqlen, n_local_heads, head_dim)
        values = repeat_kv(
            values, self.n_rep
        )  # (bs, cache_len + seqlen, n_local_heads, head_dim)


        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        keys = keys.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
        values = values.transpose(
            1, 2
        )  # (bs, n_local_heads, cache_len + seqlen, head_dim)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)

03 总结

Python的完整的LLaMa3代码在github可以快速找到,其核心代码也不过几百行,但其中的设计思想和理念,够我们这些小白喝一段时间,希望通过不断深入学习,提高对LLM实际的理解。通过记录所学的知识,建立自己的系统性思维。

相关推荐

其实TensorFlow真的很水无非就这30篇熬夜练

好的!以下是TensorFlow需要掌握的核心内容,用列表形式呈现,简洁清晰(含表情符号,<300字):1.基础概念与环境TensorFlow架构(计算图、会话->EagerE...

交叉验证和超参数调整:如何优化你的机器学习模型

准确预测Fitbit的睡眠得分在本文的前两部分中,我获取了Fitbit的睡眠数据并对其进行预处理,将这些数据分为训练集、验证集和测试集,除此之外,我还训练了三种不同的机器学习模型并比较了它们的性能。在...

机器学习交叉验证全指南:原理、类型与实战技巧

机器学习模型常常需要大量数据,但它们如何与实时新数据协同工作也同样关键。交叉验证是一种通过将数据集分成若干部分、在部分数据上训练模型、在其余数据上测试模型的方法,用来检验模型的表现。这有助于发现过拟合...

深度学习中的类别激活热图可视化

作者:ValentinaAlto编译:ronghuaiyang导读使用Keras实现图像分类中的激活热图的可视化,帮助更有针对性...

超强,必会的机器学习评估指标

大侠幸会,在下全网同名[算法金]0基础转AI上岸,多个算法赛Top[日更万日,让更多人享受智能乐趣]构建机器学习模型的关键步骤是检查其性能,这是通过使用验证指标来完成的。选择正确的验证指...

机器学习入门教程-第六课:监督学习与非监督学习

1.回顾与引入上节课我们谈到了机器学习的一些实战技巧,比如如何处理数据、选择模型以及调整参数。今天,我们将更深入地探讨机器学习的两大类:监督学习和非监督学习。2.监督学习监督学习就像是有老师的教学...

Python教程(三十八):机器学习基础

...

Python 模型部署不用愁!容器化实战,5 分钟搞定环境配置

你是不是也遇到过这种糟心事:花了好几天训练出的Python模型,在自己电脑上跑得顺顺当当,一放到服务器就各种报错。要么是Python版本不对,要么是依赖库冲突,折腾半天还是用不了。别再喊“我...

超全面讲透一个算法模型,高斯核!!

...

神经网络与传统统计方法的简单对比

传统的统计方法如...

AI 基础知识从0.1到0.2——用“房价预测”入门机器学习全流程

...

自回归滞后模型进行多变量时间序列预测

下图显示了关于不同类型葡萄酒销量的月度多元时间序列。每种葡萄酒类型都是时间序列中的一个变量。假设要预测其中一个变量。比如,sparklingwine。如何建立一个模型来进行预测呢?一种常见的方...

苹果AI策略:慢哲学——科技行业的“长期主义”试金石

苹果AI策略的深度原创分析,结合技术伦理、商业逻辑与行业博弈,揭示其“慢哲学”背后的战略智慧:一、反常之举:AI狂潮中的“逆行者”当科技巨头深陷AI军备竞赛,苹果的克制显得格格不入:功能延期:App...

时间序列预测全攻略,6大模型代码实操

如果你对数据分析感兴趣,希望学习更多的方法论,希望听听经验分享,欢迎移步宝藏公众号...

AI 基础知识从 0.4 到 0.5—— 计算机视觉之光 CNN

...

取消回复欢迎 发表评论: