Halide 21.0.0
Halide compiler and libraries
Loading...
Searching...
No Matches
IRVisitor.h
Go to the documentation of this file.
1#ifndef HALIDE_IR_VISITOR_H
2#define HALIDE_IR_VISITOR_H
3
4#include <set>
5
6#include "IR.h"
7
8/** \file
9 * Defines the base class for things that recursively walk over the IR
10 */
11
12namespace Halide {
13namespace Internal {
14
15/** A base class for algorithms that need to recursively walk over the
16 * IR. The default implementations just recursively walk over the
17 * children. Override the ones you care about.
18 */
19class IRVisitor {
20public:
21 IRVisitor() = default;
22 virtual ~IRVisitor() = default;
23
24protected:
25 // ExprNode<> and StmtNode<> are allowed to call visit (to implement accept())
26 template<typename T>
27 friend struct ExprNode;
28
29 template<typename T>
30 friend struct StmtNode;
31
32 virtual void visit(const IntImm *);
33 virtual void visit(const UIntImm *);
34 virtual void visit(const FloatImm *);
35 virtual void visit(const StringImm *);
36 virtual void visit(const Cast *);
37 virtual void visit(const Reinterpret *);
38 virtual void visit(const Variable *);
39 virtual void visit(const Add *);
40 virtual void visit(const Sub *);
41 virtual void visit(const Mul *);
42 virtual void visit(const Div *);
43 virtual void visit(const Mod *);
44 virtual void visit(const Min *);
45 virtual void visit(const Max *);
46 virtual void visit(const EQ *);
47 virtual void visit(const NE *);
48 virtual void visit(const LT *);
49 virtual void visit(const LE *);
50 virtual void visit(const GT *);
51 virtual void visit(const GE *);
52 virtual void visit(const And *);
53 virtual void visit(const Or *);
54 virtual void visit(const Not *);
55 virtual void visit(const Select *);
56 virtual void visit(const Load *);
57 virtual void visit(const Ramp *);
58 virtual void visit(const Broadcast *);
59 virtual void visit(const Call *);
60 virtual void visit(const Let *);
61 virtual void visit(const LetStmt *);
62 virtual void visit(const AssertStmt *);
63 virtual void visit(const ProducerConsumer *);
64 virtual void visit(const For *);
65 virtual void visit(const Store *);
66 virtual void visit(const Provide *);
67 virtual void visit(const Allocate *);
68 virtual void visit(const Free *);
69 virtual void visit(const Realize *);
70 virtual void visit(const Block *);
71 virtual void visit(const IfThenElse *);
72 virtual void visit(const Evaluate *);
73 virtual void visit(const Shuffle *);
74 virtual void visit(const VectorReduce *);
75 virtual void visit(const Prefetch *);
76 virtual void visit(const Fork *);
77 virtual void visit(const Acquire *);
78 virtual void visit(const Atomic *);
79 virtual void visit(const HoistedStorage *);
80};
81
82/** A base class for algorithms that walk recursively over the IR
83 * without visiting the same node twice. This is for passes that are
84 * capable of interpreting the IR as a DAG instead of a tree. */
85class IRGraphVisitor : public IRVisitor {
86protected:
87 /** By default these methods add the node to the visited set, and
88 * return whether or not it was already there. If it wasn't there,
89 * it delegates to the appropriate visit method. You can override
90 * them if you like. */
91 // @{
92 virtual void include(const Expr &);
93 virtual void include(const Stmt &);
94 // @}
95
96private:
97 /** The nodes visited so far. Only includes nodes with a ref count greater
98 * than one, because we know that nodes with a ref count of 1 will only be
99 * visited once if their parents are only visited once. */
100 std::set<const IRNode *> visited;
101
102protected:
103 /** These methods should call 'include' on the children to only
104 * visit them if they haven't been visited already. */
105 // @{
106 void visit(const IntImm *) override;
107 void visit(const UIntImm *) override;
108 void visit(const FloatImm *) override;
109 void visit(const StringImm *) override;
110 void visit(const Cast *) override;
111 void visit(const Reinterpret *) override;
112 void visit(const Variable *) override;
113 void visit(const Add *) override;
114 void visit(const Sub *) override;
115 void visit(const Mul *) override;
116 void visit(const Div *) override;
117 void visit(const Mod *) override;
118 void visit(const Min *) override;
119 void visit(const Max *) override;
120 void visit(const EQ *) override;
121 void visit(const NE *) override;
122 void visit(const LT *) override;
123 void visit(const LE *) override;
124 void visit(const GT *) override;
125 void visit(const GE *) override;
126 void visit(const And *) override;
127 void visit(const Or *) override;
128 void visit(const Not *) override;
129 void visit(const Select *) override;
130 void visit(const Load *) override;
131 void visit(const Ramp *) override;
132 void visit(const Broadcast *) override;
133 void visit(const Call *) override;
134 void visit(const Let *) override;
135 void visit(const LetStmt *) override;
136 void visit(const AssertStmt *) override;
137 void visit(const ProducerConsumer *) override;
138 void visit(const For *) override;
139 void visit(const Store *) override;
140 void visit(const Provide *) override;
141 void visit(const Allocate *) override;
142 void visit(const Free *) override;
143 void visit(const Realize *) override;
144 void visit(const Block *) override;
145 void visit(const IfThenElse *) override;
146 void visit(const Evaluate *) override;
147 void visit(const Shuffle *) override;
148 void visit(const VectorReduce *) override;
149 void visit(const Prefetch *) override;
150 void visit(const Acquire *) override;
151 void visit(const Fork *) override;
152 void visit(const Atomic *) override;
153 void visit(const HoistedStorage *) override;
154 // @}
155};
156
157/** A visitor/mutator capable of passing arbitrary arguments to the
158 * visit methods using CRTP and returning any types from them. All
159 * Expr visitors must have the same signature, and all Stmt visitors
160 * must have the same signature. Does not have default implementations
161 * of the visit methods. */
162template<typename T, typename ExprRet, typename StmtRet>
164private:
165 template<typename... Args>
166 ExprRet dispatch_expr(const BaseExprNode *node, Args &&...args) {
167 if (node == nullptr) {
168 return ExprRet{};
169 }
170 switch (node->node_type) {
172 return ((T *)this)->visit((const IntImm *)node, std::forward<Args>(args)...);
174 return ((T *)this)->visit((const UIntImm *)node, std::forward<Args>(args)...);
176 return ((T *)this)->visit((const FloatImm *)node, std::forward<Args>(args)...);
178 return ((T *)this)->visit((const StringImm *)node, std::forward<Args>(args)...);
180 return ((T *)this)->visit((const Broadcast *)node, std::forward<Args>(args)...);
181 case IRNodeType::Cast:
182 return ((T *)this)->visit((const Cast *)node, std::forward<Args>(args)...);
184 return ((T *)this)->visit((const Reinterpret *)node, std::forward<Args>(args)...);
186 return ((T *)this)->visit((const Variable *)node, std::forward<Args>(args)...);
187 case IRNodeType::Add:
188 return ((T *)this)->visit((const Add *)node, std::forward<Args>(args)...);
189 case IRNodeType::Sub:
190 return ((T *)this)->visit((const Sub *)node, std::forward<Args>(args)...);
191 case IRNodeType::Mod:
192 return ((T *)this)->visit((const Mod *)node, std::forward<Args>(args)...);
193 case IRNodeType::Mul:
194 return ((T *)this)->visit((const Mul *)node, std::forward<Args>(args)...);
195 case IRNodeType::Div:
196 return ((T *)this)->visit((const Div *)node, std::forward<Args>(args)...);
197 case IRNodeType::Min:
198 return ((T *)this)->visit((const Min *)node, std::forward<Args>(args)...);
199 case IRNodeType::Max:
200 return ((T *)this)->visit((const Max *)node, std::forward<Args>(args)...);
201 case IRNodeType::EQ:
202 return ((T *)this)->visit((const EQ *)node, std::forward<Args>(args)...);
203 case IRNodeType::NE:
204 return ((T *)this)->visit((const NE *)node, std::forward<Args>(args)...);
205 case IRNodeType::LT:
206 return ((T *)this)->visit((const LT *)node, std::forward<Args>(args)...);
207 case IRNodeType::LE:
208 return ((T *)this)->visit((const LE *)node, std::forward<Args>(args)...);
209 case IRNodeType::GT:
210 return ((T *)this)->visit((const GT *)node, std::forward<Args>(args)...);
211 case IRNodeType::GE:
212 return ((T *)this)->visit((const GE *)node, std::forward<Args>(args)...);
213 case IRNodeType::And:
214 return ((T *)this)->visit((const And *)node, std::forward<Args>(args)...);
215 case IRNodeType::Or:
216 return ((T *)this)->visit((const Or *)node, std::forward<Args>(args)...);
217 case IRNodeType::Not:
218 return ((T *)this)->visit((const Not *)node, std::forward<Args>(args)...);
220 return ((T *)this)->visit((const Select *)node, std::forward<Args>(args)...);
221 case IRNodeType::Load:
222 return ((T *)this)->visit((const Load *)node, std::forward<Args>(args)...);
223 case IRNodeType::Ramp:
224 return ((T *)this)->visit((const Ramp *)node, std::forward<Args>(args)...);
225 case IRNodeType::Call:
226 return ((T *)this)->visit((const Call *)node, std::forward<Args>(args)...);
227 case IRNodeType::Let:
228 return ((T *)this)->visit((const Let *)node, std::forward<Args>(args)...);
230 return ((T *)this)->visit((const Shuffle *)node, std::forward<Args>(args)...);
232 return ((T *)this)->visit((const VectorReduce *)node, std::forward<Args>(args)...);
233 // Explicitly list the Stmt types rather than using a
234 // default case so that when new IR nodes are added we
235 // don't miss them here.
239 case IRNodeType::For:
244 case IRNodeType::Free:
247 case IRNodeType::Fork:
253 internal_error << "Unreachable";
254 }
255 return ExprRet{};
256 }
257
258 template<typename... Args>
259 StmtRet dispatch_stmt(const BaseStmtNode *node, Args &&...args) {
260 if (node == nullptr) {
261 return StmtRet{};
262 }
263 switch (node->node_type) {
269 case IRNodeType::Cast:
272 case IRNodeType::Add:
273 case IRNodeType::Sub:
274 case IRNodeType::Mod:
275 case IRNodeType::Mul:
276 case IRNodeType::Div:
277 case IRNodeType::Min:
278 case IRNodeType::Max:
279 case IRNodeType::EQ:
280 case IRNodeType::NE:
281 case IRNodeType::LT:
282 case IRNodeType::LE:
283 case IRNodeType::GT:
284 case IRNodeType::GE:
285 case IRNodeType::And:
286 case IRNodeType::Or:
287 case IRNodeType::Not:
289 case IRNodeType::Load:
290 case IRNodeType::Ramp:
291 case IRNodeType::Call:
292 case IRNodeType::Let:
295 internal_error << "Unreachable";
296 break;
298 return ((T *)this)->visit((const LetStmt *)node, std::forward<Args>(args)...);
300 return ((T *)this)->visit((const AssertStmt *)node, std::forward<Args>(args)...);
302 return ((T *)this)->visit((const ProducerConsumer *)node, std::forward<Args>(args)...);
303 case IRNodeType::For:
304 return ((T *)this)->visit((const For *)node, std::forward<Args>(args)...);
306 return ((T *)this)->visit((const Acquire *)node, std::forward<Args>(args)...);
308 return ((T *)this)->visit((const Store *)node, std::forward<Args>(args)...);
310 return ((T *)this)->visit((const Provide *)node, std::forward<Args>(args)...);
312 return ((T *)this)->visit((const Allocate *)node, std::forward<Args>(args)...);
313 case IRNodeType::Free:
314 return ((T *)this)->visit((const Free *)node, std::forward<Args>(args)...);
316 return ((T *)this)->visit((const Realize *)node, std::forward<Args>(args)...);
318 return ((T *)this)->visit((const Block *)node, std::forward<Args>(args)...);
319 case IRNodeType::Fork:
320 return ((T *)this)->visit((const Fork *)node, std::forward<Args>(args)...);
322 return ((T *)this)->visit((const IfThenElse *)node, std::forward<Args>(args)...);
324 return ((T *)this)->visit((const Evaluate *)node, std::forward<Args>(args)...);
326 return ((T *)this)->visit((const Prefetch *)node, std::forward<Args>(args)...);
328 return ((T *)this)->visit((const Atomic *)node, std::forward<Args>(args)...);
330 return ((T *)this)->visit((const HoistedStorage *)node, std::forward<Args>(args)...);
331 }
332 return StmtRet{};
333 }
334
335public:
336 template<typename... Args>
337 HALIDE_ALWAYS_INLINE StmtRet dispatch(const Stmt &s, Args &&...args) {
338 return dispatch_stmt(s.get(), std::forward<Args>(args)...);
339 }
340
341 template<typename... Args>
342 HALIDE_ALWAYS_INLINE StmtRet dispatch(Stmt &&s, Args &&...args) {
343 return dispatch_stmt(s.get(), std::forward<Args>(args)...);
344 }
345
346 template<typename... Args>
347 HALIDE_ALWAYS_INLINE ExprRet dispatch(const Expr &e, Args &&...args) {
348 return dispatch_expr(e.get(), std::forward<Args>(args)...);
349 }
350
351 template<typename... Args>
352 HALIDE_ALWAYS_INLINE ExprRet dispatch(Expr &&e, Args &&...args) {
353 return dispatch_expr(e.get(), std::forward<Args>(args)...);
354 }
356
357} // namespace Internal
358} // namespace Halide
359
360#endif
#define internal_error
Definition Error.h:215
#define HALIDE_ALWAYS_INLINE
Subtypes for Halide expressions (Halide::Expr) and statements (Halide::Internal::Stmt)
A base class for algorithms that walk recursively over the IR without visiting the same node twice.
Definition IRVisitor.h:85
void visit(const Div *) override
void visit(const Shuffle *) override
void visit(const NE *) override
void visit(const Block *) override
void visit(const EQ *) override
void visit(const Let *) override
void visit(const Provide *) override
void visit(const StringImm *) override
virtual void include(const Expr &)
By default these methods add the node to the visited set, and return whether or not it was already th...
void visit(const For *) override
void visit(const HoistedStorage *) override
void visit(const Ramp *) override
void visit(const Or *) override
void visit(const UIntImm *) override
void visit(const Mul *) override
void visit(const AssertStmt *) override
void visit(const GE *) override
void visit(const Min *) override
void visit(const Free *) override
void visit(const Add *) override
void visit(const Acquire *) override
void visit(const Store *) override
void visit(const Max *) override
void visit(const IntImm *) override
These methods should call 'include' on the children to only visit them if they haven't been visited a...
void visit(const IfThenElse *) override
void visit(const LT *) override
void visit(const VectorReduce *) override
void visit(const Atomic *) override
void visit(const Sub *) override
void visit(const Not *) override
void visit(const Mod *) override
void visit(const ProducerConsumer *) override
void visit(const LetStmt *) override
void visit(const LE *) override
void visit(const Allocate *) override
void visit(const Load *) override
virtual void include(const Stmt &)
void visit(const Realize *) override
void visit(const Prefetch *) override
void visit(const FloatImm *) override
void visit(const Fork *) override
void visit(const Call *) override
void visit(const Reinterpret *) override
void visit(const And *) override
void visit(const Variable *) override
void visit(const Evaluate *) override
void visit(const Broadcast *) override
void visit(const GT *) override
void visit(const Cast *) override
void visit(const Select *) override
virtual void visit(const NE *)
virtual void visit(const Mul *)
virtual void visit(const Max *)
virtual void visit(const Select *)
virtual void visit(const Load *)
virtual void visit(const Div *)
virtual void visit(const Fork *)
virtual void visit(const Sub *)
virtual void visit(const LE *)
virtual ~IRVisitor()=default
virtual void visit(const ProducerConsumer *)
virtual void visit(const VectorReduce *)
virtual void visit(const GE *)
virtual void visit(const StringImm *)
virtual void visit(const Allocate *)
virtual void visit(const IfThenElse *)
virtual void visit(const For *)
virtual void visit(const Prefetch *)
virtual void visit(const Block *)
virtual void visit(const UIntImm *)
virtual void visit(const HoistedStorage *)
virtual void visit(const FloatImm *)
virtual void visit(const GT *)
virtual void visit(const Mod *)
virtual void visit(const Acquire *)
virtual void visit(const Atomic *)
virtual void visit(const Ramp *)
virtual void visit(const Free *)
virtual void visit(const IntImm *)
virtual void visit(const Or *)
virtual void visit(const EQ *)
virtual void visit(const Broadcast *)
virtual void visit(const Call *)
virtual void visit(const Min *)
virtual void visit(const Variable *)
virtual void visit(const Realize *)
virtual void visit(const Add *)
virtual void visit(const Shuffle *)
virtual void visit(const Reinterpret *)
virtual void visit(const Evaluate *)
virtual void visit(const AssertStmt *)
virtual void visit(const And *)
virtual void visit(const LetStmt *)
virtual void visit(const Store *)
virtual void visit(const Provide *)
virtual void visit(const LT *)
virtual void visit(const Cast *)
virtual void visit(const Not *)
virtual void visit(const Let *)
A visitor/mutator capable of passing arbitrary arguments to the visit methods using CRTP and returnin...
Definition IRVisitor.h:163
HALIDE_ALWAYS_INLINE StmtRet dispatch(const Stmt &s, Args &&...args)
Definition IRVisitor.h:337
HALIDE_ALWAYS_INLINE ExprRet dispatch(Expr &&e, Args &&...args)
Definition IRVisitor.h:352
HALIDE_ALWAYS_INLINE StmtRet dispatch(Stmt &&s, Args &&...args)
Definition IRVisitor.h:342
HALIDE_ALWAYS_INLINE ExprRet dispatch(const Expr &e, Args &&...args)
Definition IRVisitor.h:347
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
A fragment of Halide syntax.
Definition Expr.h:258
HALIDE_ALWAYS_INLINE const Internal::BaseExprNode * get() const
Override get() to return a BaseExprNode * instead of an IRNode *.
Definition Expr.h:321
The sum of two expressions.
Definition IR.h:56
Allocate a scratch area called with the given name, type, and size.
Definition IR.h:371
Logical and - are both expressions true.
Definition IR.h:175
If the 'condition' is false, then evaluate and return the message, which should be a call to an error...
Definition IR.h:294
Lock all the Store nodes in the body statement.
Definition IR.h:994
A base class for expression nodes.
Definition Expr.h:143
IR nodes are split into expressions and statements.
Definition Expr.h:134
A sequence of statements to be executed in-order.
Definition IR.h:442
A vector with 'lanes' elements, in which every element is 'value'.
Definition IR.h:259
A function call.
Definition IR.h:490
The actual IR nodes begin here.
Definition IR.h:30
The ratio of two expressions.
Definition IR.h:83
Is the first expression equal to the second.
Definition IR.h:121
Evaluate and discard an expression, presumably because it has some side-effect.
Definition IR.h:476
Floating point constants.
Definition Expr.h:236
A for loop.
Definition IR.h:848
A pair of statements executed concurrently.
Definition IR.h:457
Free the resources associated with the given buffer.
Definition IR.h:413
Is the first expression greater than or equal to the second.
Definition IR.h:166
Is the first expression greater than the second.
Definition IR.h:157
Represents a location where storage will be hoisted to for a Func / Realize node with a given name.
Definition IR.h:978
IRNodeType node_type
Each IR node subclass has a unique identifier.
Definition Expr.h:113
An if-then-else block.
Definition IR.h:466
Integer constants.
Definition Expr.h:218
Is the first expression less than or equal to the second.
Definition IR.h:148
Is the first expression less than the second.
Definition IR.h:139
A let expression, like you might find in a functional language.
Definition IR.h:271
The statement form of a let node.
Definition IR.h:282
Load a value from a named symbol if predicate is true.
Definition IR.h:217
The greater of two values.
Definition IR.h:112
The lesser of two values.
Definition IR.h:103
The remainder of a / b.
Definition IR.h:94
The product of two expressions.
Definition IR.h:74
Is the first expression not equal to the second.
Definition IR.h:130
Logical not - true if the expression is false.
Definition IR.h:193
Logical or - is at least one of the expressions true.
Definition IR.h:184
Represent a multi-dimensional region of a Func or an ImageParam that needs to be prefetched.
Definition IR.h:956
This node is a helpful annotation to do with permissions.
Definition IR.h:315
This defines the value of a function at a multi-dimensional location.
Definition IR.h:354
A linear ramp vector node.
Definition IR.h:247
Allocate a multi-dimensional buffer of the given type and size.
Definition IR.h:427
Reinterpret value as another type, without affecting any of the bits (on little-endian systems).
Definition IR.h:47
A ternary operator.
Definition IR.h:204
Construct a new vector by taking elements from another sequence of vectors.
Definition IR.h:884
A reference-counted handle to a statement node.
Definition Expr.h:427
HALIDE_ALWAYS_INLINE const BaseStmtNode * get() const
Override get() to return a BaseStmtNode * instead of an IRNode *.
Definition Expr.h:435
Store a 'value' to the buffer called 'name' at a given 'index' if 'predicate' is true.
Definition IR.h:333
String constants.
Definition Expr.h:245
The difference of two expressions.
Definition IR.h:65
Unsigned integer constants.
Definition Expr.h:227
A named variable.
Definition IR.h:801
Horizontally reduce a vector to a scalar or narrower vector using the given commutative and associati...
Definition IR.h:1012