当前位置: 首页 > news >正文

DFPatternFunctor遍历计算图

文件include/tvm/relay/dataflow_pattern_functor.h

功能:定义 DFPatternFunctor 基类,为 DFPattern 提供访问者模式(Visitor Pattern)的实现框架,支持对不同类型的模式节点进行差异化处理。

继承关系

template <typename R, typename... Args>
class DFPatternFunctor<R(const DFPattern& n, Args...)>

DFPatternFunctor 类概括

  DFPatternFunctor 是一个用于处理深度学习中计算图模式匹配的访问者模式(Visitor Pattern)实现,它是TVM(深度学习编译器)框架中的一个重要组件。下面我将从多个方面详细解析这个类的设计和实现。

class DFPatternFunctor;// functions to be overriden.
#define DFPATTERN_FUNCTOR_DEFAULT \{ return VisitDFPatternDefault_(op, std::forward<Args>(args)...); }#define RELAY_DFPATTERN_FUNCTOR_DISPATCH(OP)                                                    \vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, Args... args) {          \return self->VisitDFPattern_(static_cast<const OP*>(n.get()), std::forward<Args>(args)...); \});template <typename R, typename... Args>
class DFPatternFunctor<R(const DFPattern& n, Args...)> {private:using TSelf = DFPatternFunctor<R(const DFPattern& n, Args...)>;using FType = tvm::NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;public:/*! \brief virtual destructor */virtual ~DFPatternFunctor() {}/*!* \brief Same as call.* \param n The expression node.* \param args Additional arguments.* \return The result of the call*/R operator()(const DFPattern& n, Args... args) {return VisitDFPattern(n, std::forward<Args>(args)...);}/*!* \brief The functor call.* \param n The expression node.* \param args Additional arguments.* \return The result of the call*/virtual R VisitDFPattern(const DFPattern& n, Args... args) {CHECK(n.defined());static FType vtable = InitVTable();return vtable(n, this, std::forward<Args>(args)...);}// Functions that can be overriden by subclassvirtual R VisitDFPattern_(const AltPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;virtual R VisitDFPattern_(const AttrPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;virtual R VisitDFPattern_(const CallPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;virtual R VisitDFPattern_(const DataTypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;virtual R VisitDFPattern_(const DominatorPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;virtual R VisitDFPattern_(const ExprPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;virtual R VisitDFPattern_(const ShapePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;virtual R VisitDFPattern_(const TupleGetItemPatternNode* op,Args... args) DFPATTERN_FUNCTOR_DEFAULT;virtual R VisitDFPattern_(const TuplePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;virtual R VisitDFPattern_(const TypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;virtual R VisitDFPattern_(const VarPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;virtual R VisitDFPattern_(const ConstantPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;virtual R VisitDFPattern_(const WildcardPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;virtual R VisitDFPatternDefault_(const Object* op, Args...) {LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();throw;}private:// initialize the vtable.static FType InitVTable() {FType vtable;// Set dispatchRELAY_DFPATTERN_FUNCTOR_DISPATCH(AltPatternNode);RELAY_DFPATTERN_FUNCTOR_DISPATCH(AttrPatternNode);RELAY_DFPATTERN_FUNCTOR_DISPATCH(CallPatternNode);RELAY_DFPATTERN_FUNCTOR_DISPATCH(ConstantPatternNode);RELAY_DFPATTERN_FUNCTOR_DISPATCH(DataTypePatternNode);RELAY_DFPATTERN_FUNCTOR_DISPATCH(DominatorPatternNode);RELAY_DFPATTERN_FUNCTOR_DISPATCH(ExprPatternNode);RELAY_DFPATTERN_FUNCTOR_DISPATCH(ShapePatternNode);RELAY_DFPATTERN_FUNCTOR_DISPATCH(TupleGetItemPatternNode);RELAY_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode);RELAY_DFPATTERN_FUNCTOR_DISPATCH(TypePatternNode);RELAY_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode);RELAY_DFPATTERN_FUNCTOR_DISPATCH(WildcardPatternNode);return vtable;}
};

1. 模板定义

template <typename R, typename... Args>
class DFPatternFunctor<R(const DFPattern& n, Args...)> {
  • 这是一个模板类,接受两个模板参数:
    • R: 表示访问者函数的返回类型
    • Args...: 表示可变数量的额外参数类型
  • 模板特化为函数类型 R(const DFPattern& n, Args...),表示这是一个接受DFPattern和额外参数的函数对象

2. 内部类型定义

private:using TSelf = DFPatternFunctor<R(const DFPattern& n, Args...)>;using FType = tvm::NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
  • TSelf: 定义当前类的类型别名,简化代码
  • FType: 定义TVM内部的节点函数类型,它接受:
    • ObjectRef& n: 待处理的节点
    • TSelf* self: 访问者对象自身
    • Args...: 可变参数

3. 核心方法

3.1 操作符重载

R operator()(const DFPattern& n, Args... args) {return VisitDFPattern(n, std::forward<Args>(args)...);
}
  • 重载函数调用操作符(),使得对象可以像函数一样被调用
  • 将调用转发给VisitDFPattern方法
  • 使用std::forward完美转发参数,保持参数的值类别(左值/右值)

3.2 主访问方法

virtual R VisitDFPattern(const DFPattern& n, Args... args) {CHECK(n.defined());static FType vtable = InitVTable();return vtable(n, this, std::forward<Args>(args)...);
}
  • 虚函数,可以被派生类重写
  • 首先检查节点是否有效(n.defined())
  • 初始化虚函数表(vtable)
  • 通过虚函数表分发到具体的处理函数

3.3 默认实现方法

virtual R VisitDFPatternDefault_(const Object* op, Args...) {LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();throw;
}
  • 当遇到未知节点类型时的默认处理方法
  • 记录致命错误并抛出异常

4. 虚函数表初始化

static FType InitVTable() {FType vtable;// Set dispatchRELAY_DFPATTERN_FUNCTOR_DISPATCH(AltPatternNode);RELAY_DFPATTERN_FUNCTOR_DISPATCH(AttrPatternNode);// ...其他节点类型的注册return vtable;
}
  • 静态方法,初始化虚函数表
  • 使用宏RELAY_DFPATTERN_FUNCTOR_DISPATCH为每种DFPattern节点类型注册处理函数
  • 这些宏通常会展开为将节点类型与对应的VisitDFPattern_方法关联起来的代码

5. 节点处理方法

virtual R VisitDFPattern_(const AltPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const AttrPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
// ...其他节点类型的处理方法
  • 每个方法对应一种特定的DFPattern节点类型
  • 使用DFPATTERN_FUNCTOR_DEFAULT宏(可能定义为= 0或默认实现)来指定默认行为
  • 派生类可以重写这些方法来提供特定于节点类型的处理逻辑

6. 设计模式分析

这个类实现了访问者模式的变体,具有以下特点:

  1. 双重分发:
  2. 第一次分发:通过 虚函数表(vtable) 根据 DFPattern 的具体类型(如 CallPatternNodeVarPatternNode 等)选择对应的 VisitDFPattern_ 方法。
  3. 第二次分发:调用具体的 VisitDFPattern_ 方法(如 VisitDFPattern_(const CallPatternNode* op, Args...)),执行真正的处理逻辑。

流程图

+--------------------------+
|  调用 operator()(n, args)  |
+--------------------------+|v
+--------------------------+
| VisitDFPattern(n, args)   |  // 第一次分发:查虚函数表
+--------------------------+|v
+--------------------------+
| vtable(n, this, args)     |  // 根据 n 的类型找到对应的 VisitDFPattern_
+--------------------------+|v
+--------------------------+
| VisitDFPattern_(op, args) |  // 第二次分发:执行具体逻辑
+--------------------------+|v
+--------------------------+
|  返回结果 R               |
+--------------------------+
class DFPatternVisitor : public DFPatternFunctor<void(const DFPattern&)> {public:void VisitDFPattern(const DFPattern& pattern) override;void VisitDFPattern_(const AltPatternNode* op) override;void VisitDFPattern_(const AttrPatternNode* op) override;void VisitDFPattern_(const CallPatternNode* op) override;void VisitDFPattern_(const ConstantPatternNode* op) override;void VisitDFPattern_(const DataTypePatternNode* op) override;void VisitDFPattern_(const DominatorPatternNode* op) override;void VisitDFPattern_(const ExprPatternNode* op) override;void VisitDFPattern_(const ShapePatternNode* op) override;void VisitDFPattern_(const TupleGetItemPatternNode* op) override;void VisitDFPattern_(const TuplePatternNode* op) override;void VisitDFPattern_(const TypePatternNode* op) override;void VisitDFPattern_(const VarPatternNode* op) override;void VisitDFPattern_(const WildcardPatternNode* op) override;protected:// set of already-visited nodesstd::unordered_set<const Object*> visited_;
};}  // namespace relay
}  // namespace tvm
#endif  // TVM_RELAY_DATAFLOW_PATTERN_FUNCTOR_H_
  1. 扩展性:

    • 可以轻松添加新的节点类型处理方法
    • 派生类只需重写感兴趣的方法
  2. 类型安全:

    • 使用模板确保类型安全
    • 每个节点类型有专门的处理方法

7. 使用场景

这个类主要用于:

  1. 模式匹配:在深度学习计算图中查找特定模式
  2. 模式转换:修改或重写计算图模式
  3. 模式分析:收集关于计算图模式的统计信息或属性

8. 技术细节

  1. 完美转发:

    • 使用std::forward保持参数的值类别
    • 允许传递左值或右值参数
  2. 虚函数表:

    • 使用TVM内部的NodeFunctor实现动态分发
    • 比传统的虚函数更灵活,可以在运行时修改
  3. 类型系统:

    • 充分利用C++的类型系统为每种节点类型提供类型安全的接口

这个设计展示了TVM框架如何高效地处理复杂的计算图操作,同时保持代码的可扩展性和类型安全。

http://www.xdnf.cn/news/184933.html

相关文章:

  • 【博客系统】博客系统第一弹:博客系统项目配置、MyBatis-Plus 实现 Mapper 接口、处理项目公共模块:统一返回结果、统一异常处理
  • 关于华为高斯数据库出现Invalid or unsupported by client SCRAM mechanisms定位解决的过程
  • -信息革命-
  • OpenManus云端部署及经典案例应用
  • 心磁图技术突破传统局限!心血管疾病早筛迈入“三零“新时代
  • TV launcher官方下载-tv launcher汉化版-tv桌面启动器极简下载
  • c++17 对于临时对象作为右值的优化
  • MRI学习笔记-conjunction analysis
  • Linux——线程(2)线程互斥(锁)
  • 机器学习 | 基于回归模型的交通需求预测案例分析及代码示例
  • 日本IT|UIUX主要的工作都是哪些?及职业前景
  • 【每日随笔】文化属性 ② ( 高维度信息处理 | 强者思维形成 | 认知重构 | 资源捕获 | 进化路径 )
  • LangChain构建大模型应用之RAG
  • 使用ROS实现多机通讯
  • 线上查询车辆出险记录:快速掌握事故情况!
  • 大模型API密钥的环境变量配置(大模型API KEY管理)(将密钥存储在环境变量)(python-dotenv)(密钥管理)
  • 数据结构(七)---链式栈
  • AI看论文自动生成代码库:Paper2Code如何革新科研复现?
  • 函数式链表:Python编程的非常规 “链” 接
  • QT6 源(53)篇三:存储 c 语言字符串的类 QByteArray 的使用举例,
  • 移除生产环境所有console.log
  • 给视频自动打字幕:从Humanoid-X、UH-1到首个人形VLA Humanoid-VLA:迈向整合第一人称视角的通用人形控制
  • 基于STM32、HAL库的AD7616BSTZ模数转换器ADC驱动程序设计
  • Linux操作系统学习---进程地址空间
  • 【LaTex】8.1 文档类与层级
  • 前端权限管理
  • 小刚说C语言刷题——1320时钟旋转
  • 生成式人工智能认证(GAI认证)要学哪些知识?
  • google chrome 中 fcitx5 候选框不跟随光标
  • 【SpringCloudAlibaba】Dubbo 和 Spring Cloud OpenFeign 在服务治理能力上的差异