Halide 21.0.0
Halide compiler and libraries
Loading...
Searching...
No Matches
IRMatch.h
Go to the documentation of this file.
1#ifndef HALIDE_IR_MATCH_H
2#define HALIDE_IR_MATCH_H
3
4/** \file
5 * Defines a method to match a fragment of IR against a pattern containing wildcards
6 */
7
8#include <map>
9#include <random>
10#include <set>
11#include <vector>
12
13#include "IR.h"
14#include "IREquality.h"
15#include "IROperator.h"
16
17namespace Halide {
18namespace Internal {
19
20/** Does the first expression have the same structure as the second?
21 * Variables in the first expression with the name * are interpreted
22 * as wildcards, and their matching equivalent in the second
23 * expression is placed in the vector given as the third argument.
24 * Wildcards require the types to match. For the type bits and width,
25 * a 0 indicates "match anything". So an Int(8, 0) will match 8-bit
26 * integer vectors of any width (including scalars), and a UInt(0, 0)
27 * will match any unsigned integer type.
28 *
29 * For example:
30 \code
31 Expr x = Variable::make(Int(32), "*");
32 expr_match(x + x, 3 + (2*k), result);
33 \endcode
34 * should return true, and set result[0] to 3 and
35 * result[1] to 2*k.
36 */
37bool expr_match(const Expr &pattern, const Expr &expr, std::vector<Expr> &result);
38
39/** Does the first expression have the same structure as the second?
40 * Variables are matched consistently. The first time a variable is
41 * matched, it assumes the value of the matching part of the second
42 * expression. Subsequent matches must be equal to the first match.
43 *
44 * For example:
45 \code
46 Var x("x"), y("y");
47 expr_match(x*(x + y), a*(a + b), result);
48 \endcode
49 * should return true, and set result["x"] = a, and result["y"] = b.
50 */
51bool expr_match(const Expr &pattern, const Expr &expr, std::map<std::string, Expr> &result);
52
53/** Rewrite the expression x to have `lanes` lanes. This is useful
54 * for substituting the results of expr_match into a pattern expression. */
55Expr with_lanes(const Expr &x, int lanes);
56
58
59/** An alternative template-metaprogramming approach to expression
60 * matching. Potentially more efficient. We lift the expression
61 * pattern into a type, and then use force-inlined functions to
62 * generate efficient matching and reconstruction code for any
63 * pattern. Pattern elements are either one of the classes in the
64 * namespace IRMatcher, or are non-null Exprs (represented as
65 * BaseExprNode &).
66 *
67 * Pattern elements that are fully specified by their pattern can be
68 * built into an expression using the make method. Some patterns,
69 * such as a broadcast that matches any number of lanes, don't have
70 * enough information to recreate an Expr.
71 */
72namespace IRMatcher {
73
74constexpr int max_wild = 6;
75
76static const halide_type_t i64_type = {halide_type_int, 64, 1};
77
78/** To save stack space, the matcher objects are largely stateless and
79 * immutable. This state object is built up during matching and then
80 * consumed when constructing a replacement Expr.
81 */
85
86 // values of the lanes field with special meaning.
87 static constexpr uint16_t signed_integer_overflow = 0x8000;
88 static constexpr uint16_t special_values_mask = 0x8000; // currently only one
89
91
93 void set_binding(int i, const BaseExprNode &n) noexcept {
94 bindings[i] = &n;
95 }
96
98 const BaseExprNode *get_binding(int i) const noexcept {
99 return bindings[i];
100 }
101
103 void set_bound_const(int i, int64_t s, halide_type_t t) noexcept {
104 bound_const[i].u.i64 = s;
105 bound_const_type[i] = t;
106 }
107
109 void set_bound_const(int i, uint64_t u, halide_type_t t) noexcept {
110 bound_const[i].u.u64 = u;
111 bound_const_type[i] = t;
112 }
113
115 void set_bound_const(int i, double f, halide_type_t t) noexcept {
116 bound_const[i].u.f64 = f;
117 bound_const_type[i] = t;
118 }
119
122 bound_const[i] = val;
123 bound_const_type[i] = t;
124 }
125
127 void get_bound_const(int i, halide_scalar_value_t &val, halide_type_t &type) const noexcept {
128 val = bound_const[i];
129 type = bound_const_type[i];
130 }
131
133 // NOLINTNEXTLINE(modernize-use-equals-default): Can't use `= default`; clang-tidy complains about noexcept mismatch
134 MatcherState() noexcept {
135 }
136};
137
138template<typename T,
139 typename = typename std::remove_reference<T>::type::pattern_tag>
141 struct type {};
142};
143
/** Compile-time lookup of the wildcard bitmask declared by a pattern type.
 *
 * Any reference qualifier on T is stripped before the lookup, so
 * bindings<P>, bindings<P &> and bindings<P &&> all report P::binds.
 * The matchers OR this mask into their `bound` template argument to record
 * which wildcards an earlier operand has already bound.
 */
template<typename T>
struct bindings {
    constexpr static uint32_t mask = std::remove_reference<T>::type::binds;
};
148
151 ty.lanes &= ~MatcherState::special_values_mask;
154 }
155 // unreachable
156 return Expr();
157}
158
161 halide_type_t scalar_type = ty;
162 if (scalar_type.lanes & MatcherState::special_values_mask) {
163 return make_const_special_expr(scalar_type);
164 }
165
166 const int lanes = scalar_type.lanes;
167 scalar_type.lanes = 1;
168
169 Expr e;
170 switch (scalar_type.code) {
171 case halide_type_int:
172 e = IntImm::make(scalar_type, val.u.i64);
173 break;
174 case halide_type_uint:
175 e = UIntImm::make(scalar_type, val.u.u64);
176 break;
179 e = FloatImm::make(scalar_type, val.u.f64);
180 break;
181 default:
182 // Unreachable
183 return Expr();
184 }
185 if (lanes > 1) {
186 e = Broadcast::make(std::move(e), lanes);
187 }
188 return e;
189}
190
191// A pattern that matches a specific expression
193 struct pattern_tag {};
194
195 constexpr static uint32_t binds = 0;
196
197 // What is the weakest and strongest IR node this could possibly be
200 constexpr static bool canonical = true;
201
203
204 template<uint32_t bound>
205 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
206 return equal(expr, e);
207 }
208
210 Expr make(MatcherState &state, halide_type_t type_hint) const {
211 return Expr(&expr);
212 }
213
214 constexpr static bool foldable = false;
215};
216
217inline std::ostream &operator<<(std::ostream &s, const SpecificExpr &e) {
218 s << Expr(&e.expr);
219 return s;
220}
221
222template<int i>
224 struct pattern_tag {};
225
226 constexpr static uint32_t binds = 1 << i;
227
230 constexpr static bool canonical = true;
231
232 template<uint32_t bound>
233 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
234 static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
235 const BaseExprNode *op = &e;
236 if (op->node_type == IRNodeType::Broadcast) {
237 op = ((const Broadcast *)op)->value.get();
238 }
239 if (op->node_type != IRNodeType::IntImm) {
240 return false;
241 }
242 int64_t value = ((const IntImm *)op)->value;
243 if (bound & binds) {
245 halide_type_t type;
246 state.get_bound_const(i, val, type);
247 return (halide_type_t)e.type == type && value == val.u.i64;
248 }
249 state.set_bound_const(i, value, e.type);
250 return true;
251 }
252
253 template<uint32_t bound>
254 HALIDE_ALWAYS_INLINE bool match(int64_t value, MatcherState &state) const noexcept {
255 static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
256 if (bound & binds) {
258 halide_type_t type;
259 state.get_bound_const(i, val, type);
260 return type == i64_type && value == val.u.i64;
261 }
262 state.set_bound_const(i, value, i64_type);
263 return true;
264 }
265
267 Expr make(MatcherState &state, halide_type_t type_hint) const {
269 halide_type_t type;
270 state.get_bound_const(i, val, type);
271 return make_const_expr(val, type);
272 }
273
274 constexpr static bool foldable = true;
275
278 state.get_bound_const(i, val, ty);
279 }
280};
281
282template<int i>
283std::ostream &operator<<(std::ostream &s, const WildConstInt<i> &c) {
284 s << "ci" << i;
285 return s;
286}
287
288template<int i>
290 struct pattern_tag {};
291
292 constexpr static uint32_t binds = 1 << i;
293
296 constexpr static bool canonical = true;
297
298 template<uint32_t bound>
299 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
300 static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
301 const BaseExprNode *op = &e;
302 if (op->node_type == IRNodeType::Broadcast) {
303 op = ((const Broadcast *)op)->value.get();
304 }
305 if (op->node_type != IRNodeType::UIntImm) {
306 return false;
307 }
308 uint64_t value = ((const UIntImm *)op)->value;
309 if (bound & binds) {
311 halide_type_t type;
312 state.get_bound_const(i, val, type);
313 return (halide_type_t)e.type == type && value == val.u.u64;
314 }
315 state.set_bound_const(i, value, e.type);
316 return true;
317 }
318
320 Expr make(MatcherState &state, halide_type_t type_hint) const {
322 halide_type_t type;
323 state.get_bound_const(i, val, type);
324 return make_const_expr(val, type);
325 }
326
327 constexpr static bool foldable = true;
328
330 void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
331 state.get_bound_const(i, val, ty);
332 }
333};
334
335template<int i>
336std::ostream &operator<<(std::ostream &s, const WildConstUInt<i> &c) {
337 s << "cu" << i;
338 return s;
339}
340
341template<int i>
343 struct pattern_tag {};
344
345 constexpr static uint32_t binds = 1 << i;
346
349 constexpr static bool canonical = true;
350
351 template<uint32_t bound>
352 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
353 static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
354 const BaseExprNode *op = &e;
355 if (op->node_type == IRNodeType::Broadcast) {
356 op = ((const Broadcast *)op)->value.get();
357 }
358 if (op->node_type != IRNodeType::FloatImm) {
359 return false;
360 }
361 double value = ((const FloatImm *)op)->value;
362 if (bound & binds) {
364 halide_type_t type;
365 state.get_bound_const(i, val, type);
366 return (halide_type_t)e.type == type && value == val.u.f64;
367 }
368 state.set_bound_const(i, value, e.type);
369 return true;
370 }
371
373 Expr make(MatcherState &state, halide_type_t type_hint) const {
375 halide_type_t type;
376 state.get_bound_const(i, val, type);
377 return make_const_expr(val, type);
378 }
379
380 constexpr static bool foldable = true;
381
383 void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
384 state.get_bound_const(i, val, ty);
385 }
386};
387
388template<int i>
389std::ostream &operator<<(std::ostream &s, const WildConstFloat<i> &c) {
390 s << "cf" << i;
391 return s;
392}
393
394// Matches and binds to any constant Expr. Does not support constant-folding.
// NOTE(review): the sentence above conflicts with the body below, which sets
// `foldable = true` and defines make_folded_const — confirm which is intended.
// NOTE(review): this listing's numbering skips 401-402, 413, 415, 417, 430,
// 432 and 440, so some member lines (including the switch's case labels) are
// not visible here; restore them from the upstream header before compiling.
395template<int i>
396struct WildConst {
397 struct pattern_tag {};
398
// Wildcard i occupies bit i of the bound-wildcard mask.
399 constexpr static uint32_t binds = 1 << i;
400
403 constexpr static bool canonical = true;
404
// Match an IR node: look through an outer Broadcast to pick the kind of
// immediate, then delegate to the WildConstInt/UInt/Float matcher for the
// same index. The delegate is handed the original node `e`, not the
// unwrapped `op`. Any other node type fails to match.
405 template<uint32_t bound>
406 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
407 static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
408 const BaseExprNode *op = &e;
409 if (op->node_type == IRNodeType::Broadcast) {
410 op = ((const Broadcast *)op)->value.get();
411 }
412 switch (op->node_type) {
414 return WildConstInt<i>().template match<bound>(e, state);
416 return WildConstUInt<i>().template match<bound>(e, state);
418 return WildConstFloat<i>().template match<bound>(e, state);
419 default:
420 return false;
421 }
422 }
423
// Match a raw int64_t by delegating to WildConstInt<i>.
424 template<uint32_t bound>
425 HALIDE_ALWAYS_INLINE bool match(int64_t e, MatcherState &state) const noexcept {
426 static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
427 return WildConstInt<i>().template match<bound>(e, state);
428 }
429
// Rebuild an Expr from the constant and type recorded for wildcard i.
// type_hint is not consulted.
431 Expr make(MatcherState &state, halide_type_t type_hint) const {
433 halide_type_t type;
434 state.get_bound_const(i, val, type);
435 return make_const_expr(val, type);
436 }
437
438 constexpr static bool foldable = true;
439
// Report the bound constant and its type for use in constant folding.
441 void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
442 state.get_bound_const(i, val, ty);
443 }
444};
445
446template<int i>
447std::ostream &operator<<(std::ostream &s, const WildConst<i> &c) {
448 s << "c" << i;
449 return s;
450}
451
452// Matches and binds to any Expr
453template<int i>
454struct Wild {
455 struct pattern_tag {};
456
457 constexpr static uint32_t binds = 1 << (i + 16);
458
461 constexpr static bool canonical = true;
462
463 template<uint32_t bound>
464 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
465 if (bound & binds) {
466 return equal(*state.get_binding(i), e);
467 }
468 state.set_binding(i, e);
469 return true;
470 }
471
473 Expr make(MatcherState &state, halide_type_t type_hint) const {
474 return state.get_binding(i);
475 }
476
477 constexpr static bool foldable = false;
478};
479
480template<int i>
481std::ostream &operator<<(std::ostream &s, const Wild<i> &op) {
482 s << "_" << i;
483 return s;
484}
485
486// Matches a specific constant or broadcast of that constant. The
487// constant must be representable as an int64_t.
489 struct pattern_tag {};
491
492 constexpr static uint32_t binds = 0;
493
496 constexpr static bool canonical = true;
497
500 : v(v) {
501 }
502
503 template<uint32_t bound>
504 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
505 const BaseExprNode *op = &e;
506 if (e.node_type == IRNodeType::Broadcast) {
507 op = ((const Broadcast *)op)->value.get();
508 }
509 switch (op->node_type) {
511 return ((const IntImm *)op)->value == (int64_t)v;
513 return ((const UIntImm *)op)->value == (uint64_t)v;
515 return ((const FloatImm *)op)->value == (double)v;
516 default:
517 return false;
518 }
519 }
520
521 template<uint32_t bound>
522 HALIDE_ALWAYS_INLINE bool match(int64_t val, MatcherState &state) const noexcept {
523 return v == val;
524 }
525
526 template<uint32_t bound>
527 HALIDE_ALWAYS_INLINE bool match(const IntLiteral &b, MatcherState &state) const noexcept {
528 return v == b.v;
529 }
530
532 Expr make(MatcherState &state, halide_type_t type_hint) const {
533 return make_const(type_hint, v);
534 }
535
536 constexpr static bool foldable = true;
537
539 void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
540 // Assume type is already correct
541 switch (ty.code) {
542 case halide_type_int:
543 val.u.i64 = v;
544 break;
545 case halide_type_uint:
546 val.u.u64 = (uint64_t)v;
547 break;
550 val.u.f64 = (double)v;
551 break;
552 default:
553 // Unreachable
554 ;
555 }
556 }
557};
558
562
563// Convert a provided pattern, expr, or constant int into the internal
564// representation we use in the matcher trees.
565template<typename T,
566 typename = typename std::decay<T>::type::pattern_tag>
568 return t;
569}
572 return IntLiteral{x};
573}
574
575template<typename T>
577 static_assert(!std::is_same<typename std::decay<T>::type, Expr>::value || std::is_lvalue_reference<T>::value,
578 "Exprs are captured by reference by IRMatcher objects and so must be lvalues");
579}
580
582 return {*e.get()};
583}
584
585// Helpers to deref SpecificExprs to const BaseExprNode & rather than
586// passing them by value anywhere (incurring lots of refcounting)
587template<typename T,
588 // T must be a pattern node
589 typename = typename std::decay<T>::type::pattern_tag,
590 // But T may not be SpecificExpr
591 typename = typename std::enable_if<!std::is_same<typename std::decay<T>::type, SpecificExpr>::value>::type>
593 return t;
594}
595
598 return e.expr;
599}
600
601inline std::ostream &operator<<(std::ostream &s, const IntLiteral &op) {
602 s << op.v;
603 return s;
604}
605
606template<typename Op>
608
609template<typename Op>
611
612template<typename Op>
613double constant_fold_bin_op(halide_type_t &, double, double) noexcept;
614
615constexpr bool commutative(IRNodeType t) {
616 return (t == IRNodeType::Add ||
617 t == IRNodeType::Mul ||
618 t == IRNodeType::And ||
619 t == IRNodeType::Or ||
620 t == IRNodeType::Min ||
621 t == IRNodeType::Max ||
622 t == IRNodeType::EQ ||
623 t == IRNodeType::NE);
624}
625
626// Matches one of the binary operators
// Op is the IR node class (its _node_type and make() are used below); A and B
// are the pattern types of the two operands.
// NOTE(review): this listing's numbering skips 633, 663, 696-697 and 706, so
// at least one member declaration, the attribute lines in front of
// make_folded_const/make, and the case label(s) in front of the f64 branch are
// not visible here; restore them from the upstream header before compiling.
627template<typename Op, typename A, typename B>
628struct BinOp {
629 struct pattern_tag {};
// Sub-patterns for the left and right operands.
630 A a;
631 B b;
632
634
// A BinOp can only ever be a node of type Op, so weakest == strongest.
635 constexpr static IRNodeType min_node_type = Op::_node_type;
636 constexpr static IRNodeType max_node_type = Op::_node_type;
637
638 // For commutative bin ops, we expect the weaker IR node type on
639 // the right. That is, for the rule to be canonical it must be
640 // possible that A is at least as strong as B.
641 constexpr static bool canonical =
642 A::canonical && B::canonical && (!commutative(Op::_node_type) || (A::max_node_type >= B::min_node_type));
643
// Match against an IR node: the node must be an Op, then both operands must
// match. The right operand is matched with A's wildcard mask OR-ed into
// `bound`, so wildcards bound by the left operand are compared, not rebound.
644 template<uint32_t bound>
645 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
646 if (e.node_type != Op::_node_type) {
647 return false;
648 }
649 const Op &op = (const Op &)e;
650 return (a.template match<bound>(*op.a.get(), state) &&
651 b.template match<(bound | bindings<A>::mask)>(*op.b.get(), state));
652 }
653
// Match against another BinOp pattern: the operator types must be identical
// and the operands must match pairwise (each passed through unwrap()).
654 template<uint32_t bound, typename Op2, typename A2, typename B2>
655 HALIDE_ALWAYS_INLINE bool match(const BinOp<Op2, A2, B2> &op, MatcherState &state) const noexcept {
656 return (std::is_same<Op, Op2>::value &&
657 a.template match<bound>(unwrap(op.a), state) &&
658 b.template match<(bound | bindings<A>::mask)>(unwrap(op.b), state));
659 }
660
661 constexpr static bool foldable = A::foldable && B::foldable;
662
// Constant-fold this operator. When A is an IntLiteral, B is evaluated first
// so that `ty` is already populated by the time the literal is evaluated;
// otherwise A goes first. For And/Or, a first-evaluated operand equal to 0/1
// respectively decides the result and the other operand is not evaluated.
// The lanes field seen after the first operand is OR-ed back in after the
// second, so special bits set while folding the first operand are not lost.
664 void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
665 halide_scalar_value_t val_a, val_b;
666 if (std::is_same<A, IntLiteral>::value) {
667 b.make_folded_const(val_b, ty, state);
668 if ((std::is_same<Op, And>::value && val_b.u.u64 == 0) ||
669 (std::is_same<Op, Or>::value && val_b.u.u64 == 1)) {
670 // Short circuit
671 val = val_b;
672 return;
673 }
674 const uint16_t l = ty.lanes;
675 a.make_folded_const(val_a, ty, state);
676 ty.lanes |= l; // Make sure the overflow bits are sticky
677 } else {
678 a.make_folded_const(val_a, ty, state);
679 if ((std::is_same<Op, And>::value && val_a.u.u64 == 0) ||
680 (std::is_same<Op, Or>::value && val_a.u.u64 == 1)) {
681 // Short circuit
682 val = val_a;
683 return;
684 }
685 const uint16_t l = ty.lanes;
686 b.make_folded_const(val_b, ty, state);
687 ty.lanes |= l;
688 }
// Dispatch on the type code to the matching constant_fold_bin_op overload.
689 switch (ty.code) {
690 case halide_type_int:
691 val.u.i64 = constant_fold_bin_op<Op>(ty, val_a.u.i64, val_b.u.i64);
692 break;
693 case halide_type_uint:
694 val.u.u64 = constant_fold_bin_op<Op>(ty, val_a.u.u64, val_b.u.u64);
695 break;
698 val.u.f64 = constant_fold_bin_op<Op>(ty, val_a.u.f64, val_b.u.f64);
699 break;
700 default:
701 // unreachable
702 ;
703 }
704 }
705
// Build the replacement Expr. The non-literal operand is built first with the
// caller's type_hint, and its resulting type becomes the hint for the other.
// NOTE(review): this make() is declared noexcept while CmpOp::make is not,
// although both end in Op::make(...) — confirm whether noexcept is intended.
707 Expr make(MatcherState &state, halide_type_t type_hint) const noexcept {
708 Expr ea, eb;
709 if (std::is_same<A, IntLiteral>::value) {
710 eb = b.make(state, type_hint);
711 ea = a.make(state, eb.type());
712 } else {
713 ea = a.make(state, type_hint);
714 eb = b.make(state, ea.type());
715 }
716 return Op::make(std::move(ea), std::move(eb));
717 }
718};
719
720template<typename Op>
722
723template<typename Op>
725
726template<typename Op>
727uint64_t constant_fold_cmp_op(double, double) noexcept;
728
729// Matches one of the comparison operators
// Op is the IR comparison node class; A and B are the operand pattern types.
// NOTE(review): this listing's numbering skips 736, 765, 787-788 and 799, so
// at least one member declaration, the attribute lines in front of
// make_folded_const/make, and the case label(s) in front of the f64 branch are
// not visible here; restore them from the upstream header before compiling.
730template<typename Op, typename A, typename B>
731struct CmpOp {
732 struct pattern_tag {};
// Sub-patterns for the left and right operands.
733 A a;
734 B b;
735
737
738 constexpr static IRNodeType min_node_type = Op::_node_type;
739 constexpr static IRNodeType max_node_type = Op::_node_type;
// Canonical only if both operands are canonical, commutative comparisons keep
// the operand ordering rule, and the operator is neither GE nor GT.
740 constexpr static bool canonical = (A::canonical &&
741 B::canonical &&
742 (!commutative(Op::_node_type) || A::max_node_type >= B::min_node_type) &&
743 (Op::_node_type != IRNodeType::GE) &&
744 (Op::_node_type != IRNodeType::GT));
745
// Match against an IR node: the node must be an Op, then both operands must
// match; the right operand sees A's wildcards as already bound.
746 template<uint32_t bound>
747 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
748 if (e.node_type != Op::_node_type) {
749 return false;
750 }
751 const Op &op = (const Op &)e;
752 return (a.template match<bound>(*op.a.get(), state) &&
753 b.template match<(bound | bindings<A>::mask)>(*op.b.get(), state));
754 }
755
// Match against another CmpOp pattern: identical operator type and pairwise
// matching operands (each passed through unwrap()).
756 template<uint32_t bound, typename Op2, typename A2, typename B2>
757 HALIDE_ALWAYS_INLINE bool match(const CmpOp<Op2, A2, B2> &op, MatcherState &state) const noexcept {
758 return (std::is_same<Op, Op2>::value &&
759 a.template match<bound>(unwrap(op.a), state) &&
760 b.template match<(bound | bindings<A>::mask)>(unwrap(op.b), state));
761 }
762
763 constexpr static bool foldable = A::foldable && B::foldable;
764
// Constant-fold the comparison. The operands are compared in their own type,
// the 0/1 result is stored in val.u.u64, and `ty` is then rewritten to a
// 1-bit unsigned type. The lanes field seen after the first operand is OR-ed
// back in after the second.
766 void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
767 halide_scalar_value_t val_a, val_b;
768 // If one side is an untyped const, evaluate the other side first to get a type hint.
769 if (std::is_same<A, IntLiteral>::value) {
770 b.make_folded_const(val_b, ty, state);
771 const uint16_t l = ty.lanes;
772 a.make_folded_const(val_a, ty, state);
773 ty.lanes |= l;
774 } else {
775 a.make_folded_const(val_a, ty, state);
776 const uint16_t l = ty.lanes;
777 b.make_folded_const(val_b, ty, state);
778 ty.lanes |= l;
779 }
780 switch (ty.code) {
781 case halide_type_int:
782 val.u.u64 = constant_fold_cmp_op<Op>(val_a.u.i64, val_b.u.i64);
783 break;
784 case halide_type_uint:
785 val.u.u64 = constant_fold_cmp_op<Op>(val_a.u.u64, val_b.u.u64);
786 break;
789 val.u.u64 = constant_fold_cmp_op<Op>(val_a.u.f64, val_b.u.f64);
790 break;
791 default:
792 // unreachable
793 ;
794 }
// Only code and bits are overwritten; lanes keeps the value computed above.
795 ty.code = halide_type_uint;
796 ty.bits = 1;
797 }
798
// Build the replacement Expr. The caller's type_hint is not passed on: the
// first operand is built with an empty hint and its resulting type becomes
// the hint for the other operand.
800 Expr make(MatcherState &state, halide_type_t type_hint) const {
801 // If one side is an untyped const, evaluate the other side first to get a type hint.
802 Expr ea, eb;
803 if (std::is_same<A, IntLiteral>::value) {
804 eb = b.make(state, {});
805 ea = a.make(state, eb.type());
806 } else {
807 ea = a.make(state, {});
808 eb = b.make(state, ea.type());
809 }
810 return Op::make(std::move(ea), std::move(eb));
811 }
812};
813
814template<typename A, typename B>
815std::ostream &operator<<(std::ostream &s, const BinOp<Add, A, B> &op) {
816 s << "(" << op.a << " + " << op.b << ")";
817 return s;
818}
819
820template<typename A, typename B>
821std::ostream &operator<<(std::ostream &s, const BinOp<Sub, A, B> &op) {
822 s << "(" << op.a << " - " << op.b << ")";
823 return s;
824}
825
826template<typename A, typename B>
827std::ostream &operator<<(std::ostream &s, const BinOp<Mul, A, B> &op) {
828 s << "(" << op.a << " * " << op.b << ")";
829 return s;
830}
831
832template<typename A, typename B>
833std::ostream &operator<<(std::ostream &s, const BinOp<Div, A, B> &op) {
834 s << "(" << op.a << " / " << op.b << ")";
835 return s;
836}
837
838template<typename A, typename B>
839std::ostream &operator<<(std::ostream &s, const BinOp<And, A, B> &op) {
840 s << "(" << op.a << " && " << op.b << ")";
841 return s;
842}
843
844template<typename A, typename B>
845std::ostream &operator<<(std::ostream &s, const BinOp<Or, A, B> &op) {
846 s << "(" << op.a << " || " << op.b << ")";
847 return s;
848}
849
850template<typename A, typename B>
851std::ostream &operator<<(std::ostream &s, const BinOp<Min, A, B> &op) {
852 s << "min(" << op.a << ", " << op.b << ")";
853 return s;
854}
855
856template<typename A, typename B>
857std::ostream &operator<<(std::ostream &s, const BinOp<Max, A, B> &op) {
858 s << "max(" << op.a << ", " << op.b << ")";
859 return s;
860}
861
862template<typename A, typename B>
863std::ostream &operator<<(std::ostream &s, const CmpOp<LE, A, B> &op) {
864 s << "(" << op.a << " <= " << op.b << ")";
865 return s;
866}
867
868template<typename A, typename B>
869std::ostream &operator<<(std::ostream &s, const CmpOp<LT, A, B> &op) {
870 s << "(" << op.a << " < " << op.b << ")";
871 return s;
872}
873
874template<typename A, typename B>
875std::ostream &operator<<(std::ostream &s, const CmpOp<GE, A, B> &op) {
876 s << "(" << op.a << " >= " << op.b << ")";
877 return s;
878}
879
880template<typename A, typename B>
881std::ostream &operator<<(std::ostream &s, const CmpOp<GT, A, B> &op) {
882 s << "(" << op.a << " > " << op.b << ")";
883 return s;
884}
885
886template<typename A, typename B>
887std::ostream &operator<<(std::ostream &s, const CmpOp<EQ, A, B> &op) {
888 s << "(" << op.a << " == " << op.b << ")";
889 return s;
890}
891
892template<typename A, typename B>
893std::ostream &operator<<(std::ostream &s, const CmpOp<NE, A, B> &op) {
894 s << "(" << op.a << " != " << op.b << ")";
895 return s;
896}
897
898template<typename A, typename B>
899std::ostream &operator<<(std::ostream &s, const BinOp<Mod, A, B> &op) {
900 s << "(" << op.a << " % " << op.b << ")";
901 return s;
902}
903
// Builds a BinOp<Add, ...> pattern from two operands, each converted with pattern_arg().
// NOTE(review): the listing's numbering skips 906-907 inside this body; those
// lines are not visible here — restore them from the upstream header.
904template<typename A, typename B>
905HALIDE_ALWAYS_INLINE auto operator+(A &&a, B &&b) noexcept -> BinOp<Add, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
908 return {pattern_arg(a), pattern_arg(b)};
909}
910
// Named-function spelling of the Add pattern builder: calls IRMatcher::operator+(a, b).
// NOTE(review): the listing's numbering skips 913-914 inside this body; those
// lines are not visible here — restore them from the upstream header.
911template<typename A, typename B>
912HALIDE_ALWAYS_INLINE auto add(A &&a, B &&b) -> decltype(IRMatcher::operator+(a, b)) {
915 return IRMatcher::operator+(a, b);
916}
917
918template<>
920 t.lanes |= ((t.bits >= 32) && add_would_overflow(t.bits, a, b)) ? MatcherState::signed_integer_overflow : 0;
921 int dead_bits = 64 - t.bits;
922 // Drop the high bits then sign-extend them back
923 return int64_t((uint64_t(a) + uint64_t(b)) << dead_bits) >> dead_bits;
924}
925
926template<>
928 uint64_t ones = (uint64_t)(-1);
929 return (a + b) & (ones >> (64 - t.bits));
930}
931
// Floating-point constant fold for Add: returns a + b. The type argument is not consulted.
932template<>
933HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Add>(halide_type_t &t, double a, double b) noexcept {
934 return a + b;
935}
936
// Builds a BinOp<Sub, ...> pattern from two operands, each converted with pattern_arg().
// NOTE(review): the listing's numbering skips 939-940 inside this body; those
// lines are not visible here — restore them from the upstream header.
937template<typename A, typename B>
938HALIDE_ALWAYS_INLINE auto operator-(A &&a, B &&b) noexcept -> BinOp<Sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
941 return {pattern_arg(a), pattern_arg(b)};
942}
943
// Named-function spelling of the Sub pattern builder: calls IRMatcher::operator-(a, b).
// NOTE(review): the listing's numbering skips 946-947 inside this body; those
// lines are not visible here — restore them from the upstream header.
944template<typename A, typename B>
945HALIDE_ALWAYS_INLINE auto sub(A &&a, B &&b) -> decltype(IRMatcher::operator-(a, b)) {
948 return IRMatcher::operator-(a, b);
949}
950
951template<>
953 t.lanes |= ((t.bits >= 32) && sub_would_overflow(t.bits, a, b)) ? MatcherState::signed_integer_overflow : 0;
954 // Drop the high bits then sign-extend them back
955 int dead_bits = 64 - t.bits;
956 return int64_t((uint64_t(a) - uint64_t(b)) << dead_bits) >> dead_bits;
957}
958
959template<>
961 uint64_t ones = (uint64_t)(-1);
962 return (a - b) & (ones >> (64 - t.bits));
963}
964
// Floating-point constant fold for Sub: returns a - b. The type argument is not consulted.
965template<>
966HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Sub>(halide_type_t &t, double a, double b) noexcept {
967 return a - b;
968}
969
// Builds a BinOp<Mul, ...> pattern from two operands, each converted with pattern_arg().
// NOTE(review): the listing's numbering skips 972-973 inside this body; those
// lines are not visible here — restore them from the upstream header.
970template<typename A, typename B>
971HALIDE_ALWAYS_INLINE auto operator*(A &&a, B &&b) noexcept -> BinOp<Mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
974 return {pattern_arg(a), pattern_arg(b)};
975}
976
// Named-function spelling of the Mul pattern builder: calls IRMatcher::operator*(a, b).
// NOTE(review): the listing's numbering skips 979-980 inside this body; those
// lines are not visible here — restore them from the upstream header.
977template<typename A, typename B>
978HALIDE_ALWAYS_INLINE auto mul(A &&a, B &&b) -> decltype(IRMatcher::operator*(a, b)) {
981 return IRMatcher::operator*(a, b);
982}
983
984template<>
986 t.lanes |= ((t.bits >= 32) && mul_would_overflow(t.bits, a, b)) ? MatcherState::signed_integer_overflow : 0;
987 int dead_bits = 64 - t.bits;
988 // Drop the high bits then sign-extend them back
989 return int64_t((uint64_t(a) * uint64_t(b)) << dead_bits) >> dead_bits;
990}
991
992template<>
994 uint64_t ones = (uint64_t)(-1);
995 return (a * b) & (ones >> (64 - t.bits));
996}
997
// Floating-point constant fold for Mul: returns a * b. The type argument is not consulted.
998template<>
999HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Mul>(halide_type_t &t, double a, double b) noexcept {
1000 return a * b;
1001}
1002
// Builds a BinOp<Div, ...> pattern from two operands, each converted with pattern_arg().
// NOTE(review): the listing's numbering skips 1005-1006 inside this body; those
// lines are not visible here — restore them from the upstream header.
1003template<typename A, typename B>
1004HALIDE_ALWAYS_INLINE auto operator/(A &&a, B &&b) noexcept -> BinOp<Div, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1007 return {pattern_arg(a), pattern_arg(b)};
1008}
1009
// Named-function spelling of the Div pattern builder: calls IRMatcher::operator/(a, b).
1010template<typename A, typename B>
1011HALIDE_ALWAYS_INLINE auto div(A &&a, B &&b) -> decltype(IRMatcher::operator/(a, b)) {
1012 return IRMatcher::operator/(a, b);
1013}
1014
1015template<>
1019
1020template<>
1024
// Floating-point constant fold for Div: delegates to div_imp(a, b), which is
// defined outside this listing. The type argument is not consulted.
1025template<>
1026HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Div>(halide_type_t &t, double a, double b) noexcept {
1027 return div_imp(a, b);
1028}
1029
// Builds a BinOp<Mod, ...> pattern from two operands, each converted with pattern_arg().
// NOTE(review): the listing's numbering skips 1032-1033 inside this body; those
// lines are not visible here — restore them from the upstream header.
1030template<typename A, typename B>
1031HALIDE_ALWAYS_INLINE auto operator%(A &&a, B &&b) noexcept -> BinOp<Mod, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1034 return {pattern_arg(a), pattern_arg(b)};
1035}
1036
// Named-function spelling of the Mod pattern builder: calls IRMatcher::operator%(a, b).
// NOTE(review): the listing's numbering skips 1039-1040 inside this body; those
// lines are not visible here — restore them from the upstream header.
1037template<typename A, typename B>
1038HALIDE_ALWAYS_INLINE auto mod(A &&a, B &&b) -> decltype(IRMatcher::operator%(a, b)) {
1041 return IRMatcher::operator%(a, b);
1042}
1043
1044template<>
1048
1049template<>
1053
// Floating-point constant fold for Mod: delegates to mod_imp(a, b), which is
// defined outside this listing. The type argument is not consulted.
1054template<>
1055HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Mod>(halide_type_t &t, double a, double b) noexcept {
1056 return mod_imp(a, b);
1057}
1058
// Builds a BinOp<Min, ...> pattern from two operands, each converted with pattern_arg().
// NOTE(review): the listing's numbering skips 1061-1062 inside this body; those
// lines are not visible here — restore them from the upstream header.
1059template<typename A, typename B>
1060HALIDE_ALWAYS_INLINE auto min(A &&a, B &&b) noexcept -> BinOp<Min, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1063 return {pattern_arg(a), pattern_arg(b)};
1064}
1065
1066template<>
1068 return std::min(a, b);
1069}
1070
1071template<>
1073 return std::min(a, b);
1074}
1075
// Floating-point constant fold for Min: returns std::min(a, b). The type argument is not consulted.
1076template<>
1077HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Min>(halide_type_t &t, double a, double b) noexcept {
1078 return std::min(a, b);
1079}
1080
// Builds a BinOp<Max, ...> pattern from two operands, each converted with pattern_arg().
// NOTE(review): unlike the sibling builders (e.g. min), the operands here are
// passed through std::forward, while the return type is still spelled with
// the unforwarded names — confirm the difference is intentional.
// NOTE(review): the listing's numbering skips 1083-1084 inside this body; those
// lines are not visible here — restore them from the upstream header.
1081template<typename A, typename B>
1082HALIDE_ALWAYS_INLINE auto max(A &&a, B &&b) noexcept -> BinOp<Max, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1085 return {pattern_arg(std::forward<A>(a)), pattern_arg(std::forward<B>(b))};
1086}
1087
1088template<>
1090 return std::max(a, b);
1091}
1092
1093template<>
1095 return std::max(a, b);
1096}
1097
// Floating-point constant fold for Max: returns std::max(a, b). The type argument is not consulted.
1098template<>
1099HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Max>(halide_type_t &t, double a, double b) noexcept {
1100 return std::max(a, b);
1101}
1102
// Builds a CmpOp<LT, ...> pattern from two operands, each converted with pattern_arg().
1103template<typename A, typename B>
1104HALIDE_ALWAYS_INLINE auto operator<(A &&a, B &&b) noexcept -> CmpOp<LT, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1105 return {pattern_arg(a), pattern_arg(b)};
1106}
1107
// Named-function spelling of the LT pattern builder: calls IRMatcher::operator<(a, b).
1108template<typename A, typename B>
1109HALIDE_ALWAYS_INLINE auto lt(A &&a, B &&b) -> decltype(IRMatcher::operator<(a, b)) {
1110 return IRMatcher::operator<(a, b);
1111}
1112
1113template<>
1117
1118template<>
1122
1123template<>
1125 return a < b;
1126}
1127
// Builds a CmpOp<GT, ...> pattern from two operands, each converted with pattern_arg().
// CmpOp::canonical is false for GT, so rules built with this are never canonical.
1128template<typename A, typename B>
1129HALIDE_ALWAYS_INLINE auto operator>(A &&a, B &&b) noexcept -> CmpOp<GT, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1130 return {pattern_arg(a), pattern_arg(b)};
1131}
1132
// Named-function spelling of the GT pattern builder: calls IRMatcher::operator>(a, b).
1133template<typename A, typename B>
1134HALIDE_ALWAYS_INLINE auto gt(A &&a, B &&b) -> decltype(IRMatcher::operator>(a, b)) {
1135 return IRMatcher::operator>(a, b);
1136}
1137
1138template<>
1142
1143template<>
1147
1148template<>
1150 return a > b;
1151}
1152
// Builds a CmpOp<LE, ...> pattern from two operands, each converted with pattern_arg().
1153template<typename A, typename B>
1154HALIDE_ALWAYS_INLINE auto operator<=(A &&a, B &&b) noexcept -> CmpOp<LE, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1155 return {pattern_arg(a), pattern_arg(b)};
1156}
1157
// Named-function spelling of the LE pattern builder: calls IRMatcher::operator<=(a, b).
1158template<typename A, typename B>
1159HALIDE_ALWAYS_INLINE auto le(A &&a, B &&b) -> decltype(IRMatcher::operator<=(a, b)) {
1160 return IRMatcher::operator<=(a, b);
1161}
1162
1163template<>
1165 return a <= b;
1166}
1167
1168template<>
1172
1173template<>
1175 return a <= b;
1176}
1177
// Builds a CmpOp<GE, ...> pattern from two operands, each converted with pattern_arg().
// CmpOp::canonical is false for GE, so rules built with this are never canonical.
1178template<typename A, typename B>
1179HALIDE_ALWAYS_INLINE auto operator>=(A &&a, B &&b) noexcept -> CmpOp<GE, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1180 return {pattern_arg(a), pattern_arg(b)};
1181}
1182
// Named-function spelling of the GE pattern builder: calls IRMatcher::operator>=(a, b).
1183template<typename A, typename B>
1184HALIDE_ALWAYS_INLINE auto ge(A &&a, B &&b) -> decltype(IRMatcher::operator>=(a, b)) {
1185 return IRMatcher::operator>=(a, b);
1186}
1187
1188template<>
1190 return a >= b;
1191}
1192
1193template<>
1197
1198template<>
1200 return a >= b;
1201}
1202
// Builds a CmpOp<EQ, ...> pattern from two operands, each converted with pattern_arg().
1203template<typename A, typename B>
1204HALIDE_ALWAYS_INLINE auto operator==(A &&a, B &&b) noexcept -> CmpOp<EQ, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1205 return {pattern_arg(a), pattern_arg(b)};
1206}
1207
// Named-function spelling of the EQ pattern builder: calls IRMatcher::operator==(a, b).
1208template<typename A, typename B>
1209HALIDE_ALWAYS_INLINE auto eq(A &&a, B &&b) -> decltype(IRMatcher::operator==(a, b)) {
1210 return IRMatcher::operator==(a, b);
1211}
1212
1213template<>
1215 return a == b;
1216}
1217
1218template<>
1222
1223template<>
1225 return a == b;
1226}
1227
1228template<typename A, typename B>
1229HALIDE_ALWAYS_INLINE auto operator!=(A &&a, B &&b) noexcept -> CmpOp<NE, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1230 return {pattern_arg(a), pattern_arg(b)};
1231}
1232
1233template<typename A, typename B>
1234HALIDE_ALWAYS_INLINE auto ne(A &&a, B &&b) -> decltype(IRMatcher::operator!=(a, b)) {
1235 return IRMatcher::operator!=(a, b);
1236}
1237
1238template<>
1240 return a != b;
1241}
1242
1243template<>
1247
1248template<>
1250 return a != b;
1251}
1252
1253template<typename A, typename B>
1254HALIDE_ALWAYS_INLINE auto operator||(A &&a, B &&b) noexcept -> BinOp<Or, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1255 return {pattern_arg(a), pattern_arg(b)};
1256}
1257
1258template<typename A, typename B>
1259HALIDE_ALWAYS_INLINE auto or_op(A &&a, B &&b) -> decltype(IRMatcher::operator||(a, b)) {
1260 return IRMatcher::operator||(a, b);
1261}
1262
1263template<>
1265 return (a | b) & 1;
1266}
1267
1268template<>
1270 return (a | b) & 1;
1271}
1272
1273template<>
1274HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Or>(halide_type_t &t, double a, double b) noexcept {
1275 // Unreachable, as it would be a type mismatch.
1276 return 0;
1277}
1278
1279template<typename A, typename B>
1280HALIDE_ALWAYS_INLINE auto operator&&(A &&a, B &&b) noexcept -> BinOp<And, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1281 return {pattern_arg(a), pattern_arg(b)};
1282}
1283
1284template<typename A, typename B>
1285HALIDE_ALWAYS_INLINE auto and_op(A &&a, B &&b) -> decltype(IRMatcher::operator&&(a, b)) {
1286 return IRMatcher::operator&&(a, b);
1287}
1288
1289template<>
1291 return a & b & 1;
1292}
1293
1294template<>
1298
1299template<>
1300HALIDE_ALWAYS_INLINE double constant_fold_bin_op<And>(halide_type_t &t, double a, double b) noexcept {
1301 // Unreachable
1302 return 0;
1303}
1304
1306 return 0;
1307}
1308
/** Bitwise-OR one or more 32-bit masks together at compile time.
 *
 * \param first the first mask (always present for this overload).
 * \param rest  any further masks; each is converted to uint32_t before
 *              being OR-ed in.
 * \return the OR of all arguments; just `first` when `rest` is empty.
 *
 * Written as a C++17 fold expression so this overload is self-contained:
 * it no longer recurses, so it needs no separately declared zero-argument
 * overload and instantiates only once per call. */
template<typename... Args>
constexpr uint32_t bitwise_or_reduce(uint32_t first, Args... rest) {
    return (first | ... | static_cast<uint32_t>(rest));
}
1313
/** Zero-argument overload of and_reduce: with no operands the result
 * is true. */
constexpr bool and_reduce() {
    constexpr bool empty_conjunction = true;
    return empty_conjunction;
}
1317
/** Logical-AND one or more boolean values together at compile time.
 *
 * \param first the first operand (always present for this overload).
 * \param rest  any further operands, each evaluated as a bool.
 * \return true iff every argument is true; just `first` when `rest`
 *         is empty.
 *
 * Written as a C++17 fold expression so this overload is self-contained:
 * it no longer recurses, so it does not depend on the zero-argument
 * overload being declared first and instantiates only once per call. */
template<typename... Args>
constexpr bool and_reduce(bool first, Args... rest) {
    return (first && ... && rest);
}
1322
1323template<Call::IntrinsicOp intrin>
1325 bool check(const Type &) const {
1326 return true;
1327 }
1328};
1329
1330template<>
1333 bool check(const Type &t) const {
1334 return t == Type(type);
1335 }
1336};
1337
1338template<Call::IntrinsicOp intrin, typename... Args>
1339struct Intrin {
1340 struct pattern_tag {};
1341 std::tuple<Args...> args;
1342 // The type of the output of the intrinsic node.
1343 // Only necessary in cases where it can't be inferred
1344 // from the input types (e.g. saturating_cast).
1345
1347
1349
1352 constexpr static bool canonical = and_reduce((Args::canonical)...);
1353
1354 template<int i,
1355 uint32_t bound,
1356 typename = typename std::enable_if<(i < sizeof...(Args))>::type>
1357 HALIDE_ALWAYS_INLINE bool match_args(int, const Call &c, MatcherState &state) const noexcept {
1358 using T = decltype(std::get<i>(args));
1359 return (std::get<i>(args).template match<bound>(*c.args[i].get(), state) &&
1360 match_args<i + 1, (bound | bindings<T>::mask)>(0, c, state));
1361 }
1362
1363 template<int i, uint32_t binds>
1364 HALIDE_ALWAYS_INLINE bool match_args(double, const Call &c, MatcherState &state) const noexcept {
1365 return true;
1366 }
1367
1368 template<uint32_t bound>
1369 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1370 if (e.node_type != IRNodeType::Call) {
1371 return false;
1372 }
1373 const Call &c = (const Call &)e;
1374 return (c.is_intrinsic(intrin) &&
1375 optional_type_hint.check(e.type) &&
1376 match_args<0, bound>(0, c, state));
1377 }
1378
1379 template<int i,
1380 typename = typename std::enable_if<(i < sizeof...(Args))>::type>
1381 HALIDE_ALWAYS_INLINE void print_args(int, std::ostream &s) const {
1382 s << std::get<i>(args);
1383 if (i + 1 < sizeof...(Args)) {
1384 s << ", ";
1385 }
1386 print_args<i + 1>(0, s);
1387 }
1388
1389 template<int i>
1390 HALIDE_ALWAYS_INLINE void print_args(double, std::ostream &s) const {
1391 }
1392
1394 void print_args(std::ostream &s) const {
1395 print_args<0>(0, s);
1396 }
1397
1399 Expr make(MatcherState &state, halide_type_t type_hint) const {
1400 Expr arg0 = std::get<0>(args).make(state, type_hint);
1401 if (intrin == Call::likely) {
1402 return likely(std::move(arg0));
1403 } else if (intrin == Call::likely_if_innermost) {
1404 return likely_if_innermost(std::move(arg0));
1405 } else if (intrin == Call::abs) {
1406 return abs(std::move(arg0));
1407 } else if constexpr (intrin == Call::saturating_cast) {
1408 return saturating_cast(optional_type_hint.type, std::move(arg0));
1409 }
1410
1411 Expr arg1 = std::get<std::min<size_t>(1, sizeof...(Args) - 1)>(args).make(state, type_hint);
1412 if (intrin == Call::absd) {
1413 return absd(std::move(arg0), std::move(arg1));
1414 } else if (intrin == Call::widen_right_add) {
1415 return widen_right_add(std::move(arg0), std::move(arg1));
1416 } else if (intrin == Call::widen_right_mul) {
1417 return widen_right_mul(std::move(arg0), std::move(arg1));
1418 } else if (intrin == Call::widen_right_sub) {
1419 return widen_right_sub(std::move(arg0), std::move(arg1));
1420 } else if (intrin == Call::widening_add) {
1421 return widening_add(std::move(arg0), std::move(arg1));
1422 } else if (intrin == Call::widening_sub) {
1423 return widening_sub(std::move(arg0), std::move(arg1));
1424 } else if (intrin == Call::widening_mul) {
1425 return widening_mul(std::move(arg0), std::move(arg1));
1426 } else if (intrin == Call::saturating_add) {
1427 return saturating_add(std::move(arg0), std::move(arg1));
1428 } else if (intrin == Call::saturating_sub) {
1429 return saturating_sub(std::move(arg0), std::move(arg1));
1430 } else if (intrin == Call::halving_add) {
1431 return halving_add(std::move(arg0), std::move(arg1));
1432 } else if (intrin == Call::halving_sub) {
1433 return halving_sub(std::move(arg0), std::move(arg1));
1434 } else if (intrin == Call::rounding_halving_add) {
1435 return rounding_halving_add(std::move(arg0), std::move(arg1));
1436 } else if (intrin == Call::shift_left) {
1437 return std::move(arg0) << std::move(arg1);
1438 } else if (intrin == Call::shift_right) {
1439 return std::move(arg0) >> std::move(arg1);
1440 } else if (intrin == Call::rounding_shift_left) {
1441 return rounding_shift_left(std::move(arg0), std::move(arg1));
1442 } else if (intrin == Call::rounding_shift_right) {
1443 return rounding_shift_right(std::move(arg0), std::move(arg1));
1444 }
1445
1446 Expr arg2 = std::get<std::min<size_t>(2, sizeof...(Args) - 1)>(args).make(state, type_hint);
1447 if (intrin == Call::mul_shift_right) {
1448 return mul_shift_right(std::move(arg0), std::move(arg1), std::move(arg2));
1449 } else if (intrin == Call::rounding_mul_shift_right) {
1450 return rounding_mul_shift_right(std::move(arg0), std::move(arg1), std::move(arg2));
1451 }
1452
1453 internal_error << "Unhandled intrinsic in IRMatcher: " << intrin;
1454 return Expr();
1455 }
1456
1457 constexpr static bool foldable = true;
1458
1461 // Assuming the args have the same type as the intrinsic is incorrect in
1462 // general. But for the intrinsics we can fold (just shifts), the LHS
1463 // has the same type as the intrinsic, and we can always treat the RHS
1464 // as a signed int, because we're using 64 bits for it.
1465 std::get<0>(args).make_folded_const(val, ty, state);
1466 halide_type_t signed_ty = ty;
1467 signed_ty.code = halide_type_int;
1468 // We can just directly get the second arg here, because we only want to
1469 // instantiate this method for shifts, which have two args.
1470 std::get<1>(args).make_folded_const(arg1, signed_ty, state);
1471
1472 if (intrin == Call::shift_left) {
1473 if (arg1.u.i64 < 0) {
1474 if (ty.code == halide_type_int) {
1475 // Arithmetic shift
1476 val.u.i64 >>= -arg1.u.i64;
1477 } else {
1478 // Logical shift
1479 val.u.u64 >>= -arg1.u.i64;
1480 }
1481 } else {
1482 val.u.u64 <<= arg1.u.i64;
1483 }
1484 } else if (intrin == Call::shift_right) {
1485 if (arg1.u.i64 > 0) {
1486 if (ty.code == halide_type_int) {
1487 // Arithmetic shift
1488 val.u.i64 >>= arg1.u.i64;
1489 } else {
1490 // Logical shift
1491 val.u.u64 >>= arg1.u.i64;
1492 }
1493 } else {
1494 val.u.u64 <<= -arg1.u.i64;
1495 }
1496 } else {
1497 internal_error << "Folding not implemented for intrinsic: " << intrin;
1498 }
1499 }
1500
1502 Intrin(Args... args) noexcept
1503 : args(args...) {
1504 }
1505};
1506
1507template<Call::IntrinsicOp intrin, typename... Args>
1508std::ostream &operator<<(std::ostream &s, const Intrin<intrin, Args...> &op) {
1509 s << intrin << "(";
1510 op.print_args(s);
1511 s << ")";
1512 return s;
1513}
1514
1515template<typename A, typename B>
1516auto widen_right_add(A &&a, B &&b) noexcept -> Intrin<Call::widen_right_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1517 return {pattern_arg(a), pattern_arg(b)};
1518}
1519template<typename A, typename B>
1520auto widen_right_mul(A &&a, B &&b) noexcept -> Intrin<Call::widen_right_mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1521 return {pattern_arg(a), pattern_arg(b)};
1522}
1523template<typename A, typename B>
1524auto widen_right_sub(A &&a, B &&b) noexcept -> Intrin<Call::widen_right_sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1525 return {pattern_arg(a), pattern_arg(b)};
1526}
1527
1528template<typename A, typename B>
1529auto widening_add(A &&a, B &&b) noexcept -> Intrin<Call::widening_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1530 return {pattern_arg(a), pattern_arg(b)};
1531}
1532template<typename A, typename B>
1533auto widening_sub(A &&a, B &&b) noexcept -> Intrin<Call::widening_sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1534 return {pattern_arg(a), pattern_arg(b)};
1535}
1536template<typename A, typename B>
1537auto widening_mul(A &&a, B &&b) noexcept -> Intrin<Call::widening_mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1538 return {pattern_arg(a), pattern_arg(b)};
1539}
1540template<typename A, typename B>
1541auto saturating_add(A &&a, B &&b) noexcept -> Intrin<Call::saturating_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1542 return {pattern_arg(a), pattern_arg(b)};
1543}
1544template<typename A, typename B>
1545auto saturating_sub(A &&a, B &&b) noexcept -> Intrin<Call::saturating_sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1546 return {pattern_arg(a), pattern_arg(b)};
1547}
1548template<typename A>
1549auto saturating_cast(const Type &t, A &&a) noexcept -> Intrin<Call::saturating_cast, decltype(pattern_arg(a))> {
1550 Intrin<Call::saturating_cast, decltype(pattern_arg(a))> p = {pattern_arg(a)};
1551 p.optional_type_hint.type = t;
1552 return p;
1553}
1554template<typename A, typename B>
1555auto halving_add(A &&a, B &&b) noexcept -> Intrin<Call::halving_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1556 return {pattern_arg(a), pattern_arg(b)};
1557}
1558template<typename A, typename B>
1559auto halving_sub(A &&a, B &&b) noexcept -> Intrin<Call::halving_sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1560 return {pattern_arg(a), pattern_arg(b)};
1561}
1562template<typename A, typename B>
1563auto rounding_halving_add(A &&a, B &&b) noexcept -> Intrin<Call::rounding_halving_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1564 return {pattern_arg(a), pattern_arg(b)};
1565}
1566template<typename A, typename B>
1567auto shift_left(A &&a, B &&b) noexcept -> Intrin<Call::shift_left, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1568 return {pattern_arg(a), pattern_arg(b)};
1569}
1570template<typename A, typename B>
1571auto shift_right(A &&a, B &&b) noexcept -> Intrin<Call::shift_right, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1572 return {pattern_arg(a), pattern_arg(b)};
1573}
1574template<typename A, typename B>
1575auto rounding_shift_left(A &&a, B &&b) noexcept -> Intrin<Call::rounding_shift_left, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1576 return {pattern_arg(a), pattern_arg(b)};
1577}
1578template<typename A, typename B>
1579auto rounding_shift_right(A &&a, B &&b) noexcept -> Intrin<Call::rounding_shift_right, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1580 return {pattern_arg(a), pattern_arg(b)};
1581}
1582template<typename A, typename B, typename C>
1583auto mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin<Call::mul_shift_right, decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
1584 return {pattern_arg(a), pattern_arg(b), pattern_arg(c)};
1585}
1586template<typename A, typename B, typename C>
1587auto rounding_mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin<Call::rounding_mul_shift_right, decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
1588 return {pattern_arg(a), pattern_arg(b), pattern_arg(c)};
1589}
1590
1591template<typename A>
1592auto abs(A &&a) noexcept -> Intrin<Call::abs, decltype(pattern_arg(a))> {
1593 return {pattern_arg(a)};
1594}
1595
1596template<typename A, typename B>
1597auto absd(A &&a, B &&b) noexcept -> Intrin<Call::absd, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1598 return {pattern_arg(a), pattern_arg(b)};
1599}
1600
1601template<typename A>
1602auto likely(A &&a) noexcept -> Intrin<Call::likely, decltype(pattern_arg(a))> {
1603 return {pattern_arg(a)};
1604}
1605
1606template<typename A>
1607auto likely_if_innermost(A &&a) noexcept -> Intrin<Call::likely_if_innermost, decltype(pattern_arg(a))> {
1608 return {pattern_arg(a)};
1609}
1610
1611template<typename A>
1612struct NotOp {
1613 struct pattern_tag {};
1614 A a;
1615
1617
1620 constexpr static bool canonical = A::canonical;
1621
1622 template<uint32_t bound>
1623 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1624 if (e.node_type != IRNodeType::Not) {
1625 return false;
1626 }
1627 const Not &op = (const Not &)e;
1628 return (a.template match<bound>(*op.a.get(), state));
1629 }
1630
1631 template<uint32_t bound, typename A2>
1632 HALIDE_ALWAYS_INLINE bool match(const NotOp<A2> &op, MatcherState &state) const noexcept {
1633 return a.template match<bound>(unwrap(op.a), state);
1634 }
1635
1637 Expr make(MatcherState &state, halide_type_t type_hint) const {
1638 return Not::make(a.make(state, type_hint));
1639 }
1640
1641 constexpr static bool foldable = A::foldable;
1642
1643 template<typename A1 = A>
1645 a.make_folded_const(val, ty, state);
1646 val.u.u64 = ~val.u.u64;
1647 val.u.u64 &= 1;
1648 }
1649};
1650
1651template<typename A>
1652HALIDE_ALWAYS_INLINE auto operator!(A &&a) noexcept -> NotOp<decltype(pattern_arg(a))> {
1654 return {pattern_arg(a)};
1655}
1656
1657template<typename A>
1662
1663template<typename A>
1664inline std::ostream &operator<<(std::ostream &s, const NotOp<A> &op) {
1665 s << "!(" << op.a << ")";
1666 return s;
1667}
1668
1669template<typename C, typename T, typename F>
1670struct SelectOp {
1671 struct pattern_tag {};
1673 T t;
1674 F f;
1675
1677
1680
1681 constexpr static bool canonical = C::canonical && T::canonical && F::canonical;
1682
1683 template<uint32_t bound>
1684 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1685 if (e.node_type != Select::_node_type) {
1686 return false;
1687 }
1688 const Select &op = (const Select &)e;
1689 return (c.template match<bound>(*op.condition.get(), state) &&
1690 t.template match<(bound | bindings<C>::mask)>(*op.true_value.get(), state) &&
1691 f.template match<(bound | bindings<C>::mask | bindings<T>::mask)>(*op.false_value.get(), state));
1692 }
1693 template<uint32_t bound, typename C2, typename T2, typename F2>
1694 HALIDE_ALWAYS_INLINE bool match(const SelectOp<C2, T2, F2> &instance, MatcherState &state) const noexcept {
1695 return (c.template match<bound>(unwrap(instance.c), state) &&
1696 t.template match<(bound | bindings<C>::mask)>(unwrap(instance.t), state) &&
1697 f.template match<(bound | bindings<C>::mask | bindings<T>::mask)>(unwrap(instance.f), state));
1698 }
1699
1701 Expr make(MatcherState &state, halide_type_t type_hint) const {
1702 return Select::make(c.make(state, {}), t.make(state, type_hint), f.make(state, type_hint));
1703 }
1704
1705 constexpr static bool foldable = C::foldable && T::foldable && F::foldable;
1706
1707 template<typename C1 = C>
1709 halide_scalar_value_t c_val, t_val, f_val;
1710 halide_type_t c_ty;
1711 c.make_folded_const(c_val, c_ty, state);
1712 if ((c_val.u.u64 & 1) == 1) {
1713 t.make_folded_const(val, ty, state);
1714 } else {
1715 f.make_folded_const(val, ty, state);
1716 }
1717 ty.lanes |= c_ty.lanes & MatcherState::special_values_mask;
1718 }
1719};
1720
1721template<typename C, typename T, typename F>
1722std::ostream &operator<<(std::ostream &s, const SelectOp<C, T, F> &op) {
1723 s << "select(" << op.c << ", " << op.t << ", " << op.f << ")";
1724 return s;
1725}
1726
1727template<typename C, typename T, typename F>
1728HALIDE_ALWAYS_INLINE auto select(C &&c, T &&t, F &&f) noexcept -> SelectOp<decltype(pattern_arg(c)), decltype(pattern_arg(t)), decltype(pattern_arg(f))> {
1732 return {pattern_arg(c), pattern_arg(t), pattern_arg(f)};
1733}
1734
1735template<typename A, typename B>
1737 struct pattern_tag {};
1738 A a;
1740
1742
1745
1746 constexpr static bool canonical = A::canonical && B::canonical;
1747
1748 template<uint32_t bound>
1749 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1750 if (e.node_type == Broadcast::_node_type) {
1751 const Broadcast &op = (const Broadcast &)e;
1752 if (a.template match<bound>(*op.value.get(), state) &&
1753 lanes.template match<bound>(op.lanes, state)) {
1754 return true;
1755 }
1756 }
1757 return false;
1758 }
1759
1760 template<uint32_t bound, typename A2, typename B2>
1761 HALIDE_ALWAYS_INLINE bool match(const BroadcastOp<A2, B2> &op, MatcherState &state) const noexcept {
1762 return (a.template match<bound>(unwrap(op.a), state) &&
1763 lanes.template match<(bound | bindings<A>::mask)>(unwrap(op.lanes), state));
1764 }
1765
1767 Expr make(MatcherState &state, halide_type_t type_hint) const {
1768 halide_scalar_value_t lanes_val;
1769 halide_type_t ty;
1770 lanes.make_folded_const(lanes_val, ty, state);
1771 int32_t l = (int32_t)lanes_val.u.i64;
1772 type_hint.lanes /= l;
1773 Expr val = a.make(state, type_hint);
1774 if (l == 1) {
1775 return val;
1776 } else {
1777 return Broadcast::make(std::move(val), l);
1778 }
1779 }
1780
1781 constexpr static bool foldable = false;
1782
1783 template<typename A1 = A>
1785 halide_scalar_value_t lanes_val;
1786 halide_type_t lanes_ty;
1787 lanes.make_folded_const(lanes_val, lanes_ty, state);
1788 uint16_t l = (uint16_t)lanes_val.u.i64;
1789 a.make_folded_const(val, ty, state);
1790 ty.lanes = l | (ty.lanes & MatcherState::special_values_mask);
1791 }
1792};
1793
1794template<typename A, typename B>
1795inline std::ostream &operator<<(std::ostream &s, const BroadcastOp<A, B> &op) {
1796 s << "broadcast(" << op.a << ", " << op.lanes << ")";
1797 return s;
1798}
1799
1800template<typename A, typename B>
1801HALIDE_ALWAYS_INLINE auto broadcast(A &&a, B lanes) noexcept -> BroadcastOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes))> {
1803 return {pattern_arg(a), pattern_arg(lanes)};
1804}
1805
1806template<typename A, typename B, typename C>
1807struct RampOp {
1808 struct pattern_tag {};
1809 A a;
1810 B b;
1812
1814
1817
1818 constexpr static bool canonical = A::canonical && B::canonical && C::canonical;
1819
1820 template<uint32_t bound>
1821 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1822 if (e.node_type != Ramp::_node_type) {
1823 return false;
1824 }
1825 const Ramp &op = (const Ramp &)e;
1826 if (a.template match<bound>(*op.base.get(), state) &&
1827 b.template match<(bound | bindings<A>::mask)>(*op.stride.get(), state) &&
1828 lanes.template match<(bound | bindings<A>::mask | bindings<B>::mask)>(op.lanes, state)) {
1829 return true;
1830 } else {
1831 return false;
1832 }
1833 }
1834
1835 template<uint32_t bound, typename A2, typename B2, typename C2>
1836 HALIDE_ALWAYS_INLINE bool match(const RampOp<A2, B2, C2> &op, MatcherState &state) const noexcept {
1837 return (a.template match<bound>(unwrap(op.a), state) &&
1838 b.template match<(bound | bindings<A>::mask)>(unwrap(op.b), state) &&
1839 lanes.template match<(bound | bindings<A>::mask | bindings<B>::mask)>(unwrap(op.lanes), state));
1840 }
1841
1843 Expr make(MatcherState &state, halide_type_t type_hint) const {
1844 halide_scalar_value_t lanes_val;
1845 halide_type_t ty;
1846 lanes.make_folded_const(lanes_val, ty, state);
1847 int32_t l = (int32_t)lanes_val.u.i64;
1848 type_hint.lanes /= l;
1849 Expr ea, eb;
1850 eb = b.make(state, type_hint);
1851 ea = a.make(state, eb.type());
1852 return Ramp::make(std::move(ea), std::move(eb), l);
1853 }
1854
1855 constexpr static bool foldable = false;
1856};
1857
1858template<typename A, typename B, typename C>
1859std::ostream &operator<<(std::ostream &s, const RampOp<A, B, C> &op) {
1860 s << "ramp(" << op.a << ", " << op.b << ", " << op.lanes << ")";
1861 return s;
1862}
1863
1864template<typename A, typename B, typename C>
1865HALIDE_ALWAYS_INLINE auto ramp(A &&a, B &&b, C &&c) noexcept -> RampOp<decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
1869 return {pattern_arg(a), pattern_arg(b), pattern_arg(c)};
1870}
1871
1872template<typename A, typename B, VectorReduce::Operator reduce_op>
1874 struct pattern_tag {};
1875 A a;
1877
1879
1882 constexpr static bool canonical = A::canonical;
1883
1884 template<uint32_t bound>
1885 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1886 if (e.node_type == VectorReduce::_node_type) {
1887 const VectorReduce &op = (const VectorReduce &)e;
1888 if (op.op == reduce_op &&
1889 a.template match<bound>(*op.value.get(), state) &&
1890 lanes.template match<(bound | bindings<A>::mask)>(op.type.lanes(), state)) {
1891 return true;
1892 }
1893 }
1894 return false;
1895 }
1896
1897 template<uint32_t bound, typename A2, typename B2, VectorReduce::Operator reduce_op_2>
1899 return (reduce_op == reduce_op_2 &&
1900 a.template match<bound>(unwrap(op.a), state) &&
1901 lanes.template match<(bound | bindings<A>::mask)>(unwrap(op.lanes), state));
1902 }
1903
1905 Expr make(MatcherState &state, halide_type_t type_hint) const {
1906 halide_scalar_value_t lanes_val;
1907 halide_type_t ty;
1908 lanes.make_folded_const(lanes_val, ty, state);
1909 int l = (int)lanes_val.u.i64;
1910 return VectorReduce::make(reduce_op, a.make(state, type_hint), l);
1911 }
1912
1913 constexpr static bool foldable = false;
1914};
1915
1916template<typename A, typename B, VectorReduce::Operator reduce_op>
1917inline std::ostream &operator<<(std::ostream &s, const VectorReduceOp<A, B, reduce_op> &op) {
1918 s << "vector_reduce(" << reduce_op << ", " << op.a << ", " << op.lanes << ")";
1919 return s;
1920}
1921
1922template<typename A, typename B>
1923HALIDE_ALWAYS_INLINE auto h_add(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Add> {
1925 return {pattern_arg(a), pattern_arg(lanes)};
1926}
1927
1928template<typename A, typename B>
1929HALIDE_ALWAYS_INLINE auto h_min(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Min> {
1931 return {pattern_arg(a), pattern_arg(lanes)};
1932}
1933
1934template<typename A, typename B>
1935HALIDE_ALWAYS_INLINE auto h_max(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Max> {
1937 return {pattern_arg(a), pattern_arg(lanes)};
1938}
1939
1940template<typename A, typename B>
1941HALIDE_ALWAYS_INLINE auto h_and(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::And> {
1943 return {pattern_arg(a), pattern_arg(lanes)};
1944}
1945
1946template<typename A, typename B>
1947HALIDE_ALWAYS_INLINE auto h_or(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Or> {
1949 return {pattern_arg(a), pattern_arg(lanes)};
1950}
1951
1952template<typename A>
1953struct NegateOp {
1954 struct pattern_tag {};
1955 A a;
1956
1958
1961
1962 constexpr static bool canonical = A::canonical;
1963
1964 template<uint32_t bound>
1965 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1966 if (e.node_type != Sub::_node_type) {
1967 return false;
1968 }
1969 const Sub &op = (const Sub &)e;
1970 return (a.template match<bound>(*op.b.get(), state) &&
1971 is_const_zero(op.a));
1972 }
1973
1974 template<uint32_t bound, typename A2>
1975 HALIDE_ALWAYS_INLINE bool match(NegateOp<A2> &&p, MatcherState &state) const noexcept {
1976 return a.template match<bound>(unwrap(p.a), state);
1977 }
1978
1980 Expr make(MatcherState &state, halide_type_t type_hint) const {
1981 Expr ea = a.make(state, type_hint);
1982 Expr z = make_zero(ea.type());
1983 return Sub::make(std::move(z), std::move(ea));
1984 }
1985
1986 constexpr static bool foldable = A::foldable;
1987
1988 template<typename A1 = A>
1990 a.make_folded_const(val, ty, state);
1991 int dead_bits = 64 - ty.bits;
1992 switch (ty.code) {
1993 case halide_type_int:
1994 if (ty.bits >= 32 && val.u.u64 && (val.u.u64 << (65 - ty.bits)) == 0) {
1995 // Trying to negate the most negative signed int for a no-overflow type.
1997 } else {
1998 // Negate, drop the high bits, and then sign-extend them back
1999 val.u.i64 = int64_t(uint64_t(-val.u.i64) << dead_bits) >> dead_bits;
2000 }
2001 break;
2002 case halide_type_uint:
2003 val.u.u64 = ((-val.u.u64) << dead_bits) >> dead_bits;
2004 break;
2005 case halide_type_float:
2006 case halide_type_bfloat:
2007 val.u.f64 = -val.u.f64;
2008 break;
2009 default:
2010 // unreachable
2011 ;
2012 }
2013 }
2014};
2015
2016template<typename A>
2017std::ostream &operator<<(std::ostream &s, const NegateOp<A> &op) {
2018 s << "-" << op.a;
2019 return s;
2020}
2021
2022template<typename A>
2023HALIDE_ALWAYS_INLINE auto operator-(A &&a) noexcept -> NegateOp<decltype(pattern_arg(a))> {
2025 return {pattern_arg(a)};
2026}
2027
2028template<typename A>
2033
2034template<typename A>
2035struct CastOp {
2036 struct pattern_tag {};
2038 A a;
2039
2041
2044 constexpr static bool canonical = A::canonical;
2045
2046 template<uint32_t bound>
2047 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
2048 if (e.node_type != Cast::_node_type) {
2049 return false;
2050 }
2051 const Cast &op = (const Cast &)e;
2052 return (e.type == t &&
2053 a.template match<bound>(*op.value.get(), state));
2054 }
2055 template<uint32_t bound, typename A2>
2056 HALIDE_ALWAYS_INLINE bool match(const CastOp<A2> &op, MatcherState &state) const noexcept {
2057 return t == op.t && a.template match<bound>(unwrap(op.a), state);
2058 }
2059
2061 Expr make(MatcherState &state, halide_type_t type_hint) const {
2062 return cast(t, a.make(state, {}));
2063 }
2064
2065 constexpr static bool foldable = false;
2066};
2067
2068template<typename A>
2069std::ostream &operator<<(std::ostream &s, const CastOp<A> &op) {
2070 s << "cast(" << op.t << ", " << op.a << ")";
2071 return s;
2072}
2073
2074template<typename A>
2075HALIDE_ALWAYS_INLINE auto cast(halide_type_t t, A &&a) noexcept -> CastOp<decltype(pattern_arg(a))> {
2077 return {t, pattern_arg(a)};
2078}
2079
2080template<typename A>
2081struct WidenOp {
2082 struct pattern_tag {};
2083 A a;
2084
2086
2089 constexpr static bool canonical = A::canonical;
2090
2091 template<uint32_t bound>
2092 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
2093 if (e.node_type != Cast::_node_type) {
2094 return false;
2095 }
2096 const Cast &op = (const Cast &)e;
2097 return (e.type == op.value.type().widen() &&
2098 a.template match<bound>(*op.value.get(), state));
2099 }
2100 template<uint32_t bound, typename A2>
2101 HALIDE_ALWAYS_INLINE bool match(const WidenOp<A2> &op, MatcherState &state) const noexcept {
2102 return a.template match<bound>(unwrap(op.a), state);
2103 }
2104
2106 Expr make(MatcherState &state, halide_type_t type_hint) const {
2107 Expr e = a.make(state, {});
2108 Type w = e.type().widen();
2109 return cast(w, std::move(e));
2110 }
2111
2112 constexpr static bool foldable = false;
2113};
2114
2115template<typename A>
2116std::ostream &operator<<(std::ostream &s, const WidenOp<A> &op) {
2117 s << "widen(" << op.a << ")";
2118 return s;
2119}
2120
2121template<typename A>
2122HALIDE_ALWAYS_INLINE auto widen(A &&a) noexcept -> WidenOp<decltype(pattern_arg(a))> {
2124 return {pattern_arg(a)};
2125}
2126
2127template<typename Vec, typename Base, typename Stride, typename Lanes>
2128struct SliceOp {
2129 struct pattern_tag {};
2130 Vec vec;
2131 Base base;
2132 Stride stride;
2133 Lanes lanes;
2134
2135 static constexpr uint32_t binds = Vec::binds | Base::binds | Stride::binds | Lanes::binds;
2136
2139 constexpr static bool canonical = Vec::canonical && Base::canonical && Stride::canonical && Lanes::canonical;
2140
2141 template<uint32_t bound>
2142 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
2143 if (e.node_type != IRNodeType::Shuffle) {
2144 return false;
2145 }
2146 const Shuffle &v = (const Shuffle &)e;
2147 return v.vectors.size() == 1 &&
2148 v.is_slice() &&
2149 vec.template match<bound>(*v.vectors[0].get(), state) &&
2150 base.template match<(bound | bindings<Vec>::mask)>(v.slice_begin(), state) &&
2151 stride.template match<(bound | bindings<Vec>::mask | bindings<Base>::mask)>(v.slice_stride(), state) &&
2153 }
2154
2156 Expr make(MatcherState &state, halide_type_t type_hint) const {
2157 halide_scalar_value_t base_val, stride_val, lanes_val;
2158 halide_type_t ty;
2159 base.make_folded_const(base_val, ty, state);
2160 int b = (int)base_val.u.i64;
2161 stride.make_folded_const(stride_val, ty, state);
2162 int s = (int)stride_val.u.i64;
2163 lanes.make_folded_const(lanes_val, ty, state);
2164 int l = (int)lanes_val.u.i64;
2165 return Shuffle::make_slice(vec.make(state, type_hint), b, s, l);
2166 }
2167
2168 constexpr static bool foldable = false;
2169
2171 SliceOp(Vec v, Base b, Stride s, Lanes l)
2172 : vec(v), base(b), stride(s), lanes(l) {
2173 static_assert(Base::foldable, "Base of slice should consist only of operations that constant-fold");
2174 static_assert(Stride::foldable, "Stride of slice should consist only of operations that constant-fold");
2175 static_assert(Lanes::foldable, "Lanes of slice should consist only of operations that constant-fold");
2176 }
2177};
2178
2179template<typename Vec, typename Base, typename Stride, typename Lanes>
2180std::ostream &operator<<(std::ostream &s, const SliceOp<Vec, Base, Stride, Lanes> &op) {
2181 s << "slice(" << op.vec << ", " << op.base << ", " << op.stride << ", " << op.lanes << ")";
2182 return s;
2183}
2184
2185template<typename Vec, typename Base, typename Stride, typename Lanes>
2186HALIDE_ALWAYS_INLINE auto slice(Vec vec, Base base, Stride stride, Lanes lanes) noexcept
2187 -> SliceOp<decltype(pattern_arg(vec)), decltype(pattern_arg(base)), decltype(pattern_arg(stride)), decltype(pattern_arg(lanes))> {
2188 return {pattern_arg(vec), pattern_arg(base), pattern_arg(stride), pattern_arg(lanes)};
2189}
2190
// Pattern node that wraps a sub-pattern `a` and, when built, evaluates it
// with make_folded_const and emits the result as a constant expression.
// NOTE(review): several listing lines of this struct (2196, 2198-2199, 2202,
// 2204, 2228) are absent from this extract, including the declaration of `c`
// and the signature of make_folded_const.
2191 template<typename A>
2192 struct Fold {
2193 struct pattern_tag {};
// The wrapped sub-pattern.
2194 A a;
2195
2197
2200 constexpr static bool canonical = true;
2201
// Fold `a` under the current matcher state and return the constant as an
// Expr. `type_hint` seeds the type passed into the fold and, when its bits
// field is non-zero, overrides the code and bits of the folded type.
2203 Expr make(MatcherState &state, halide_type_t type_hint) const noexcept {
2205 halide_type_t ty = type_hint;
2206 a.make_folded_const(c, ty, state);
2207
2208 // The result of the fold may have an underspecified type
2209 // (e.g. because it's from an int literal). Make the type code
2210 // and bits match the required type, if there is one (we can
2211 // tell from the bits field).
2212 if (type_hint.bits) {
// An integer result that must become a float is converted by value
// (int64 -> double) rather than having its bits reinterpreted.
2213 if (((int)ty.code == (int)halide_type_int) &&
2214 ((int)type_hint.code == (int)halide_type_float)) {
2215 int64_t x = c.u.i64;
2216 c.u.f64 = (double)x;
2217 }
// Only code and bits are taken from the hint; ty.lanes keeps whatever the
// fold produced.
2218 ty.code = type_hint.code;
2219 ty.bits = type_hint.bits;
2220 }
2221
2222 return make_const_expr(c, ty);
2223 }
2224
// A Fold is itself foldable exactly when the wrapped pattern is.
2225 constexpr static bool foldable = A::foldable;
2226
// Folding a Fold simply forwards to the wrapped pattern.
2227 template<typename A1 = A>
2229 a.make_folded_const(val, ty, state);
2230 }
2231};
2232
// Wrap `a` (after pattern_arg) in a Fold node.
// NOTE(review): listing line 2235 is absent from this extract.
2233 template<typename A>
2234 HALIDE_ALWAYS_INLINE auto fold(A &&a) noexcept -> Fold<decltype(pattern_arg(a))> {
2236 return {pattern_arg(a)};
2237 }
2238
2239template<typename A>
2240std::ostream &operator<<(std::ostream &s, const Fold<A> &op) {
2241 s << "fold(" << op.a << ")";
2242 return s;
2243}
2244
// Predicate pattern: folds the wrapped pattern `a` and reports whether the
// fold flagged a special value (tested via MatcherState::special_values_mask
// on the folded type's lanes field).
// NOTE(review): the `struct Overflows {` header (listing line 2246) and lines
// 2250, 2254-2255, 2261 are absent from this extract.
2245 template<typename A>
2247 struct pattern_tag {};
// The pattern whose fold is inspected.
2248 A a;
2249
2251
2252 // This rule is a predicate, so it always evaluates to a boolean,
2253 // which has IRNodeType UIntImm
2256 constexpr static bool canonical = true;
2257
2258 constexpr static bool foldable = A::foldable;
2259
2260 template<typename A1 = A>
// Fold `a` first so that ty.lanes carries any special-value bits, then
// overwrite the result with a scalar 64-bit unsigned 0/1.
2262 a.make_folded_const(val, ty, state);
2263 ty.code = halide_type_uint;
2264 ty.bits = 64;
// The mask must be read before ty.lanes is reset to 1 on the next line.
2265 val.u.u64 = (ty.lanes & MatcherState::special_values_mask) != 0;
2266 ty.lanes = 1;
2267 }
2268};
2269
// Wrap `a` (after pattern_arg) in an Overflows predicate.
// NOTE(review): listing line 2272 is absent from this extract.
2270 template<typename A>
2271 HALIDE_ALWAYS_INLINE auto overflows(A &&a) noexcept -> Overflows<decltype(pattern_arg(a))> {
2273 return {pattern_arg(a)};
2274 }
2275
2276template<typename A>
2277std::ostream &operator<<(std::ostream &s, const Overflows<A> &op) {
2278 s << "overflows(" << op.a << ")";
2279 return s;
2280}
2281
// Leaf pattern standing for an overflow marker. It binds no wildcards.
// NOTE(review): listing lines 2288-2289, 2298, 2301, 2303, 2309-2310 and 2312
// are absent from this extract, so the final test in match() and most of
// make_folded_const are not visible here.
2282 struct Overflow {
2283 struct pattern_tag {};
2284
2285 constexpr static uint32_t binds = 0;
2286
2287 // Overflow is an intrinsic, represented as a Call node
2290 constexpr static bool canonical = true;
2291
2292 template<uint32_t bound>
2293 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
// Anything that is not a Call node cannot match.
2294 if (e.node_type != Call::_node_type) {
2295 return false;
2296 }
// Safe to view the node as a Call after the node_type check above.
2297 const Call &op = (const Call &)e;
2299 }
2300
// Build the expression for this pattern using the requested type.
2302 Expr make(MatcherState &state, halide_type_t type_hint) const {
2304 return make_const_special_expr(type_hint);
2305 }
2306
2307 constexpr static bool foldable = true;
2308
// Visible part of the fold: the value is zeroed.
2311 val.u.u64 = 0;
2313 }
2314};
2315
2316inline std::ostream &operator<<(std::ostream &s, const Overflow &op) {
2317 s << "overflow()";
2318 return s;
2319}
2320
// Predicate pattern: builds the expression for `a` and reports whether it
// is a constant, optionally whether it is the specific constant `v`.
// NOTE(review): the declarations of `check_v` and `v` (listing lines
// 2333-2334) and the make_folded_const signature (2339) are absent from this
// extract.
2321 template<typename A>
2322 struct IsConst {
2323 struct pattern_tag {};
2324
2326
2327 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2330 constexpr static bool canonical = true;
2331
// The pattern whose built expression is tested.
2332 A a;
2335
2336 constexpr static bool foldable = true;
2337
2338 template<typename A1 = A>
// Materialize `a` with an empty type hint, then report the answer as a
// scalar 64-bit unsigned 0/1.
2340 Expr e = a.make(state, {});
2341 ty.code = halide_type_uint;
2342 ty.bits = 64;
2343 ty.lanes = 1;
// check_v selects between "is the constant v" and "is any constant".
2344 if (check_v) {
2345 val.u.u64 = ::Halide::Internal::is_const(e, v) ? 1 : 0;
2346 } else {
2347 val.u.u64 = ::Halide::Internal::is_const(e) ? 1 : 0;
2348 }
2349 }
2350};
2351
// Build an IsConst predicate that accepts any constant: the second
// initializer (false) disables the value check and the third is a placeholder.
// NOTE(review): listing line 2354 is absent from this extract.
2352 template<typename A>
2353 HALIDE_ALWAYS_INLINE auto is_const(A &&a) noexcept -> IsConst<decltype(pattern_arg(a))> {
2355 return {pattern_arg(a), false, 0};
2356 }
2357
// Build an IsConst predicate that additionally requires the constant to
// equal `value`: the second initializer (true) enables the value check.
// NOTE(review): listing line 2360 is absent from this extract.
2358 template<typename A>
2359 HALIDE_ALWAYS_INLINE auto is_const(A &&a, int64_t value) noexcept -> IsConst<decltype(pattern_arg(a))> {
2361 return {pattern_arg(a), true, value};
2362 }
2363
2364template<typename A>
2365std::ostream &operator<<(std::ostream &s, const IsConst<A> &op) {
2366 if (op.check_v) {
2367 s << "is_const(" << op.a << ")";
2368 } else {
2369 s << "is_const(" << op.a << ", " << op.v << ")";
2370 }
2371 return s;
2372}
2373
// Predicate pattern: builds the expression for `a`, runs it through the
// supplied prover's mutate(), and is true when the result is the constant one.
// NOTE(review): listing lines 2380, 2383-2384, 2390 (the make_folded_const
// signature) and 2394 are absent from this extract.
2374 template<typename A, typename Prover>
2375 struct CanProve {
2376 struct pattern_tag {};
// The condition to prove.
2377 A a;
2378 Prover *prover; // An existing simplifying mutator
2379
2381
2382 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2385 constexpr static bool canonical = true;
2386
2387 constexpr static bool foldable = true;
2388
2389 // Includes a raw call to an inlined make method, so don't inline.
// Materialize the condition with an empty type hint, let the prover rewrite
// it, and report whether it became the constant one.
2391 Expr condition = a.make(state, {});
2392 condition = prover->mutate(condition, nullptr);
2393 val.u.u64 = is_const_one(condition);
// The reported type is 1 bit wide and keeps the lane count of the mutated
// condition.
2395 ty.bits = 1;
2396 ty.lanes = condition.type().lanes();
2397 }
2398};
2399
// Build a CanProve predicate from a condition pattern and a prover pointer.
// The pointer is stored as-is, not copied.
// NOTE(review): listing line 2402 is absent from this extract.
2400 template<typename A, typename Prover>
2401 HALIDE_ALWAYS_INLINE auto can_prove(A &&a, Prover *p) noexcept -> CanProve<decltype(pattern_arg(a)), Prover> {
2403 return {pattern_arg(a), p};
2404 }
2405
2406template<typename A, typename Prover>
2407std::ostream &operator<<(std::ostream &s, const CanProve<A, Prover> &op) {
2408 s << "can_prove(" << op.a << ")";
2409 return s;
2410}
2411
// Predicate pattern: true when the type of the expression built from `a`
// reports is_float().
// NOTE(review): listing lines 2417, 2420-2421, 2426-2427 (the
// make_folded_const signature) and 2431 are absent from this extract.
2412 template<typename A>
2413 struct IsFloat {
2414 struct pattern_tag {};
// The pattern whose type is inspected.
2415 A a;
2416
2418
2419 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2422 constexpr static bool canonical = true;
2423
2424 constexpr static bool foldable = true;
2425
2428 // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
// Only the type of the built expression is used; the expression is discarded.
2429 Type t = a.make(state, {}).type();
2430 val.u.u64 = t.is_float();
// The result is 1 bit wide and keeps the lane count of the inspected type.
2432 ty.bits = 1;
2433 ty.lanes = t.lanes();
2434 }
2435};
2436
// Wrap `a` (after pattern_arg) in an IsFloat predicate.
// NOTE(review): listing line 2439 is absent from this extract.
2437 template<typename A>
2438 HALIDE_ALWAYS_INLINE auto is_float(A &&a) noexcept -> IsFloat<decltype(pattern_arg(a))> {
2440 return {pattern_arg(a)};
2441 }
2442
2443template<typename A>
2444std::ostream &operator<<(std::ostream &s, const IsFloat<A> &op) {
2445 s << "is_float(" << op.a << ")";
2446 return s;
2447}
2448
// Predicate pattern: true when the type of the expression built from `a`
// reports is_int() and also satisfies the optional `bits` / `lanes`
// constraints.
// NOTE(review): the declarations of `bits` and `lanes` (listing lines
// 2453-2454), the make_folded_const signature (2465-2466) and lines 2456,
// 2459-2460, 2470 are absent from this extract.
2449 template<typename A>
2450 struct IsInt {
2451 struct pattern_tag {};
// The pattern whose type is inspected.
2452 A a;
2455
2457
2458 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2461 constexpr static bool canonical = true;
2462
2463 constexpr static bool foldable = true;
2464
2467 // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2468 Type t = a.make(state, {}).type();
// A bits or lanes value of 0 means "do not constrain that property".
2469 val.u.u64 = t.is_int() && (bits == 0 || t.bits() == bits) && (lanes == 0 || t.lanes() == lanes);
// The result is 1 bit wide and keeps the lane count of the inspected type.
2471 ty.bits = 1;
2472 ty.lanes = t.lanes();
2473 }
2474};
2475
// Build an IsInt predicate. `bits` and `lanes` default to 0, which
// IsInt treats as "any" for that property.
// NOTE(review): listing line 2478 is absent from this extract.
2476 template<typename A>
2477 HALIDE_ALWAYS_INLINE auto is_int(A &&a, uint8_t bits = 0, uint16_t lanes = 0) noexcept -> IsInt<decltype(pattern_arg(a))> {
2479 return {pattern_arg(a), bits, lanes};
2480 }
2481
2482template<typename A>
2483std::ostream &operator<<(std::ostream &s, const IsInt<A> &op) {
2484 s << "is_int(" << op.a;
2485 if (op.bits > 0) {
2486 s << ", " << op.bits;
2487 }
2488 if (op.lanes > 0) {
2489 s << ", " << op.lanes;
2490 }
2491 s << ")";
2492 return s;
2493}
2494
// Predicate pattern: true when the type of the expression built from `a`
// reports is_uint() and also satisfies the optional `bits` / `lanes`
// constraints.
// NOTE(review): the declarations of `bits` and `lanes` (listing lines
// 2499-2500), the make_folded_const signature (2511-2512) and lines 2502,
// 2505-2506, 2516 are absent from this extract.
2495 template<typename A>
2496 struct IsUInt {
2497 struct pattern_tag {};
// The pattern whose type is inspected.
2498 A a;
2501
2503
2504 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2507 constexpr static bool canonical = true;
2508
2509 constexpr static bool foldable = true;
2510
2513 // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2514 Type t = a.make(state, {}).type();
// A bits or lanes value of 0 means "do not constrain that property".
2515 val.u.u64 = t.is_uint() && (bits == 0 || t.bits() == bits) && (lanes == 0 || t.lanes() == lanes);
// The result is 1 bit wide and keeps the lane count of the inspected type.
2517 ty.bits = 1;
2518 ty.lanes = t.lanes();
2519 }
2520};
2521
// Build an IsUInt predicate. `bits` and `lanes` default to 0, which
// IsUInt treats as "any" for that property.
// NOTE(review): listing line 2524 is absent from this extract.
2522 template<typename A>
2523 HALIDE_ALWAYS_INLINE auto is_uint(A &&a, uint8_t bits = 0, uint16_t lanes = 0) noexcept -> IsUInt<decltype(pattern_arg(a))> {
2525 return {pattern_arg(a), bits, lanes};
2526 }
2527
2528template<typename A>
2529std::ostream &operator<<(std::ostream &s, const IsUInt<A> &op) {
2530 s << "is_uint(" << op.a;
2531 if (op.bits > 0) {
2532 s << ", " << op.bits;
2533 }
2534 if (op.lanes > 0) {
2535 s << ", " << op.lanes;
2536 }
2537 s << ")";
2538 return s;
2539}
2540
// Predicate pattern: true when the type of the expression built from `a`
// reports is_scalar().
// NOTE(review): listing lines 2546, 2549-2550, 2555-2556 (the
// make_folded_const signature) and 2560 are absent from this extract.
2541 template<typename A>
2542 struct IsScalar {
2543 struct pattern_tag {};
// The pattern whose type is inspected.
2544 A a;
2545
2547
2548 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2551 constexpr static bool canonical = true;
2552
2553 constexpr static bool foldable = true;
2554
2557 // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
// Only the type of the built expression is used; the expression is discarded.
2558 Type t = a.make(state, {}).type();
2559 val.u.u64 = t.is_scalar();
// The result is 1 bit wide and keeps the lane count of the inspected type.
2561 ty.bits = 1;
2562 ty.lanes = t.lanes();
2563 }
2564};
2565
// Wrap `a` (after pattern_arg) in an IsScalar predicate.
// NOTE(review): listing line 2568 is absent from this extract.
2566 template<typename A>
2567 HALIDE_ALWAYS_INLINE auto is_scalar(A &&a) noexcept -> IsScalar<decltype(pattern_arg(a))> {
2569 return {pattern_arg(a)};
2570 }
2571
2572template<typename A>
2573std::ostream &operator<<(std::ostream &s, const IsScalar<A> &op) {
2574 s << "is_scalar(" << op.a << ")";
2575 return s;
2576}
2577
// Predicate pattern: folds `a` to a constant and reports whether it equals
// the largest value of its integer type. Non-integer type codes yield false.
// NOTE(review): the `struct IsMaxValue {` header (listing line 2579), the
// make_folded_const signature (2592-2593) and lines 2583, 2586-2587, 2602 are
// absent from this extract.
2578 template<typename A>
2580 struct pattern_tag {};
// The pattern whose folded value is tested.
2581 A a;
2582
2584
2585 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2588 constexpr static bool canonical = true;
2589
2590 constexpr static bool foldable = true;
2591
2594 // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2595 a.make_folded_const(val, ty, state);
// All-ones shifted right so that ty.bits low bits remain for unsigned types,
// and one fewer (the sign bit is excluded) for signed types.
// NOTE(review): this is computed before the type code is checked, and the
// shift count reaches 64 if ty.bits is 0 with a non-int code - confirm that
// cannot happen here.
2596 const uint64_t max_bits = (uint64_t)(-1) >> (64 - ty.bits + (ty.code == halide_type_int));
2597 if (ty.code == halide_type_uint || ty.code == halide_type_int) {
2598 val.u.u64 = (val.u.u64 == max_bits);
2599 } else {
2600 val.u.u64 = 0;
2601 }
// The result is 1 bit wide; ty.lanes is left as the fold produced it.
2603 ty.bits = 1;
2604 }
2605};
2606
// Wrap `a` (after pattern_arg) in an IsMaxValue predicate.
// NOTE(review): listing line 2609 is absent from this extract.
2607 template<typename A>
2608 HALIDE_ALWAYS_INLINE auto is_max_value(A &&a) noexcept -> IsMaxValue<decltype(pattern_arg(a))> {
2610 return {pattern_arg(a)};
2611 }
2612
2613template<typename A>
2614std::ostream &operator<<(std::ostream &s, const IsMaxValue<A> &op) {
2615 s << "is_max_value(" << op.a << ")";
2616 return s;
2617}
2618
// Predicate pattern: folds `a` to a constant and reports whether it equals
// the smallest value of its integer type. Non-integer type codes yield false.
// NOTE(review): the `struct IsMinValue {` header (listing line 2620), the
// make_folded_const signature (2633-2634) and lines 2624, 2627-2628, 2645 are
// absent from this extract.
2619 template<typename A>
2621 struct pattern_tag {};
// The pattern whose folded value is tested.
2622 A a;
2623
2625
2626 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2629 constexpr static bool canonical = true;
2630
2631 constexpr static bool foldable = true;
2632
2635 // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2636 a.make_folded_const(val, ty, state);
2637 if (ty.code == halide_type_int) {
// Signed minimum: all-ones shifted left so that the sign bit and every bit
// above it are set. The comparison is on the 64-bit representation.
2638 const uint64_t min_bits = (uint64_t)(-1) << (ty.bits - 1);
2639 val.u.u64 = (val.u.u64 == min_bits);
2640 } else if (ty.code == halide_type_uint) {
// Unsigned minimum is zero.
2641 val.u.u64 = (val.u.u64 == 0);
2642 } else {
2643 val.u.u64 = 0;
2644 }
// The result is 1 bit wide; ty.lanes is left as the fold produced it.
2646 ty.bits = 1;
2647 }
2648};
2649
// Wrap `a` (after pattern_arg) in an IsMinValue predicate.
// NOTE(review): listing line 2652 is absent from this extract.
2650 template<typename A>
2651 HALIDE_ALWAYS_INLINE auto is_min_value(A &&a) noexcept -> IsMinValue<decltype(pattern_arg(a))> {
2653 return {pattern_arg(a)};
2654 }
2655
2656template<typename A>
2657std::ostream &operator<<(std::ostream &s, const IsMinValue<A> &op) {
2658 s << "is_min_value(" << op.a << ")";
2659 return s;
2660}
2661
// Foldable pattern that evaluates to the number of lanes in the type of the
// expression built from `a`.
// NOTE(review): listing lines 2667, 2670-2671, 2676-2677 (the
// make_folded_const signature) and 2681 are absent from this extract.
2662 template<typename A>
2663 struct LanesOf {
2664 struct pattern_tag {};
// The pattern whose type is inspected.
2665 A a;
2666
2668
2669 // This rule evaluates to a lane count, not a boolean (see the fold below).
2672 constexpr static bool canonical = true;
2673
2674 constexpr static bool foldable = true;
2675
2678 // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
// Only the type of the built expression is used; the expression is discarded.
2679 Type t = a.make(state, {}).type();
2680 val.u.u64 = t.lanes();
// The count itself is reported as a 32-bit scalar.
2682 ty.bits = 32;
2683 ty.lanes = 1;
2684 }
2685};
2686
// Wrap `a` (after pattern_arg) in a LanesOf node.
// NOTE(review): listing line 2689 is absent from this extract.
2687 template<typename A>
2688 HALIDE_ALWAYS_INLINE auto lanes_of(A &&a) noexcept -> LanesOf<decltype(pattern_arg(a))> {
2690 return {pattern_arg(a)};
2691 }
2692
2693template<typename A>
2694std::ostream &operator<<(std::ostream &s, const LanesOf<A> &op) {
2695 s << "lanes_of(" << op.a << ")";
2696 return s;
2697}
2698
2699 // Verify properties of each rewrite rule. Currently just fuzz tests them.
// This overload is selected only when both sides of the rule are foldable.
// For up to 100 trials it binds random constants and constant expressions to
// every wildcard slot, skips trials where the predicate does not hold, folds
// `before` and `after`, and compares the two results for the output type.
// On a mismatch it prints the bindings and both folded values.
// NOTE(review): listing lines 2778, 2810 and 2820 are absent from this
// extract (the condition guarding the `continue` after the two folds, the
// declaration of `val` in the failure dump, and the statement that follows
// the dump).
2700 template<typename Before,
2701 typename After,
2702 typename Predicate,
2703 typename = typename std::enable_if<std::decay<Before>::type::foldable &&
2704 std::decay<After>::type::foldable>::type>
2705 HALIDE_NEVER_INLINE void fuzz_test_rule(Before &&before, After &&after, Predicate &&pred,
2706 halide_type_t wildcard_type, halide_type_t output_type) noexcept {
2707
2708 // We only validate the rules in the scalar case
2709 wildcard_type.lanes = output_type.lanes = 1;
2710
2711 // Track which types this rule has been tested for before
// Being a function-local static in a template, there is one set per
// instantiation, i.e. per distinct (Before, After, Predicate) combination.
2712 static std::set<uint32_t> tested;
2713
// insert().second is false when this wildcard type was already recorded.
2714 if (!tested.insert(reinterpret_bits<uint32_t>(wildcard_type)).second) {
2715 return;
2716 }
2717
2718 // Print it in a form where it can be piped into a python/z3 validator
2719 debug(0) << "validate('" << before << "', '" << after << "', '" << pred << "', " << Type(wildcard_type) << ", " << Type(output_type) << ")\n";
2720
2721 // Substitute some random constants into the before and after
2722 // expressions and see if the rule holds true. This should catch
2723 // silly errors, but not necessarily corner cases.
// Fixed seed, so the sequence of random values is reproducible.
2724 static std::mt19937_64 rng(0);
2725 MatcherState state;
2726
// Keeps the generated Exprs alive while the state refers to their nodes.
2727 Expr exprs[max_wild];
2728
2729 for (int trials = 0; trials < 100; trials++) {
2730 // We want to test small constants more frequently than
2731 // large ones, otherwise we'll just get coverage of
2732 // overflow rules.
2733 int shift = (int)(rng() & (wildcard_type.bits - 1));
2734
2735 for (int i = 0; i < max_wild; i++) {
2736 // Bind all the exprs and constants
2737 switch (wildcard_type.code) {
2738 case halide_type_uint: {
2739 // Normalize to the type's range by adding zero
2740 uint64_t val = constant_fold_bin_op<Add>(wildcard_type, (uint64_t)rng() >> shift, 0);
2741 state.set_bound_const(i, val, wildcard_type);
2742 val = constant_fold_bin_op<Add>(wildcard_type, (uint64_t)rng() >> shift, 0);
2743 exprs[i] = make_const(wildcard_type, val);
// NOTE(review): the same binding is set again right after the switch, so
// this call looks redundant (the int and float cases do not have it).
2744 state.set_binding(i, *exprs[i].get());
2745 } break;
2746 case halide_type_int: {
2747 int64_t val = constant_fold_bin_op<Add>(wildcard_type, (int64_t)rng() >> shift, 0);
2748 state.set_bound_const(i, val, wildcard_type);
2749 val = constant_fold_bin_op<Add>(wildcard_type, (int64_t)rng() >> shift, 0);
2750 exprs[i] = make_const(wildcard_type, val);
2751 } break;
2752 case halide_type_float:
2753 case halide_type_bfloat: {
2754 // Use a very narrow range of precise floats, so
2755 // that none of the rules a human is likely to
2756 // write have instabilities.
2757 double val = ((int64_t)(rng() & 15) - 8) / 2.0;
2758 state.set_bound_const(i, val, wildcard_type);
2759 val = ((int64_t)(rng() & 15) - 8) / 2.0;
2760 exprs[i] = make_const(wildcard_type, val);
2761 } break;
2762 default:
2763 return; // Don't care about handles
2764 }
2765 state.set_binding(i, *exprs[i].get());
2766 }
2767
2768 halide_scalar_value_t val_pred, val_before, val_after;
2769 halide_type_t type = output_type;
// Trials whose random bindings do not satisfy the predicate are skipped.
2770 if (!evaluate_predicate(pred, state)) {
2771 continue;
2772 }
// Fold both sides, OR-ing together the lanes field each fold leaves behind.
2773 before.make_folded_const(val_before, type, state);
2774 uint16_t lanes = type.lanes;
2775 after.make_folded_const(val_after, type, state);
2776 lanes |= type.lanes;
2777
2779 continue;
2780 }
2781
2782 bool ok = true;
2783 switch (output_type.code) {
2784 case halide_type_uint:
2785 // Compare normalized representations
2786 ok &= (constant_fold_bin_op<Add>(output_type, val_before.u.u64, 0) ==
2787 constant_fold_bin_op<Add>(output_type, val_after.u.u64, 0));
2788 break;
2789 case halide_type_int:
2790 ok &= (constant_fold_bin_op<Add>(output_type, val_before.u.i64, 0) ==
2791 constant_fold_bin_op<Add>(output_type, val_after.u.i64, 0));
2792 break;
2793 case halide_type_float:
2794 case halide_type_bfloat: {
2795 double error = std::abs(val_before.u.f64 - val_after.u.f64);
2796 // We accept an equal bit pattern (e.g. inf vs inf),
2797 // a small floating point difference, or turning a nan into not-a-nan.
2798 ok &= (error < 0.01 ||
2799 val_before.u.u64 == val_after.u.u64 ||
2800 std::isnan(val_before.u.f64));
2801 break;
2802 }
2803 default:
2804 return;
2805 }
2806
2807 if (!ok) {
// Dump every bound constant (c<i>) and bound expression (_<i>), then the
// two folded results, both as expressions and as raw 64-bit values.
2808 debug(0) << "Fails with values:\n";
2809 for (int i = 0; i < max_wild; i++) {
2811 state.get_bound_const(i, val, wildcard_type);
2812 debug(0) << " c" << i << ": " << make_const_expr(val, wildcard_type) << "\n";
2813 }
2814 for (int i = 0; i < max_wild; i++) {
2815 debug(0) << " _" << i << ": " << Expr(state.get_binding(i)) << "\n";
2816 }
2817 debug(0) << " Before: " << make_const_expr(val_before, output_type) << "\n";
2818 debug(0) << " After: " << make_const_expr(val_after, output_type) << "\n";
2819 debug(0) << val_before.u.u64 << " " << val_after.u.u64 << "\n";
2821 }
2822 }
2823 }
2824
2825template<typename Before,
2826 typename After,
2827 typename Predicate,
2828 typename = typename std::enable_if<!(std::decay<Before>::type::foldable &&
2829 std::decay<After>::type::foldable)>::type>
2830HALIDE_ALWAYS_INLINE void fuzz_test_rule(Before &&before, After &&after, Predicate &&pred,
2831 halide_type_t, halide_type_t, int dummy = 0) noexcept {
2832 // We can't verify rewrite rules that can't be constant-folded.
2833}
2834
// Predicate overload for a plain bool: the value is already the answer and
// the matcher state is ignored.
// NOTE(review): listing line 2835, immediately before this signature, is
// absent from this extract.
2836 bool evaluate_predicate(bool x, MatcherState &) noexcept {
2837 return x;
2838 }
2839
// Predicate overload for pattern types: folds the pattern under the current
// matcher state and treats a non-zero result as true.
// NOTE(review): the function signature and the declaration of `c` (listing
// lines 2842-2843) are absent from this extract.
2840 template<typename Pattern,
2841 typename = typename enable_if_pattern<Pattern>::type>
// The fold starts from the bool type and may overwrite it.
2844 halide_type_t ty = halide_type_of<bool>();
2845 p.make_folded_const(c, ty, state);
2846 // Overflow counts as a failed predicate
// A set special-value bit in ty.lanes after the fold makes the predicate fail
// regardless of the folded value.
2847 return (c.u.u64 != 0) && ((ty.lanes & MatcherState::special_values_mask) == 0);
2848 }
2849
2850// #defines for testing
2851
2852// Print all successful or failed matches
2853#define HALIDE_DEBUG_MATCHED_RULES 0
2854#define HALIDE_DEBUG_UNMATCHED_RULES 0
2855
2856// Set to true if you want to fuzz test every rewrite passed to
2857// operator() to ensure the input and the output have the same value
2858// for lots of random values of the wildcards. Run
2859// correctness_simplify with this on.
2860#define HALIDE_FUZZ_TEST_RULES 0
2861
2862template<typename Instance>
2863struct Rewriter {
2864 Instance instance;
2869
// Constructor initializer list: the instance is moved into the rewriter and
// the output and wildcard types are stored alongside it.
// NOTE(review): the constructor signature (listing lines 2870-2871) is absent
// from this extract.
2872 : instance(std::move(instance)), output_type(ot), wildcard_type(wt) {
2873 }
2874
// Builds the replacement expression from the `after` pattern using the
// current matcher state and the rewriter's output type, and stores it in
// `result`.
// NOTE(review): the signature (listing line 2876) is absent from this extract;
// presumably this is the build_replacement helper called by operator().
2875 template<typename After>
2877 #if HALIDE_DEBUG_MATCHED_RULES
2878 debug(0) << instance << " -> " << after << "\n";
2879 #endif
2880 result = after.make(state, output_type);
2881 }
2882
2883 template<typename Before,
2884 typename After,
2885 typename = typename enable_if_pattern<Before>::type,
2886 typename = typename enable_if_pattern<After>::type>
2887 HALIDE_ALWAYS_INLINE bool operator()(Before before, After after) {
2888 static_assert((Before::binds & After::binds) == After::binds, "Rule result uses unbound values");
2889 static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2890 static_assert(After::canonical, "RHS of rewrite rule should be in canonical form");
2891#if HALIDE_FUZZ_TEST_RULES
2892 fuzz_test_rule(before, after, true, wildcard_type, output_type);
2893#endif
2894 if (before.template match<0>(unwrap(instance), state)) {
2895 build_replacement(after);
2896#if HALIDE_DEBUG_MATCHED_RULES
2897 debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n";
2898#endif
2899 return true;
2900 } else {
2901#if HALIDE_DEBUG_UNMATCHED_RULES
2902 debug(0) << instance << " does not match " << before << "\n";
2903#endif
2904 return false;
2905 }
2906 }
2907
2908 template<typename Before,
2909 typename = typename enable_if_pattern<Before>::type>
2910 HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after) noexcept {
2911 static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2912 if (before.template match<0>(unwrap(instance), state)) {
2913 result = after;
2914#if HALIDE_DEBUG_MATCHED_RULES
2915 debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n";
2916#endif
2917 return true;
2918 } else {
2919#if HALIDE_DEBUG_UNMATCHED_RULES
2920 debug(0) << instance << " does not match " << before << "\n";
2921#endif
2922 return false;
2923 }
2924 }
2925
2926 template<typename Before,
2927 typename = typename enable_if_pattern<Before>::type>
2928 HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after) noexcept {
2929 static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2930#if HALIDE_FUZZ_TEST_RULES
2931 fuzz_test_rule(before, IntLiteral(after), true, wildcard_type, output_type);
2932#endif
2933 if (before.template match<0>(unwrap(instance), state)) {
2934 result = make_const(output_type, after);
2935#if HALIDE_DEBUG_MATCHED_RULES
2936 debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n";
2937#endif
2938 return true;
2939 } else {
2940#if HALIDE_DEBUG_UNMATCHED_RULES
2941 debug(0) << instance << " does not match " << before << "\n";
2942#endif
2943 return false;
2944 }
2945 }
2946
2947 template<typename Before,
2948 typename After,
2949 typename Predicate,
2950 typename = typename enable_if_pattern<Before>::type,
2951 typename = typename enable_if_pattern<After>::type,
2952 typename = typename enable_if_pattern<Predicate>::type>
2953 HALIDE_ALWAYS_INLINE bool operator()(Before before, After after, Predicate pred) {
2954 static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold");
2955 static_assert((Before::binds & After::binds) == After::binds, "Rule result uses unbound values");
2956 static_assert((Before::binds & Predicate::binds) == Predicate::binds, "Rule predicate uses unbound values");
2957 static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2958 static_assert(After::canonical, "RHS of rewrite rule should be in canonical form");
2959
2960#if HALIDE_FUZZ_TEST_RULES
2961 fuzz_test_rule(before, after, pred, wildcard_type, output_type);
2962#endif
2963 if (before.template match<0>(unwrap(instance), state) &&
2964 evaluate_predicate(pred, state)) {
2965 build_replacement(after);
2966#if HALIDE_DEBUG_MATCHED_RULES
2967 debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n";
2968#endif
2969 return true;
2970 } else {
2971#if HALIDE_DEBUG_UNMATCHED_RULES
2972 debug(0) << instance << " does not match " << before << "\n";
2973#endif
2974 return false;
2975 }
2976 }
2977
2978 template<typename Before,
2979 typename Predicate,
2980 typename = typename enable_if_pattern<Before>::type,
2981 typename = typename enable_if_pattern<Predicate>::type>
2982 HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after, Predicate pred) {
2983 static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold");
2984 static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2985
2986 if (before.template match<0>(unwrap(instance), state) &&
2987 evaluate_predicate(pred, state)) {
2988 result = after;
2989#if HALIDE_DEBUG_MATCHED_RULES
2990 debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n";
2991#endif
2992 return true;
2993 } else {
2994#if HALIDE_DEBUG_UNMATCHED_RULES
2995 debug(0) << instance << " does not match " << before << "\n";
2996#endif
2997 return false;
2998 }
2999 }
3000
3001 template<typename Before,
3002 typename Predicate,
3003 typename = typename enable_if_pattern<Before>::type,
3004 typename = typename enable_if_pattern<Predicate>::type>
3005 HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after, Predicate pred) {
3006 static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold");
3007 static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
3008#if HALIDE_FUZZ_TEST_RULES
3009 fuzz_test_rule(before, IntLiteral(after), pred, wildcard_type, output_type);
3010#endif
3011 if (before.template match<0>(unwrap(instance), state) &&
3012 evaluate_predicate(pred, state)) {
3013 result = make_const(output_type, after);
3014#if HALIDE_DEBUG_MATCHED_RULES
3015 debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n";
3016#endif
3017 return true;
3018 } else {
3019#if HALIDE_DEBUG_UNMATCHED_RULES
3020 debug(0) << instance << " does not match " << before << "\n";
3021#endif
3022 return false;
3023 }
3024 }
3025};
3026
3027/** Construct a rewriter for the given instance, which may be a pattern
3028 * with concrete expressions as leaves, or just an expression. The
3029 * second optional argument (wildcard_type) is a hint as to what the
3030 * type of the wildcards is likely to be. If omitted it uses the same
3031 * type as the expression itself. They are not required to be this
3032 * type, but the rule will only be tested for wildcards of that type
3033 * when testing is enabled.
3034 *
3035 * The rewriter can be used to check to see if the instance is one of
3036 * some number of patterns and if so rewrite it into another form,
3037 * using its operator() method. See Simplify.cpp for a bunch of
3038 * example usage.
3039 *
3040 * Important: Any Exprs in patterns are captured by reference, not by
3041 * value, so ensure they outlive the rewriter.
3042 */
3043// @{
3044template<typename Instance,
3045 typename = typename enable_if_pattern<Instance>::type>
3046HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type, halide_type_t wildcard_type) noexcept -> Rewriter<decltype(pattern_arg(instance))> {
3047 return {pattern_arg(instance), output_type, wildcard_type};
3048}
3049
3050template<typename Instance,
3051 typename = typename enable_if_pattern<Instance>::type>
3052HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type) noexcept -> Rewriter<decltype(pattern_arg(instance))> {
3053 return {pattern_arg(instance), output_type, output_type};
3054}
3055
// Expr instance with an explicit wildcard type hint; the output type is the
// expression's own type.
// NOTE(review): listing line 3056, immediately before this signature, is
// absent from this extract.
3057 auto rewriter(const Expr &e, halide_type_t wildcard_type) noexcept -> Rewriter<decltype(pattern_arg(e))> {
3058 return {pattern_arg(e), e.type(), wildcard_type};
3059 }
3060
// Expr instance with no hints; both the output type and the wildcard type
// are the expression's own type.
// NOTE(review): listing line 3061, immediately before this signature, is
// absent from this extract.
3062 auto rewriter(const Expr &e) noexcept -> Rewriter<decltype(pattern_arg(e))> {
3063 return {pattern_arg(e), e.type(), e.type()};
3064 }
3065// @}
3066
3067} // namespace IRMatcher
3068
3069} // namespace Internal
3070} // namespace Halide
3071
3072#endif
#define debug(n)
For optional debugging during codegen, use the debug macro as follows:
Definition Debug.h:52
#define internal_error
Definition Error.h:215
@ halide_type_float
IEEE floating point numbers.
@ halide_type_bfloat
floating point numbers in the bfloat format
@ halide_type_int
signed integers
@ halide_type_uint
unsigned integers
#define HALIDE_NEVER_INLINE
#define HALIDE_ALWAYS_INLINE
Subtypes for Halide expressions (Halide::Expr) and statements (Halide::Internal::Stmt)
Methods to test Exprs and Stmts for equality of value.
Defines various operator overloads and utility functions that make it more pleasant to work with Hali...
An alternative template-metaprogramming approach to expression matching.
Definition IRMatch.h:72
HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type, halide_type_t wildcard_type) noexcept -> Rewriter< decltype(pattern_arg(instance))>
Construct a rewriter for the given instance, which may be a pattern with concrete expressions as leav...
Definition IRMatch.h:3046
HALIDE_ALWAYS_INLINE T pattern_arg(T t)
Definition IRMatch.h:567
auto rounding_halving_add(A &&a, B &&b) noexcept -> Intrin< Call::rounding_halving_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1563
HALIDE_ALWAYS_INLINE auto or_op(A &&a, B &&b) -> decltype(IRMatcher::operator||(a, b))
Definition IRMatch.h:1259
HALIDE_ALWAYS_INLINE auto operator!(A &&a) noexcept -> NotOp< decltype(pattern_arg(a))>
Definition IRMatch.h:1652
auto shift_right(A &&a, B &&b) noexcept -> Intrin< Call::shift_right, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1571
HALIDE_ALWAYS_INLINE auto min(A &&a, B &&b) noexcept -> BinOp< Min, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1060
auto widening_add(A &&a, B &&b) noexcept -> Intrin< Call::widening_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1529
HALIDE_ALWAYS_INLINE auto is_int(A &&a, uint8_t bits=0, uint16_t lanes=0) noexcept -> IsInt< decltype(pattern_arg(a))>
Definition IRMatch.h:2477
HALIDE_ALWAYS_INLINE bool evaluate_predicate(bool x, MatcherState &) noexcept
Definition IRMatch.h:2836
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Div >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:1016
auto abs(A &&a) noexcept -> Intrin< Call::abs, decltype(pattern_arg(a))>
Definition IRMatch.h:1592
HALIDE_ALWAYS_INLINE auto ne(A &&a, B &&b) -> decltype(IRMatcher::operator!=(a, b))
Definition IRMatch.h:1234
HALIDE_ALWAYS_INLINE auto is_uint(A &&a, uint8_t bits=0, uint16_t lanes=0) noexcept -> IsUInt< decltype(pattern_arg(a))>
Definition IRMatch.h:2523
HALIDE_ALWAYS_INLINE auto negate(A &&a) -> decltype(IRMatcher::operator-(a))
Definition IRMatch.h:2029
uint64_t constant_fold_cmp_op(int64_t, int64_t) noexcept
HALIDE_ALWAYS_INLINE auto operator<=(A &&a, B &&b) noexcept -> CmpOp< LE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1154
HALIDE_ALWAYS_INLINE auto operator+(A &&a, B &&b) noexcept -> BinOp< Add, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:905
HALIDE_ALWAYS_INLINE auto is_max_value(A &&a) noexcept -> IsMaxValue< decltype(pattern_arg(a))>
Definition IRMatch.h:2608
std::ostream & operator<<(std::ostream &s, const SpecificExpr &e)
Definition IRMatch.h:217
HALIDE_ALWAYS_INLINE auto and_op(A &&a, B &&b) -> decltype(IRMatcher::operator&&(a, b))
Definition IRMatch.h:1285
HALIDE_ALWAYS_INLINE auto h_and(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::And >
Definition IRMatch.h:1941
HALIDE_ALWAYS_INLINE auto gt(A &&a, B &&b) -> decltype(IRMatcher::operator>(a, b))
Definition IRMatch.h:1134
HALIDE_ALWAYS_INLINE auto is_const(A &&a) noexcept -> IsConst< decltype(pattern_arg(a))>
Definition IRMatch.h:2353
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< LE >(int64_t a, int64_t b) noexcept
Definition IRMatch.h:1164
HALIDE_ALWAYS_INLINE auto operator*(A &&a, B &&b) noexcept -> BinOp< Mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:971
HALIDE_ALWAYS_INLINE auto add(A &&a, B &&b) -> decltype(IRMatcher::operator+(a, b))
Definition IRMatch.h:912
auto widen_right_add(A &&a, B &&b) noexcept -> Intrin< Call::widen_right_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1516
HALIDE_ALWAYS_INLINE auto div(A &&a, B &&b) -> decltype(IRMatcher::operator/(a, b))
Definition IRMatch.h:1011
auto widen_right_mul(A &&a, B &&b) noexcept -> Intrin< Call::widen_right_mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1520
HALIDE_ALWAYS_INLINE auto mul(A &&a, B &&b) -> decltype(IRMatcher::operator*(a, b))
Definition IRMatch.h:978
HALIDE_ALWAYS_INLINE auto max(A &&a, B &&b) noexcept -> BinOp< Max, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1082
auto absd(A &&a, B &&b) noexcept -> Intrin< Call::absd, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1597
HALIDE_ALWAYS_INLINE auto slice(Vec vec, Base base, Stride stride, Lanes lanes) noexcept -> SliceOp< decltype(pattern_arg(vec)), decltype(pattern_arg(base)), decltype(pattern_arg(stride)), decltype(pattern_arg(lanes))>
Definition IRMatch.h:2186
HALIDE_ALWAYS_INLINE auto ramp(A &&a, B &&b, C &&c) noexcept -> RampOp< decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
Definition IRMatch.h:1865
HALIDE_ALWAYS_INLINE auto operator/(A &&a, B &&b) noexcept -> BinOp< Div, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1004
HALIDE_ALWAYS_INLINE auto widen(A &&a) noexcept -> WidenOp< decltype(pattern_arg(a))>
Definition IRMatch.h:2122
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Mod >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:1045
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< And >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:1290
HALIDE_ALWAYS_INLINE int64_t unwrap(IntLiteral t)
Definition IRMatch.h:559
auto widening_mul(A &&a, B &&b) noexcept -> Intrin< Call::widening_mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1537
HALIDE_ALWAYS_INLINE auto operator>(A &&a, B &&b) noexcept -> CmpOp< GT, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1129
HALIDE_ALWAYS_INLINE auto cast(halide_type_t t, A &&a) noexcept -> CastOp< decltype(pattern_arg(a))>
Definition IRMatch.h:2075
HALIDE_ALWAYS_INLINE auto overflows(A &&a) noexcept -> Overflows< decltype(pattern_arg(a))>
Definition IRMatch.h:2271
auto saturating_cast(const Type &t, A &&a) noexcept -> Intrin< Call::saturating_cast, decltype(pattern_arg(a))>
Definition IRMatch.h:1549
HALIDE_ALWAYS_INLINE void assert_is_lvalue_if_expr()
Definition IRMatch.h:576
HALIDE_ALWAYS_INLINE auto operator%(A &&a, B &&b) noexcept -> BinOp< Mod, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1031
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Sub >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:952
auto rounding_shift_left(A &&a, B &&b) noexcept -> Intrin< Call::rounding_shift_left, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1575
HALIDE_ALWAYS_INLINE auto is_scalar(A &&a) noexcept -> IsScalar< decltype(pattern_arg(a))>
Definition IRMatch.h:2567
HALIDE_ALWAYS_INLINE auto fold(A &&a) noexcept -> Fold< decltype(pattern_arg(a))>
Definition IRMatch.h:2234
HALIDE_ALWAYS_INLINE auto not_op(A &&a) -> decltype(IRMatcher::operator!(a))
Definition IRMatch.h:1658
auto likely(A &&a) noexcept -> Intrin< Call::likely, decltype(pattern_arg(a))>
Definition IRMatch.h:1602
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Max >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:1089
constexpr bool and_reduce()
Definition IRMatch.h:1314
HALIDE_ALWAYS_INLINE auto operator||(A &&a, B &&b) noexcept -> BinOp< Or, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1254
constexpr int max_wild
Definition IRMatch.h:74
HALIDE_ALWAYS_INLINE auto operator!=(A &&a, B &&b) noexcept -> CmpOp< NE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1229
auto halving_add(A &&a, B &&b) noexcept -> Intrin< Call::halving_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1555
auto mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin< Call::mul_shift_right, decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
Definition IRMatch.h:1583
HALIDE_ALWAYS_INLINE auto is_float(A &&a) noexcept -> IsFloat< decltype(pattern_arg(a))>
Definition IRMatch.h:2438
auto widening_sub(A &&a, B &&b) noexcept -> Intrin< Call::widening_sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1533
HALIDE_ALWAYS_INLINE auto operator>=(A &&a, B &&b) noexcept -> CmpOp< GE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1179
HALIDE_ALWAYS_INLINE auto operator<(A &&a, B &&b) noexcept -> CmpOp< LT, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1104
HALIDE_ALWAYS_INLINE auto operator&&(A &&a, B &&b) noexcept -> BinOp< And, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1280
HALIDE_ALWAYS_INLINE auto h_or(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Or >
Definition IRMatch.h:1947
constexpr bool commutative(IRNodeType t)
Definition IRMatch.h:615
HALIDE_ALWAYS_INLINE auto sub(A &&a, B &&b) -> decltype(IRMatcher::operator-(a, b))
Definition IRMatch.h:945
auto likely_if_innermost(A &&a) noexcept -> Intrin< Call::likely_if_innermost, decltype(pattern_arg(a))>
Definition IRMatch.h:1607
HALIDE_ALWAYS_INLINE auto h_max(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Max >
Definition IRMatch.h:1935
HALIDE_ALWAYS_INLINE auto broadcast(A &&a, B lanes) noexcept -> BroadcastOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes))>
Definition IRMatch.h:1801
HALIDE_ALWAYS_INLINE auto select(C &&c, T &&t, F &&f) noexcept -> SelectOp< decltype(pattern_arg(c)), decltype(pattern_arg(t)), decltype(pattern_arg(f))>
Definition IRMatch.h:1728
auto saturating_sub(A &&a, B &&b) noexcept -> Intrin< Call::saturating_sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1545
HALIDE_ALWAYS_INLINE auto is_min_value(A &&a) noexcept -> IsMinValue< decltype(pattern_arg(a))>
Definition IRMatch.h:2651
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Min >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:1067
HALIDE_NEVER_INLINE void fuzz_test_rule(Before &&before, After &&after, Predicate &&pred, halide_type_t wildcard_type, halide_type_t output_type) noexcept
Definition IRMatch.h:2705
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< GT >(int64_t a, int64_t b) noexcept
Definition IRMatch.h:1139
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Mul >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:985
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< GE >(int64_t a, int64_t b) noexcept
Definition IRMatch.h:1189
HALIDE_ALWAYS_INLINE auto operator-(A &&a, B &&b) noexcept -> BinOp< Sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:938
auto rounding_mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin< Call::rounding_mul_shift_right, decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
Definition IRMatch.h:1587
auto saturating_add(A &&a, B &&b) noexcept -> Intrin< Call::saturating_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1541
HALIDE_ALWAYS_INLINE auto le(A &&a, B &&b) -> decltype(IRMatcher::operator<=(a, b))
Definition IRMatch.h:1159
HALIDE_ALWAYS_INLINE auto lt(A &&a, B &&b) -> decltype(IRMatcher::operator<(a, b))
Definition IRMatch.h:1109
auto shift_left(A &&a, B &&b) noexcept -> Intrin< Call::shift_left, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1567
HALIDE_ALWAYS_INLINE auto lanes_of(A &&a) noexcept -> LanesOf< decltype(pattern_arg(a))>
Definition IRMatch.h:2688
auto rounding_shift_right(A &&a, B &&b) noexcept -> Intrin< Call::rounding_shift_right, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1579
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< LT >(int64_t a, int64_t b) noexcept
Definition IRMatch.h:1114
HALIDE_ALWAYS_INLINE auto h_min(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Min >
Definition IRMatch.h:1929
HALIDE_ALWAYS_INLINE auto h_add(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Add >
Definition IRMatch.h:1923
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Or >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:1264
HALIDE_ALWAYS_INLINE Expr make_const_expr(halide_scalar_value_t val, halide_type_t ty)
Definition IRMatch.h:160
constexpr uint32_t bitwise_or_reduce()
Definition IRMatch.h:1305
int64_t constant_fold_bin_op(halide_type_t &, int64_t, int64_t) noexcept
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< EQ >(int64_t a, int64_t b) noexcept
Definition IRMatch.h:1214
HALIDE_NEVER_INLINE Expr make_const_special_expr(halide_type_t ty)
Definition IRMatch.h:149
auto halving_sub(A &&a, B &&b) noexcept -> Intrin< Call::halving_sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1559
HALIDE_ALWAYS_INLINE auto ge(A &&a, B &&b) -> decltype(IRMatcher::operator>=(a, b))
Definition IRMatch.h:1184
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< NE >(int64_t a, int64_t b) noexcept
Definition IRMatch.h:1239
HALIDE_ALWAYS_INLINE auto mod(A &&a, B &&b) -> decltype(IRMatcher::operator%(a, b))
Definition IRMatch.h:1038
HALIDE_ALWAYS_INLINE auto operator==(A &&a, B &&b) noexcept -> CmpOp< EQ, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1204
auto widen_right_sub(A &&a, B &&b) noexcept -> Intrin< Call::widen_right_sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1524
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Add >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:919
HALIDE_ALWAYS_INLINE auto can_prove(A &&a, Prover *p) noexcept -> CanProve< decltype(pattern_arg(a)), Prover >
Definition IRMatch.h:2401
HALIDE_ALWAYS_INLINE auto eq(A &&a, B &&b) -> decltype(IRMatcher::operator==(a, b))
Definition IRMatch.h:1209
T div_imp(T a, T b)
Definition IROperator.h:278
bool is_const_zero(const Expr &e)
Is the expression a const (as defined by is_const), and also equal to zero (in all lanes,...
Expr make_zero(Type t)
Construct the representation of zero in the given type.
void expr_match_test()
bool is_const_one(const Expr &e)
Is the expression a const (as defined by is_const), and also equal to one (in all lanes,...
bool equal(const RDom &bounds0, const RDom &bounds1)
Return true if bounds0 and bounds1 represent the same bounds.
constexpr IRNodeType StrongestExprNodeType
Definition Expr.h:81
Expr make_const(Type t, int64_t val)
Construct an immediate of the given type from any numeric C++ type.
T mod_imp(T a, T b)
Implementations of division and mod that are specific to Halide.
Definition IROperator.h:257
bool sub_would_overflow(int bits, int64_t a, int64_t b)
bool add_would_overflow(int bits, int64_t a, int64_t b)
Routines to test if math would overflow for signed integers with the given number of bits.
DstType reinterpret_bits(const SrcType &src)
An aggressive form of reinterpret cast used for correct type-punning.
Definition Util.h:135
bool mul_would_overflow(int bits, int64_t a, int64_t b)
Expr with_lanes(const Expr &x, int lanes)
Rewrite the expression x to have `lanes` lanes.
bool expr_match(const Expr &pattern, const Expr &expr, std::vector< Expr > &result)
Does the first expression have the same structure as the second?
Expr make_signed_integer_overflow(Type type)
Construct a unique signed_integer_overflow Expr.
IRNodeType
All our IR node types get unique IDs for the purposes of RTTI.
Definition Expr.h:25
bool is_const(const Expr &e)
Is the expression either an IntImm, a FloatImm, a StringImm, or a Cast of the same,...
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
@ Predicate
Guard the loads and stores in the loop with an if statement that prevents evaluation beyond the origi...
Definition Schedule.h:59
@ C
No name mangling.
Definition Function.h:28
unsigned __INT64_TYPE__ uint64_t
signed __INT64_TYPE__ int64_t
signed __INT32_TYPE__ int32_t
unsigned __INT8_TYPE__ uint8_t
unsigned __INT16_TYPE__ uint16_t
unsigned __INT32_TYPE__ uint32_t
A fragment of Halide syntax.
Definition Expr.h:258
HALIDE_ALWAYS_INLINE Type type() const
Get the type of this expression node.
Definition Expr.h:327
HALIDE_ALWAYS_INLINE const Internal::BaseExprNode * get() const
Override get() to return a BaseExprNode * instead of an IRNode *.
Definition Expr.h:321
The sum of two expressions.
Definition IR.h:56
Logical and - are both expressions true.
Definition IR.h:175
A base class for expression nodes.
Definition Expr.h:143
A vector with 'lanes' elements, in which every element is 'value'.
Definition IR.h:259
static Expr make(Expr value, int lanes)
static const IRNodeType _node_type
Definition IR.h:265
A function call.
Definition IR.h:490
bool is_intrinsic() const
Definition IR.h:737
static const IRNodeType _node_type
Definition IR.h:795
std::vector< Expr > args
Definition IR.h:492
The actual IR nodes begin here.
Definition IR.h:30
static const IRNodeType _node_type
Definition IR.h:35
The ratio of two expressions.
Definition IR.h:83
Is the first expression equal to the second.
Definition IR.h:121
Floating point constants.
Definition Expr.h:236
static const FloatImm * make(Type t, double value)
Is the first expression greater than or equal to the second.
Definition IR.h:166
Is the first expression greater than the second.
Definition IR.h:157
static constexpr bool canonical
Definition IRMatch.h:641
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:664
static constexpr uint32_t binds
Definition IRMatch.h:633
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:645
static constexpr bool foldable
Definition IRMatch.h:661
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const noexcept
Definition IRMatch.h:707
HALIDE_ALWAYS_INLINE bool match(const BinOp< Op2, A2, B2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:655
static constexpr IRNodeType max_node_type
Definition IRMatch.h:636
static constexpr IRNodeType min_node_type
Definition IRMatch.h:635
static constexpr IRNodeType min_node_type
Definition IRMatch.h:1743
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:1767
HALIDE_ALWAYS_INLINE bool match(const BroadcastOp< A2, B2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:1761
static constexpr uint32_t binds
Definition IRMatch.h:1741
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:1749
static constexpr IRNodeType max_node_type
Definition IRMatch.h:1744
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:1784
HALIDE_NEVER_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2390
static constexpr uint32_t binds
Definition IRMatch.h:2380
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2383
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2384
static constexpr bool foldable
Definition IRMatch.h:2387
static constexpr bool canonical
Definition IRMatch.h:2385
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2043
static constexpr bool foldable
Definition IRMatch.h:2065
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:2047
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2042
static constexpr uint32_t binds
Definition IRMatch.h:2040
static constexpr bool canonical
Definition IRMatch.h:2044
HALIDE_ALWAYS_INLINE bool match(const CastOp< A2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:2056
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:2061
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:800
static constexpr IRNodeType max_node_type
Definition IRMatch.h:739
static constexpr uint32_t binds
Definition IRMatch.h:736
static constexpr bool canonical
Definition IRMatch.h:740
static constexpr bool foldable
Definition IRMatch.h:763
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:747
static constexpr IRNodeType min_node_type
Definition IRMatch.h:738
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:766
HALIDE_ALWAYS_INLINE bool match(const CmpOp< Op2, A2, B2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:757
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2199
static constexpr uint32_t binds
Definition IRMatch.h:2196
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2198
static constexpr bool canonical
Definition IRMatch.h:2200
static constexpr bool foldable
Definition IRMatch.h:2225
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const noexcept
Definition IRMatch.h:2203
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:2228
static constexpr IRNodeType max_node_type
Definition IRMatch.h:495
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:504
HALIDE_ALWAYS_INLINE IntLiteral(int64_t v)
Definition IRMatch.h:499
HALIDE_ALWAYS_INLINE bool match(const IntLiteral &b, MatcherState &state) const noexcept
Definition IRMatch.h:527
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:539
static constexpr IRNodeType min_node_type
Definition IRMatch.h:494
static constexpr bool canonical
Definition IRMatch.h:496
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:532
HALIDE_ALWAYS_INLINE bool match(int64_t val, MatcherState &state) const noexcept
Definition IRMatch.h:522
static constexpr uint32_t binds
Definition IRMatch.h:492
HALIDE_ALWAYS_INLINE Intrin(Args... args) noexcept
Definition IRMatch.h:1502
HALIDE_ALWAYS_INLINE void print_args(std::ostream &s) const
Definition IRMatch.h:1394
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:1369
static constexpr bool foldable
Definition IRMatch.h:1457
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:1459
std::tuple< Args... > args
Definition IRMatch.h:1341
static constexpr uint32_t binds
Definition IRMatch.h:1348
HALIDE_ALWAYS_INLINE void print_args(double, std::ostream &s) const
Definition IRMatch.h:1390
HALIDE_ALWAYS_INLINE bool match_args(int, const Call &c, MatcherState &state) const noexcept
Definition IRMatch.h:1357
static constexpr bool canonical
Definition IRMatch.h:1352
HALIDE_ALWAYS_INLINE bool match_args(double, const Call &c, MatcherState &state) const noexcept
Definition IRMatch.h:1364
HALIDE_ALWAYS_INLINE void print_args(int, std::ostream &s) const
Definition IRMatch.h:1381
static constexpr IRNodeType max_node_type
Definition IRMatch.h:1351
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:1399
static constexpr IRNodeType min_node_type
Definition IRMatch.h:1350
OptionalIntrinType< intrin > optional_type_hint
Definition IRMatch.h:1346
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2328
static constexpr bool canonical
Definition IRMatch.h:2330
static constexpr bool foldable
Definition IRMatch.h:2336
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2329
static constexpr uint32_t binds
Definition IRMatch.h:2325
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:2339
static constexpr bool foldable
Definition IRMatch.h:2424
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2427
static constexpr bool canonical
Definition IRMatch.h:2422
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2420
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2421
static constexpr uint32_t binds
Definition IRMatch.h:2417
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2460
static constexpr bool foldable
Definition IRMatch.h:2463
static constexpr uint32_t binds
Definition IRMatch.h:2456
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2466
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2459
static constexpr bool canonical
Definition IRMatch.h:2461
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2586
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2587
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2593
static constexpr uint32_t binds
Definition IRMatch.h:2583
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2627
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2634
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2628
static constexpr uint32_t binds
Definition IRMatch.h:2624
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2550
static constexpr uint32_t binds
Definition IRMatch.h:2546
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2556
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2549
static constexpr bool foldable
Definition IRMatch.h:2553
static constexpr bool canonical
Definition IRMatch.h:2551
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2512
static constexpr bool foldable
Definition IRMatch.h:2509
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2505
static constexpr bool canonical
Definition IRMatch.h:2507
static constexpr uint32_t binds
Definition IRMatch.h:2502
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2506
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2671
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2677
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2670
static constexpr bool foldable
Definition IRMatch.h:2674
static constexpr uint32_t binds
Definition IRMatch.h:2667
static constexpr bool canonical
Definition IRMatch.h:2672
To save stack space, the matcher objects are largely stateless and immutable.
Definition IRMatch.h:82
HALIDE_ALWAYS_INLINE void get_bound_const(int i, halide_scalar_value_t &val, halide_type_t &type) const noexcept
Definition IRMatch.h:127
HALIDE_ALWAYS_INLINE void set_bound_const(int i, int64_t s, halide_type_t t) noexcept
Definition IRMatch.h:103
HALIDE_ALWAYS_INLINE void set_bound_const(int i, double f, halide_type_t t) noexcept
Definition IRMatch.h:115
static constexpr uint16_t special_values_mask
Definition IRMatch.h:88
HALIDE_ALWAYS_INLINE void set_bound_const(int i, halide_scalar_value_t val, halide_type_t t) noexcept
Definition IRMatch.h:121
halide_type_t bound_const_type[max_wild]
Definition IRMatch.h:90
HALIDE_ALWAYS_INLINE void set_binding(int i, const BaseExprNode &n) noexcept
Definition IRMatch.h:93
HALIDE_ALWAYS_INLINE MatcherState() noexcept
Definition IRMatch.h:134
HALIDE_ALWAYS_INLINE const BaseExprNode * get_binding(int i) const noexcept
Definition IRMatch.h:98
halide_scalar_value_t bound_const[max_wild]
Definition IRMatch.h:84
const BaseExprNode * bindings[max_wild]
Definition IRMatch.h:83
HALIDE_ALWAYS_INLINE void set_bound_const(int i, uint64_t u, halide_type_t t) noexcept
Definition IRMatch.h:109
static constexpr uint16_t signed_integer_overflow
Definition IRMatch.h:87
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:1965
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:1980
HALIDE_ALWAYS_INLINE bool match(NegateOp< A2 > &&p, MatcherState &state) const noexcept
Definition IRMatch.h:1975
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:1989
static constexpr uint32_t binds
Definition IRMatch.h:1957
static constexpr bool canonical
Definition IRMatch.h:1962
static constexpr bool foldable
Definition IRMatch.h:1986
static constexpr IRNodeType max_node_type
Definition IRMatch.h:1960
static constexpr IRNodeType min_node_type
Definition IRMatch.h:1959
static constexpr uint32_t binds
Definition IRMatch.h:1616
static constexpr bool foldable
Definition IRMatch.h:1641
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:1623
static constexpr IRNodeType max_node_type
Definition IRMatch.h:1619
static constexpr bool canonical
Definition IRMatch.h:1620
HALIDE_ALWAYS_INLINE bool match(const NotOp< A2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:1632
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:1637
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:1644
static constexpr IRNodeType min_node_type
Definition IRMatch.h:1618
static constexpr uint32_t binds
Definition IRMatch.h:2285
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2289
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:2293
static constexpr bool canonical
Definition IRMatch.h:2290
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:2302
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:2310
static constexpr bool foldable
Definition IRMatch.h:2307
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2288
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:2261
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2254
static constexpr uint32_t binds
Definition IRMatch.h:2250
static constexpr bool canonical
Definition IRMatch.h:2256
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2255
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:1843
static constexpr bool canonical
Definition IRMatch.h:1818
static constexpr IRNodeType max_node_type
Definition IRMatch.h:1816
static constexpr IRNodeType min_node_type
Definition IRMatch.h:1815
static constexpr uint32_t binds
Definition IRMatch.h:1813
HALIDE_ALWAYS_INLINE bool match(const RampOp< A2, B2, C2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:1836
static constexpr bool foldable
Definition IRMatch.h:1855
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:1821
HALIDE_NEVER_INLINE void build_replacement(After after)
Definition IRMatch.h:2876
HALIDE_ALWAYS_INLINE bool operator()(Before before, After after, Predicate pred)
Definition IRMatch.h:2953
HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after) noexcept
Definition IRMatch.h:2928
HALIDE_ALWAYS_INLINE Rewriter(Instance instance, halide_type_t ot, halide_type_t wt)
Definition IRMatch.h:2871
HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after, Predicate pred)
Definition IRMatch.h:2982
HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after) noexcept
Definition IRMatch.h:2910
HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after, Predicate pred)
Definition IRMatch.h:3005
HALIDE_ALWAYS_INLINE bool operator()(Before before, After after)
Definition IRMatch.h:2887
static constexpr uint32_t binds
Definition IRMatch.h:1676
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:1708
static constexpr bool foldable
Definition IRMatch.h:1705
static constexpr bool canonical
Definition IRMatch.h:1681
HALIDE_ALWAYS_INLINE bool match(const SelectOp< C2, T2, F2 > &instance, MatcherState &state) const noexcept
Definition IRMatch.h:1694
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:1684
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:1701
static constexpr IRNodeType max_node_type
Definition IRMatch.h:1679
static constexpr IRNodeType min_node_type
Definition IRMatch.h:1678
static constexpr bool canonical
Definition IRMatch.h:2139
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2138
static constexpr bool foldable
Definition IRMatch.h:2168
HALIDE_ALWAYS_INLINE SliceOp(Vec v, Base b, Stride s, Lanes l)
Definition IRMatch.h:2171
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2137
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:2142
static constexpr uint32_t binds
Definition IRMatch.h:2135
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:2156
static constexpr IRNodeType min_node_type
Definition IRMatch.h:198
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:205
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:210
static constexpr IRNodeType max_node_type
Definition IRMatch.h:199
static constexpr uint32_t binds
Definition IRMatch.h:195
HALIDE_ALWAYS_INLINE bool match(const VectorReduceOp< A2, B2, reduce_op_2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:1898
static constexpr IRNodeType min_node_type
Definition IRMatch.h:1880
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:1885
static constexpr IRNodeType max_node_type
Definition IRMatch.h:1881
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:1905
static constexpr uint32_t binds
Definition IRMatch.h:2085
static constexpr bool canonical
Definition IRMatch.h:2089
static constexpr bool foldable
Definition IRMatch.h:2112
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:2106
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2088
HALIDE_ALWAYS_INLINE bool match(const WidenOp< A2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:2101
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2087
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:2092
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:352
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:373
static constexpr IRNodeType max_node_type
Definition IRMatch.h:348
static constexpr IRNodeType min_node_type
Definition IRMatch.h:347
static constexpr uint32_t binds
Definition IRMatch.h:345
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:383
static constexpr bool canonical
Definition IRMatch.h:403
static constexpr IRNodeType max_node_type
Definition IRMatch.h:402
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:431
static constexpr uint32_t binds
Definition IRMatch.h:399
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:406
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:441
HALIDE_ALWAYS_INLINE bool match(int64_t e, MatcherState &state) const noexcept
Definition IRMatch.h:425
static constexpr IRNodeType min_node_type
Definition IRMatch.h:401
static constexpr bool foldable
Definition IRMatch.h:438
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:267
static constexpr uint32_t binds
Definition IRMatch.h:226
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:277
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:233
HALIDE_ALWAYS_INLINE bool match(int64_t value, MatcherState &state) const noexcept
Definition IRMatch.h:254
static constexpr IRNodeType min_node_type
Definition IRMatch.h:228
static constexpr IRNodeType max_node_type
Definition IRMatch.h:229
static constexpr uint32_t binds
Definition IRMatch.h:292
static constexpr IRNodeType max_node_type
Definition IRMatch.h:295
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:299
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:330
static constexpr IRNodeType min_node_type
Definition IRMatch.h:294
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:320
static constexpr IRNodeType min_node_type
Definition IRMatch.h:459
static constexpr uint32_t binds
Definition IRMatch.h:457
static constexpr IRNodeType max_node_type
Definition IRMatch.h:460
static constexpr bool canonical
Definition IRMatch.h:461
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:473
static constexpr bool foldable
Definition IRMatch.h:477
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:464
static constexpr uint32_t mask
Definition IRMatch.h:146
IRNodeType node_type
Each IR node subclass has a unique identifier.
Definition Expr.h:113
Integer constants.
Definition Expr.h:218
static const IntImm * make(Type t, int64_t value)
Is the first expression less than or equal to the second.
Definition IR.h:148
Is the first expression less than the second.
Definition IR.h:139
The greater of two values.
Definition IR.h:112
The lesser of two values.
Definition IR.h:103
The remainder of a / b.
Definition IR.h:94
The product of two expressions.
Definition IR.h:74
Is the first expression not equal to the second.
Definition IR.h:130
Logical not - true if the expression is false.
Definition IR.h:193
static Expr make(Expr a)
Logical or - is at least one of the expressions true.
Definition IR.h:184
A linear ramp vector node.
Definition IR.h:247
static const IRNodeType _node_type
Definition IR.h:253
static Expr make(Expr base, Expr stride, int lanes)
A ternary operator.
Definition IR.h:204
static Expr make(Expr condition, Expr true_value, Expr false_value)
static const IRNodeType _node_type
Definition IR.h:209
Construct a new vector by taking elements from another sequence of vectors.
Definition IR.h:884
static Expr make_slice(Expr vector, int begin, int stride, int size)
Convenience constructor for making a shuffle representing a contiguous subset of a vector.
std::vector< Expr > vectors
Definition IR.h:885
bool is_slice() const
Check if this shuffle is a contiguous strict subset of the vector arguments, and if so, the offset and size of the slice.
int slice_stride() const
Check if this shuffle is a contiguous strict subset of the vector arguments, and if so, the offset and size of the slice.
Definition IR.h:938
int slice_begin() const
Check if this shuffle is a contiguous strict subset of the vector arguments, and if so, the offset and size of the slice.
Definition IR.h:935
The difference of two expressions.
Definition IR.h:65
static const IRNodeType _node_type
Definition IR.h:70
static Expr make(Expr a, Expr b)
Unsigned integer constants.
Definition Expr.h:227
static const UIntImm * make(Type t, uint64_t value)
Horizontally reduce a vector to a scalar or narrower vector using the given commutative and associative binary operator.
Definition IR.h:1012
static const IRNodeType _node_type
Definition IR.h:1031
static Expr make(Operator op, Expr vec, int lanes)
Type widen() const
Return Type with the same type code and number of lanes, but with at least twice as many bits.
Definition Type.h:378
HALIDE_ALWAYS_INLINE bool is_int() const
Is this type a signed integer type?
Definition Type.h:435
HALIDE_ALWAYS_INLINE int lanes() const
Return the number of vector elements in this type.
Definition Type.h:355
HALIDE_ALWAYS_INLINE bool is_uint() const
Is this type an unsigned integer type?
Definition Type.h:441
HALIDE_ALWAYS_INLINE int bits() const
Return the bit size of a single element of this type.
Definition Type.h:349
HALIDE_ALWAYS_INLINE bool is_scalar() const
Is this type a scalar type?
Definition Type.h:417
HALIDE_ALWAYS_INLINE bool is_float() const
Is this type a floating point type (float or double).
Definition Type.h:423
halide_scalar_value_t is a simple union able to represent all the well-known scalar values in a filter argument.
union halide_scalar_value_t::@314322015121151262135054202130057122113055355347 u
A runtime tag for a type in the halide type system.
uint8_t bits
The number of bits of precision of a single scalar value of this type.
uint16_t lanes
How many elements in a vector.
uint8_t code
The basic type code: signed integer, unsigned integer, or floating point.