当前位置: 首页 > news >正文

pytorch-frame开源程序适用于 PyTorch 的表格深度学习库,一个模块化深度学习框架,用于在异构表格数据上构建神经网络模型。

​一、软件介绍

文末提供程序和源码下载

      pytorch-frame开源程序适用于 PyTorch 的表格深度学习库,一个模块化深度学习框架,用于在异构表格数据上构建神经网络模型。

     PyTorch Frame 是 PyTorch 的深度学习扩展,专为具有不同列类型(包括数字、分类、时间、文本和图像)的异构表格数据而设计。它为实现现有和未来的方法提供了一个模块化框架。该库包含来自最先进模型、用户友好的小批量加载器、基准测试数据集和自定义数据集成接口的方法。

二、Library Highlights 库亮点

PyTorch Frame builds directly upon PyTorch, ensuring a smooth transition for existing PyTorch users. Key features include:
PyTorch Frame 直接基于 PyTorch 构建,确保现有 PyTorch 用户能够顺利过渡。主要功能包括:

  • Diverse column types: PyTorch Frame supports learning across various column types: numericalcategoricalmulticategoricaltext_embeddedtext_tokenizedtimestampimage_embedded, and embedding. See here for the detailed tutorial.
    多种列类型:PyTorch Frame 支持跨各种列类型学习: numerical 、 categorical multicategorical text_embedded text_tokenized timestamp image_embedded embedding 和 。有关详细教程,请参阅此处。
  • Modular model design: Enables modular deep learning model implementations, promoting reusability, clear coding, and experimentation flexibility. Further details in the architecture overview.
    模块化模型设计:支持模块化深度学习模型实施,促进可重用性、清晰的编码和实验灵活性。有关更多详细信息,请参阅 体系结构概述.
  • Models Implements many state-of-the-art deep tabular models as well as strong GBDTs (XGBoost, CatBoost, and LightGBM) with hyper-parameter tuning.
    模型 实现许多最先进的深度表格模型以及具有超参数优化的强大 GBDT(XGBoost、CatBoost 和 LightGBM)。
  • Datasets: Comes with a collection of readily-usable tabular datasets. Also supports custom datasets to solve your own problem. We benchmark deep tabular models against GBDTs.
    数据集:附带一组易于使用的表格数据集。还支持自定义数据集来解决您自己的问题。我们将深度表格模型与 GBDT 进行基准测试。
  • PyTorch integration: Integrates effortlessly with other PyTorch libraries, facilitating end-to-end training of PyTorch Frame with downstream PyTorch models. For example, by integrating with PyG, a PyTorch library for GNNs, we can perform deep learning over relational databases. Learn more in RelBench and example code.
    PyTorch 集成:轻松与其他 PyTorch 库集成,促进 PyTorch Frame 与下游 PyTorch 模型的端到端训练。例如,通过与 PyG(一个用于 GNN 的 PyTorch 库)集成,我们可以对关系数据库执行深度学习。在 RelBench 和示例代码中了解更多信息。

三、Architecture Overview 架构概述

Models in PyTorch Frame follow a modular design of FeatureEncoderTableConv, and Decoder, as shown in the figure below:
PyTorch Frame 中的模型遵循 FeatureEncoder 、 、 TableConv 和 Decoder 的模块化设计,如下图所示:

In essence, this modular setup empowers users to effortlessly experiment with myriad architectures:
从本质上讲,这种模块化设置使用户能够毫不费力地尝试各种架构:

  • Materialization handles converting the raw pandas DataFrame into a TensorFrame that is amenable to Pytorch-based training and modeling.
    Materialization 处理将原始 pandas 转换为 TensorFrame 适合基于 Pytorch 的训练和建模的 pandas DataFrame 。
  • FeatureEncoder encodes TensorFrame into hidden column embeddings of size [batch_size, num_cols, channels].
    FeatureEncoder 编码 TensorFrame 为 size [batch_size, num_cols, channels] 的隐藏列嵌入向量。
  • TableConv models column-wise interactions over the hidden embeddings.
    TableConv 对隐藏嵌入的逐列交互进行建模。
  • Decoder generates embedding/prediction per row.
    Decoder 每行生成嵌入/预测。

四、Quick Tour 快速浏览

In this quick tour, we showcase the ease of creating and training a deep tabular model with only a few lines of code.
在这个快速导览中,我们展示了仅使用几行代码创建和训练深度表格模型的便利性。

Build and train your own deep tabular model
构建和训练您自己的深度表格模型

As an example, we implement a simple ExampleTransformer following the modular architecture of Pytorch Frame. In the example below:
例如,我们按照 Pytorch Frame 的模块化架构实现了一个简单的 ExampleTransformer 。在下面的示例中:

  • self.encoder maps an input TensorFrame to an embedding of size [batch_size, num_cols, channels].
    self.encoder 将 input TensorFrame 映射到 size [batch_size, num_cols, channels] 的嵌入向量。
  • self.convs iteratively transforms the embedding of size [batch_size, num_cols, channels] into an embedding of the same size.
    self.convs 迭代地将 size [batch_size, num_cols, channels] 的嵌入转换为相同大小的嵌入。
  • self.decoder pools the embedding of size [batch_size, num_cols, channels] into [batch_size, out_channels].
    self.decoder 将 size [batch_size, num_cols, channels] 的嵌入池化到 [batch_size, out_channels] 中。
from torch import Tensor
from torch.nn import Linear, Module, ModuleListfrom torch_frame import TensorFrame, stype
from torch_frame.nn.conv import TabTransformerConv
from torch_frame.nn.encoder import (EmbeddingEncoder,LinearEncoder,StypeWiseFeatureEncoder,
)class ExampleTransformer(Module):def __init__(self,channels, out_channels, num_layers, num_heads,col_stats, col_names_dict,):super().__init__()self.encoder = StypeWiseFeatureEncoder(out_channels=channels,col_stats=col_stats,col_names_dict=col_names_dict,stype_encoder_dict={stype.categorical: EmbeddingEncoder(),stype.numerical: LinearEncoder()},)self.convs = ModuleList([TabTransformerConv(channels=channels,num_heads=num_heads,) for _ in range(num_layers)])self.decoder = Linear(channels, out_channels)def forward(self, tf: TensorFrame) -> Tensor:x, _ = self.encoder(tf)for conv in self.convs:x = conv(x)out = self.decoder(x.mean(dim=1))return out

To prepare the data, we can quickly instantiate a pre-defined dataset and create a PyTorch-compatible data loader as follows:
为了准备数据,我们可以快速实例化预定义的数据集并创建与 PyTorch 兼容的数据加载器,如下所示:

from torch_frame.datasets import Yandex
from torch_frame.data import DataLoaderdataset = Yandex(root='/tmp/adult', name='adult')
dataset.materialize()
train_dataset = dataset[:0.8]
train_loader = DataLoader(train_dataset.tensor_frame, batch_size=128,shuffle=True)

Then, we just follow the standard PyTorch training procedure to optimize the model parameters. That's it!
然后,我们只需按照标准的 PyTorch 训练过程来优化模型参数。就是这样!

import torch
import torch.nn.functional as Fdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ExampleTransformer(channels=32,out_channels=dataset.num_classes,num_layers=2,num_heads=8,col_stats=train_dataset.col_stats,col_names_dict=train_dataset.tensor_frame.col_names_dict,
).to(device)optimizer = torch.optim.Adam(model.parameters())for epoch in range(50):for tf in train_loader:tf = tf.to(device)pred = model.forward(tf)loss = F.cross_entropy(pred, tf.y)optimizer.zero_grad()loss.backward()

五、Implemented Deep Tabular Models实现的深度表格模型

We list currently supported deep tabular models:
我们列出了当前支持的深度表格模型:

  • Trompt from Chen et al.: Trompt: Towards a Better Deep Neural Network for Tabular Data (ICML 2023) [Example]
    Chen 等人的 Trompt:Trompt:为表格数据提供更好的深度神经网络 (ICML 2023) [示例]
  • FTTransformer from Gorishniy et al.: Revisiting Deep Learning Models for Tabular Data (NeurIPS 2021) [Example]
    来自 Gorishniy 等人的 FTTransformer:重新审视表格数据的深度学习模型 (NeurIPS 2021) [示例]
  • ResNet from Gorishniy et al.: Revisiting Deep Learning Models for Tabular Data (NeurIPS 2021) [Example]
    Gorishniy 等人的 ResNet:重新审视表格数据的深度学习模型 (NeurIPS 2021) [示例]
  • TabNet from Arık et al.: TabNet: Attentive Interpretable Tabular Learning (AAAI 2021) [Example]
    来自 Arık 等人的 TabNet:TabNet:专注可解释表格学习 (AAAI 2021) [示例]
  • ExcelFormer from Chen et al.: ExcelFormer: A Neural Network Surpassing GBDTs on Tabular Data [Example]
    来自 Chen 等人的 ExcelFormer:ExcelFormer:在表格数据上超越 GBDT 的神经网络 [示例]
  • TabTransformer from Huang et al.: TabTransformer: Tabular Data Modeling Using Contextual Embeddings [Example]
    来自 Huang 等人的 TabTransformer:TabTransformer:使用上下文嵌入的表格数据建模 [示例]

In addition, we implemented XGBoostCatBoost, and LightGBM examples with hyperparameter-tuning using Optuna for users who'd like to compare their model performance with GBDTs.
此外,我们还使用 Optuna 为 XGBoost CatBoost LightGBM 希望将其模型性能与 GBDTs .

Benchmark 基准

We benchmark recent tabular deep learning models against GBDTs over diverse public datasets with different sizes and task types.
我们在具有不同大小和任务类型的各种公有数据集上将最近的表格深度学习模型与 GBDT 进行基准测试。

The following chart shows the performance of various models on small regression datasets, where the row represents the model names and the column represents dataset indices (we have 13 datasets here). For more results on classification and larger datasets, please check the benchmark documentation.
下图显示了各种模型在小型回归数据集上的性能,其中行表示模型名称,列表示数据集索引(我们这里有 13 个数据集)。有关分类和更大数据集的更多结果,请查看基准测试文档。

Model Name 型号名称dataset_0dataset_1dataset_2dataset_3dataset_4dataset_5dataset_6dataset_7dataset_8dataset_9dataset_10dataset_11dataset_12
XGBoost0.250±0.000 0.250±0.000 元0.038±0.000 0.038±0.000 元0.187±0.000 0,187±0.000 元0.475±0.000 0.475±0.000 元0.328±0.000 0,328±0.000 元0.401±0.000 0.401±0.000 元0.249±0.0000.363±0.0000.904±0.0000.056±0.0000.820±0.0000.857±0.0000.418±0.000
CatBoost 猫加速0.265±0.000 0.265±0.000 元0.062±0.000 0,062±0.000 元0.128±0.000 0.128±0.000 元0.336±0.000 0,336±0.000 元0.346±0.000 0.346±0.000 元0.443±0.000 0.443±0.000 元0.375±0.0000.273±0.0000.881±0.0000.040±0.0000.756±0.0000.876±0.0000.439±0.000
LightGBM0.253±0.000 0,253±0.000 元0.054±0.000 0,054±0.000 元0.112±0.000 0.112±0.000 元0.302±0.000 0.302±0.000 元0.325±0.000 0.325±0.000 元0.384±0.000 0.384±0.000 元0.295±0.0000.272±0.0000.877±0.0000.011±0.0000.702±0.0000.863±0.0000.395±0.000
Trompt Trompt (错视)0.261±0.003 0.261±0.003 元0.015±0.0050.118±0.0010.262±0.0010.323±0.001 0.323±0.001 元0.418±0.003 0.418±0.003 元0.329±0.0090.312±0.002OOM0.008±0.0010.779±0.0060.874±0.0040.424±0.005
ResNet ResNet 公司0.288±0.006 0.288±0.006 元0.018±0.0030.124±0.0010.268±0.0010.335±0.001 0.335±0.001 元0.434±0.004 0.434±0.004 元0.325±0.0120.324±0.0040.895±0.0050.036±0.0020.794±0.0060.875±0.0040.468±0.004
FTTransformerBucket0.325±0.008 0,325±0.008 元0.096±0.0050.360±0.354 0.360±0.354 元0.284±0.005 0.284±0.005 元0.342±0.004 0.342±0.004 元0.441±0.003 0.441±0.003 元0.345±0.0070.339±0.003OOM0.105±0.0110.807±0.0100.885±0.0080.468±0.006
ExcelFormer0.262±0.0040.099±0.003 0.099±0.003 元0.128±0.000 0.128±0.000 元0.264±0.003 0.264±0.003 元0.331±0.0030.411±0.0050.298±0.0120.308±0.007OOM0.011±0.0010.785±0.0110.890±0.0030.431±0.006
FTTransformer0.335±0.010 0.335±0.010 元0.161±0.022 0,161±0.022 元0.140±0.0020.277±0.0040.335±0.003 0.335±0.003 元0.445±0.003 0,445±0.003 元0.361±0.0180.345±0.005OOM0.106±0.0120.826±0.0050.896±0.0070.461±0.003
TabNet 标签网0.279±0.003 0.279±0.003 元0.224±0.016 0.224±0.016 元0.141±0.010 0.141±0.010 元0.275±0.002 0.275±0.002 元0.348±0.003 0,348±0.003 元0.451±0.007 0.451±0.007 元0.355±0.0300.332±0.0040.992±0.1820.015±0.0020.805±0.0140.885±0.0130.544±0.011
TabTransformer TabTransformer (标签变压器)0.624±0.0030.229±0.0030.369±0.005 0.369±0.005 元0.340±0.0040.388±0.0020.539±0.003 0.539±0.003 元0.619±0.0050.351±0.0010.893±0.0050.431±0.0010.819±0.0020.886±0.0050.545±0.004

We see that some recent deep tabular models were able to achieve competitive model performance to strong GBDTs (despite being 5--100 times slower to train). Making deep tabular models even more performant with less compute is a fruitful direction for future research.
我们看到,一些最近的深度表格模型能够实现与强 GBDT 相比有竞争力的模型性能(尽管训练速度慢了 5--100 倍)。以更少的计算量使深度表格模型的性能更高,是未来研究的一个富有成效的方向。

We also benchmark different text encoders on a real-world tabular dataset (Wine Reviews) with one text column. The following table shows the performance:
我们还在具有一个文本列的真实表格数据集 ( Wine Reviews) 上对不同的文本编码器进行基准测试。性能如下表所示:

Test Acc 测试账户Method 方法Model Name 型号名称Source 源
0.7926Pre-trained 预训练sentence-transformers/all-distilroberta-v1 (125M # params)
sentence-transformers/all-distilroberta-v1(125M# 参数)
Hugging Face 拥抱脸
0.7998Pre-trained 预训练embed-english-v3.0 (dimension size: 1024)
embed-english-v3.0(维度大小:1024)
Cohere 凝聚
0.8102Pre-trained 预训练text-embedding-ada-002 (dimension size: 1536)
text-embedding-ada-002(维度大小:1536)
OpenAI 开放人工智能
0.8147Pre-trained 预训练voyage-01 (dimension size: 1024)
voyage-01 (尺寸: 1024)
Voyage AI AI Travel
0.8203Pre-trained 预训练intfloat/e5-mistral-7b-instruct (7B # params)
intfloat/e5-mistral-7b-instruct (7B # 参数)
Hugging Face 拥抱脸
0.8230LoRA Finetune LoRA 微调DistilBERT (66M # params)
DistilBERT (66M # 参数)
Hugging Face 拥抱脸

The benchmark script for Hugging Face text encoders is in this file and for the rest of text encoders is in this file.
Hugging Face 文本编码器的基准脚本位于此文件中,其余文本编码器的基准脚本位于此文件中。

Installation 安装

PyTorch Frame is available for Python 3.9 to Python 3.13.
PyTorch Frame 适用于 Python 3.9 到 Python 3.13。

<span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><span style="color:#1f2328"><span style="color:var(--fgColor-default, var(--color-fg-default))"><span style="background-color:var(--bgColor-muted, var(--color-canvas-subtle))"><code>pip install pytorch-frame
</code></span></span></span></span>

六、软件下载

夸克网盘分享

本文信息来源于GitHub作者地址:https://github.com/pyg-team/pytorch-frame

http://www.xdnf.cn/news/969427.html

相关文章:

  • dMSA 滥用(BadSuccessor)导致权限提升
  • C++11 花括号等式初始化器(Brace-or-Equal Initializers):从入门到精通
  • 安全大模型智驱网络和数据安全效能跃迁
  • 利用最小二乘法找圆心和半径
  • 【从零学习JVM|第五篇】打破双亲委派机制
  • OceanBase v4.3.5 特性解读:通过OSS WORM特性进行备份归档
  • 【动手学深度学习】3.2. 线性回归的从零开始实现
  • [UnrealCircle武汉]UE5跨平台游戏常见问题及解决方案笔记
  • Java八股文——JVM「垃圾回收篇」
  • 鸿蒙接入微信sdk登录 解决提示BundleID信息校验不通过
  • rasa NLU意图解析基础学习
  • 全国空气质量监测站点数据分析:从原始数据到空间可视化
  • 1. 网络基础
  • 带eachers的html转word
  • 渲染学进阶内容——joml库
  • 深入浅出Asp.Net Core MVC应用开发系列-AspNetCore中的日志记录
  • cell properties修改参数
  • 突破不可导策略的训练难题:零阶优化与强化学习的深度嵌合
  • Vue 指令详解:概念与作用
  • 渲染学进阶内容——模型
  • ssc377d修改flash分区大小
  • 86壳温湿度传感器:高温下工业生产的安全防线
  • Elasticsearch增删改查语句
  • GAMES202-高质量实时渲染(Real-Time Ray-Tracing)
  • Minktec 柔性弯曲传感器应用:3D 脊柱姿势追踪与人体活动识别
  • 半加器和全加器
  • React19源码系列之 事件优先级
  • Netty从入门到进阶(三)
  • 淘宝SKU与视频详情API接口使用指南
  • 6月10日day50打卡