Halide
Expr.h
Go to the documentation of this file.
1 #ifndef HALIDE_EXPR_H
2 #define HALIDE_EXPR_H
3 
4 /** \file
5  * Base classes for Halide expressions (\ref Halide::Expr) and statements (\ref Halide::Internal::Stmt)
6  */
7 
#include <functional>
#include <string>
#include <vector>

#include "Debug.h"
#include "Error.h"
#include "Float16.h"
#include "Type.h"
#include "IntrusivePtr.h"
#include "Util.h"
17 
18 namespace Halide {
19 namespace Internal {
20 
21 class IRMutator2;
22 class IRVisitor;
23 
/** Every IR node type gets a distinct tag in this enum. Nodes store their
 * tag (IRNode::node_type) and IRHandle::as<T>() compares it against
 * T::_node_type, giving a lightweight form of runtime type identification
 * without enabling C++ RTTI. The enumerators are laid out several per line;
 * their relative order, and therefore their numeric values, is significant. */
enum class IRNodeType {
    // Constant nodes.
    IntImm, UIntImm, FloatImm, StringImm,
    Cast, Variable,
    Add, Sub, Mul, Div, Mod, Min, Max,
    EQ, NE, LT, LE, GT, GE,
    And, Or, Not,
    Select, Load, Ramp, Broadcast, Call, Let,
    LetStmt, AssertStmt, For, Store, Provide, Allocate, Free,
    Realize, Block, IfThenElse, Evaluate, Shuffle, Prefetch,
};
69 
70 /** The abstract base classes for a node in the Halide IR. */
71 struct IRNode {
72 
73  /** We use the visitor pattern to traverse IR nodes throughout the
74  * compiler, so we have a virtual accept method which accepts
75  * visitors.
76  */
77  virtual void accept(IRVisitor *v) const = 0;
78  IRNode(IRNodeType t) : node_type(t) {}
79  virtual ~IRNode() {}
80 
81  /** These classes are all managed with intrusive reference
82  * counting, so we also track a reference count. It's mutable
83  * so that we can do reference counting even through const
84  * references to IR nodes.
85  */
87 
88  /** Each IR node subclass has a unique identifier. We can compare
89  * these values to do runtime type identification. We don't
90  * compile with rtti because that injects run-time type
91  * identification stuff everywhere (and often breaks when linking
92  * external libraries compiled without it), and we only want it
93  * for IR nodes. One might want to put this value in the vtable,
94  * but that adds another level of indirection, and for Exprs we
95  * have 32 free bits in between the ref count and the Type
96  * anyway, so this doesn't increase the memory footprint of an IR node.
97  */
99 };
100 
101 template<>
102 EXPORT inline RefCount &ref_count<IRNode>(const IRNode *n) {return n->ref_count;}
103 
104 template<>
105 EXPORT inline void destroy<IRNode>(const IRNode *n) {delete n;}
106 
107 /** IR nodes are split into expressions and statements. These are
108  similar to expressions and statements in C - expressions
109  represent some value and have some type (e.g. x + 3), and
110  statements are side-effecting pieces of code that do not
111  represent a value (e.g. assert(x > 3)) */
112 
113 /** A base class for statement nodes. They have no properties or
114  methods beyond base IR nodes for now. */
115 struct BaseStmtNode : public IRNode {
117  virtual Stmt mutate_stmt(IRMutator2 *v) const = 0;
118 };
119 
120 /** A base class for expression nodes. They all contain their types
121  * (e.g. Int(32), Float(32)) */
122 struct BaseExprNode : public IRNode {
124  virtual Expr mutate_expr(IRMutator2 *v) const = 0;
126 };
127 
128 /** We use the "curiously recurring template pattern" to avoid
129  duplicated code in the IR Nodes. These classes live between the
130  abstract base classes and the actual IR Nodes in the
131  inheritance hierarchy. It provides an implementation of the
132  accept function necessary for the visitor pattern to work, and
133  a concrete instantiation of a unique IRNodeType per class. */
134 template<typename T>
135 struct ExprNode : public BaseExprNode {
136  EXPORT void accept(IRVisitor *v) const;
137  EXPORT Expr mutate_expr(IRMutator2 *v) const;
138  ExprNode() : BaseExprNode(T::_node_type) {}
139  virtual ~ExprNode() {}
140 };
141 
142 template<typename T>
143 struct StmtNode : public BaseStmtNode {
144  EXPORT void accept(IRVisitor *v) const;
145  EXPORT Stmt mutate_stmt(IRMutator2 *v) const;
146  StmtNode() : BaseStmtNode(T::_node_type) {}
147  virtual ~StmtNode() {}
148 };
149 
150 /** IR nodes are passed around opaque handles to them. This is a
151  base class for those handles. It manages the reference count,
152  and dispatches visitors. */
153 struct IRHandle : public IntrusivePtr<const IRNode> {
154  IRHandle() : IntrusivePtr<const IRNode>() {}
155  IRHandle(const IRNode *p) : IntrusivePtr<const IRNode>(p) {}
156 
157  /** Dispatch to the correct visitor method for this node. E.g. if
158  * this node is actually an Add node, then this will call
159  * IRVisitor::visit(const Add *) */
160  void accept(IRVisitor *v) const {
161  ptr->accept(v);
162  }
163 
164  /** Downcast this ir node to its actual type (e.g. Add, or
165  * Select). This returns nullptr if the node is not of the requested
166  * type. Example usage:
167  *
168  * if (const Add *add = node->as<Add>()) {
169  * // This is an add node
170  * }
171  */
172  template<typename T> const T *as() const {
173  if (ptr && ptr->node_type == T::_node_type) {
174  return (const T *)ptr;
175  }
176  return nullptr;
177  }
178 };
179 
180 
181 /** Integer constants */
182 struct IntImm : public ExprNode<IntImm> {
184 
185  static const IntImm *make(Type t, int64_t value) {
186  internal_assert(t.is_int() && t.is_scalar())
187  << "IntImm must be a scalar Int\n";
188  internal_assert(t.bits() == 8 || t.bits() == 16 || t.bits() == 32 || t.bits() == 64)
189  << "IntImm must be 8, 16, 32, or 64-bit\n";
190 
191  // Normalize the value by dropping the high bits
192  value <<= (64 - t.bits());
193  // Then sign-extending to get them back
194  value >>= (64 - t.bits());
195 
196  IntImm *node = new IntImm;
197  node->type = t;
198  node->value = value;
199  return node;
200  }
201 
202  static const IRNodeType _node_type = IRNodeType::IntImm;
203 };
204 
205 /** Unsigned integer constants */
206 struct UIntImm : public ExprNode<UIntImm> {
208 
209  static const UIntImm *make(Type t, uint64_t value) {
210  internal_assert(t.is_uint() && t.is_scalar())
211  << "UIntImm must be a scalar UInt\n";
212  internal_assert(t.bits() == 1 || t.bits() == 8 || t.bits() == 16 || t.bits() == 32 || t.bits() == 64)
213  << "UIntImm must be 1, 8, 16, 32, or 64-bit\n";
214 
215  // Normalize the value by dropping the high bits
216  value <<= (64 - t.bits());
217  value >>= (64 - t.bits());
218 
219  UIntImm *node = new UIntImm;
220  node->type = t;
221  node->value = value;
222  return node;
223  }
224 
225  static const IRNodeType _node_type = IRNodeType::UIntImm;
226 };
227 
228 /** Floating point constants */
229 struct FloatImm : public ExprNode<FloatImm> {
230  double value;
231 
232  static const FloatImm *make(Type t, double value) {
234  << "FloatImm must be a scalar Float\n";
235  FloatImm *node = new FloatImm;
236  node->type = t;
237  switch (t.bits()) {
238  case 16:
239  node->value = (double)((float16_t)value);
240  break;
241  case 32:
242  node->value = (float)value;
243  break;
244  case 64:
245  node->value = value;
246  break;
247  default:
248  internal_error << "FloatImm must be 16, 32, or 64-bit\n";
249  }
250 
251  return node;
252  }
253 
254  static const IRNodeType _node_type = IRNodeType::FloatImm;
255 };
256 
257 /** String constants */
258 struct StringImm : public ExprNode<StringImm> {
259  std::string value;
260 
261  static const StringImm *make(const std::string &val) {
262  StringImm *node = new StringImm;
263  node->type = type_of<const char *>();
264  node->value = val;
265  return node;
266  }
267 
268  static const IRNodeType _node_type = IRNodeType::StringImm;
269 };
270 
271 } // namespace Internal
272 
273 /** A fragment of Halide syntax. It's implemented as reference-counted
274  * handle to a concrete expression node, but it's immutable, so you
275  * can treat it as a value type. */
276 struct Expr : public Internal::IRHandle {
277  /** Make an undefined expression */
278  Expr() : Internal::IRHandle() {}
279 
280  /** Make an expression from a concrete expression node pointer (e.g. Add) */
281  Expr(const Internal::BaseExprNode *n) : IRHandle(n) {}
282 
283  /** Make an expression representing numeric constants of various types. */
284  // @{
285  EXPORT explicit Expr(int8_t x) : IRHandle(Internal::IntImm::make(Int(8), x)) {}
286  EXPORT explicit Expr(int16_t x) : IRHandle(Internal::IntImm::make(Int(16), x)) {}
287  EXPORT Expr(int32_t x) : IRHandle(Internal::IntImm::make(Int(32), x)) {}
288  EXPORT explicit Expr(int64_t x) : IRHandle(Internal::IntImm::make(Int(64), x)) {}
289  EXPORT explicit Expr(uint8_t x) : IRHandle(Internal::UIntImm::make(UInt(8), x)) {}
290  EXPORT explicit Expr(uint16_t x) : IRHandle(Internal::UIntImm::make(UInt(16), x)) {}
291  EXPORT explicit Expr(uint32_t x) : IRHandle(Internal::UIntImm::make(UInt(32), x)) {}
292  EXPORT explicit Expr(uint64_t x) : IRHandle(Internal::UIntImm::make(UInt(64), x)) {}
293  EXPORT Expr(float16_t x) : IRHandle(Internal::FloatImm::make(Float(16), (double)x)) {}
294  EXPORT Expr(float x) : IRHandle(Internal::FloatImm::make(Float(32), x)) {}
295  EXPORT explicit Expr(double x) : IRHandle(Internal::FloatImm::make(Float(64), x)) {}
296  // @}
297 
298  /** Make an expression representing a const string (i.e. a StringImm) */
299  EXPORT Expr(const std::string &s) : IRHandle(Internal::StringImm::make(s)) {}
300 
301  /** Get the type of this expression node */
302  Type type() const {
303  return ((const Internal::BaseExprNode *)ptr)->type;
304  }
305 };
306 
307 /** This lets you use an Expr as a key in a map of the form
308  * map<Expr, Foo, ExprCompare> */
309 struct ExprCompare {
310  bool operator()(const Expr &a, const Expr &b) const {
311  return a.get() < b.get();
312  }
313 };
314 
/** An enum describing a type of device API. Used by schedules, and in
 * the For loop IR node. */
enum class DeviceAPI {
    None,  ///< Used to denote for loops that run on the same device as the containing code.
    Host,
    Default_GPU,
    CUDA,
    OpenCL,
    GLSL,
    Metal,
    Hexagon
};
328 
329 /** An array containing all the device apis. Useful for iterating
330  * through them. */
340 
341 namespace Internal {
342 
/** How a loop is traversed. Used in schedules and in the For loop IR
 * node. The GPUBlock and GPUThread kinds are implicitly parallel. */
enum class ForType {
    Serial, Parallel, Vectorized, Unrolled,
    GPUBlock, GPUThread
};
353 
354 
355 /** A reference-counted handle to a statement node. */
356 struct Stmt : public IRHandle {
357  Stmt() : IRHandle() {}
358  Stmt(const BaseStmtNode *n) : IRHandle(n) {}
359 
360  /** This lets you use a Stmt as a key in a map of the form
361  * map<Stmt, Foo, Stmt::Compare> */
362  struct Compare {
363  bool operator()(const Stmt &a, const Stmt &b) const {
364  return a.ptr < b.ptr;
365  }
366  };
367 };
368 
369 
370 } // namespace Internal
371 } // namespace Halide
372 
373 #endif
RefCount ref_count
These classes are all managed with intrusive reference counting, so we also track a reference count...
Definition: Expr.h:86
Unsigned integer constants.
Definition: Expr.h:206
Various utility functions used internally by Halide.
A base class for algorithms that need to recursively walk over the IR.
Definition: IRVisitor.h:22
bool is_int() const
Is this type a signed integer type?
Definition: Type.h:372
A fragment of Halide syntax.
Definition: Expr.h:276
Integer constants.
Definition: Expr.h:182
A reference-counted handle to a statement node.
Definition: Expr.h:356
IR nodes are split into expressions and statements.
Definition: Expr.h:115
EXPORT RefCount & ref_count< IRNode >(const IRNode *n)
Definition: Expr.h:102
A base class for passes over the IR which modify it (e.g.
Definition: IRMutator.h:124
We use the "curiously recurring template pattern" to avoid duplicated code in the IR Nodes...
Definition: Expr.h:135
Expr()
Make an undefined expression.
Definition: Expr.h:278
bool is_uint() const
Is this type an unsigned integer type?
Definition: Type.h:375
IRNode(IRNodeType t)
Definition: Expr.h:78
#define internal_error
Definition: Error.h:135
static const UIntImm * make(Type t, uint64_t value)
Definition: Expr.h:209
Floating point constants.
Definition: Expr.h:229
signed __INT8_TYPE__ int8_t
virtual ~IRNode()
Definition: Expr.h:79
EXPORT Expr(int8_t x)
Make an expression representing numeric constants of various types.
Definition: Expr.h:285
Intrusive shared pointers have a reference count (a RefCount object) stored in the class itself...
Definition: IntrusivePtr.h:57
BaseExprNode(IRNodeType t)
Definition: Expr.h:123
Defines methods for manipulating and analyzing boolean expressions.
EXPORT Expr(int16_t x)
Make an expression representing numeric constants of various types.
Definition: Expr.h:286
unsigned __INT8_TYPE__ uint8_t
Expr(const Internal::BaseExprNode *n)
Make an expression from a concrete expression node pointer (e.g.
Definition: Expr.h:281
IRNodeType
All our IR node types get unique IDs for the purposes of RTTI.
Definition: Expr.h:25
A base class for expression nodes.
Definition: Expr.h:122
IR nodes are passed around as opaque handles.
Definition: Expr.h:153
EXPORT Expr(double x)
Make an expression representing numeric constants of various types.
Definition: Expr.h:295
BaseStmtNode(IRNodeType t)
Definition: Expr.h:116
static const IntImm * make(Type t, int64_t value)
Definition: Expr.h:185
Used to denote for loops that run on the same device as the containing code.
This lets you use a Stmt as a key in a map of the form map<Stmt, Foo, Stmt::Compare> ...
Definition: Expr.h:362
IRNodeType node_type
Each IR node subclass has a unique identifier.
Definition: Expr.h:98
Defines halide types.
bool operator()(const Stmt &a, const Stmt &b) const
Definition: Expr.h:363
Class that provides a type that implements half precision floating point (IEEE754 2008 binary16) in s...
Definition: Float16.h:17
static const StringImm * make(const std::string &val)
Definition: Expr.h:261
bool is_float() const
Is this type a floating point type (float or double).
Definition: Type.h:369
const DeviceAPI all_device_apis[]
An array containing all the device apis.
Definition: Expr.h:331
void accept(IRVisitor *v) const
Dispatch to the correct visitor method for this node.
Definition: Expr.h:160
#define internal_assert(c)
Definition: Error.h:140
const T * as() const
Downcast this ir node to its actual type (e.g.
Definition: Expr.h:172
This lets you use an Expr as a key in a map of the form map<Expr, Foo, ExprCompare> ...
Definition: Expr.h:309
unsigned __INT32_TYPE__ uint32_t
EXPORT Expr(uint16_t x)
Make an expression representing numeric constants of various types.
Definition: Expr.h:290
EXPORT Expr(float x)
Make an expression representing numeric constants of various types.
Definition: Expr.h:294
Type type() const
Get the type of this expression node.
Definition: Expr.h:302
EXPORT Expr(int32_t x)
Make an expression representing numeric constants of various types.
Definition: Expr.h:287
Support classes for reference-counting via intrusive shared pointers.
signed __INT64_TYPE__ int64_t
EXPORT void destroy< IRNode >(const IRNode *n)
Definition: Expr.h:105
static const FloatImm * make(Type t, double value)
Definition: Expr.h:232
ForType
An enum describing a type of loop traversal.
Definition: Expr.h:345
Stmt(const BaseStmtNode *n)
Definition: Expr.h:358
The abstract base class for a node in the Halide IR.
Definition: Expr.h:71
Type UInt(int bits, int lanes=1)
Constructing an unsigned integer type.
Definition: Type.h:442
bool operator()(const Expr &a, const Expr &b) const
Definition: Expr.h:310
bool is_scalar() const
Is this type a scalar type? (lanes() == 1).
Definition: Type.h:366
unsigned __INT16_TYPE__ uint16_t
EXPORT Expr(uint64_t x)
Make an expression representing numeric constants of various types.
Definition: Expr.h:292
EXPORT Expr(int64_t x)
Make an expression representing numeric constants of various types.
Definition: Expr.h:288
Types in the halide type system.
Definition: Type.h:285
T * get() const
Access the raw pointer in a variety of ways.
Definition: IntrusivePtr.h:89
int bits() const
Return the bit size of a single element of this type.
Definition: Type.h:331
EXPORT Expr(uint32_t x)
Make an expression representing numeric constants of various types.
Definition: Expr.h:291
#define EXPORT
Definition: Util.h:30
EXPORT Expr(float16_t x)
Make an expression representing numeric constants of various types.
Definition: Expr.h:293
String constants.
Definition: Expr.h:258
unsigned __INT64_TYPE__ uint64_t
A class representing a reference count to be used with IntrusivePtr.
Definition: IntrusivePtr.h:19
DeviceAPI
An enum describing a type of device API.
Definition: Expr.h:317
EXPORT Expr(uint8_t x)
Make an expression representing numeric constants of various types.
Definition: Expr.h:289
IRHandle(const IRNode *p)
Definition: Expr.h:155
signed __INT32_TYPE__ int32_t
Type Int(int bits, int lanes=1)
Constructing a signed integer type.
Definition: Type.h:437
Defines functions for debug logging during code generation.
signed __INT16_TYPE__ int16_t
Type Float(int bits, int lanes=1)
Construct a floating-point type.
Definition: Type.h:447
EXPORT Expr(const std::string &s)
Make an expression representing a const string (i.e.
Definition: Expr.h:299