TIME - MoE 模型代码 3.2——Time-MoE-main/time_moe/datasets/time_moe_dataset.py
源码:GitHub - Time-MoE/Time-MoE: [ICLR 2025 Spotlight] Official implementation of "Time-MoE: Billion-Scale Time Series Foundation Models with Mixture of Experts"
这段代码定义了一个用于时间序列数据处理的 TimeMoEDataset
类,支持多种数据格式和归一化方法。
1. 类定义与初始化 (__init__
)
class TimeMoEDataset(TimeSeriesDataset):def __init__(self, data_folder, normalization_method=None):self.data_folder = data_folderself.normalization_method = normalization_methodself.datasets = [] # 存储子数据集(BinaryDataset/GeneralDataset)self.num_tokens = None # 总时间点数量# 处理归一化方法if normalization_method is None:self.normalization_method = Noneelif isinstance(normalization_method, str):if normalization_method.lower() == 'max':self.normalization_method = max_scaler # 最大值归一化elif normalization_method.lower() == 'zero':self.normalization_method = zero_scaler # 标准化(Z-score)else:raise ValueError(f'未知归一化方法: {normalization_method}')else:self.normalization_method = normalization_method # 自定义归一化函数# 加载数据:支持二进制文件或普通文件/文件夹if BinaryDataset.is_valid_path(self.data_folder):ds = BinaryDataset(self.data_folder)if len(ds) > 0:self.datasets.append(ds)elif GeneralDataset.is_valid_path(self.data_folder):ds = GeneralDataset(self.data_folder)if len(ds) > 0:self.datasets.append(ds)else:# 递归遍历文件夹,加载所有有效文件for root, dirs, files in os.walk(self.data_folder):for file in files:fn_path = os.path.join(root, file)# 跳过二进制元数据文件,加载普通数据集if file != BinaryDataset.meta_file_name and GeneralDataset.is_valid_path(fn_path):ds = GeneralDataset(fn_path)if len(ds) > 0:self.datasets.append(ds)for sub_folder in dirs:folder_path = os.path.join(root, sub_folder)# 检查子文件夹是否为二进制数据集if BinaryDataset.is_valid_path(folder_path):ds = BinaryDataset(folder_path)if len(ds) > 0:self.datasets.append(ds)# 计算累计长度数组,用于快速定位子数据集self.cumsum_lengths = [0]for ds in self.datasets:self.cumsum_lengths.append(self.cumsum_lengths[-1] + len(ds))self.num_sequences = self.cumsum_lengths[-1] # 总序列数
- 数据加载:
- 支持两种数据集类型:
BinaryDataset
(二进制格式)和GeneralDataset
(普通文本 / CSV 等)。 - 通过
os.walk
递归遍历文件夹,自动识别有效数据文件,避免手动指定每个文件路径。
- 支持两种数据集类型:
- 归一化处理:
- 内置两种归一化方法:
max_scaler
(最大值归一化)和zero_scaler
(Z-score 标准化)。 - 支持自定义归一化函数,通过
normalization_method
参数传入。
- 内置两种归一化方法:
- 子数据集管理:
cumsum_lengths
数组记录每个子数据集的起始索引(类似前缀和),例如cumsum_lengths = [0, 100, 300]
表示第一个子数据集有 100 条序列,第二个有 200 条。
2. 序列索引与获取 (__getitem__
)
def __getitem__(self, seq_idx):if seq_idx >= self.cumsum_lengths[-1] or seq_idx < 0:raise ValueError(f'索引越界: {seq_idx}')# 二分查找确定子数据集索引和偏移量dataset_idx = binary_search(self.cumsum_lengths, seq_idx)dataset_offset = seq_idx - self.cumsum_lengths[dataset_idx]seq = self.datasets[dataset_idx][dataset_offset] # 获取原始序列# 应用归一化if self.normalization_method is not None:seq = self.normalization_method(seq)return seq
- 二分查找:通过
binary_search
函数在cumsum_lengths
中快速定位序列所属的子数据集(时间复杂度为 \(O(\log N)\))。 - 归一化应用:对获取的序列调用
normalization_method
函数,返回归一化后的数据(numpy 数组)。
3. 辅助方法
4. 归一化函数实现
5. 二分查找函数 (binary_search
)
def binary_search(sorted_list, value):low, high = 0, len(sorted_list) - 1best_index = -1while low <= high:mid = (low + high) // 2if sorted_list[mid] <= value: # 寻找最大的不超过value的索引best_index = midlow = mid + 1else:high = mid - 1return best_index
- 在有序数组
sorted_list
(即cumsum_lengths
)中查找value
所属的区间,返回子数据集索引。例如,若cumsum_lengths = [0, 100, 300]
,value=150
会被定位到索引 1(第二个子数据集,偏移量 50)。
6. 总结
TimeMoEDataset
是一个高效、鲁棒的时间序列数据集加载器,核心功能包括:
- 多格式数据加载:自动识别二进制和普通文件,支持递归遍历文件夹。
- 灵活归一化:内置两种常用归一化方法,支持自定义函数,处理边界情况。
- 高效索引:通过前缀和数组和二分查找,快速定位子数据集,适合大规模数据。
该类为后续模型训练(如 TimeMoeTrainer
)提供了统一的数据接口,确保数据预处理的标准化和高效性。