当前位置: 首页 > news >正文

时间序列数据集增强构造方案(时空网络建模)

时间序列数据集增强构造方案(时空网络建模)

时间序列数据集TimeSeriesDataset
时间序列数据集增强EnhancedTimeSeriesDataset

一、方案背景与动机

1.1 背景分析

传统时间序列预测方法(如ARIMA、Prophet等)以及很多深度学习方法(如 TFT、DeepAR、Auto former等模型)主要关注单序列的时序模式以及对应的协变量,不考虑实体间(序列间)的依赖关系。不考虑实体间依赖关系的训练集可以通过TimeSeriesDataset构建,但在现代商业场景中,实体间往往存在复杂的依赖关系,这时就需要更复杂的数据集生成方式。具体业务中的实体依赖关系比如:

  • 零售场景:互补商品(咖啡机与咖啡胶囊,排骨和玉米/冬瓜)存在连带销售效应,替代商品(肋排和汤小排)存在竞争销售效应

  • 能源场景:相邻变电站的电力负荷存在空间相关性

  • 交通场景:相邻路段的交通流量相互影响

现有解决方案的局限性:

  • 孤立预测:独立处理各实体序列,忽视跨实体关联

  • 静态关系:无法动态捕捉随时间变化的依赖模式

  • 计算效率:全连接注意力机制导致计算复杂度爆炸式增长,销量预测场景中实体百万量级,全连接图无法满足内存需求

1.2 设计动机

本方案旨在解决以下核心问题:

  • 关系建模:显式建模具有业务逻辑的实体依赖关系(替代/互补/空间邻近等),同时建模时空依赖关系

  • 效率优化:通过分组批处理减少无效计算,分组批处理可以理解为一个子图,每个节点代表一个实体(一个时间序列)

  • 兼容性:保持与传统时间序列模型的接口一致性,在传统时间序列数据模型的接口上进行继承和增强

二、应用场景

2.1典型应用案例

场景依赖类型数据特征
零售销量预测替代关系(竞品商品)
互补关系(配套商品)
多门店×多商品×时间序列(多变量,销量、活动、calendar、weather等)
电力负荷预测空间邻近关系
电压等级关联
多变电站×时间序列
交通流量预测路段连通性
时段相关性
多路段×时间序列

2.2业务价值

  • 提升预测精度:通过关联序列信息修正预测偏差,比如建模商品替代互补作用对销量的影响,替代关系中A品有大折扣会导致B销量降低,互补关系中A品大折扣会带着B销量增高
  • 增强可解释性:可视化注意力权重揭示商品关联强度,可以挖掘商品替代互补关系强度、路段连通性等
  • 优化库存管理:基于互补关系预测优化备货策略

三、批处理方案详解

3.1 核心设计思想

批处理架构图

graph TDA[原始数据] --> B{依赖关系定义}B --> C[Cluster分组,子图]C --> D[批次采样,每次一个cluster子图]D --> E[动态填充]E --> F[模型训练]

3.2 关键技术实现

(1) ClusterBatchSampler

关键特性:

  • 依赖关系保持:确保同cluster样本同批次
  • 灵活批次控制:
    • shuffle=True:打乱cluster顺序(保持cluster内顺序)
    • drop_last=True:丢弃不足批次大小的尾部数据
  • 高效内存管理:预先生成所有批次索引

与DataLoader的协作:

DataLoader参数作用本方案设置
batch_sampler替代默认采样逻辑使用ClusterBatchSampler实例
collate_fn自定义批次组装逻辑使用 cluster_collate_fn
shuffle全局打乱数据必须设为 False

(2) 动态填充策略

填充策略优势:

  • 形状一致性:确保张量维度统一
  • 信息无损:原始数据完整保留
  • 灵活配置:支持动态/固定批次大小

3.3与传统方案的对比

维度传统方案本方案
实体关系处理忽略显式建模
内存占用全连接矩阵存储仅存储分组索引
业务适配性通用场景可定制依赖关系

四、总结

本方案通过创新的批处理机制,在传统时间序列预测框架中引入实体关系建模能力,为复杂业务场景提供了有效的解决方案。建议根据实际业务需求调整依赖关系定义策略,并通过可视化工具持续监控模型学习效果。

五、代码

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from collections import defaultdict
import random
import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoderclass EnhancedTimeSeriesDataset(TimeSeriesDataset): ## 继承TimeSeriesDatasetdef __init__(self, dependency_clusters: dict,  # {group_id: cluster_id}**kwargs):"""增强版时间序列数据集Args:dependency_clusters (dict): 定义分组关系的字典**kwargs: 原始TimeSeriesDataset参数"""super().__init__(**kwargs)# 存储分组信息self.dependency_clusters = dependency_clusters# 构建cluster索引self._build_cluster_index()def _build_cluster_index(self):"""构建cluster到样本索引的映射"""self.cluster_to_indices = defaultdict(list)for idx, sample in enumerate(self.samples):group_id = sample["group_id"]cluster_id = self.dependency_clusters.get(group_id, -1)  # -1表示未分组self.cluster_to_indices[cluster_id].append(idx)class ClusterBatchSampler:def __init__(self, dataset, batch_size, shuffle=True, drop_last=False):"""基于cluster的批采样器Args:dataset: EnhancedTimeSeriesDataset实例batch_size: 目标批次大小shuffle: 是否打乱cluster顺序drop_last: 是否丢弃不足批次"""self.batch_size = batch_sizeself.shuffle = shuffleself.drop_last = drop_last# 按cluster组织索引self.batches = []for cluster_id, indices in dataset.cluster_to_indices.items():if cluster_id == -1:  # 跳过未分组样本continue# 打乱cluster内部顺序if shuffle:random.shuffle(indices)# 拆分cluster为多个批次,如果batch_size大于cluster size,则一个cluster在一个batch内for i in range(0, len(indices), batch_size):batch = indices[i:i+batch_size]if not drop_last or len(batch) == batch_size:self.batches.append(batch)# 打乱批次顺序if shuffle:random.shuffle(self.batches)def __iter__(self):return iter(self.batches)def __len__(self):return len(self.batches)def cluster_collate_fn(batch, dataset, fixed_batch_size=None):"""增强版批处理函数Args:batch: 原始批次数据dataset: 数据集实例fixed_batch_size: 固定批次大小(None表示动态大小)"""# 动态填充逻辑current_size = len(batch)padding_samples = []if fixed_batch_size is not None and current_size < fixed_batch_size:# 生成填充样本for _ in range(fixed_batch_size - current_size):padding_samples.append({"encoder_input": torch.full((dataset.max_encoder_length, len(dataset.encoder_features)), dataset.padding_value),"decoder_input": torch.full((dataset.max_decoder_length, len(dataset.decoder_features)), dataset.padding_value),"target": torch.full((dataset.max_decoder_length,), dataset.padding_value),"encoder_mask": torch.zeros(dataset.max_encoder_length),"decoder_mask": torch.zeros(dataset.max_decoder_length),"static_features": torch.zeros(len(dataset.static_features), dtype=torch.long),"actual_lengths": (0, 0)})# 合并真实数据和填充数据full_batch = batch + padding_samples# 转换张量collated = {}for key in batch[0].keys():if key == "actual_lengths":L = [item[key][0] for item in full_batch]D = [item[key][1] for item in full_batch]collated["actual_lengths"] = (torch.tensor(L), torch.tensor(D))else:collated[key] = torch.stack([item[key] for item in full_batch])# 生成批次掩码if fixed_batch_size is not None:batch_mask = torch.zeros(fixed_batch_size, dtype=torch.float)batch_mask[:current_size] = 1.0collated["batch_mask"] = batch_maskreturn collatedclass SpatioTemporalTransformer(nn.Module):def __init__(self, input_dim, nhead=8, hidden_dim=256):super().__init__()# 位置编码self.pos_encoder = nn.Embedding(1000, input_dim)  # 假设最大序列长度1000# 跨序列注意力层self.cross_attention = nn.MultiheadAttention(embed_dim=input_dim,num_heads=nhead,batch_first=True)# 前馈网络self.ffn = nn.Sequential(nn.Linear(input_dim, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, input_dim))def forward(self, x, batch_mask=None):"""Args:x: 输入序列 [B, L, D]batch_mask: 批次掩码 [B]"""B, L, D = x.size()# 位置编码positions = torch.arange(L, device=x.device).expand(B, L)x = x + self.pos_encoder(positions)# 跨序列注意力if batch_mask is not None:padding_mask = ~batch_mask.bool()  # 反转掩码else:padding_mask = None# 维度变换 [B, L, D] -> [L, B, D]x = x.permute(1, 0, 2)# 计算注意力attn_output, _ = self.cross_attention(query=x,key=x,value=x,key_padding_mask=padding_mask)# 维度恢复 [L, B, D] -> [B, L, D]attn_output = attn_output.permute(1, 0, 2)# 前馈网络output = self.ffn(attn_output)return output# 使用示例
if __name__ == "__main__":# 创建测试数据dates = pd.date_range(start="2023-01-01", periods=90, name="target_date")example_data = pd.DataFrame({"store_id": np.repeat([1, 2], 45),"product_id": np.repeat([101, 102], 45),"target_date": np.tile(dates, 2),"sale_amount": np.random.randint(0, 100, 180),"discount": np.random.rand(180),"precipitation": np.random.rand(180),"temperature": np.random.rand(180),})# 定义分组关系,示例:互补商品分组dependency_clusters = {(1, 101): 0,(1, 102): 0,  # 同一cluster(2, 101): 1,(2, 102): 1   # 另一cluster}# 初始化数据集dataset = EnhancedTimeSeriesDataset(data=example_data,dependency_clusters=dependency_clusters,max_encoder_length=35,min_encoder_length=14,max_decoder_length=14,min_decoder_length=7,num_samples_per_step=1)# 创建数据加载器dataloader = DataLoader(dataset,batch_sampler=ClusterBatchSampler(dataset, batch_size=32, shuffle=True),collate_fn=lambda b: cluster_collate_fn(b, dataset, fixed_batch_size=32),num_workers=4)# 初始化模型model = SpatioTemporalTransformer(input_dim=len(dataset.encoder_features),nhead=4,hidden_dim=128)# 训练循环optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)for epoch in range(10):for batch in dataloader:# 前向传播encoder_input = batch["encoder_input"]  # [32, 35, 4]batch_mask = batch.get("batch_mask", None)output = model(encoder_input, batch_mask)# 计算损失(示例)target = batch["target"]  # [32, 14]loss = F.mse_loss(output[:, -14:], target)  # 取最后14个时间步# 反向传播optimizer.zero_grad()loss.backward()optimizer.step()print(f"Epoch {epoch} Loss: {loss.item():.4f}")
http://www.xdnf.cn/news/328429.html

相关文章:

  • 【网络编程】二、UDP网络套接字编程详解
  • 项目文档归档的最佳实践有哪些?
  • Nacos源码—Nacos集群高可用分析(二)
  • java实现一个操作日志模块功能,怎么设计
  • 【云备份】项目展示项目总结
  • 深入理解Redis缓存与数据库不一致问题及其解决方案
  • Matlab 多策略改进蜣螂优化算法及其在CEC2017性能
  • PCI-Compatible Configuration Registers--BIST Register (Offset 0Fh)
  • 跨物种交流新时代!百度发布动物语言转换专利,听懂宠物心声
  • 电池管理系统BMS三级架构——BMU、BCU和BAU详解
  • Webug4.0靶场通关笔记20- 第25关越权查看admin
  • 读《暗时间》有感
  • 基于RT-Thread的STM32G4开发第二讲第二篇——ADC
  • 2014年写的一个文档《基于大数据应用的综合健康服务平台研发及应用示范》
  • layui下拉框输入关键字才出数据
  • JMeter快速指南:命令行生成HTML测试报告(附样例命令解析)
  • Android学习总结之网络篇补充
  • conda init before conda activate
  • MVC是什么?分别对应SpringBoot哪些层?
  • 【C/C++】ARM处理器对齐_伪共享问题
  • autojs和冰狐智能辅助该怎么选择?
  • 从D盘分配空间为C盘扩容?利用工具1+1>2
  • 使用JMeter 编写的测试计划的多个线程组如何生成独立的线程组报告
  • 理解文本嵌入:语义空间之旅
  • 探索 H-ZERO 模态框组件:提升用户交互体验的利器
  • PaaS筑基,中国中化实现转型飞跃
  • ROS1和ROS2使用桥接工具通信
  • 【CF】Day53——Codeforces Round 1023 (Div. 2) CD
  • 中级网络工程师知识点1
  • 自定义分区器-基础