TimeDART:基于扩散自回归Transformer 的自监督时间序列预测方法
ztj100 2025-01-01 23:50 24 浏览 0 评论
近年来,随着机器学习技术的进步,深度神经网络已经成为解决时间序列预测问题的主流方法。这反映了学术界和工业界在利用先进技术处理序列数据复杂性方面的持续努力。
自监督学习概述
基本定义
自监督学习是一种创新的学习范式,其特点是模型能够从未标记数据中通过内部生成的监督信号进行学习,通常这种学习通过预文任务来实现。与传统的监督学习不同,自监督学习不需要外部标签,而是利用数据本身的内在结构来创建必要的学习信号。
在时间序列领域的应用
在时间序列分析领域,自监督学习展现出独特的优势。它使得模型能够:
- 从未标记数据中学习通用表示
- 同时捕获数据中的长期依赖关系和局部细节特征
然而,这种学习方式仍面临着显著的挑战,这也是为什么需要像TimeDART这样的创新方法。通过集成扩散和自回归建模,TimeDART旨在解决这些根本性的挑战。
现有方法的问题
时间序列预测面临两个主要挑战:
全局依赖关系捕获:
需要有效理解和建模长期时间依赖;传统方法难以准确捕获序列中的全局模式
局部特征提取:
需要精确捕获时间序列中的局部细节特征;现有方法在同时处理这两个任务时表现不佳
这些挑战严重影响了模型学习全面和富有表现力的时间序列数据表示的能力。
TimeDarT方法详解
TimeDART是一种专为时间序列预测设计的自监督学习方法。它的核心思想是通过从时间序列历史数据中学习模式来改进未来数据点的预测。研究者采用了一种创新的方法,将时间序列数据分解成更小的片段(patches),并将这些patches作为建模的基本单位。
核心技术组件
- Transformer编码器设计:
- 使用了具有自注意力机制的Transformer编码器专注于理解patches之间的依赖关系有效捕获数据的整体序列结构
- 扩散和去噪过程:
- 实现了两个关键过程:扩散和去噪通过向数据添加和移除噪声来捕获局部特征这是所有扩散模型中的典型过程提升了模型在详细模式上的表现
TimeDART架构详解
TimeDART架构图展示了模型如何:
- 使用自回归生成捕获全局依赖关系
- 通过去噪扩散模型处理局部结构
- 在前向扩散过程中向输入patches引入噪声
- 生成自监督信号
- 通过自回归方式在反向过程中恢复原始序列
实例归一化和Patch嵌入
这一阶段包含几个关键步骤:
- 实例归一化:
- 对输入的多变量时间序列数据进行标准化确保每个实例具有零均值和单位标准差目的是保持最终预测的一致性
- 数据分割策略:
- 将时间序列数据划分为patches而非单个点这种方法能够捕获更全面的局部信息
- 避免信息泄漏:
- patch长度设置为等于stride(步长)确保每个patch包含原始序列的非重叠段防止训练过程中的信息泄漏
Transformer编码器中的Patch间依赖关系
在架构中,研究者实现了以下关键特性:
- 基于自注意力的处理:
- 使用自注意力的Transformer编码器专门用于建模patches之间的依赖关系
- 全局依赖性捕获:
- 通过考虑时间序列数据中不同patches之间的关系有效捕获全局序列依赖关系
- 表示学习:
- Transformer编码器能够学习有意义的patch间表示这对于理解时间序列的高层结构至关重要
class TransformerEncoderBlock(nn.Module):
def __init__(
self, d_model: int, num_heads: int, feedforward_dim: int, dropout: float
):
super(TransformerEncoderBlock, self).__init__()
self.attention = nn.MultiheadAttention(
embed_dim=d_model, num_heads=num_heads, dropout=dropout, batch_first=True
)
self.norm1 = nn.LayerNorm(d_model)
self.ff = nn.Sequential(
nn.Linear(d_model, feedforward_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(feedforward_dim, d_model),
)
self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=feedforward_dim, kernel_size=1)
self.activation = nn.GELU()
self.conv2 = nn.Conv1d(in_channels=feedforward_dim, out_channels=d_model, kernel_size=1)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask):
"""
:param x: [batch_size * num_features, seq_len, d_model]
:param mask: [1, 1, seq_len, seq_len]
:return: [batch_size * num_features, seq_len, d_model]
"""
# Self-attention
attn_output, _ = self.attention(x, x, x, attn_mask=mask)
x = self.norm1(x + self.dropout(attn_output))
# Feed-forward network
# y = self.dropout(self.activation(self.conv1(y.permute(0, 2, 1))))
# ff_output = self.conv2(y).permute(0, 2, 1)
ff_output = self.ff(x)
output = self.norm2(x + self.dropout(ff_output))
return output
前向扩散过程
前向扩散过程的主要特点:
- 噪声应用:
- 在输入patches上应用噪声生成自监督信号通过从带噪声版本中重构原始数据来学习稳健的表示
- 模式识别:
- 噪声帮助模型识别和关注专注于时间序列数据中的内在模式
class Diffusion(nn.Module):
def __init__(
self,
time_steps: int,
device: torch.device,
scheduler: str = "cosine",
):
super(Diffusion, self).__init__()
self.device = device
self.time_steps = time_steps
if scheduler == "cosine":
self.betas = self._cosine_beta_schedule().to(self.device)
elif scheduler == "linear":
self.betas = self._linear_beta_schedule().to(self.device)
else:
raise ValueError(f"Invalid scheduler: {scheduler=}")
self.alpha = 1 - self.betas
self.gamma = torch.cumprod(self.alpha, dim=0).to(self.device)
def _cosine_beta_schedule(self, s=0.008):
steps = self.time_steps + 1
x = torch.linspace(0, self.time_steps, steps)
alphas_cumprod = (
torch.cos(((x / self.time_steps) + s) / (1 + s) * torch.pi * 0.5) ** 2
)
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
return torch.clip(betas, 0, 0.999)
def _linear_beta_schedule(self, beta_start=1e-4, beta_end=0.02):
betas = torch.linspace(beta_start, beta_end, self.time_steps)
return betas
def sample_time_steps(self, shape):
return torch.randint(0, self.time_steps, shape, device=self.device)
def noise(self, x, t):
noise = torch.randn_like(x)
gamma_t = self.gamma[t].unsqueeze(-1) # [batch_size * num_features, seq_len, 1]
# x_t = sqrt(gamma_t) * x + sqrt(1 - gamma_t) * noise
noisy_x = torch.sqrt(gamma_t) * x + torch.sqrt(1 - gamma_t) * noise
return noisy_x, noise
def forward(self, x):
# x: [batch_size * num_features, seq_len, patch_len]
t = self.sample_time_steps(x.shape[:2]) # [batch_size * num_features, seq_len]
noisy_x, noise = self.noise(x, t)
return noisy_x, noise, t
基于交叉注意力的去噪解码器
该解码器具有以下特点:
- 核心功能:
- 使用交叉注意力机制目的是重构原始的、无噪声的patches
- 优化设计:
- 允许可调整的优化难度使自监督任务更有效使模型能够专注于捕获详细的patch内特征
解码器的工作机制:
- 接收噪声(作为查询)和编码器的输出(键和值)
- 使用掩码确保第j个噪声输入对应于Transformer编码器的第j个输出
class TransformerDecoderBlock(nn.Module):
def __init__(
self, d_model: int, num_heads: int, feedforward_dim: int, dropout: float
):
super(TransformerDecoderBlock, self).__init__()
self.self_attention = nn.MultiheadAttention(
embed_dim=d_model, num_heads=num_heads, dropout=dropout, batch_first=True
)
self.norm1 = nn.LayerNorm(d_model)
self.encoder_attention = nn.MultiheadAttention(
embed_dim=d_model, num_heads=num_heads, dropout=dropout, batch_first=True
)
self.norm2 = nn.LayerNorm(d_model)
self.ff = nn.Sequential(
nn.Linear(d_model, feedforward_dim),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(feedforward_dim, d_model),
)
self.norm3 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, query, key, value, tgt_mask, src_mask):
"""
:param query: [batch_size * num_features, seq_len, d_model]
:param key: [batch_size * num_features, seq_len, d_model]
:param value: [batch_size * num_features, seq_len, d_model]
:param mask: [1, 1, seq_len, seq_len]
:return: [batch_size * num_features, seq_len, d_model]
"""
# Self-attention
attn_output, _ = self.self_attention(query, query, query, attn_mask=tgt_mask)
query = self.norm1(query + self.dropout(attn_output))
# Encoder attention
attn_output, _ = self.encoder_attention(query, key, value, attn_mask=src_mask)
query = self.norm2(query + self.dropout(attn_output))
# Feed-forward network
ff_output = self.ff(query)
x = self.norm3(query + self.dropout(ff_output))
return x
用于全局依赖关系的自回归生成
自回归生成的主要职责:
- 高层依赖捕获:
- 捕获时间序列中的高层全局依赖关系通过自回归方式恢复原始序列使模型能够理解整体时间模式和依赖关系显著提升预测能力
class DenoisingPatchDecoder(nn.Module):
def __init__(
self,
d_model: int,
num_heads: int,
num_layers: int,
feedforward_dim: int,
dropout: float,
):
super(DenoisingPatchDecoder, self).__init__()
self.layers = nn.ModuleList(
[
TransformerDecoderBlock(d_model, num_heads, feedforward_dim, dropout)
for _ in range(num_layers)
]
)
self.norm = nn.LayerNorm(d_model)
def forward(self, query, key, value, is_tgt_mask=True, is_src_mask=True):
seq_len = query.size(1)
tgt_mask = (
generate_self_only_mask(seq_len).to(query.device) if is_tgt_mask else None
)
src_mask = (
generate_self_only_mask(seq_len).to(query.device) if is_src_mask else None
)
for layer in self.layers:
query = layer(query, key, value, tgt_mask, src_mask)
x = self.norm(query)
return x
class ForecastingHead(nn.Module):
def __init__(
self,
seq_len: int,
d_model: int,
pred_len: int,
dropout: float,
):
super(ForecastingHead, self).__init__()
self.pred_len = pred_len
self.flatten = nn.Flatten(start_dim=-2)
self.forecast_head = nn.Linear(seq_len * d_model, pred_len)
self.dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
:param x: [batch_size, num_features, seq_len, d_model]
:return: [batch_size, pred_len, num_features]
"""
x = self.flatten(x) # (batch_size, num_features, seq_len * d_model)
x = self.forecast_head(x) # (batch_size, num_features, pred_len)
x = self.dropout(x) # (batch_size, num_features, pred_len)
x = x.permute(0, 2, 1) # (batch_size, pred_len, num_features)
return x
优化和微调
优化过程的关键特点:
- 自回归优化:
- 整个模型以自回归方式进行优化获得可以针对特定预测任务进行微调的可迁移表示
- 表示特性:
- 确保模型学习的表示既全面又适应性强能够适应各种下游应用在时间序列预测中实现卓越性能
实验评估
数据集介绍
实验使用了八个广泛使用的数据集:
- ETT数据集系列:
- ETTh1、ETTh2、ETTm1、ETTm2四个子集代表能源领域的时间序列数据
- 其他领域数据集:
- Weather数据集Exchange数据集Electricity数据集Traffic数据集
这些数据集涵盖了多个应用场景,包括电力系统、交通网络和天气预测等领域。
实验结果分析
表1展示了TimeDART与现有方法的对比结果:
- 与最先进的自监督方法和监督方法进行比较
- 最佳结果用粗体标示
- 第二好的结果带有下划线
- "#1 Counts"表示该方法达到最佳结果的次数
表2显示了TimeDART在不同设置下的性能:
- 展示了在五个数据集上预训练并在特定数据集上微调的结果
- 所有结果都是从4个不同预测窗口{96, 192, 336, 720}中平均得出
- 最好的结果用粗体标示
消融研究结果:
- 所有结果都是从4个不同预测窗口{96, 192, 336, 720}中平均得出
- 最好的结果用粗体标示
超参数敏感性分析
前向过程参数
- 噪声步数T的影响:测试了{750, 1000, 1250}三个设置发现噪声步数对预训练难度影响不大所有设置都优于随机初始化
- 噪声调度器的选择:余弦调度器显著优于线性调度器某些情况下,线性调度器甚至导致性能低于随机初始化证实了平滑噪声添加的重要性
去噪patch解码器层数
- 测试了{0, 1, 2, 3}层配置
- 单层解码器通常提供最佳的模型复杂度和准确性平衡
- 过多的层数可能导致表示网络的训练不足
patch长度的影响
- 测试了{1, 2, 4, 8, 16}不同长度
- 最佳patch长度取决于数据集特征
- 较大的patch长度可能更适合具有高冗余性的数据集
总结
TimeDART通过创新性地结合扩散模型和自回归建模,成功解决了时间序列预测中的关键挑战:
- 技术创新:
- 首次将扩散和自回归建模统一到单一框架设计了灵活的交叉注意力去噪网络
- 性能提升:
- 在多个数据集上实现了最优性能展示了强大的域内和跨域泛化能力
- 实际意义:
- 为时间序列预测提供了新的研究方向为实际应用提供了更可靠的预测工具
TimeDART的成功表明,结合不同的生成方法可以有效提升时间序列预测的性能,为该领域的进一步研究提供了新的思路。
相关推荐
- sharding-jdbc实现`分库分表`与`读写分离`
-
一、前言本文将基于以下环境整合...
- 三分钟了解mysql中主键、外键、非空、唯一、默认约束是什么
-
在数据库中,数据表是数据库中最重要、最基本的操作对象,是数据存储的基本单位。数据表被定义为列的集合,数据在表中是按照行和列的格式来存储的。每一行代表一条唯一的记录,每一列代表记录中的一个域。...
- MySQL8行级锁_mysql如何加行级锁
-
MySQL8行级锁版本:8.0.34基本概念...
- mysql使用小技巧_mysql使用入门
-
1、MySQL中有许多很实用的函数,好好利用它们可以省去很多时间:group_concat()将取到的值用逗号连接,可以这么用:selectgroup_concat(distinctid)fr...
- MySQL/MariaDB中如何支持全部的Unicode?
-
永远不要在MySQL中使用utf8,并且始终使用utf8mb4。utf8mb4介绍MySQL/MariaDB中,utf8字符集并不是对Unicode的真正实现,即不是真正的UTF-8编码,因...
- 聊聊 MySQL Server 可执行注释,你懂了吗?
-
前言MySQLServer当前支持如下3种注释风格:...
- MySQL系列-源码编译安装(v5.7.34)
-
一、系统环境要求...
- MySQL的锁就锁住我啦!与腾讯大佬的技术交谈,是我小看它了
-
对酒当歌,人生几何!朝朝暮暮,唯有己脱。苦苦寻觅找工作之间,殊不知今日之事乃我心之痛,难道是我不配拥有工作嘛。自面试后他所谓的等待都过去一段时日,可惜在下京东上的小金库都要见低啦。每每想到不由心中一...
- MySQL字符问题_mysql中字符串的位置
-
中文写入乱码问题:我输入的中文编码是urf8的,建的库是urf8的,但是插入mysql总是乱码,一堆"???????????????????????"我用的是ibatis,终于找到原因了,我是这么解决...
- 深圳尚学堂:mysql基本sql语句大全(三)
-
数据开发-经典1.按姓氏笔画排序:Select*FromTableNameOrderByCustomerNameCollateChinese_PRC_Stroke_ci_as//从少...
- MySQL进行行级锁的?一会next-key锁,一会间隙锁,一会记录锁?
-
大家好,是不是很多人都对MySQL加行级锁的规则搞的迷迷糊糊,一会是next-key锁,一会是间隙锁,一会又是记录锁。坦白说,确实还挺复杂的,但是好在我找点了点规律,也知道如何如何用命令分析加...
- 一文讲清怎么利用Python Django实现Excel数据表的导入导出功能
-
摘要:Python作为一门简单易学且功能强大的编程语言,广受程序员、数据分析师和AI工程师的青睐。本文系统讲解了如何使用Python的Django框架结合openpyxl库实现Excel...
- 用DataX实现两个MySQL实例间的数据同步
-
DataXDataX使用Java实现。如果可以实现数据库实例之间准实时的...
- MySQL数据库知识_mysql数据库基础知识
-
MySQL是一种关系型数据库管理系统;那废话不多说,直接上自己以前学习整理文档:查看数据库命令:(1).查看存储过程状态:showprocedurestatus;(2).显示系统变量:show...
- 如何为MySQL中的JSON字段设置索引
-
背景MySQL在2015年中发布的5.7.8版本中首次引入了JSON数据类型。自此,它成了一种逃离严格列定义的方式,可以存储各种形状和大小的JSON文档,例如审计日志、配置信息、第三方数据包、用户自定...
你 发表评论:
欢迎- 一周热门
-
-
MySQL中这14个小玩意,让人眼前一亮!
-
旗舰机新标杆 OPPO Find X2系列正式发布 售价5499元起
-
【VueTorrent】一款吊炸天的qBittorrent主题,人人都可用
-
面试官:使用int类型做加减操作,是线程安全吗
-
C++编程知识:ToString()字符串转换你用正确了吗?
-
【Spring Boot】WebSocket 的 6 种集成方式
-
PyTorch 深度学习实战(26):多目标强化学习Multi-Objective RL
-
pytorch中的 scatter_()函数使用和详解
-
与 Java 17 相比,Java 21 究竟有多快?
-
基于TensorRT_LLM的大模型推理加速与OpenAI兼容服务优化
-
- 最近发表
- 标签列表
-
- idea eval reset (50)
- vue dispatch (70)
- update canceled (42)
- order by asc (53)
- spring gateway (67)
- 简单代码编程 贪吃蛇 (40)
- transforms.resize (33)
- redisson trylock (35)
- 卸载node (35)
- np.reshape (33)
- torch.arange (34)
- npm 源 (35)
- vue3 deep (35)
- win10 ssh (35)
- vue foreach (34)
- idea设置编码为utf8 (35)
- vue 数组添加元素 (34)
- std find (34)
- tablefield注解用途 (35)
- python str转json (34)
- java websocket客户端 (34)
- tensor.view (34)
- java jackson (34)
- vmware17pro最新密钥 (34)
- mysql单表最大数据量 (35)