TimeDART:基于扩散自回归Transformer 的自监督时间序列预测方法
ztj100 2025-01-01 23:50 19 浏览 0 评论
近年来,随着机器学习技术的进步,深度神经网络已经成为解决时间序列预测问题的主流方法。这反映了学术界和工业界在利用先进技术处理序列数据复杂性方面的持续努力。
自监督学习概述
基本定义
自监督学习是一种创新的学习范式,其特点是模型能够从未标记数据中通过内部生成的监督信号进行学习,通常这种学习通过预文任务来实现。与传统的监督学习不同,自监督学习不需要外部标签,而是利用数据本身的内在结构来创建必要的学习信号。
在时间序列领域的应用
在时间序列分析领域,自监督学习展现出独特的优势。它使得模型能够:
- 从未标记数据中学习通用表示
- 同时捕获数据中的长期依赖关系和局部细节特征
然而,这种学习方式仍面临着显著的挑战,这也是为什么需要像TimeDART这样的创新方法。通过集成扩散和自回归建模,TimeDART旨在解决这些根本性的挑战。
现有方法的问题
时间序列预测面临两个主要挑战:
全局依赖关系捕获:
需要有效理解和建模长期时间依赖;传统方法难以准确捕获序列中的全局模式
局部特征提取:
需要精确捕获时间序列中的局部细节特征;现有方法在同时处理这两个任务时表现不佳
这些挑战严重影响了模型学习全面和富有表现力的时间序列数据表示的能力。
TimeDarT方法详解
TimeDART是一种专为时间序列预测设计的自监督学习方法。它的核心思想是通过从时间序列历史数据中学习模式来改进未来数据点的预测。研究者采用了一种创新的方法,将时间序列数据分解成更小的片段(patches),并将这些patches作为建模的基本单位。
核心技术组件
- Transformer编码器设计:
- 使用了具有自注意力机制的Transformer编码器专注于理解patches之间的依赖关系有效捕获数据的整体序列结构
- 扩散和去噪过程:
- 实现了两个关键过程:扩散和去噪通过向数据添加和移除噪声来捕获局部特征这是所有扩散模型中的典型过程提升了模型在详细模式上的表现
TimeDART架构详解
TimeDART架构图展示了模型如何:
- 使用自回归生成捕获全局依赖关系
- 通过去噪扩散模型处理局部结构
- 在前向扩散过程中向输入patches引入噪声
- 生成自监督信号
- 通过自回归方式在反向过程中恢复原始序列
实例归一化和Patch嵌入
这一阶段包含几个关键步骤:
- 实例归一化:
- 对输入的多变量时间序列数据进行标准化确保每个实例具有零均值和单位标准差目的是保持最终预测的一致性
- 数据分割策略:
- 将时间序列数据划分为patches而非单个点这种方法能够捕获更全面的局部信息
- 避免信息泄漏:
- patch长度设置为等于stride(步长)确保每个patch包含原始序列的非重叠段防止训练过程中的信息泄漏
Transformer编码器中的Patch间依赖关系
在架构中,研究者实现了以下关键特性:
- 基于自注意力的处理:
- 使用自注意力的Transformer编码器专门用于建模patches之间的依赖关系
- 全局依赖性捕获:
- 通过考虑时间序列数据中不同patches之间的关系有效捕获全局序列依赖关系
- 表示学习:
- Transformer编码器能够学习有意义的patch间表示这对于理解时间序列的高层结构至关重要
class TransformerEncoderBlock(nn.Module):
def __init__(
self, d_model: int, num_heads: int, feedforward_dim: int, dropout: float
):
super(TransformerEncoderBlock, self).__init__()
self.attention = nn.MultiheadAttention(
embed_dim=d_model, num_heads=num_heads, dropout=dropout, batch_first=True
)
self.norm1 = nn.LayerNorm(d_model)
self.ff = nn.Sequential(
nn.Linear(d_model, feedforward_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(feedforward_dim, d_model),
)
self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=feedforward_dim, kernel_size=1)
self.activation = nn.GELU()
self.conv2 = nn.Conv1d(in_channels=feedforward_dim, out_channels=d_model, kernel_size=1)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask):
"""
:param x: [batch_size * num_features, seq_len, d_model]
:param mask: [1, 1, seq_len, seq_len]
:return: [batch_size * num_features, seq_len, d_model]
"""
# Self-attention
attn_output, _ = self.attention(x, x, x, attn_mask=mask)
x = self.norm1(x + self.dropout(attn_output))
# Feed-forward network
# y = self.dropout(self.activation(self.conv1(y.permute(0, 2, 1))))
# ff_output = self.conv2(y).permute(0, 2, 1)
ff_output = self.ff(x)
output = self.norm2(x + self.dropout(ff_output))
return output
前向扩散过程
前向扩散过程的主要特点:
- 噪声应用:
- 在输入patches上应用噪声生成自监督信号通过从带噪声版本中重构原始数据来学习稳健的表示
- 模式识别:
- 噪声帮助模型识别和关注专注于时间序列数据中的内在模式
class Diffusion(nn.Module):
def __init__(
self,
time_steps: int,
device: torch.device,
scheduler: str = "cosine",
):
super(Diffusion, self).__init__()
self.device = device
self.time_steps = time_steps
if scheduler == "cosine":
self.betas = self._cosine_beta_schedule().to(self.device)
elif scheduler == "linear":
self.betas = self._linear_beta_schedule().to(self.device)
else:
raise ValueError(f"Invalid scheduler: {scheduler=}")
self.alpha = 1 - self.betas
self.gamma = torch.cumprod(self.alpha, dim=0).to(self.device)
def _cosine_beta_schedule(self, s=0.008):
steps = self.time_steps + 1
x = torch.linspace(0, self.time_steps, steps)
alphas_cumprod = (
torch.cos(((x / self.time_steps) + s) / (1 + s) * torch.pi * 0.5) ** 2
)
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
return torch.clip(betas, 0, 0.999)
def _linear_beta_schedule(self, beta_start=1e-4, beta_end=0.02):
betas = torch.linspace(beta_start, beta_end, self.time_steps)
return betas
def sample_time_steps(self, shape):
return torch.randint(0, self.time_steps, shape, device=self.device)
def noise(self, x, t):
noise = torch.randn_like(x)
gamma_t = self.gamma[t].unsqueeze(-1) # [batch_size * num_features, seq_len, 1]
# x_t = sqrt(gamma_t) * x + sqrt(1 - gamma_t) * noise
noisy_x = torch.sqrt(gamma_t) * x + torch.sqrt(1 - gamma_t) * noise
return noisy_x, noise
def forward(self, x):
# x: [batch_size * num_features, seq_len, patch_len]
t = self.sample_time_steps(x.shape[:2]) # [batch_size * num_features, seq_len]
noisy_x, noise = self.noise(x, t)
return noisy_x, noise, t
基于交叉注意力的去噪解码器
该解码器具有以下特点:
- 核心功能:
- 使用交叉注意力机制目的是重构原始的、无噪声的patches
- 优化设计:
- 允许可调整的优化难度使自监督任务更有效使模型能够专注于捕获详细的patch内特征
解码器的工作机制:
- 接收噪声(作为查询)和编码器的输出(键和值)
- 使用掩码确保第j个噪声输入对应于Transformer编码器的第j个输出
class TransformerDecoderBlock(nn.Module):
def __init__(
self, d_model: int, num_heads: int, feedforward_dim: int, dropout: float
):
super(TransformerDecoderBlock, self).__init__()
self.self_attention = nn.MultiheadAttention(
embed_dim=d_model, num_heads=num_heads, dropout=dropout, batch_first=True
)
self.norm1 = nn.LayerNorm(d_model)
self.encoder_attention = nn.MultiheadAttention(
embed_dim=d_model, num_heads=num_heads, dropout=dropout, batch_first=True
)
self.norm2 = nn.LayerNorm(d_model)
self.ff = nn.Sequential(
nn.Linear(d_model, feedforward_dim),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(feedforward_dim, d_model),
)
self.norm3 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, query, key, value, tgt_mask, src_mask):
"""
:param query: [batch_size * num_features, seq_len, d_model]
:param key: [batch_size * num_features, seq_len, d_model]
:param value: [batch_size * num_features, seq_len, d_model]
:param mask: [1, 1, seq_len, seq_len]
:return: [batch_size * num_features, seq_len, d_model]
"""
# Self-attention
attn_output, _ = self.self_attention(query, query, query, attn_mask=tgt_mask)
query = self.norm1(query + self.dropout(attn_output))
# Encoder attention
attn_output, _ = self.encoder_attention(query, key, value, attn_mask=src_mask)
query = self.norm2(query + self.dropout(attn_output))
# Feed-forward network
ff_output = self.ff(query)
x = self.norm3(query + self.dropout(ff_output))
return x
用于全局依赖关系的自回归生成
自回归生成的主要职责:
- 高层依赖捕获:
- 捕获时间序列中的高层全局依赖关系通过自回归方式恢复原始序列使模型能够理解整体时间模式和依赖关系显著提升预测能力
class DenoisingPatchDecoder(nn.Module):
def __init__(
self,
d_model: int,
num_heads: int,
num_layers: int,
feedforward_dim: int,
dropout: float,
):
super(DenoisingPatchDecoder, self).__init__()
self.layers = nn.ModuleList(
[
TransformerDecoderBlock(d_model, num_heads, feedforward_dim, dropout)
for _ in range(num_layers)
]
)
self.norm = nn.LayerNorm(d_model)
def forward(self, query, key, value, is_tgt_mask=True, is_src_mask=True):
seq_len = query.size(1)
tgt_mask = (
generate_self_only_mask(seq_len).to(query.device) if is_tgt_mask else None
)
src_mask = (
generate_self_only_mask(seq_len).to(query.device) if is_src_mask else None
)
for layer in self.layers:
query = layer(query, key, value, tgt_mask, src_mask)
x = self.norm(query)
return x
class ForecastingHead(nn.Module):
def __init__(
self,
seq_len: int,
d_model: int,
pred_len: int,
dropout: float,
):
super(ForecastingHead, self).__init__()
self.pred_len = pred_len
self.flatten = nn.Flatten(start_dim=-2)
self.forecast_head = nn.Linear(seq_len * d_model, pred_len)
self.dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
:param x: [batch_size, num_features, seq_len, d_model]
:return: [batch_size, pred_len, num_features]
"""
x = self.flatten(x) # (batch_size, num_features, seq_len * d_model)
x = self.forecast_head(x) # (batch_size, num_features, pred_len)
x = self.dropout(x) # (batch_size, num_features, pred_len)
x = x.permute(0, 2, 1) # (batch_size, pred_len, num_features)
return x
优化和微调
优化过程的关键特点:
- 自回归优化:
- 整个模型以自回归方式进行优化获得可以针对特定预测任务进行微调的可迁移表示
- 表示特性:
- 确保模型学习的表示既全面又适应性强能够适应各种下游应用在时间序列预测中实现卓越性能
实验评估
数据集介绍
实验使用了八个广泛使用的数据集:
- ETT数据集系列:
- ETTh1、ETTh2、ETTm1、ETTm2四个子集代表能源领域的时间序列数据
- 其他领域数据集:
- Weather数据集Exchange数据集Electricity数据集Traffic数据集
这些数据集涵盖了多个应用场景,包括电力系统、交通网络和天气预测等领域。
实验结果分析
表1展示了TimeDART与现有方法的对比结果:
- 与最先进的自监督方法和监督方法进行比较
- 最佳结果用粗体标示
- 第二好的结果带有下划线
- "#1 Counts"表示该方法达到最佳结果的次数
表2显示了TimeDART在不同设置下的性能:
- 展示了在五个数据集上预训练并在特定数据集上微调的结果
- 所有结果都是从4个不同预测窗口{96, 192, 336, 720}中平均得出
- 最好的结果用粗体标示
消融研究结果:
- 所有结果都是从4个不同预测窗口{96, 192, 336, 720}中平均得出
- 最好的结果用粗体标示
超参数敏感性分析
前向过程参数
- 噪声步数T的影响:测试了{750, 1000, 1250}三个设置发现噪声步数对预训练难度影响不大所有设置都优于随机初始化
- 噪声调度器的选择:余弦调度器显著优于线性调度器某些情况下,线性调度器甚至导致性能低于随机初始化证实了平滑噪声添加的重要性
去噪patch解码器层数
- 测试了{0, 1, 2, 3}层配置
- 单层解码器通常提供最佳的模型复杂度和准确性平衡
- 过多的层数可能导致表示网络的训练不足
patch长度的影响
- 测试了{1, 2, 4, 8, 16}不同长度
- 最佳patch长度取决于数据集特征
- 较大的patch长度可能更适合具有高冗余性的数据集
总结
TimeDART通过创新性地结合扩散模型和自回归建模,成功解决了时间序列预测中的关键挑战:
- 技术创新:
- 首次将扩散和自回归建模统一到单一框架设计了灵活的交叉注意力去噪网络
- 性能提升:
- 在多个数据集上实现了最优性能展示了强大的域内和跨域泛化能力
- 实际意义:
- 为时间序列预测提供了新的研究方向为实际应用提供了更可靠的预测工具
TimeDART的成功表明,结合不同的生成方法可以有效提升时间序列预测的性能,为该领域的进一步研究提供了新的思路。
相关推荐
- SpringBoot整合SpringSecurity+JWT
-
作者|Sans_https://juejin.im/post/5da82f066fb9a04e2a73daec一.说明SpringSecurity是一个用于Java企业级应用程序的安全框架,主要包含...
- 「计算机毕设」一个精美的JAVA博客系统源码分享
-
前言大家好,我是程序员it分享师,今天给大家带来一个精美的博客系统源码!可以自己买一个便宜的云服务器,当自己的博客网站,记录一下自己学习的心得。开发技术博客系统源码基于SpringBoot,shiro...
- springboot教务管理系统+微信小程序云开发附带源码
-
今天给大家分享的程序是基于springboot的管理,前端是小程序,系统非常的nice,不管是学习还是毕设都非常的靠谱。本系统主要分为pc端后台管理和微信小程序端,pc端有三个角色:管理员、学生、教师...
- SpringBoot+LayUI后台管理系统开发脚手架
-
源码获取方式:关注,转发之后私信回复【源码】即可免费获取到!项目简介本项目本着避免重复造轮子的原则,建立一套快速开发JavaWEB项目(springboot-mini),能满足大部分后台管理系统基础开...
- Spring Boot的Security安全控制——认识SpringSecurity!
-
SpringBoot的Security安全控制在Web项目开发中,安全控制是非常重要的,不同的人配置不同的权限,这样的系统才安全。最常见的权限框架有Shiro和SpringSecurity。Shi...
- 前同事2024年接私活已入百万,都是用这几个开源的SpringBoot项目
-
前言不得不佩服SpringBoot的生态如此强大,今天给大家推荐几款优秀的后台管理系统,小伙伴们再也不用从头到尾撸一个项目了。SmartAdmin...
- 值得学习的15 个优秀开源的 Spring Boot 学习项目
-
SpringBoot算是目前Java领域最火的技术栈了,除了书呢?当然就是开源项目了,今天整理15个开源领域非常不错的SpringBoot项目供大家学习,参考。高富帅的路上只能帮你到这里了,...
- 开发企业官网就用这个基于SpringBoot的CMS系统,真香
-
前言推荐这个项目是因为使用手册部署手册非常...
- 2021年超详细的java学习路线总结—纯干货分享
-
本文整理了java开发的学习路线和相关的学习资源,非常适合零基础入门java的同学,希望大家在学习的时候,能够节省时间。纯干货,良心推荐!第一阶段:Java基础...
- jeecg-boot学习总结及使用心得(jeecgboot简单吗)
-
jeecg-boot学习总结及使用心得1.jeecg-boot是一个真正前后端分离的模版项目,便于二次开发,使用的都是较流行的新技术,后端技术主要有spring-boot2.x、shiro、Myb...
- 后勤集团原料管理系统springboot+Layui+MybatisPlus+Shiro源代码
-
本项目为前几天收费帮学妹做的一个项目,JavaEEJSP项目,在工作环境中基本使用不到,但是很多学校把这个当作编程入门的项目来做,故分享出本项目供初学者参考。一、项目描述后勤集团原料管理系统spr...
- 白卷开源SpringBoot+Vue的前后端分离入门项目
-
简介白卷是一个简单的前后端分离项目,主要采用Vue.js+SpringBoot技术栈开发。除了用作入门练习,作者还希望该项目可以作为一些常见Web项目的脚手架,帮助大家简化搭建网站的流程。...
- Spring Security 自动踢掉前一个登录用户,一个配置搞定
-
登录成功后,自动踢掉前一个登录用户,松哥第一次见到这个功能,就是在扣扣里边见到的,当时觉得挺好玩的。自己做开发后,也遇到过一模一样的需求,正好最近的SpringSecurity系列正在连载,就结...
- 收藏起来!这款开源在线考试系统,我爱了
-
大家好,我是为广大程序员兄弟操碎了心的小编,每天推荐一个小工具/源码,装满你的收藏夹,每天分享一个小技巧,让你轻松节省开发效率,实现不加班不熬夜不掉头发,是我的目标!今天小编推荐一款基于Spr...
- Shiro框架:认证和授权原理(shiro权限认证流程)
-
优质文章,及时送达前言Shiro作为解决权限问题的常用框架,常用于解决认证、授权、加密、会话管理等场景。本文将对Shiro的认证和授权原理进行介绍:Shiro可以做什么?、Shiro是由什么组成的?举...
你 发表评论:
欢迎- 一周热门
- 最近发表
-
- SpringBoot整合SpringSecurity+JWT
- 「计算机毕设」一个精美的JAVA博客系统源码分享
- springboot教务管理系统+微信小程序云开发附带源码
- SpringBoot+LayUI后台管理系统开发脚手架
- Spring Boot的Security安全控制——认识SpringSecurity!
- 前同事2024年接私活已入百万,都是用这几个开源的SpringBoot项目
- 值得学习的15 个优秀开源的 Spring Boot 学习项目
- 开发企业官网就用这个基于SpringBoot的CMS系统,真香
- 2021年超详细的java学习路线总结—纯干货分享
- jeecg-boot学习总结及使用心得(jeecgboot简单吗)
- 标签列表
-
- idea eval reset (50)
- vue dispatch (70)
- update canceled (42)
- order by asc (53)
- spring gateway (67)
- 简单代码编程 贪吃蛇 (40)
- transforms.resize (33)
- redisson trylock (35)
- 卸载node (35)
- np.reshape (33)
- torch.arange (34)
- npm 源 (35)
- vue3 deep (35)
- win10 ssh (35)
- vue foreach (34)
- idea设置编码为utf8 (35)
- vue 数组添加元素 (34)
- std find (34)
- tablefield注解用途 (35)
- python str转json (34)
- java websocket客户端 (34)
- tensor.view (34)
- java jackson (34)
- vmware17pro最新密钥 (34)
- mysql单表最大数据量 (35)