深度探索:通过图表和示例来演示什么是循环神经网络
ztj100 2025-04-26 22:45 3 浏览 0 评论
许多问题和现象都是基于顺序的。常见的例子包括语音、天气模式和时间序列。这些系统的下一个位置取决于之前的状态。
不幸的是,传统的神经网络无法处理或预测此类数据,因为它们单独分析输入。他们不知道数据确实是连续的。
那么,我们如何预测这类数据呢?
好吧,我们转向称为循环神经网络的东西!
注意:存在使传统神经网络能够处理序列数据的技术方法和技术。但这就像试图把大象塞进鞋盒一样——根本行不通!
什么是循环神经网络?
下图展示了循环神经网络(RNN):
RNN 的示例架构。作者绘制的图表。
左侧是循环神经元,右侧是随时间展开的循环神经元。 RNN 看起来类似于普通的前馈神经网络,除了它接收来自先前向后执行的输入的关键区别之外。
这就是为什么它们被称为“循环”,因为每个步骤的输出都会及时传播,以帮助计算下一步的值。系统中有一些固有的“记忆”,可以帮助模型获取历史模式。
例如,在预测Y_1时,它将使用X_1的输入加上Y_0的上一个时间步的输出。由于Y_0影响Y_1,我们可以看到Y_0也会间接影响Y_2,生动地展示了该算法的循环性质。
隐藏状态
在文献中,您通常会看到隐藏状态的概念,通常用通过循环神经元传递的h表示。
显示具有隐藏状态的 RNN 示例架构。作者绘制的图表。
在简单的情况下,隐藏状态只是单元的输出,因此h=Y。然而,正如我们将在后面的文章中看到的,只有在更复杂的单元(例如长期短记忆 (LSTM)和门控循环单元 GRU)中,这种情况有时才成立。
因此,最好明确我们通过并进入每个神经元的内容,这就是为什么它在大多数文献中都像上面那样显示。
理论
循环神经元的每个隐藏状态可以计算如下:
循环神经元隐藏状态方程。由作者在 LaTeX 中编写。
在哪里:
- h_t是时间t 的隐藏状态。
- h_{t-1}是上一个时间步的隐藏状态。
- x_t是时间t 的输入数据。
- W_h是隐藏状态的权重矩阵。
- W_x是 输入数据的权重矩阵。
- b_h是隐藏状态的偏差向量。
- σ是激活函数,通常为 tanh 或 sigmoid。
注意:这些值可以是标量,但在大多数实际应用中通常是向量;因此,它们被这样表示。
然后每个循环神经元的预测输出为:
循环神经元输出方程。由作者在 LaTeX 中编写。
在哪里:
- y_t是时间t 的输出。
- W_y是与输出相关的权重矩阵。
- b_y是输出偏置向量。
正如您所看到的,许多符号和变量与常规前馈神经网络类似。唯一的区别是隐藏状态的传递,它可以被视为模型的另一个输入或特征,用于预测输出。
每个隐藏层可以包含多个循环神经元,因此我们将隐藏状态向量传递给每个后续输入神经元。这使得网络能够捕获并表示数据中更复杂的模式。您可以将其想象为每个时间步内的迷你神经网络。
工作示例
我们可以回顾一个简单的例子来解释RNN 内部到底发生了什么。这将是一个非常简单的场景,但它将说明您需要了解的主要直觉。事实上,现实生活中没有任何问题会这么简单!
设置
假设我们有一个数字 1、2 和 3 的序列,我们想要训练一个 RNN 来预测序列中的下一个数字,即 4。
我们的 RNN 将具有以下架构:
- 1个输入神经元
- 1个隐藏神经元
- 1个输出神经元
我们可以随机初始化权重和偏差:
- W_x (隐藏权重的输入):0.5
- W_h(隐藏到隐藏权重):1.0
- b_h(隐藏偏差):0
- _(输出偏差):0
并使用以下激活函数:
- 隐藏层:tanh
- 输出层:无(恒等/线性)
初始隐藏状态值:
- h_0 = 0
时间步长 1(输入:1)
第一个隐藏状态计算如下:
第一次隐藏状态更新。由作者在 LaTeX 中编写。
然后输出计算如下:
第一个输出状态。由作者在 LaTeX 中编写。
在这个例子中,输出激活函数是恒等的,因此输出值与隐藏状态值相同。但是,请记住,在许多问题中情况并非总是如此。
时间步长 2(输入:2)
现在,我们可以使用最近计算的h_1值在时间步 2 处对下一个输入值重复上述过程:
第二次隐藏状态更新。由作者在 LaTeX 中编写。
我们再次计算第 2 步的输出值:
第二输出状态。由作者在 LaTeX 中编写。
时间步长 3(输入:3)
最后,对于最后一个输入值和第三个时间步长:
第三次隐藏状态更新。由作者在 LaTeX 中编写。
第三输出状态。由作者在 LaTeX 中编写。
因此,当前模型预测接下来的数字为 0.984,这显然与实际值 4 相距甚远。实际上,我们将拥有更广泛的训练集并随着时间执行反向传播来优化我们的参数。这将在我的下一篇文章中介绍!
幸运的是,所有这些计算和优化都是通过 PyTorch 和 TensorFlow 等软件包在 Python 中完成的。我将在本文后面展示如何执行此操作的示例!
RNN 的类型
上面的例子说明了多对一RNN的逻辑过程。我们从多个输入 (1,2,3) 开始,旨在预测序列中的下一个数字,即单个值。
然而,还有其他类型的 RNN 可用于不同的任务,我们现在将介绍它们。
一对一
这只是一个传统的神经网络,具有一组输入,可给出单个预测。它有助于解决一般的监督机器学习问题。
一对一循环神经网络。作者绘制的图表。
一对多
单个输入导致多个输出。这可用于生成图像标题和生成音乐。
一对多循环神经网络。作者绘制的图表。
多对一
多个输入生成一个最终输出;情感分析是用于该架构的一个示例。你给它一个电影评论,如果电影好或坏,它会分别分配+1或-1。
多对一循环神经网络。作者绘制的图表。
多对多
这个在每一步都会得到一个输入,并在每一步产生一个输出。该架构用于机器翻译以及语音标记等问题。
多对多循环神经网络。作者绘制的图表。
编码器-解码器
最后,您可以拥有一个编码器-解码器网络。这是一个多对一网络,然后是一对多网络。它通常用于将句子从一种语言翻译成另一种语言。
编码器-解码器模型是用于创建 LLM 的转换器模型背后的基础。
编码器-解码器递归神经网络。作者绘制的图表。
PyTorch 示例
下面是在 PyTorch 中实现简单 RNN 的简单示例。它经历了我们上面解决的问题,我们输入了 1,2,3 并想要预测序列中的后续数字。
import torch
import torch.nn as nn
import torch.optim as optim
# RNN Model Definition
class SimpleRNN(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(SimpleRNN, self).__init__()
self.hidden_size = hidden_size
self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
self.fc = nn.Linear(hidden_size, output_size)
def forward(self, x):
x = x.unsqueeze(-1)
h_0 = torch.zeros(1, x.size(0), self.hidden_size)
rnn_out, _ = self.rnn(x, h_0)
out = self.fc(rnn_out[:, -1, :])
return out
# Dataset
train = torch.tensor([1, 2, 3, 4], dtype=torch.float32)
target = torch.tensor([5], dtype=torch.float32)
# Model Configuration
input_size = 1
hidden_size = 1
output_size = 1
model = SimpleRNN(input_size, hidden_size, output_size)
# Loss and Optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
for epoch in range(1000):
optimizer.zero_grad()
output = model(train.unsqueeze(0)).squeeze() # Add batch dimension and squeeze to match target shape
loss = criterion(output, target)
loss.backward()
optimizer.step()
# Function to Predict Next Number
def predict(model, input_seq):
with torch.no_grad():
input_seq = torch.tensor(input_seq, dtype=torch.float32).unsqueeze(0)
output = model(input_seq).squeeze().item()
return output
# Example Test Set
test = [2, 3, 4]
predicted = predict(model, test)
print(f'Input: {test}, Predicted Next Number: {predicted:.2f}')
运行这个,1000 个 epoch 后我们的输出是 5!显然,在这种情况下,模型实际上是通过反向传播训练了 1000 次,这就是为什么它的性能比我们上面手工计算的示例要好得多。
如果您有兴趣,可以在我的 GitHub 上找到该代码:
Medium-Articles/Neural Networks/rnn_example.py at main · egorhowell/Medium-Articles
优点与缺点
有了所有这些新获得的信息,让我们来看看 RNN 的主要优点和缺点:
优点
- 它们具有来自先前输入的记忆形式,这使得它们有助于处理基于序列的数据。
- 确切的权重和偏差在所有时间步长之间共享,从而减少参数并获得更好的泛化能力。
- 由于其递归性质,RNN 可以处理可变长度的序列。
缺点
- 他们严重遭受梯度消失问题的困扰,从而导致长期记忆问题。
- 每个时间步长都取决于前一步的输出,这使得 RNN 的计算效率低下,因为它们无法并行化。
概括
RNN 对于序列建模非常有用,因为它们保留先前执行的信息和内存,然后传播到下一个预测。它们的优点是可以处理任意长度的输入,并且模型大小不会随着输入大小的增加而增加。然而,由于它们具有递归性质,因此无法并行化,因此它们的计算效率不高,并且严重遭受梯度消失问题。
相关推荐
- 如何将数据仓库迁移到阿里云 AnalyticDB for PostgreSQL
-
阿里云AnalyticDBforPostgreSQL(以下简称ADBPG,即原HybridDBforPostgreSQL)为基于PostgreSQL内核的MPP架构的实时数据仓库服务,可以...
- Python数据分析:探索性分析
-
写在前面如果你忘记了前面的文章,可以看看加深印象:Python数据处理...
- C++基础语法梳理:算法丨十大排序算法(二)
-
本期是C++基础语法分享的第十六节,今天给大家来梳理一下十大排序算法后五个!归并排序...
- C 语言的标准库有哪些
-
C语言的标准库并不是一个单一的实体,而是由一系列头文件(headerfiles)组成的集合。每个头文件声明了一组相关的函数、宏、类型和常量。程序员通过在代码中使用#include<...
- [深度学习] ncnn安装和调用基础教程
-
1介绍ncnn是腾讯开发的一个为手机端极致优化的高性能神经网络前向计算框架,无第三方依赖,跨平台,但是通常都需要protobuf和opencv。ncnn目前已在腾讯多款应用中使用,如QQ,Qzon...
- 用rust实现经典的冒泡排序和快速排序
-
1.假设待排序数组如下letmutarr=[5,3,8,4,2,7,1];...
- ncnn+PPYOLOv2首次结合!全网最详细代码解读来了
-
编辑:好困LRS【新智元导读】今天给大家安利一个宝藏仓库miemiedetection,该仓库集合了PPYOLO、PPYOLOv2、PPYOLOE三个算法pytorch实现三合一,其中的PPYOL...
- C++特性使用建议
-
1.引用参数使用引用替代指针且所有不变的引用参数必须加上const。在C语言中,如果函数需要修改变量的值,参数必须为指针,如...
- Qt4/5升级到Qt6吐血经验总结V202308
-
00:直观总结增加了很多轮子,同时原有模块拆分的也更细致,估计为了方便拓展个管理。把一些过度封装的东西移除了(比如同样的功能有多个函数),保证了只有一个函数执行该功能。把一些Qt5中兼容Qt4的方法废...
- 到底什么是C++11新特性,请看下文
-
C++11是一个比较大的更新,引入了很多新特性,以下是对这些特性的详细解释,帮助您快速理解C++11的内容1.自动类型推导(auto和decltype)...
- 掌握C++11这些特性,代码简洁性、安全性和性能轻松跃升!
-
C++11(又称C++0x)是C++编程语言的一次重大更新,引入了许多新特性,显著提升了代码简洁性、安全性和性能。以下是主要特性的分类介绍及示例:一、核心语言特性1.自动类型推导(auto)编译器自...
- 经典算法——凸包算法
-
凸包算法(ConvexHull)一、概念与问题描述凸包是指在平面上给定一组点,找到包含这些点的最小面积或最小周长的凸多边形。这个多边形没有任何内凹部分,即从一个多边形内的任意一点画一条线到多边形边界...
- 一起学习c++11——c++11中的新增的容器
-
c++11新增的容器1:array当时的初衷是希望提供一个在栈上分配的,定长数组,而且可以使用stl中的模板算法。array的用法如下:#include<string>#includ...
- C++ 编程中的一些最佳实践
-
1.遵循代码简洁原则尽量避免冗余代码,通过模块化设计、清晰的命名和良好的结构,让代码更易于阅读和维护...
你 发表评论:
欢迎- 一周热门
- 最近发表
- 标签列表
-
- idea eval reset (50)
- vue dispatch (70)
- update canceled (42)
- order by asc (53)
- spring gateway (67)
- 简单代码编程 贪吃蛇 (40)
- transforms.resize (33)
- redisson trylock (35)
- 卸载node (35)
- np.reshape (33)
- torch.arange (34)
- node卸载 (33)
- npm 源 (35)
- vue3 deep (35)
- win10 ssh (35)
- exceptionininitializererror (33)
- vue foreach (34)
- idea设置编码为utf8 (35)
- vue 数组添加元素 (34)
- std find (34)
- tablefield注解用途 (35)
- python str转json (34)
- java websocket客户端 (34)
- tensor.view (34)
- java jackson (34)