DFPatternFunctor遍历计算图
文件:include/tvm/relay/dataflow_pattern_functor.h
功能:定义 DFPatternFunctor
基类,为 DFPattern
提供访问者模式(Visitor Pattern)的实现框架,支持对不同类型的模式节点进行差异化处理。
继承关系:
template <typename R, typename... Args>
class DFPatternFunctor<R(const DFPattern& n, Args...)>
DFPatternFunctor 类概括
DFPatternFunctor
是一个用于处理深度学习中计算图模式匹配的访问者模式(Visitor Pattern)实现,它是TVM(深度学习编译器)框架中的一个重要组件。下面我将从多个方面详细解析这个类的设计和实现。
class DFPatternFunctor;// functions to be overriden.
#define DFPATTERN_FUNCTOR_DEFAULT \{ return VisitDFPatternDefault_(op, std::forward<Args>(args)...); }#define RELAY_DFPATTERN_FUNCTOR_DISPATCH(OP) \vtable.template set_dispatch<OP>([](const ObjectRef& n, TSelf* self, Args... args) { \return self->VisitDFPattern_(static_cast<const OP*>(n.get()), std::forward<Args>(args)...); \});template <typename R, typename... Args>
class DFPatternFunctor<R(const DFPattern& n, Args...)> {private:using TSelf = DFPatternFunctor<R(const DFPattern& n, Args...)>;using FType = tvm::NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;public:/*! \brief virtual destructor */virtual ~DFPatternFunctor() {}/*!* \brief Same as call.* \param n The expression node.* \param args Additional arguments.* \return The result of the call*/R operator()(const DFPattern& n, Args... args) {return VisitDFPattern(n, std::forward<Args>(args)...);}/*!* \brief The functor call.* \param n The expression node.* \param args Additional arguments.* \return The result of the call*/virtual R VisitDFPattern(const DFPattern& n, Args... args) {CHECK(n.defined());static FType vtable = InitVTable();return vtable(n, this, std::forward<Args>(args)...);}// Functions that can be overriden by subclassvirtual R VisitDFPattern_(const AltPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;virtual R VisitDFPattern_(const AttrPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;virtual R VisitDFPattern_(const CallPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;virtual R VisitDFPattern_(const DataTypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;virtual R VisitDFPattern_(const DominatorPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;virtual R VisitDFPattern_(const ExprPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;virtual R VisitDFPattern_(const ShapePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;virtual R VisitDFPattern_(const TupleGetItemPatternNode* op,Args... args) DFPATTERN_FUNCTOR_DEFAULT;virtual R VisitDFPattern_(const TuplePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;virtual R VisitDFPattern_(const TypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;virtual R VisitDFPattern_(const VarPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;virtual R VisitDFPattern_(const ConstantPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;virtual R VisitDFPattern_(const WildcardPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;virtual R VisitDFPatternDefault_(const Object* op, Args...) {LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();throw;}private:// initialize the vtable.static FType InitVTable() {FType vtable;// Set dispatchRELAY_DFPATTERN_FUNCTOR_DISPATCH(AltPatternNode);RELAY_DFPATTERN_FUNCTOR_DISPATCH(AttrPatternNode);RELAY_DFPATTERN_FUNCTOR_DISPATCH(CallPatternNode);RELAY_DFPATTERN_FUNCTOR_DISPATCH(ConstantPatternNode);RELAY_DFPATTERN_FUNCTOR_DISPATCH(DataTypePatternNode);RELAY_DFPATTERN_FUNCTOR_DISPATCH(DominatorPatternNode);RELAY_DFPATTERN_FUNCTOR_DISPATCH(ExprPatternNode);RELAY_DFPATTERN_FUNCTOR_DISPATCH(ShapePatternNode);RELAY_DFPATTERN_FUNCTOR_DISPATCH(TupleGetItemPatternNode);RELAY_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode);RELAY_DFPATTERN_FUNCTOR_DISPATCH(TypePatternNode);RELAY_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode);RELAY_DFPATTERN_FUNCTOR_DISPATCH(WildcardPatternNode);return vtable;}
};
1. 模板定义
template <typename R, typename... Args>
class DFPatternFunctor<R(const DFPattern& n, Args...)> {
- 这是一个模板类,接受两个模板参数:
R
: 表示访问者函数的返回类型Args...
: 表示可变数量的额外参数类型
- 模板特化为函数类型
R(const DFPattern& n, Args...)
,表示这是一个接受DFPattern和额外参数的函数对象
2. 内部类型定义
private:using TSelf = DFPatternFunctor<R(const DFPattern& n, Args...)>;using FType = tvm::NodeFunctor<R(const ObjectRef& n, TSelf* self, Args...)>;
TSelf
: 定义当前类的类型别名,简化代码FType
: 定义TVM内部的节点函数类型,它接受:ObjectRef& n
: 待处理的节点TSelf* self
: 访问者对象自身Args...
: 可变参数
3. 核心方法
3.1 操作符重载
R operator()(const DFPattern& n, Args... args) {return VisitDFPattern(n, std::forward<Args>(args)...);
}
- 重载函数调用操作符
()
,使得对象可以像函数一样被调用 - 将调用转发给
VisitDFPattern
方法 - 使用
std::forward
完美转发参数,保持参数的值类别(左值/右值)
3.2 主访问方法
virtual R VisitDFPattern(const DFPattern& n, Args... args) {CHECK(n.defined());static FType vtable = InitVTable();return vtable(n, this, std::forward<Args>(args)...);
}
- 虚函数,可以被派生类重写
- 首先检查节点是否有效(
n.defined()
) - 初始化虚函数表(
vtable
) - 通过虚函数表分发到具体的处理函数
3.3 默认实现方法
virtual R VisitDFPatternDefault_(const Object* op, Args...) {LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();throw;
}
- 当遇到未知节点类型时的默认处理方法
- 记录致命错误并抛出异常
4. 虚函数表初始化
static FType InitVTable() {FType vtable;// Set dispatchRELAY_DFPATTERN_FUNCTOR_DISPATCH(AltPatternNode);RELAY_DFPATTERN_FUNCTOR_DISPATCH(AttrPatternNode);// ...其他节点类型的注册return vtable;
}
- 静态方法,初始化虚函数表
- 使用宏
RELAY_DFPATTERN_FUNCTOR_DISPATCH
为每种DFPattern节点类型注册处理函数 - 这些宏通常会展开为将节点类型与对应的
VisitDFPattern_
方法关联起来的代码
5. 节点处理方法
virtual R VisitDFPattern_(const AltPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const AttrPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
// ...其他节点类型的处理方法
- 每个方法对应一种特定的DFPattern节点类型
- 使用
DFPATTERN_FUNCTOR_DEFAULT
宏(可能定义为= 0
或默认实现)来指定默认行为 - 派生类可以重写这些方法来提供特定于节点类型的处理逻辑
6. 设计模式分析
这个类实现了访问者模式的变体,具有以下特点:
- 双重分发:
- 第一次分发:通过 虚函数表(vtable) 根据
DFPattern
的具体类型(如CallPatternNode
、VarPatternNode
等)选择对应的VisitDFPattern_
方法。 - 第二次分发:调用具体的
VisitDFPattern_
方法(如VisitDFPattern_(const CallPatternNode* op, Args...)
),执行真正的处理逻辑。
流程图
+--------------------------+
| 调用 operator()(n, args) |
+--------------------------+|v
+--------------------------+
| VisitDFPattern(n, args) | // 第一次分发:查虚函数表
+--------------------------+|v
+--------------------------+
| vtable(n, this, args) | // 根据 n 的类型找到对应的 VisitDFPattern_
+--------------------------+|v
+--------------------------+
| VisitDFPattern_(op, args) | // 第二次分发:执行具体逻辑
+--------------------------+|v
+--------------------------+
| 返回结果 R |
+--------------------------+
class DFPatternVisitor : public DFPatternFunctor<void(const DFPattern&)> {public:void VisitDFPattern(const DFPattern& pattern) override;void VisitDFPattern_(const AltPatternNode* op) override;void VisitDFPattern_(const AttrPatternNode* op) override;void VisitDFPattern_(const CallPatternNode* op) override;void VisitDFPattern_(const ConstantPatternNode* op) override;void VisitDFPattern_(const DataTypePatternNode* op) override;void VisitDFPattern_(const DominatorPatternNode* op) override;void VisitDFPattern_(const ExprPatternNode* op) override;void VisitDFPattern_(const ShapePatternNode* op) override;void VisitDFPattern_(const TupleGetItemPatternNode* op) override;void VisitDFPattern_(const TuplePatternNode* op) override;void VisitDFPattern_(const TypePatternNode* op) override;void VisitDFPattern_(const VarPatternNode* op) override;void VisitDFPattern_(const WildcardPatternNode* op) override;protected:// set of already-visited nodesstd::unordered_set<const Object*> visited_;
};} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_DATAFLOW_PATTERN_FUNCTOR_H_
-
扩展性:
- 可以轻松添加新的节点类型处理方法
- 派生类只需重写感兴趣的方法
-
类型安全:
- 使用模板确保类型安全
- 每个节点类型有专门的处理方法
7. 使用场景
这个类主要用于:
- 模式匹配:在深度学习计算图中查找特定模式
- 模式转换:修改或重写计算图模式
- 模式分析:收集关于计算图模式的统计信息或属性
8. 技术细节
-
完美转发:
- 使用
std::forward
保持参数的值类别 - 允许传递左值或右值参数
- 使用
-
虚函数表:
- 使用TVM内部的
NodeFunctor
实现动态分发 - 比传统的虚函数更灵活,可以在运行时修改
- 使用TVM内部的
-
类型系统:
- 充分利用C++的类型系统为每种节点类型提供类型安全的接口
这个设计展示了TVM框架如何高效地处理复杂的计算图操作,同时保持代码的可扩展性和类型安全。