Halide
IRVisitor.h
Go to the documentation of this file.
1 #ifndef HALIDE_IR_VISITOR_H
2 #define HALIDE_IR_VISITOR_H
3 
4 #include <map>
5 #include <set>
6 #include <string>
7 
8 #include "IR.h"
9 
10 /** \file
11  * Defines the base class for things that recursively walk over the IR
12  */
13 
14 namespace Halide {
15 namespace Internal {
16 
17 /** A base class for algorithms that need to recursively walk over the
18  * IR. The default implementations just recursively walk over the
19  * children. Override the ones you care about.
20  */
21 class IRVisitor {
22 public:
23  IRVisitor();
24  virtual ~IRVisitor();
25 
26 protected:
27  // ExprNode<> and StmtNode<> are allowed to call visit (to implement accept())
28  template<typename T>
29  friend struct ExprNode;
30 
31  template<typename T>
32  friend struct StmtNode;
33 
34  virtual void visit(const IntImm *);
35  virtual void visit(const UIntImm *);
36  virtual void visit(const FloatImm *);
37  virtual void visit(const StringImm *);
38  virtual void visit(const Cast *);
39  virtual void visit(const Variable *);
40  virtual void visit(const Add *);
41  virtual void visit(const Sub *);
42  virtual void visit(const Mul *);
43  virtual void visit(const Div *);
44  virtual void visit(const Mod *);
45  virtual void visit(const Min *);
46  virtual void visit(const Max *);
47  virtual void visit(const EQ *);
48  virtual void visit(const NE *);
49  virtual void visit(const LT *);
50  virtual void visit(const LE *);
51  virtual void visit(const GT *);
52  virtual void visit(const GE *);
53  virtual void visit(const And *);
54  virtual void visit(const Or *);
55  virtual void visit(const Not *);
56  virtual void visit(const Select *);
57  virtual void visit(const Load *);
58  virtual void visit(const Ramp *);
59  virtual void visit(const Broadcast *);
60  virtual void visit(const Call *);
61  virtual void visit(const Let *);
62  virtual void visit(const LetStmt *);
63  virtual void visit(const AssertStmt *);
64  virtual void visit(const ProducerConsumer *);
65  virtual void visit(const For *);
66  virtual void visit(const Store *);
67  virtual void visit(const Provide *);
68  virtual void visit(const Allocate *);
69  virtual void visit(const Free *);
70  virtual void visit(const Realize *);
71  virtual void visit(const Block *);
72  virtual void visit(const IfThenElse *);
73  virtual void visit(const Evaluate *);
74  virtual void visit(const Shuffle *);
75  virtual void visit(const VectorReduce *);
76  virtual void visit(const Prefetch *);
77  virtual void visit(const Fork *);
78  virtual void visit(const Acquire *);
79  virtual void visit(const Atomic *);
80 };
81 
82 /** A base class for algorithms that walk recursively over the IR
83  * without visiting the same node twice. This is for passes that are
84  * capable of interpreting the IR as a DAG instead of a tree. */
85 class IRGraphVisitor : public IRVisitor {
86 protected:
87  /** By default these methods add the node to the visited set, and
88  * return whether or not it was already there. If it wasn't there,
89  * it delegates to the appropriate visit method. You can override
90  * them if you like. */
91  // @{
92  virtual void include(const Expr &);
93  virtual void include(const Stmt &);
94  // @}
95 
96 private:
97  /** The nodes visited so far */
98  std::set<IRHandle> visited;
99 
100 protected:
101  /** These methods should call 'include' on the children to only
102  * visit them if they haven't been visited already. */
103  // @{
104  void visit(const IntImm *) override;
105  void visit(const UIntImm *) override;
106  void visit(const FloatImm *) override;
107  void visit(const StringImm *) override;
108  void visit(const Cast *) override;
109  void visit(const Variable *) override;
110  void visit(const Add *) override;
111  void visit(const Sub *) override;
112  void visit(const Mul *) override;
113  void visit(const Div *) override;
114  void visit(const Mod *) override;
115  void visit(const Min *) override;
116  void visit(const Max *) override;
117  void visit(const EQ *) override;
118  void visit(const NE *) override;
119  void visit(const LT *) override;
120  void visit(const LE *) override;
121  void visit(const GT *) override;
122  void visit(const GE *) override;
123  void visit(const And *) override;
124  void visit(const Or *) override;
125  void visit(const Not *) override;
126  void visit(const Select *) override;
127  void visit(const Load *) override;
128  void visit(const Ramp *) override;
129  void visit(const Broadcast *) override;
130  void visit(const Call *) override;
131  void visit(const Let *) override;
132  void visit(const LetStmt *) override;
133  void visit(const AssertStmt *) override;
134  void visit(const ProducerConsumer *) override;
135  void visit(const For *) override;
136  void visit(const Store *) override;
137  void visit(const Provide *) override;
138  void visit(const Allocate *) override;
139  void visit(const Free *) override;
140  void visit(const Realize *) override;
141  void visit(const Block *) override;
142  void visit(const IfThenElse *) override;
143  void visit(const Evaluate *) override;
144  void visit(const Shuffle *) override;
145  void visit(const VectorReduce *) override;
146  void visit(const Prefetch *) override;
147  void visit(const Acquire *) override;
148  void visit(const Fork *) override;
149  void visit(const Atomic *) override;
150  // @}
151 };
152 
153 /** A visitor/mutator capable of passing arbitrary arguments to the
154  * visit methods using CRTP and returning any types from them. All
155  * Expr visitors must have the same signature, and all Stmt visitors
156  * must have the same signature. Does not have default implementations
157  * of the visit methods. */
158 template<typename T, typename ExprRet, typename StmtRet>
160 private:
161  template<typename... Args>
162  ExprRet dispatch_expr(const BaseExprNode *node, Args &&... args) {
163  if (node == nullptr) return ExprRet{};
164  switch (node->node_type) {
165  case IRNodeType::IntImm:
166  return ((T *)this)->visit((const IntImm *)node, std::forward<Args>(args)...);
167  case IRNodeType::UIntImm:
168  return ((T *)this)->visit((const UIntImm *)node, std::forward<Args>(args)...);
170  return ((T *)this)->visit((const FloatImm *)node, std::forward<Args>(args)...);
172  return ((T *)this)->visit((const StringImm *)node, std::forward<Args>(args)...);
174  return ((T *)this)->visit((const Broadcast *)node, std::forward<Args>(args)...);
175  case IRNodeType::Cast:
176  return ((T *)this)->visit((const Cast *)node, std::forward<Args>(args)...);
178  return ((T *)this)->visit((const Variable *)node, std::forward<Args>(args)...);
179  case IRNodeType::Add:
180  return ((T *)this)->visit((const Add *)node, std::forward<Args>(args)...);
181  case IRNodeType::Sub:
182  return ((T *)this)->visit((const Sub *)node, std::forward<Args>(args)...);
183  case IRNodeType::Mod:
184  return ((T *)this)->visit((const Mod *)node, std::forward<Args>(args)...);
185  case IRNodeType::Mul:
186  return ((T *)this)->visit((const Mul *)node, std::forward<Args>(args)...);
187  case IRNodeType::Div:
188  return ((T *)this)->visit((const Div *)node, std::forward<Args>(args)...);
189  case IRNodeType::Min:
190  return ((T *)this)->visit((const Min *)node, std::forward<Args>(args)...);
191  case IRNodeType::Max:
192  return ((T *)this)->visit((const Max *)node, std::forward<Args>(args)...);
193  case IRNodeType::EQ:
194  return ((T *)this)->visit((const EQ *)node, std::forward<Args>(args)...);
195  case IRNodeType::NE:
196  return ((T *)this)->visit((const NE *)node, std::forward<Args>(args)...);
197  case IRNodeType::LT:
198  return ((T *)this)->visit((const LT *)node, std::forward<Args>(args)...);
199  case IRNodeType::LE:
200  return ((T *)this)->visit((const LE *)node, std::forward<Args>(args)...);
201  case IRNodeType::GT:
202  return ((T *)this)->visit((const GT *)node, std::forward<Args>(args)...);
203  case IRNodeType::GE:
204  return ((T *)this)->visit((const GE *)node, std::forward<Args>(args)...);
205  case IRNodeType::And:
206  return ((T *)this)->visit((const And *)node, std::forward<Args>(args)...);
207  case IRNodeType::Or:
208  return ((T *)this)->visit((const Or *)node, std::forward<Args>(args)...);
209  case IRNodeType::Not:
210  return ((T *)this)->visit((const Not *)node, std::forward<Args>(args)...);
211  case IRNodeType::Select:
212  return ((T *)this)->visit((const Select *)node, std::forward<Args>(args)...);
213  case IRNodeType::Load:
214  return ((T *)this)->visit((const Load *)node, std::forward<Args>(args)...);
215  case IRNodeType::Ramp:
216  return ((T *)this)->visit((const Ramp *)node, std::forward<Args>(args)...);
217  case IRNodeType::Call:
218  return ((T *)this)->visit((const Call *)node, std::forward<Args>(args)...);
219  case IRNodeType::Let:
220  return ((T *)this)->visit((const Let *)node, std::forward<Args>(args)...);
221  case IRNodeType::Shuffle:
222  return ((T *)this)->visit((const Shuffle *)node, std::forward<Args>(args)...);
224  return ((T *)this)->visit((const VectorReduce *)node, std::forward<Args>(args)...);
225  // Explicitly list the Stmt types rather than using a
226  // default case so that when new IR nodes are added we
227  // don't miss them here.
228  case IRNodeType::LetStmt:
231  case IRNodeType::For:
232  case IRNodeType::Acquire:
233  case IRNodeType::Store:
234  case IRNodeType::Provide:
236  case IRNodeType::Free:
237  case IRNodeType::Realize:
238  case IRNodeType::Block:
239  case IRNodeType::Fork:
243  case IRNodeType::Atomic:
244  internal_error << "Unreachable";
245  }
246  return ExprRet{};
247  }
248 
249  template<typename... Args>
250  StmtRet dispatch_stmt(const BaseStmtNode *node, Args &&... args) {
251  if (node == nullptr) return StmtRet{};
252  switch (node->node_type) {
253  case IRNodeType::IntImm:
254  case IRNodeType::UIntImm:
258  case IRNodeType::Cast:
260  case IRNodeType::Add:
261  case IRNodeType::Sub:
262  case IRNodeType::Mod:
263  case IRNodeType::Mul:
264  case IRNodeType::Div:
265  case IRNodeType::Min:
266  case IRNodeType::Max:
267  case IRNodeType::EQ:
268  case IRNodeType::NE:
269  case IRNodeType::LT:
270  case IRNodeType::LE:
271  case IRNodeType::GT:
272  case IRNodeType::GE:
273  case IRNodeType::And:
274  case IRNodeType::Or:
275  case IRNodeType::Not:
276  case IRNodeType::Select:
277  case IRNodeType::Load:
278  case IRNodeType::Ramp:
279  case IRNodeType::Call:
280  case IRNodeType::Let:
281  case IRNodeType::Shuffle:
283  internal_error << "Unreachable";
284  break;
285  case IRNodeType::LetStmt:
286  return ((T *)this)->visit((const LetStmt *)node, std::forward<Args>(args)...);
288  return ((T *)this)->visit((const AssertStmt *)node, std::forward<Args>(args)...);
290  return ((T *)this)->visit((const ProducerConsumer *)node, std::forward<Args>(args)...);
291  case IRNodeType::For:
292  return ((T *)this)->visit((const For *)node, std::forward<Args>(args)...);
293  case IRNodeType::Acquire:
294  return ((T *)this)->visit((const Acquire *)node, std::forward<Args>(args)...);
295  case IRNodeType::Store:
296  return ((T *)this)->visit((const Store *)node, std::forward<Args>(args)...);
297  case IRNodeType::Provide:
298  return ((T *)this)->visit((const Provide *)node, std::forward<Args>(args)...);
300  return ((T *)this)->visit((const Allocate *)node, std::forward<Args>(args)...);
301  case IRNodeType::Free:
302  return ((T *)this)->visit((const Free *)node, std::forward<Args>(args)...);
303  case IRNodeType::Realize:
304  return ((T *)this)->visit((const Realize *)node, std::forward<Args>(args)...);
305  case IRNodeType::Block:
306  return ((T *)this)->visit((const Block *)node, std::forward<Args>(args)...);
307  case IRNodeType::Fork:
308  return ((T *)this)->visit((const Fork *)node, std::forward<Args>(args)...);
310  return ((T *)this)->visit((const IfThenElse *)node, std::forward<Args>(args)...);
312  return ((T *)this)->visit((const Evaluate *)node, std::forward<Args>(args)...);
314  return ((T *)this)->visit((const Prefetch *)node, std::forward<Args>(args)...);
315  case IRNodeType::Atomic:
316  return ((T *)this)->visit((const Atomic *)node, std::forward<Args>(args)...);
317  }
318  return StmtRet{};
319  }
320 
321 public:
322  template<typename... Args>
323  HALIDE_ALWAYS_INLINE StmtRet dispatch(const Stmt &s, Args &&... args) {
324  return dispatch_stmt(s.get(), std::forward<Args>(args)...);
325  }
326 
327  template<typename... Args>
328  HALIDE_ALWAYS_INLINE StmtRet dispatch(Stmt &&s, Args &&... args) {
329  return dispatch_stmt(s.get(), std::forward<Args>(args)...);
330  }
331 
332  template<typename... Args>
333  HALIDE_ALWAYS_INLINE ExprRet dispatch(const Expr &e, Args &&... args) {
334  return dispatch_expr(e.get(), std::forward<Args>(args)...);
335  }
336 
337  template<typename... Args>
338  HALIDE_ALWAYS_INLINE ExprRet dispatch(Expr &&e, Args &&... args) {
339  return dispatch_expr(e.get(), std::forward<Args>(args)...);
340  }
341 };
342 
343 } // namespace Internal
344 } // namespace Halide
345 
346 #endif
Halide::Internal::Acquire
Definition: IR.h:717
Halide::Internal::Allocate
Allocate a scratch area called with the given name, type, and size.
Definition: IR.h:352
Halide::Internal::Add
The sum of two expressions.
Definition: IR.h:38
Halide::Internal::IRNodeType::Cast
@ Cast
Halide::Internal::IRVisitor::visit
virtual void visit(const IntImm *)
Halide::Internal::IRVisitor::~IRVisitor
virtual ~IRVisitor()
Halide::Internal::IRNodeType::Not
@ Not
Halide::Internal::VariadicVisitor::dispatch
HALIDE_ALWAYS_INLINE StmtRet dispatch(Stmt &&s, Args &&... args)
Definition: IRVisitor.h:328
Halide::Internal::IRNodeType::Select
@ Select
Halide::Internal::IRNodeType::Block
@ Block
Halide::Internal::VectorReduce
Horizontally reduce a vector to a scalar or narrower vector using the given commutative and associative binary operator.
Definition: IR.h:827
Halide::Internal::IRNodeType::LE
@ LE
Halide::Internal::IRNodeType::For
@ For
Halide::Internal::GE
Is the first expression greater than or equal to the second.
Definition: IR.h:148
Halide::Internal::IRVisitor
A base class for algorithms that need to recursively walk over the IR.
Definition: IRVisitor.h:21
Halide::Internal::For
A for loop.
Definition: IR.h:698
Halide::Internal::IRNodeType::EQ
@ EQ
Halide::Internal::FloatImm
Floating point constants.
Definition: Expr.h:234
Halide::Internal::IRNodeType::Evaluate
@ Evaluate
Halide::Internal::IRNodeType::Or
@ Or
Halide::Internal::Broadcast
A vector with 'lanes' elements, in which every element is 'value'.
Definition: IR.h:241
Halide::Internal::Div
The ratio of two expressions.
Definition: IR.h:65
Halide::Internal::IRNodeType::LetStmt
@ LetStmt
Halide::Internal::ExprNode
We use the "curiously recurring template pattern" to avoid duplicated code in the IR Nodes.
Definition: Expr.h:156
Halide::Internal::IRNodeType::LT
@ LT
Halide::Internal::IntImm
Integer constants.
Definition: Expr.h:216
Halide::Internal::IRNodeType::Provide
@ Provide
Halide::Internal::LetStmt
The statement form of a let node.
Definition: IR.h:264
Halide::Internal::VariadicVisitor::dispatch
HALIDE_ALWAYS_INLINE ExprRet dispatch(const Expr &e, Args &&... args)
Definition: IRVisitor.h:333
Halide::Internal::IRNodeType::Min
@ Min
IR.h
Halide::Internal::Cast
The actual IR nodes begin here.
Definition: IR.h:29
Halide::Internal::LE
Is the first expression less than or equal to the second.
Definition: IR.h:130
Halide::Internal::IRNodeType::GE
@ GE
Halide::Internal::NE
Is the first expression not equal to the second.
Definition: IR.h:112
Halide::Internal::Fork
A pair of statements executed concurrently.
Definition: IR.h:431
Halide::Internal::IRVisitor::IRVisitor
IRVisitor()
Halide::Internal::Stmt
A reference-counted handle to a statement node.
Definition: Expr.h:409
Halide::Internal::IRNodeType::Variable
@ Variable
Halide::Internal::IRNodeType::ProducerConsumer
@ ProducerConsumer
Halide::Internal::BaseStmtNode
IR nodes are split into expressions and statements.
Definition: Expr.h:132
Halide::Internal::Load
Load a value from a named symbol if predicate is true.
Definition: IR.h:199
Halide::Internal::Free
Free the resources associated with the given buffer.
Definition: IR.h:388
Halide::Internal::Realize
Allocate a multi-dimensional buffer of the given type and size.
Definition: IR.h:402
Halide
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
Definition: AddAtomicMutex.h:21
Halide::Internal::VariadicVisitor::dispatch
HALIDE_ALWAYS_INLINE StmtRet dispatch(const Stmt &s, Args &&... args)
Definition: IRVisitor.h:323
Halide::Internal::Or
Logical or - is at least one of the expressions true.
Definition: IR.h:166
Halide::Internal::EQ
Is the first expression equal to the second.
Definition: IR.h:103
Halide::Internal::Provide
This defines the value of a function at a multi-dimensional location.
Definition: IR.h:336
Halide::Internal::IRNodeType::Atomic
@ Atomic
Halide::LinkageType::Internal
@ Internal
Not visible externally, similar to 'static' linkage in C.
Halide::Internal::Max
The greater of two values.
Definition: IR.h:94
Halide::Internal::IRNodeType::Fork
@ Fork
Halide::Internal::VariadicVisitor::dispatch
HALIDE_ALWAYS_INLINE ExprRet dispatch(Expr &&e, Args &&... args)
Definition: IRVisitor.h:338
Halide::Internal::IRNodeType::Broadcast
@ Broadcast
Halide::Internal::IRNodeType::Free
@ Free
Halide::Internal::IRNodeType::Add
@ Add
Halide::Internal::Let
A let expression, like you might find in a functional language.
Definition: IR.h:253
Halide::Internal::IRGraphVisitor::visit
void visit(const IntImm *) override
These methods should call 'include' on the children to only visit them if they haven't been visited already.
Halide::Internal::IRNodeType::Sub
@ Sub
Halide::Internal::IRNodeType::FloatImm
@ FloatImm
Halide::Internal::Ramp
A linear ramp vector node.
Definition: IR.h:229
HALIDE_ALWAYS_INLINE
#define HALIDE_ALWAYS_INLINE
Definition: HalideRuntime.h:32
Halide::Internal::IRNodeType::AssertStmt
@ AssertStmt
Halide::Internal::IRNodeType::IntImm
@ IntImm
Halide::Internal::Evaluate
Evaluate and discard an expression, presumably because it has some side-effect.
Definition: IR.h:450
Halide::Internal::IRNodeType::And
@ And
Halide::Internal::IRNodeType::Max
@ Max
Halide::Internal::BaseExprNode
A base class for expression nodes.
Definition: Expr.h:141
Halide::Internal::Store
Store a 'value' to the buffer called 'name' at a given 'index' if 'predicate' is true.
Definition: IR.h:315
Halide::Internal::VariadicVisitor
A visitor/mutator capable of passing arbitrary arguments to the visit methods using CRTP and returning any types from them.
Definition: IRVisitor.h:159
Halide::Internal::Variable
A named variable.
Definition: IR.h:651
Halide::Internal::Min
The lesser of two values.
Definition: IR.h:85
Halide::Internal::ProducerConsumer
This node is a helpful annotation to do with permissions.
Definition: IR.h:297
Halide::Internal::IRNodeType::VectorReduce
@ VectorReduce
Halide::Internal::Mod
The remainder of a / b.
Definition: IR.h:76
Halide::Internal::IRNodeType::Let
@ Let
Halide::Internal::AssertStmt
If the 'condition' is false, then evaluate and return the message, which should be a call to an error function.
Definition: IR.h:276
Halide::Internal::IRNodeType::NE
@ NE
internal_error
#define internal_error
Definition: Errors.h:23
Halide::Internal::Call
A function call.
Definition: IR.h:464
Halide::Internal::Stmt::get
const HALIDE_ALWAYS_INLINE BaseStmtNode * get() const
Override get() to return a BaseStmtNode * instead of an IRNode *.
Definition: Expr.h:417
Halide::Internal::IRNodeType::IfThenElse
@ IfThenElse
Halide::Internal::IRNodeType::Mul
@ Mul
Halide::Internal::IRNodeType::GT
@ GT
Halide::Internal::IRNodeType::UIntImm
@ UIntImm
Halide::Internal::Select
A ternary operator.
Definition: IR.h:186
Halide::Internal::IRNodeType::StringImm
@ StringImm
Halide::Internal::IRNodeType::Ramp
@ Ramp
Halide::Internal::IRNode::node_type
IRNodeType node_type
Each IR node subclass has a unique identifier.
Definition: Expr.h:111
Halide::Expr
A fragment of Halide syntax.
Definition: Expr.h:256
Halide::Internal::IRGraphVisitor
A base class for algorithms that walk recursively over the IR without visiting the same node twice.
Definition: IRVisitor.h:85
Halide::Internal::IRGraphVisitor::include
virtual void include(const Expr &)
By default these methods add the node to the visited set and, if it was not already there, delegate to the appropriate visit method (they return nothing).
Halide::Internal::IRNodeType::Acquire
@ Acquire
Halide::Internal::Prefetch
Represent a multi-dimensional region of a Func or an ImageParam that needs to be prefetched.
Definition: IR.h:786
Halide::Internal::GT
Is the first expression greater than the second.
Definition: IR.h:139
Halide::Internal::Atomic
Lock all the Store nodes in the body statement.
Definition: IR.h:809
Halide::Internal::IRNodeType::Realize
@ Realize
Halide::Internal::Shuffle
Construct a new vector by taking elements from another sequence of vectors.
Definition: IR.h:729
Halide::Internal::IRNodeType::Mod
@ Mod
Halide::Expr::get
const HALIDE_ALWAYS_INLINE Internal::BaseExprNode * get() const
Override get() to return a BaseExprNode * instead of an IRNode *.
Definition: Expr.h:314
Halide::Internal::IRNodeType::Shuffle
@ Shuffle
Halide::Internal::UIntImm
Unsigned integer constants.
Definition: Expr.h:225
Halide::Internal::StmtNode
Definition: Expr.h:166
Halide::Internal::IRNodeType::Store
@ Store
Halide::Internal::IRNodeType::Prefetch
@ Prefetch
Halide::Internal::IRNodeType::Allocate
@ Allocate
Halide::Internal::IRNodeType::Div
@ Div
Halide::Internal::IfThenElse
An if-then-else block.
Definition: IR.h:440
Halide::Internal::Sub
The difference of two expressions.
Definition: IR.h:47
Halide::Internal::IRNodeType::Call
@ Call
Halide::Internal::And
Logical and - are both expressions true.
Definition: IR.h:157
Halide::Internal::StringImm
String constants.
Definition: Expr.h:243
Halide::Internal::Not
Logical not - true if the expression is false.
Definition: IR.h:175
Halide::Internal::LT
Is the first expression less than the second.
Definition: IR.h:121
Halide::Internal::Mul
The product of two expressions.
Definition: IR.h:56
Halide::Internal::Block
A sequence of statements to be executed in-order.
Definition: IR.h:417
Halide::Internal::IRNodeType::Load
@ Load