Halide
IRVisitor.h
Go to the documentation of this file.
1 #ifndef HALIDE_IR_VISITOR_H
2 #define HALIDE_IR_VISITOR_H
3 
#include <set>
#include <utility>

#include "IR.h"
7 
8 /** \file
9  * Defines the base class for things that recursively walk over the IR
10  */
11 
12 namespace Halide {
13 namespace Internal {
14 
15 /** A base class for algorithms that need to recursively walk over the
16  * IR. The default implementations just recursively walk over the
17  * children. Override the ones you care about.
18  */
19 class IRVisitor {
20 public:
21  IRVisitor() = default;
22  virtual ~IRVisitor() = default;
23 
24 protected:
25  // ExprNode<> and StmtNode<> are allowed to call visit (to implement accept())
26  template<typename T>
27  friend struct ExprNode;
28 
29  template<typename T>
30  friend struct StmtNode;
31 
32  virtual void visit(const IntImm *);
33  virtual void visit(const UIntImm *);
34  virtual void visit(const FloatImm *);
35  virtual void visit(const StringImm *);
36  virtual void visit(const Cast *);
37  virtual void visit(const Reinterpret *);
38  virtual void visit(const Variable *);
39  virtual void visit(const Add *);
40  virtual void visit(const Sub *);
41  virtual void visit(const Mul *);
42  virtual void visit(const Div *);
43  virtual void visit(const Mod *);
44  virtual void visit(const Min *);
45  virtual void visit(const Max *);
46  virtual void visit(const EQ *);
47  virtual void visit(const NE *);
48  virtual void visit(const LT *);
49  virtual void visit(const LE *);
50  virtual void visit(const GT *);
51  virtual void visit(const GE *);
52  virtual void visit(const And *);
53  virtual void visit(const Or *);
54  virtual void visit(const Not *);
55  virtual void visit(const Select *);
56  virtual void visit(const Load *);
57  virtual void visit(const Ramp *);
58  virtual void visit(const Broadcast *);
59  virtual void visit(const Call *);
60  virtual void visit(const Let *);
61  virtual void visit(const LetStmt *);
62  virtual void visit(const AssertStmt *);
63  virtual void visit(const ProducerConsumer *);
64  virtual void visit(const For *);
65  virtual void visit(const Store *);
66  virtual void visit(const Provide *);
67  virtual void visit(const Allocate *);
68  virtual void visit(const Free *);
69  virtual void visit(const Realize *);
70  virtual void visit(const Block *);
71  virtual void visit(const IfThenElse *);
72  virtual void visit(const Evaluate *);
73  virtual void visit(const Shuffle *);
74  virtual void visit(const VectorReduce *);
75  virtual void visit(const Prefetch *);
76  virtual void visit(const Fork *);
77  virtual void visit(const Acquire *);
78  virtual void visit(const Atomic *);
79 };
80 
81 /** A base class for algorithms that walk recursively over the IR
82  * without visiting the same node twice. This is for passes that are
83  * capable of interpreting the IR as a DAG instead of a tree. */
84 class IRGraphVisitor : public IRVisitor {
85 protected:
86  /** By default these methods add the node to the visited set, and
87  * return whether or not it was already there. If it wasn't there,
88  * it delegates to the appropriate visit method. You can override
89  * them if you like. */
90  // @{
91  virtual void include(const Expr &);
92  virtual void include(const Stmt &);
93  // @}
94 
95 private:
96  /** The nodes visited so far */
97  std::set<IRHandle> visited;
98 
99 protected:
100  /** These methods should call 'include' on the children to only
101  * visit them if they haven't been visited already. */
102  // @{
103  void visit(const IntImm *) override;
104  void visit(const UIntImm *) override;
105  void visit(const FloatImm *) override;
106  void visit(const StringImm *) override;
107  void visit(const Cast *) override;
108  void visit(const Reinterpret *) override;
109  void visit(const Variable *) override;
110  void visit(const Add *) override;
111  void visit(const Sub *) override;
112  void visit(const Mul *) override;
113  void visit(const Div *) override;
114  void visit(const Mod *) override;
115  void visit(const Min *) override;
116  void visit(const Max *) override;
117  void visit(const EQ *) override;
118  void visit(const NE *) override;
119  void visit(const LT *) override;
120  void visit(const LE *) override;
121  void visit(const GT *) override;
122  void visit(const GE *) override;
123  void visit(const And *) override;
124  void visit(const Or *) override;
125  void visit(const Not *) override;
126  void visit(const Select *) override;
127  void visit(const Load *) override;
128  void visit(const Ramp *) override;
129  void visit(const Broadcast *) override;
130  void visit(const Call *) override;
131  void visit(const Let *) override;
132  void visit(const LetStmt *) override;
133  void visit(const AssertStmt *) override;
134  void visit(const ProducerConsumer *) override;
135  void visit(const For *) override;
136  void visit(const Store *) override;
137  void visit(const Provide *) override;
138  void visit(const Allocate *) override;
139  void visit(const Free *) override;
140  void visit(const Realize *) override;
141  void visit(const Block *) override;
142  void visit(const IfThenElse *) override;
143  void visit(const Evaluate *) override;
144  void visit(const Shuffle *) override;
145  void visit(const VectorReduce *) override;
146  void visit(const Prefetch *) override;
147  void visit(const Acquire *) override;
148  void visit(const Fork *) override;
149  void visit(const Atomic *) override;
150  // @}
151 };
152 
153 /** A visitor/mutator capable of passing arbitrary arguments to the
154  * visit methods using CRTP and returning any types from them. All
155  * Expr visitors must have the same signature, and all Stmt visitors
156  * must have the same signature. Does not have default implementations
157  * of the visit methods. */
158 template<typename T, typename ExprRet, typename StmtRet>
160 private:
161  template<typename... Args>
162  ExprRet dispatch_expr(const BaseExprNode *node, Args &&...args) {
163  if (node == nullptr) {
164  return ExprRet{};
165  }
166  switch (node->node_type) {
167  case IRNodeType::IntImm:
168  return ((T *)this)->visit((const IntImm *)node, std::forward<Args>(args)...);
169  case IRNodeType::UIntImm:
170  return ((T *)this)->visit((const UIntImm *)node, std::forward<Args>(args)...);
172  return ((T *)this)->visit((const FloatImm *)node, std::forward<Args>(args)...);
174  return ((T *)this)->visit((const StringImm *)node, std::forward<Args>(args)...);
176  return ((T *)this)->visit((const Broadcast *)node, std::forward<Args>(args)...);
177  case IRNodeType::Cast:
178  return ((T *)this)->visit((const Cast *)node, std::forward<Args>(args)...);
180  return ((T *)this)->visit((const Reinterpret *)node, std::forward<Args>(args)...);
182  return ((T *)this)->visit((const Variable *)node, std::forward<Args>(args)...);
183  case IRNodeType::Add:
184  return ((T *)this)->visit((const Add *)node, std::forward<Args>(args)...);
185  case IRNodeType::Sub:
186  return ((T *)this)->visit((const Sub *)node, std::forward<Args>(args)...);
187  case IRNodeType::Mod:
188  return ((T *)this)->visit((const Mod *)node, std::forward<Args>(args)...);
189  case IRNodeType::Mul:
190  return ((T *)this)->visit((const Mul *)node, std::forward<Args>(args)...);
191  case IRNodeType::Div:
192  return ((T *)this)->visit((const Div *)node, std::forward<Args>(args)...);
193  case IRNodeType::Min:
194  return ((T *)this)->visit((const Min *)node, std::forward<Args>(args)...);
195  case IRNodeType::Max:
196  return ((T *)this)->visit((const Max *)node, std::forward<Args>(args)...);
197  case IRNodeType::EQ:
198  return ((T *)this)->visit((const EQ *)node, std::forward<Args>(args)...);
199  case IRNodeType::NE:
200  return ((T *)this)->visit((const NE *)node, std::forward<Args>(args)...);
201  case IRNodeType::LT:
202  return ((T *)this)->visit((const LT *)node, std::forward<Args>(args)...);
203  case IRNodeType::LE:
204  return ((T *)this)->visit((const LE *)node, std::forward<Args>(args)...);
205  case IRNodeType::GT:
206  return ((T *)this)->visit((const GT *)node, std::forward<Args>(args)...);
207  case IRNodeType::GE:
208  return ((T *)this)->visit((const GE *)node, std::forward<Args>(args)...);
209  case IRNodeType::And:
210  return ((T *)this)->visit((const And *)node, std::forward<Args>(args)...);
211  case IRNodeType::Or:
212  return ((T *)this)->visit((const Or *)node, std::forward<Args>(args)...);
213  case IRNodeType::Not:
214  return ((T *)this)->visit((const Not *)node, std::forward<Args>(args)...);
215  case IRNodeType::Select:
216  return ((T *)this)->visit((const Select *)node, std::forward<Args>(args)...);
217  case IRNodeType::Load:
218  return ((T *)this)->visit((const Load *)node, std::forward<Args>(args)...);
219  case IRNodeType::Ramp:
220  return ((T *)this)->visit((const Ramp *)node, std::forward<Args>(args)...);
221  case IRNodeType::Call:
222  return ((T *)this)->visit((const Call *)node, std::forward<Args>(args)...);
223  case IRNodeType::Let:
224  return ((T *)this)->visit((const Let *)node, std::forward<Args>(args)...);
225  case IRNodeType::Shuffle:
226  return ((T *)this)->visit((const Shuffle *)node, std::forward<Args>(args)...);
228  return ((T *)this)->visit((const VectorReduce *)node, std::forward<Args>(args)...);
229  // Explicitly list the Stmt types rather than using a
230  // default case so that when new IR nodes are added we
231  // don't miss them here.
232  case IRNodeType::LetStmt:
235  case IRNodeType::For:
236  case IRNodeType::Acquire:
237  case IRNodeType::Store:
238  case IRNodeType::Provide:
240  case IRNodeType::Free:
241  case IRNodeType::Realize:
242  case IRNodeType::Block:
243  case IRNodeType::Fork:
247  case IRNodeType::Atomic:
248  internal_error << "Unreachable";
249  }
250  return ExprRet{};
251  }
252 
253  template<typename... Args>
254  StmtRet dispatch_stmt(const BaseStmtNode *node, Args &&...args) {
255  if (node == nullptr) {
256  return StmtRet{};
257  }
258  switch (node->node_type) {
259  case IRNodeType::IntImm:
260  case IRNodeType::UIntImm:
264  case IRNodeType::Cast:
267  case IRNodeType::Add:
268  case IRNodeType::Sub:
269  case IRNodeType::Mod:
270  case IRNodeType::Mul:
271  case IRNodeType::Div:
272  case IRNodeType::Min:
273  case IRNodeType::Max:
274  case IRNodeType::EQ:
275  case IRNodeType::NE:
276  case IRNodeType::LT:
277  case IRNodeType::LE:
278  case IRNodeType::GT:
279  case IRNodeType::GE:
280  case IRNodeType::And:
281  case IRNodeType::Or:
282  case IRNodeType::Not:
283  case IRNodeType::Select:
284  case IRNodeType::Load:
285  case IRNodeType::Ramp:
286  case IRNodeType::Call:
287  case IRNodeType::Let:
288  case IRNodeType::Shuffle:
290  internal_error << "Unreachable";
291  break;
292  case IRNodeType::LetStmt:
293  return ((T *)this)->visit((const LetStmt *)node, std::forward<Args>(args)...);
295  return ((T *)this)->visit((const AssertStmt *)node, std::forward<Args>(args)...);
297  return ((T *)this)->visit((const ProducerConsumer *)node, std::forward<Args>(args)...);
298  case IRNodeType::For:
299  return ((T *)this)->visit((const For *)node, std::forward<Args>(args)...);
300  case IRNodeType::Acquire:
301  return ((T *)this)->visit((const Acquire *)node, std::forward<Args>(args)...);
302  case IRNodeType::Store:
303  return ((T *)this)->visit((const Store *)node, std::forward<Args>(args)...);
304  case IRNodeType::Provide:
305  return ((T *)this)->visit((const Provide *)node, std::forward<Args>(args)...);
307  return ((T *)this)->visit((const Allocate *)node, std::forward<Args>(args)...);
308  case IRNodeType::Free:
309  return ((T *)this)->visit((const Free *)node, std::forward<Args>(args)...);
310  case IRNodeType::Realize:
311  return ((T *)this)->visit((const Realize *)node, std::forward<Args>(args)...);
312  case IRNodeType::Block:
313  return ((T *)this)->visit((const Block *)node, std::forward<Args>(args)...);
314  case IRNodeType::Fork:
315  return ((T *)this)->visit((const Fork *)node, std::forward<Args>(args)...);
317  return ((T *)this)->visit((const IfThenElse *)node, std::forward<Args>(args)...);
319  return ((T *)this)->visit((const Evaluate *)node, std::forward<Args>(args)...);
321  return ((T *)this)->visit((const Prefetch *)node, std::forward<Args>(args)...);
322  case IRNodeType::Atomic:
323  return ((T *)this)->visit((const Atomic *)node, std::forward<Args>(args)...);
324  }
325  return StmtRet{};
326  }
327 
328 public:
329  template<typename... Args>
330  HALIDE_ALWAYS_INLINE StmtRet dispatch(const Stmt &s, Args &&...args) {
331  return dispatch_stmt(s.get(), std::forward<Args>(args)...);
332  }
333 
334  template<typename... Args>
335  HALIDE_ALWAYS_INLINE StmtRet dispatch(Stmt &&s, Args &&...args) {
336  return dispatch_stmt(s.get(), std::forward<Args>(args)...);
337  }
338 
339  template<typename... Args>
340  HALIDE_ALWAYS_INLINE ExprRet dispatch(const Expr &e, Args &&...args) {
341  return dispatch_expr(e.get(), std::forward<Args>(args)...);
342  }
343 
344  template<typename... Args>
345  HALIDE_ALWAYS_INLINE ExprRet dispatch(Expr &&e, Args &&...args) {
346  return dispatch_expr(e.get(), std::forward<Args>(args)...);
347  }
348 };
349 
350 } // namespace Internal
351 } // namespace Halide
352 
353 #endif
Halide::Internal::Acquire
Definition: IR.h:807
Halide::Internal::Allocate
Allocate a scratch area with the given name, type, and size.
Definition: IR.h:363
Halide::Internal::Add
The sum of two expressions.
Definition: IR.h:48
Halide::Internal::IRNodeType::Cast
@ Cast
Halide::Internal::IRVisitor::visit
virtual void visit(const IntImm *)
Halide::Internal::IRNodeType::Not
@ Not
Halide::Internal::IRNodeType::Select
@ Select
Halide::Internal::IRNodeType::Block
@ Block
Halide::Internal::VectorReduce
Horizontally reduce a vector to a scalar or narrower vector using the given commutative and associati...
Definition: IR.h:929
Halide::Internal::IRNodeType::LE
@ LE
Halide::Internal::IRNodeType::For
@ For
Halide::Internal::GE
Is the first expression greater than or equal to the second.
Definition: IR.h:158
Halide::Internal::IRVisitor
A base class for algorithms that need to recursively walk over the IR.
Definition: IRVisitor.h:19
Halide::Internal::For
A for loop.
Definition: IR.h:788
Halide::Internal::IRNodeType::EQ
@ EQ
Halide::Internal::FloatImm
Floating point constants.
Definition: Expr.h:235
Halide::Internal::IRNodeType::Evaluate
@ Evaluate
Halide::Internal::IRNodeType::Or
@ Or
Halide::Internal::Broadcast
A vector with 'lanes' elements, in which every element is 'value'.
Definition: IR.h:251
Halide::Internal::Div
The ratio of two expressions.
Definition: IR.h:75
Halide::Internal::IRNodeType::LetStmt
@ LetStmt
Halide::Internal::ExprNode
We use the "curiously recurring template pattern" to avoid duplicated code in the IR Nodes.
Definition: Expr.h:157
Halide::Internal::IRNodeType::LT
@ LT
Halide::Internal::IntImm
Integer constants.
Definition: Expr.h:217
Halide::Internal::IRNodeType::Provide
@ Provide
Halide::Internal::LetStmt
The statement form of a let node.
Definition: IR.h:274
Halide::Internal::VariadicVisitor::dispatch
HALIDE_ALWAYS_INLINE StmtRet dispatch(Stmt &&s, Args &&...args)
Definition: IRVisitor.h:335
Halide::Internal::IRNodeType::Min
@ Min
IR.h
Halide::Internal::Cast
The actual IR nodes begin here.
Definition: IR.h:29
Halide::Internal::LE
Is the first expression less than or equal to the second.
Definition: IR.h:140
Halide::Internal::IRNodeType::GE
@ GE
Halide::Internal::NE
Is the first expression not equal to the second.
Definition: IR.h:122
Halide::Internal::Fork
A pair of statements executed concurrently.
Definition: IR.h:449
Halide::Internal::Stmt
A reference-counted handle to a statement node.
Definition: Expr.h:418
Halide::Internal::IRNodeType::Variable
@ Variable
Halide::Internal::IRNodeType::ProducerConsumer
@ ProducerConsumer
Halide::Internal::BaseStmtNode
IR nodes are split into expressions and statements.
Definition: Expr.h:133
Halide::Internal::Load
Load a value from a named symbol if predicate is true.
Definition: IR.h:209
Halide::Internal::Free
Free the resources associated with the given buffer.
Definition: IR.h:405
Halide::Internal::Realize
Allocate a multi-dimensional buffer of the given type and size.
Definition: IR.h:419
Halide
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
Definition: AbstractGenerator.h:19
Halide::Internal::Or
Logical or - is at least one of the expressions true.
Definition: IR.h:176
Halide::Internal::EQ
Is the first expression equal to the second.
Definition: IR.h:113
Halide::Internal::Provide
This defines the value of a function at a multi-dimensional location.
Definition: IR.h:346
Halide::Internal::IRNodeType::Atomic
@ Atomic
Halide::LinkageType::Internal
@ Internal
Not visible externally, similar to 'static' linkage in C.
Halide::Internal::Max
The greater of two values.
Definition: IR.h:104
Halide::Internal::IRNodeType::Fork
@ Fork
Halide::Internal::IRNodeType::Broadcast
@ Broadcast
Halide::Internal::IRNodeType::Free
@ Free
Halide::Internal::IRNodeType::Add
@ Add
Halide::Internal::Let
A let expression, like you might find in a functional language.
Definition: IR.h:263
Halide::Internal::IRGraphVisitor::visit
void visit(const IntImm *) override
These methods should call 'include' on the children to only visit them if they haven't been visited a...
Halide::Internal::IRNodeType::Sub
@ Sub
Halide::Internal::IRNodeType::FloatImm
@ FloatImm
Halide::Internal::Ramp
A linear ramp vector node.
Definition: IR.h:239
HALIDE_ALWAYS_INLINE
#define HALIDE_ALWAYS_INLINE
Definition: HalideRuntime.h:40
Halide::Internal::IRNodeType::AssertStmt
@ AssertStmt
Halide::Internal::IRNodeType::IntImm
@ IntImm
Halide::Internal::Evaluate
Evaluate and discard an expression, presumably because it has some side-effect.
Definition: IR.h:468
Halide::Internal::VariadicVisitor::dispatch
HALIDE_ALWAYS_INLINE ExprRet dispatch(const Expr &e, Args &&...args)
Definition: IRVisitor.h:340
Halide::Internal::IRNodeType::And
@ And
Halide::Internal::IRNodeType::Max
@ Max
Halide::Internal::BaseExprNode
A base class for expression nodes.
Definition: Expr.h:142
Halide::Internal::VariadicVisitor::dispatch
HALIDE_ALWAYS_INLINE StmtRet dispatch(const Stmt &s, Args &&...args)
Definition: IRVisitor.h:330
Halide::Internal::Store
Store a 'value' to the buffer called 'name' at a given 'index' if 'predicate' is true.
Definition: IR.h:325
Halide::Internal::VariadicVisitor
A visitor/mutator capable of passing arbitrary arguments to the visit methods using CRTP and returnin...
Definition: IRVisitor.h:159
Halide::Internal::IRVisitor::~IRVisitor
virtual ~IRVisitor()=default
Halide::Internal::Variable
A named variable.
Definition: IR.h:741
Halide::Internal::Min
The lesser of two values.
Definition: IR.h:95
Halide::Internal::ProducerConsumer
This node is a helpful annotation to do with permissions.
Definition: IR.h:307
Halide::Internal::IRNodeType::VectorReduce
@ VectorReduce
Halide::Internal::Mod
The remainder of a / b.
Definition: IR.h:86
Halide::Internal::IRNodeType::Let
@ Let
Halide::Internal::VariadicVisitor::dispatch
HALIDE_ALWAYS_INLINE ExprRet dispatch(Expr &&e, Args &&...args)
Definition: IRVisitor.h:345
Halide::Internal::AssertStmt
If the 'condition' is false, then evaluate and return the message, which should be a call to an error...
Definition: IR.h:286
Halide::Internal::IRNodeType::NE
@ NE
Halide::Internal::IRNodeType::Reinterpret
@ Reinterpret
internal_error
#define internal_error
Definition: Errors.h:23
Halide::Internal::Call
A function call.
Definition: IR.h:482
Halide::Internal::Stmt::get
const HALIDE_ALWAYS_INLINE BaseStmtNode * get() const
Override get() to return a BaseStmtNode * instead of an IRNode *.
Definition: Expr.h:426
Halide::Internal::IRNodeType::IfThenElse
@ IfThenElse
Halide::Internal::IRNodeType::Mul
@ Mul
Halide::Internal::Reinterpret
Reinterpret value as another type, without affecting any of the bits (on little-endian systems).
Definition: IR.h:39
Halide::Internal::IRNodeType::GT
@ GT
Halide::Internal::IRNodeType::UIntImm
@ UIntImm
Halide::Internal::Select
A ternary operator.
Definition: IR.h:196
Halide::Internal::IRNodeType::StringImm
@ StringImm
Halide::Internal::IRNodeType::Ramp
@ Ramp
Halide::Internal::IRNode::node_type
IRNodeType node_type
Each IR node subclass has a unique identifier.
Definition: Expr.h:112
Halide::Expr
A fragment of Halide syntax.
Definition: Expr.h:257
Halide::Internal::IRGraphVisitor
A base class for algorithms that walk recursively over the IR without visiting the same node twice.
Definition: IRVisitor.h:84
Halide::Internal::IRGraphVisitor::include
virtual void include(const Expr &)
By default these methods add the node to the visited set and, only if it was not already there, delegate to the appropriate visit method.
Halide::Internal::IRNodeType::Acquire
@ Acquire
Halide::Internal::Prefetch
Represent a multi-dimensional region of a Func or an ImageParam that needs to be prefetched.
Definition: IR.h:888
Halide::Internal::GT
Is the first expression greater than the second.
Definition: IR.h:149
Halide::Internal::IRVisitor::IRVisitor
IRVisitor()=default
Halide::Internal::Atomic
Lock all the Store nodes in the body statement.
Definition: IR.h:911
Halide::Internal::IRNodeType::Realize
@ Realize
Halide::Internal::Shuffle
Construct a new vector by taking elements from another sequence of vectors.
Definition: IR.h:819
Halide::Internal::IRNodeType::Mod
@ Mod
Halide::Expr::get
const HALIDE_ALWAYS_INLINE Internal::BaseExprNode * get() const
Override get() to return a BaseExprNode * instead of an IRNode *.
Definition: Expr.h:315
Halide::Internal::IRNodeType::Shuffle
@ Shuffle
Halide::Internal::UIntImm
Unsigned integer constants.
Definition: Expr.h:226
Halide::Internal::StmtNode
Definition: Expr.h:167
Halide::Internal::IRNodeType::Store
@ Store
Halide::Internal::IRNodeType::Prefetch
@ Prefetch
Halide::Internal::IRNodeType::Allocate
@ Allocate
Halide::Internal::IRNodeType::Div
@ Div
Halide::Internal::IfThenElse
An if-then-else block.
Definition: IR.h:458
Halide::Internal::Sub
The difference of two expressions.
Definition: IR.h:57
Halide::Internal::IRNodeType::Call
@ Call
Halide::Internal::And
Logical and - are both expressions true.
Definition: IR.h:167
Halide::Internal::StringImm
String constants.
Definition: Expr.h:244
Halide::Internal::Not
Logical not - true if the expression is false.
Definition: IR.h:185
Halide::Internal::LT
Is the first expression less than the second.
Definition: IR.h:131
Halide::Internal::Mul
The product of two expressions.
Definition: IR.h:66
Halide::Internal::Block
A sequence of statements to be executed in-order.
Definition: IR.h:434
Halide::Internal::IRNodeType::Load
@ Load