Halide 19.0.0
Halide compiler and libraries
Loading...
Searching...
No Matches
IRMatch.h
Go to the documentation of this file.
1#ifndef HALIDE_IR_MATCH_H
2#define HALIDE_IR_MATCH_H
3
4/** \file
5 * Defines a method to match a fragment of IR against a pattern containing wildcards
6 */
7
8#include <map>
9#include <random>
10#include <set>
11#include <vector>
12
13#include "IR.h"
14#include "IREquality.h"
15#include "IROperator.h"
16
17namespace Halide {
18namespace Internal {
19
20/** Does the first expression have the same structure as the second?
21 * Variables in the first expression with the name * are interpreted
22 * as wildcards, and their matching equivalent in the second
23 * expression is placed in the vector given as the third argument.
24 * Wildcards require the types to match. For the type bits and width,
25 * a 0 indicates "match anything". So an Int(8, 0) will match 8-bit
26 * integer vectors of any width (including scalars), and a UInt(0, 0)
27 * will match any unsigned integer type.
28 *
29 * For example:
30 \code
31 Expr x = Variable::make(Int(32), "*");
32 expr_match(x + x, 3 + (2*k), result)
33 \endcode
34 * should return true, and set result[0] to 3 and
35 * result[1] to 2*k.
36 */
37bool expr_match(const Expr &pattern, const Expr &expr, std::vector<Expr> &result);
38
39/** Does the first expression have the same structure as the second?
40 * Variables are matched consistently. The first time a variable is
41 * matched, it assumes the value of the matching part of the second
42 * expression. Subsequent matches must be equal to the first match.
43 *
44 * For example:
45 \code
46 Var x("x"), y("y");
47 expr_match(x*(x + y), a*(a + b), result)
48 \endcode
49 * should return true, and set result["x"] = a, and result["y"] = b.
50 */
51bool expr_match(const Expr &pattern, const Expr &expr, std::map<std::string, Expr> &result);
52
53/** Rewrite the expression x to have `lanes` lanes. This is useful
54 * for substituting the results of expr_match into a pattern expression. */
55Expr with_lanes(const Expr &x, int lanes);
56
58
59/** An alternative template-metaprogramming approach to expression
60 * matching. Potentially more efficient. We lift the expression
61 * pattern into a type, and then use force-inlined functions to
62 * generate efficient matching and reconstruction code for any
63 * pattern. Pattern elements are either one of the classes in the
64 * namespace IRMatcher, or are non-null Exprs (represented as
65 * BaseExprNode &).
66 *
67 * Pattern elements that are fully specified by their pattern can be
68 * built into an expression using the make method. Some patterns,
69 * such as a broadcast that matches any number of lanes, don't have
70 * enough information to recreate an Expr.
71 */
72namespace IRMatcher {
73
// Number of constant-wildcard slots. The WildConst* match() methods
// static_assert that their slot index lies in [0, max_wild).
74constexpr int max_wild = 6;
75
// Type recorded when a constant wildcard binds a raw int64_t value rather than
// an IR node (see WildConstInt::match(int64_t, MatcherState &)).
76static const halide_type_t i64_type = {halide_type_int, 64, 1};
77
78/** To save stack space, the matcher objects are largely stateless and
79 * immutable. This state object is built up during matching and then
80 * consumed when constructing a replacement Expr.
81 */
85
86 // values of the lanes field with special meaning.
87 static constexpr uint16_t signed_integer_overflow = 0x8000;
88 static constexpr uint16_t special_values_mask = 0x8000; // currently only one
89
91
93 void set_binding(int i, const BaseExprNode &n) noexcept {
94 bindings[i] = &n;
95 }
96
98 const BaseExprNode *get_binding(int i) const noexcept {
99 return bindings[i];
100 }
101
103 void set_bound_const(int i, int64_t s, halide_type_t t) noexcept {
104 bound_const[i].u.i64 = s;
105 bound_const_type[i] = t;
106 }
107
109 void set_bound_const(int i, uint64_t u, halide_type_t t) noexcept {
110 bound_const[i].u.u64 = u;
111 bound_const_type[i] = t;
112 }
113
115 void set_bound_const(int i, double f, halide_type_t t) noexcept {
116 bound_const[i].u.f64 = f;
117 bound_const_type[i] = t;
118 }
119
122 bound_const[i] = val;
123 bound_const_type[i] = t;
124 }
125
127 void get_bound_const(int i, halide_scalar_value_t &val, halide_type_t &type) const noexcept {
128 val = bound_const[i];
129 type = bound_const_type[i];
130 }
131
133 // NOLINTNEXTLINE(modernize-use-equals-default): Can't use `= default`; clang-tidy complains about noexcept mismatch
134 MatcherState() noexcept {
135 }
136};
137
138template<typename T,
139 typename = typename std::remove_reference<T>::type::pattern_tag>
141 struct type {};
142};
143
template<typename T>
struct bindings {
    // Bitmask of the wildcard slots that pattern type T binds on a successful
    // match. T may arrive as a reference type (e.g. deduced from a forwarding
    // reference), so the reference is stripped before reading `binds`.
    static constexpr uint32_t mask = std::remove_reference<T>::type::binds;
};
148
151 ty.lanes &= ~MatcherState::special_values_mask;
154 }
155 // unreachable
156 return Expr();
157}
158
161 halide_type_t scalar_type = ty;
162 if (scalar_type.lanes & MatcherState::special_values_mask) {
163 return make_const_special_expr(scalar_type);
164 }
165
166 const int lanes = scalar_type.lanes;
167 scalar_type.lanes = 1;
168
169 Expr e;
170 switch (scalar_type.code) {
171 case halide_type_int:
172 e = IntImm::make(scalar_type, val.u.i64);
173 break;
174 case halide_type_uint:
175 e = UIntImm::make(scalar_type, val.u.u64);
176 break;
179 e = FloatImm::make(scalar_type, val.u.f64);
180 break;
181 default:
182 // Unreachable
183 return Expr();
184 }
185 if (lanes > 1) {
186 e = Broadcast::make(std::move(e), lanes);
187 }
188 return e;
189}
190
191// A pattern that matches a specific expression
193 struct pattern_tag {};
194
195 constexpr static uint32_t binds = 0;
196
197 // What is the weakest and strongest IR node this could possibly be
200 constexpr static bool canonical = true;
201
203
204 template<uint32_t bound>
205 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
206 return equal(expr, e);
207 }
208
210 Expr make(MatcherState &state, halide_type_t type_hint) const {
211 return Expr(&expr);
212 }
213
214 constexpr static bool foldable = false;
215};
216
217inline std::ostream &operator<<(std::ostream &s, const SpecificExpr &e) {
218 s << Expr(&e.expr);
219 return s;
220}
221
222template<int i>
224 struct pattern_tag {};
225
226 constexpr static uint32_t binds = 1 << i;
227
230 constexpr static bool canonical = true;
231
232 template<uint32_t bound>
233 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
234 static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
235 const BaseExprNode *op = &e;
236 if (op->node_type == IRNodeType::Broadcast) {
237 op = ((const Broadcast *)op)->value.get();
238 }
239 if (op->node_type != IRNodeType::IntImm) {
240 return false;
241 }
242 int64_t value = ((const IntImm *)op)->value;
243 if (bound & binds) {
245 halide_type_t type;
246 state.get_bound_const(i, val, type);
247 return (halide_type_t)e.type == type && value == val.u.i64;
248 }
249 state.set_bound_const(i, value, e.type);
250 return true;
251 }
252
253 template<uint32_t bound>
254 HALIDE_ALWAYS_INLINE bool match(int64_t value, MatcherState &state) const noexcept {
255 static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
256 if (bound & binds) {
258 halide_type_t type;
259 state.get_bound_const(i, val, type);
260 return type == i64_type && value == val.u.i64;
261 }
262 state.set_bound_const(i, value, i64_type);
263 return true;
264 }
265
267 Expr make(MatcherState &state, halide_type_t type_hint) const {
269 halide_type_t type;
270 state.get_bound_const(i, val, type);
271 return make_const_expr(val, type);
272 }
273
274 constexpr static bool foldable = true;
275
278 state.get_bound_const(i, val, ty);
279 }
280};
281
282template<int i>
283std::ostream &operator<<(std::ostream &s, const WildConstInt<i> &c) {
284 s << "ci" << i;
285 return s;
286}
287
288template<int i>
290 struct pattern_tag {};
291
292 constexpr static uint32_t binds = 1 << i;
293
296 constexpr static bool canonical = true;
297
298 template<uint32_t bound>
299 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
300 static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
301 const BaseExprNode *op = &e;
302 if (op->node_type == IRNodeType::Broadcast) {
303 op = ((const Broadcast *)op)->value.get();
304 }
305 if (op->node_type != IRNodeType::UIntImm) {
306 return false;
307 }
308 uint64_t value = ((const UIntImm *)op)->value;
309 if (bound & binds) {
311 halide_type_t type;
312 state.get_bound_const(i, val, type);
313 return (halide_type_t)e.type == type && value == val.u.u64;
314 }
315 state.set_bound_const(i, value, e.type);
316 return true;
317 }
318
320 Expr make(MatcherState &state, halide_type_t type_hint) const {
322 halide_type_t type;
323 state.get_bound_const(i, val, type);
324 return make_const_expr(val, type);
325 }
326
327 constexpr static bool foldable = true;
328
330 void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
331 state.get_bound_const(i, val, ty);
332 }
333};
334
335template<int i>
336std::ostream &operator<<(std::ostream &s, const WildConstUInt<i> &c) {
337 s << "cu" << i;
338 return s;
339}
340
341template<int i>
343 struct pattern_tag {};
344
345 constexpr static uint32_t binds = 1 << i;
346
349 constexpr static bool canonical = true;
350
351 template<uint32_t bound>
352 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
353 static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
354 const BaseExprNode *op = &e;
355 if (op->node_type == IRNodeType::Broadcast) {
356 op = ((const Broadcast *)op)->value.get();
357 }
358 if (op->node_type != IRNodeType::FloatImm) {
359 return false;
360 }
361 double value = ((const FloatImm *)op)->value;
362 if (bound & binds) {
364 halide_type_t type;
365 state.get_bound_const(i, val, type);
366 return (halide_type_t)e.type == type && value == val.u.f64;
367 }
368 state.set_bound_const(i, value, e.type);
369 return true;
370 }
371
373 Expr make(MatcherState &state, halide_type_t type_hint) const {
375 halide_type_t type;
376 state.get_bound_const(i, val, type);
377 return make_const_expr(val, type);
378 }
379
380 constexpr static bool foldable = true;
381
383 void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
384 state.get_bound_const(i, val, ty);
385 }
386};
387
388template<int i>
389std::ostream &operator<<(std::ostream &s, const WildConstFloat<i> &c) {
390 s << "cf" << i;
391 return s;
392}
393
394// Matches and binds to any constant Expr (an IntImm, UIntImm, or FloatImm, possibly broadcast). The bound constant can be constant-folded.
395template<int i>
396struct WildConst {
397 struct pattern_tag {};
398
399 constexpr static uint32_t binds = 1 << i;
400
403 constexpr static bool canonical = true;
404
405 template<uint32_t bound>
406 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
407 static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
408 const BaseExprNode *op = &e;
409 if (op->node_type == IRNodeType::Broadcast) {
410 op = ((const Broadcast *)op)->value.get();
411 }
412 switch (op->node_type) {
414 return WildConstInt<i>().template match<bound>(e, state);
416 return WildConstUInt<i>().template match<bound>(e, state);
418 return WildConstFloat<i>().template match<bound>(e, state);
419 default:
420 return false;
421 }
422 }
423
424 template<uint32_t bound>
425 HALIDE_ALWAYS_INLINE bool match(int64_t e, MatcherState &state) const noexcept {
426 static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
427 return WildConstInt<i>().template match<bound>(e, state);
428 }
429
431 Expr make(MatcherState &state, halide_type_t type_hint) const {
433 halide_type_t type;
434 state.get_bound_const(i, val, type);
435 return make_const_expr(val, type);
436 }
437
438 constexpr static bool foldable = true;
439
441 void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
442 state.get_bound_const(i, val, ty);
443 }
444};
445
446template<int i>
447std::ostream &operator<<(std::ostream &s, const WildConst<i> &c) {
448 s << "c" << i;
449 return s;
450}
451
452// Matches and binds to any Expr
453template<int i>
454struct Wild {
455 struct pattern_tag {};
456
457 constexpr static uint32_t binds = 1 << (i + 16);
458
461 constexpr static bool canonical = true;
462
463 template<uint32_t bound>
464 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
465 if (bound & binds) {
466 return equal(*state.get_binding(i), e);
467 }
468 state.set_binding(i, e);
469 return true;
470 }
471
473 Expr make(MatcherState &state, halide_type_t type_hint) const {
474 return state.get_binding(i);
475 }
476
477 constexpr static bool foldable = false;
478};
479
480template<int i>
481std::ostream &operator<<(std::ostream &s, const Wild<i> &op) {
482 s << "_" << i;
483 return s;
484}
485
486// Matches a specific constant or broadcast of that constant. The
487// constant must be representable as an int64_t.
489 struct pattern_tag {};
491
492 constexpr static uint32_t binds = 0;
493
496 constexpr static bool canonical = true;
497
500 : v(v) {
501 }
502
503 template<uint32_t bound>
504 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
505 const BaseExprNode *op = &e;
506 if (e.node_type == IRNodeType::Broadcast) {
507 op = ((const Broadcast *)op)->value.get();
508 }
509 switch (op->node_type) {
511 return ((const IntImm *)op)->value == (int64_t)v;
513 return ((const UIntImm *)op)->value == (uint64_t)v;
515 return ((const FloatImm *)op)->value == (double)v;
516 default:
517 return false;
518 }
519 }
520
521 template<uint32_t bound>
522 HALIDE_ALWAYS_INLINE bool match(int64_t val, MatcherState &state) const noexcept {
523 return v == val;
524 }
525
526 template<uint32_t bound>
527 HALIDE_ALWAYS_INLINE bool match(const IntLiteral &b, MatcherState &state) const noexcept {
528 return v == b.v;
529 }
530
532 Expr make(MatcherState &state, halide_type_t type_hint) const {
533 return make_const(type_hint, v);
534 }
535
536 constexpr static bool foldable = true;
537
539 void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
540 // Assume type is already correct
541 switch (ty.code) {
542 case halide_type_int:
543 val.u.i64 = v;
544 break;
545 case halide_type_uint:
546 val.u.u64 = (uint64_t)v;
547 break;
550 val.u.f64 = (double)v;
551 break;
552 default:
553 // Unreachable
554 ;
555 }
556 }
557};
558
562
563// Convert a provided pattern, expr, or constant int into the internal
564// representation we use in the matcher trees.
565template<typename T,
566 typename = typename std::decay<T>::type::pattern_tag>
568 return t;
569}
572 return IntLiteral{x};
573}
574
575template<typename T>
577 static_assert(!std::is_same<typename std::decay<T>::type, Expr>::value || std::is_lvalue_reference<T>::value,
578 "Exprs are captured by reference by IRMatcher objects and so must be lvalues");
579}
580
582 return {*e.get()};
583}
584
585// Helpers to deref SpecificExprs to const BaseExprNode & rather than
586// passing them by value anywhere (incurring lots of refcounting)
587template<typename T,
588 // T must be a pattern node
589 typename = typename std::decay<T>::type::pattern_tag,
590 // But T may not be SpecificExpr
591 typename = typename std::enable_if<!std::is_same<typename std::decay<T>::type, SpecificExpr>::value>::type>
593 return t;
594}
595
598 return e.expr;
599}
600
601inline std::ostream &operator<<(std::ostream &s, const IntLiteral &op) {
602 s << op.v;
603 return s;
604}
605
606template<typename Op>
608
609template<typename Op>
611
612template<typename Op>
613double constant_fold_bin_op(halide_type_t &, double, double) noexcept;
614
615constexpr bool commutative(IRNodeType t) {
616 return (t == IRNodeType::Add ||
617 t == IRNodeType::Mul ||
618 t == IRNodeType::And ||
619 t == IRNodeType::Or ||
620 t == IRNodeType::Min ||
621 t == IRNodeType::Max ||
622 t == IRNodeType::EQ ||
623 t == IRNodeType::NE);
624}
625
626// Matches one of the binary operators
627template<typename Op, typename A, typename B>
628struct BinOp {
629 struct pattern_tag {};
630 A a;
631 B b;
632
634
635 constexpr static IRNodeType min_node_type = Op::_node_type;
636 constexpr static IRNodeType max_node_type = Op::_node_type;
637
638 // For commutative bin ops, we expect the weaker IR node type on
639 // the right. That is, for the rule to be canonical it must be
640 // possible that A is at least as strong as B.
641 constexpr static bool canonical =
642 A::canonical && B::canonical && (!commutative(Op::_node_type) || (A::max_node_type >= B::min_node_type));
643
644 template<uint32_t bound>
645 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
646 if (e.node_type != Op::_node_type) {
647 return false;
648 }
649 const Op &op = (const Op &)e;
650 return (a.template match<bound>(*op.a.get(), state) &&
651 b.template match<bound | bindings<A>::mask>(*op.b.get(), state));
652 }
653
654 template<uint32_t bound, typename Op2, typename A2, typename B2>
655 HALIDE_ALWAYS_INLINE bool match(const BinOp<Op2, A2, B2> &op, MatcherState &state) const noexcept {
656 return (std::is_same<Op, Op2>::value &&
657 a.template match<bound>(unwrap(op.a), state) &&
658 b.template match<bound | bindings<A>::mask>(unwrap(op.b), state));
659 }
660
661 constexpr static bool foldable = A::foldable && B::foldable;
662
664 void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
665 halide_scalar_value_t val_a, val_b;
666 if (std::is_same<A, IntLiteral>::value) {
667 b.make_folded_const(val_b, ty, state);
668 if ((std::is_same<Op, And>::value && val_b.u.u64 == 0) ||
669 (std::is_same<Op, Or>::value && val_b.u.u64 == 1)) {
670 // Short circuit
671 val = val_b;
672 return;
673 }
674 const uint16_t l = ty.lanes;
675 a.make_folded_const(val_a, ty, state);
676 ty.lanes |= l; // Make sure the overflow bits are sticky
677 } else {
678 a.make_folded_const(val_a, ty, state);
679 if ((std::is_same<Op, And>::value && val_a.u.u64 == 0) ||
680 (std::is_same<Op, Or>::value && val_a.u.u64 == 1)) {
681 // Short circuit
682 val = val_a;
683 return;
684 }
685 const uint16_t l = ty.lanes;
686 b.make_folded_const(val_b, ty, state);
687 ty.lanes |= l;
688 }
689 switch (ty.code) {
690 case halide_type_int:
691 val.u.i64 = constant_fold_bin_op<Op>(ty, val_a.u.i64, val_b.u.i64);
692 break;
693 case halide_type_uint:
694 val.u.u64 = constant_fold_bin_op<Op>(ty, val_a.u.u64, val_b.u.u64);
695 break;
698 val.u.f64 = constant_fold_bin_op<Op>(ty, val_a.u.f64, val_b.u.f64);
699 break;
700 default:
701 // unreachable
702 ;
703 }
704 }
705
707 Expr make(MatcherState &state, halide_type_t type_hint) const noexcept {
708 Expr ea, eb;
709 if (std::is_same<A, IntLiteral>::value) {
710 eb = b.make(state, type_hint);
711 ea = a.make(state, eb.type());
712 } else {
713 ea = a.make(state, type_hint);
714 eb = b.make(state, ea.type());
715 }
716 return Op::make(std::move(ea), std::move(eb));
717 }
718};
719
720template<typename Op>
722
723template<typename Op>
725
726template<typename Op>
727uint64_t constant_fold_cmp_op(double, double) noexcept;
728
729// Matches one of the comparison operators
730template<typename Op, typename A, typename B>
731struct CmpOp {
732 struct pattern_tag {};
733 A a;
734 B b;
735
737
738 constexpr static IRNodeType min_node_type = Op::_node_type;
739 constexpr static IRNodeType max_node_type = Op::_node_type;
740 constexpr static bool canonical = (A::canonical &&
741 B::canonical &&
742 (!commutative(Op::_node_type) || A::max_node_type >= B::min_node_type) &&
743 (Op::_node_type != IRNodeType::GE) &&
744 (Op::_node_type != IRNodeType::GT));
745
746 template<uint32_t bound>
747 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
748 if (e.node_type != Op::_node_type) {
749 return false;
750 }
751 const Op &op = (const Op &)e;
752 return (a.template match<bound>(*op.a.get(), state) &&
753 b.template match<bound | bindings<A>::mask>(*op.b.get(), state));
754 }
755
756 template<uint32_t bound, typename Op2, typename A2, typename B2>
757 HALIDE_ALWAYS_INLINE bool match(const CmpOp<Op2, A2, B2> &op, MatcherState &state) const noexcept {
758 return (std::is_same<Op, Op2>::value &&
759 a.template match<bound>(unwrap(op.a), state) &&
760 b.template match<bound | bindings<A>::mask>(unwrap(op.b), state));
761 }
762
763 constexpr static bool foldable = A::foldable && B::foldable;
764
766 void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
767 halide_scalar_value_t val_a, val_b;
768 // If one side is an untyped const, evaluate the other side first to get a type hint.
769 if (std::is_same<A, IntLiteral>::value) {
770 b.make_folded_const(val_b, ty, state);
771 const uint16_t l = ty.lanes;
772 a.make_folded_const(val_a, ty, state);
773 ty.lanes |= l;
774 } else {
775 a.make_folded_const(val_a, ty, state);
776 const uint16_t l = ty.lanes;
777 b.make_folded_const(val_b, ty, state);
778 ty.lanes |= l;
779 }
780 switch (ty.code) {
781 case halide_type_int:
782 val.u.u64 = constant_fold_cmp_op<Op>(val_a.u.i64, val_b.u.i64);
783 break;
784 case halide_type_uint:
785 val.u.u64 = constant_fold_cmp_op<Op>(val_a.u.u64, val_b.u.u64);
786 break;
789 val.u.u64 = constant_fold_cmp_op<Op>(val_a.u.f64, val_b.u.f64);
790 break;
791 default:
792 // unreachable
793 ;
794 }
795 ty.code = halide_type_uint;
796 ty.bits = 1;
797 }
798
800 Expr make(MatcherState &state, halide_type_t type_hint) const {
801 // If one side is an untyped const, evaluate the other side first to get a type hint.
802 Expr ea, eb;
803 if (std::is_same<A, IntLiteral>::value) {
804 eb = b.make(state, {});
805 ea = a.make(state, eb.type());
806 } else {
807 ea = a.make(state, {});
808 eb = b.make(state, ea.type());
809 }
810 return Op::make(std::move(ea), std::move(eb));
811 }
812};
813
814template<typename A, typename B>
815std::ostream &operator<<(std::ostream &s, const BinOp<Add, A, B> &op) {
816 s << "(" << op.a << " + " << op.b << ")";
817 return s;
818}
819
820template<typename A, typename B>
821std::ostream &operator<<(std::ostream &s, const BinOp<Sub, A, B> &op) {
822 s << "(" << op.a << " - " << op.b << ")";
823 return s;
824}
825
826template<typename A, typename B>
827std::ostream &operator<<(std::ostream &s, const BinOp<Mul, A, B> &op) {
828 s << "(" << op.a << " * " << op.b << ")";
829 return s;
830}
831
832template<typename A, typename B>
833std::ostream &operator<<(std::ostream &s, const BinOp<Div, A, B> &op) {
834 s << "(" << op.a << " / " << op.b << ")";
835 return s;
836}
837
838template<typename A, typename B>
839std::ostream &operator<<(std::ostream &s, const BinOp<And, A, B> &op) {
840 s << "(" << op.a << " && " << op.b << ")";
841 return s;
842}
843
844template<typename A, typename B>
845std::ostream &operator<<(std::ostream &s, const BinOp<Or, A, B> &op) {
846 s << "(" << op.a << " || " << op.b << ")";
847 return s;
848}
849
850template<typename A, typename B>
851std::ostream &operator<<(std::ostream &s, const BinOp<Min, A, B> &op) {
852 s << "min(" << op.a << ", " << op.b << ")";
853 return s;
854}
855
856template<typename A, typename B>
857std::ostream &operator<<(std::ostream &s, const BinOp<Max, A, B> &op) {
858 s << "max(" << op.a << ", " << op.b << ")";
859 return s;
860}
861
862template<typename A, typename B>
863std::ostream &operator<<(std::ostream &s, const CmpOp<LE, A, B> &op) {
864 s << "(" << op.a << " <= " << op.b << ")";
865 return s;
866}
867
868template<typename A, typename B>
869std::ostream &operator<<(std::ostream &s, const CmpOp<LT, A, B> &op) {
870 s << "(" << op.a << " < " << op.b << ")";
871 return s;
872}
873
874template<typename A, typename B>
875std::ostream &operator<<(std::ostream &s, const CmpOp<GE, A, B> &op) {
876 s << "(" << op.a << " >= " << op.b << ")";
877 return s;
878}
879
880template<typename A, typename B>
881std::ostream &operator<<(std::ostream &s, const CmpOp<GT, A, B> &op) {
882 s << "(" << op.a << " > " << op.b << ")";
883 return s;
884}
885
886template<typename A, typename B>
887std::ostream &operator<<(std::ostream &s, const CmpOp<EQ, A, B> &op) {
888 s << "(" << op.a << " == " << op.b << ")";
889 return s;
890}
891
892template<typename A, typename B>
893std::ostream &operator<<(std::ostream &s, const CmpOp<NE, A, B> &op) {
894 s << "(" << op.a << " != " << op.b << ")";
895 return s;
896}
897
898template<typename A, typename B>
899std::ostream &operator<<(std::ostream &s, const BinOp<Mod, A, B> &op) {
900 s << "(" << op.a << " % " << op.b << ")";
901 return s;
902}
903
904template<typename A, typename B>
905HALIDE_ALWAYS_INLINE auto operator+(A &&a, B &&b) noexcept -> BinOp<Add, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
908 return {pattern_arg(a), pattern_arg(b)};
909}
910
911template<typename A, typename B>
912HALIDE_ALWAYS_INLINE auto add(A &&a, B &&b) -> decltype(IRMatcher::operator+(a, b)) {
915 return IRMatcher::operator+(a, b);
916}
917
918template<>
920 t.lanes |= ((t.bits >= 32) && add_would_overflow(t.bits, a, b)) ? MatcherState::signed_integer_overflow : 0;
921 int dead_bits = 64 - t.bits;
922 // Drop the high bits then sign-extend them back
923 return int64_t((uint64_t(a) + uint64_t(b)) << dead_bits) >> dead_bits;
924}
925
926template<>
928 uint64_t ones = (uint64_t)(-1);
929 return (a + b) & (ones >> (64 - t.bits));
930}
931
932template<>
933HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Add>(halide_type_t &t, double a, double b) noexcept {
934 return a + b;
935}
936
937template<typename A, typename B>
938HALIDE_ALWAYS_INLINE auto operator-(A &&a, B &&b) noexcept -> BinOp<Sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
941 return {pattern_arg(a), pattern_arg(b)};
942}
943
944template<typename A, typename B>
945HALIDE_ALWAYS_INLINE auto sub(A &&a, B &&b) -> decltype(IRMatcher::operator-(a, b)) {
948 return IRMatcher::operator-(a, b);
949}
950
951template<>
953 t.lanes |= ((t.bits >= 32) && sub_would_overflow(t.bits, a, b)) ? MatcherState::signed_integer_overflow : 0;
954 // Drop the high bits then sign-extend them back
955 int dead_bits = 64 - t.bits;
956 return int64_t((uint64_t(a) - uint64_t(b)) << dead_bits) >> dead_bits;
957}
958
959template<>
961 uint64_t ones = (uint64_t)(-1);
962 return (a - b) & (ones >> (64 - t.bits));
963}
964
965template<>
966HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Sub>(halide_type_t &t, double a, double b) noexcept {
967 return a - b;
968}
969
970template<typename A, typename B>
971HALIDE_ALWAYS_INLINE auto operator*(A &&a, B &&b) noexcept -> BinOp<Mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
974 return {pattern_arg(a), pattern_arg(b)};
975}
976
977template<typename A, typename B>
978HALIDE_ALWAYS_INLINE auto mul(A &&a, B &&b) -> decltype(IRMatcher::operator*(a, b)) {
981 return IRMatcher::operator*(a, b);
982}
983
984template<>
986 t.lanes |= ((t.bits >= 32) && mul_would_overflow(t.bits, a, b)) ? MatcherState::signed_integer_overflow : 0;
987 int dead_bits = 64 - t.bits;
988 // Drop the high bits then sign-extend them back
989 return int64_t((uint64_t(a) * uint64_t(b)) << dead_bits) >> dead_bits;
990}
991
992template<>
994 uint64_t ones = (uint64_t)(-1);
995 return (a * b) & (ones >> (64 - t.bits));
996}
997
998template<>
999HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Mul>(halide_type_t &t, double a, double b) noexcept {
1000 return a * b;
1001}
1002
1003template<typename A, typename B>
1004HALIDE_ALWAYS_INLINE auto operator/(A &&a, B &&b) noexcept -> BinOp<Div, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1007 return {pattern_arg(a), pattern_arg(b)};
1008}
1009
1010template<typename A, typename B>
1011HALIDE_ALWAYS_INLINE auto div(A &&a, B &&b) -> decltype(IRMatcher::operator/(a, b)) {
1012 return IRMatcher::operator/(a, b);
1013}
1014
1015template<>
1019
1020template<>
1024
1025template<>
1026HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Div>(halide_type_t &t, double a, double b) noexcept {
1027 return div_imp(a, b);
1028}
1029
1030template<typename A, typename B>
1031HALIDE_ALWAYS_INLINE auto operator%(A &&a, B &&b) noexcept -> BinOp<Mod, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1034 return {pattern_arg(a), pattern_arg(b)};
1035}
1036
1037template<typename A, typename B>
1038HALIDE_ALWAYS_INLINE auto mod(A &&a, B &&b) -> decltype(IRMatcher::operator%(a, b)) {
1041 return IRMatcher::operator%(a, b);
1042}
1043
1044template<>
1048
1049template<>
1053
1054template<>
1055HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Mod>(halide_type_t &t, double a, double b) noexcept {
1056 return mod_imp(a, b);
1057}
1058
1059template<typename A, typename B>
1060HALIDE_ALWAYS_INLINE auto min(A &&a, B &&b) noexcept -> BinOp<Min, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1063 return {pattern_arg(a), pattern_arg(b)};
1064}
1065
1066template<>
1068 return std::min(a, b);
1069}
1070
1071template<>
1073 return std::min(a, b);
1074}
1075
1076template<>
1077HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Min>(halide_type_t &t, double a, double b) noexcept {
1078 return std::min(a, b);
1079}
1080
1081template<typename A, typename B>
1082HALIDE_ALWAYS_INLINE auto max(A &&a, B &&b) noexcept -> BinOp<Max, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1085 return {pattern_arg(std::forward<A>(a)), pattern_arg(std::forward<B>(b))};
1086}
1087
1088template<>
1090 return std::max(a, b);
1091}
1092
1093template<>
1095 return std::max(a, b);
1096}
1097
1098template<>
1099HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Max>(halide_type_t &t, double a, double b) noexcept {
1100 return std::max(a, b);
1101}
1102
1103template<typename A, typename B>
1104HALIDE_ALWAYS_INLINE auto operator<(A &&a, B &&b) noexcept -> CmpOp<LT, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1105 return {pattern_arg(a), pattern_arg(b)};
1106}
1107
1108template<typename A, typename B>
1109HALIDE_ALWAYS_INLINE auto lt(A &&a, B &&b) -> decltype(IRMatcher::operator<(a, b)) {
1110 return IRMatcher::operator<(a, b);
1111}
1112
1113template<>
1117
1118template<>
1122
1123template<>
1125 return a < b;
1126}
1127
1128template<typename A, typename B>
1129HALIDE_ALWAYS_INLINE auto operator>(A &&a, B &&b) noexcept -> CmpOp<GT, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1130 return {pattern_arg(a), pattern_arg(b)};
1131}
1132
1133template<typename A, typename B>
1134HALIDE_ALWAYS_INLINE auto gt(A &&a, B &&b) -> decltype(IRMatcher::operator>(a, b)) {
1135 return IRMatcher::operator>(a, b);
1136}
1137
1138template<>
1142
1143template<>
1147
1148template<>
1150 return a > b;
1151}
1152
1153template<typename A, typename B>
1154HALIDE_ALWAYS_INLINE auto operator<=(A &&a, B &&b) noexcept -> CmpOp<LE, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1155 return {pattern_arg(a), pattern_arg(b)};
1156}
1157
1158template<typename A, typename B>
1159HALIDE_ALWAYS_INLINE auto le(A &&a, B &&b) -> decltype(IRMatcher::operator<=(a, b)) {
1160 return IRMatcher::operator<=(a, b);
1161}
1162
1163template<>
1165 return a <= b;
1166}
1167
1168template<>
1172
1173template<>
1175 return a <= b;
1176}
1177
1178template<typename A, typename B>
1179HALIDE_ALWAYS_INLINE auto operator>=(A &&a, B &&b) noexcept -> CmpOp<GE, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1180 return {pattern_arg(a), pattern_arg(b)};
1181}
1182
1183template<typename A, typename B>
1184HALIDE_ALWAYS_INLINE auto ge(A &&a, B &&b) -> decltype(IRMatcher::operator>=(a, b)) {
1185 return IRMatcher::operator>=(a, b);
1186}
1187
1188template<>
1190 return a >= b;
1191}
1192
1193template<>
1197
1198template<>
1200 return a >= b;
1201}
1202
1203template<typename A, typename B>
1204HALIDE_ALWAYS_INLINE auto operator==(A &&a, B &&b) noexcept -> CmpOp<EQ, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1205 return {pattern_arg(a), pattern_arg(b)};
1206}
1207
1208template<typename A, typename B>
1209HALIDE_ALWAYS_INLINE auto eq(A &&a, B &&b) -> decltype(IRMatcher::operator==(a, b)) {
1210 return IRMatcher::operator==(a, b);
1211}
1212
1213template<>
1215 return a == b;
1216}
1217
1218template<>
1222
1223template<>
1225 return a == b;
1226}
1227
1228template<typename A, typename B>
1229HALIDE_ALWAYS_INLINE auto operator!=(A &&a, B &&b) noexcept -> CmpOp<NE, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1230 return {pattern_arg(a), pattern_arg(b)};
1231}
1232
1233template<typename A, typename B>
1234HALIDE_ALWAYS_INLINE auto ne(A &&a, B &&b) -> decltype(IRMatcher::operator!=(a, b)) {
1235 return IRMatcher::operator!=(a, b);
1236}
1237
1238template<>
1240 return a != b;
1241}
1242
1243template<>
1247
1248template<>
1250 return a != b;
1251}
1252
1253template<typename A, typename B>
1254HALIDE_ALWAYS_INLINE auto operator||(A &&a, B &&b) noexcept -> BinOp<Or, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1255 return {pattern_arg(a), pattern_arg(b)};
1256}
1257
1258template<typename A, typename B>
1259HALIDE_ALWAYS_INLINE auto or_op(A &&a, B &&b) -> decltype(IRMatcher::operator||(a, b)) {
1260 return IRMatcher::operator||(a, b);
1261}
1262
1263template<>
1265 return (a | b) & 1;
1266}
1267
1268template<>
1270 return (a | b) & 1;
1271}
1272
1273template<>
1274HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Or>(halide_type_t &t, double a, double b) noexcept {
1275 // Unreachable, as it would be a type mismatch.
1276 return 0;
1277}
1278
1279template<typename A, typename B>
1280HALIDE_ALWAYS_INLINE auto operator&&(A &&a, B &&b) noexcept -> BinOp<And, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1281 return {pattern_arg(a), pattern_arg(b)};
1282}
1283
1284template<typename A, typename B>
1285HALIDE_ALWAYS_INLINE auto and_op(A &&a, B &&b) -> decltype(IRMatcher::operator&&(a, b)) {
1286 return IRMatcher::operator&&(a, b);
1287}
1288
1289template<>
1291 return a & b & 1;
1292}
1293
1294template<>
1298
1299template<>
1300HALIDE_ALWAYS_INLINE double constant_fold_bin_op<And>(halide_type_t &t, double a, double b) noexcept {
1301 // Unreachable
1302 return 0;
1303}
1304
1306 return 0;
1307}
1308
/** Bitwise OR of one or more masks, evaluated at compile time when the
 * arguments are constants. Each trailing argument is converted to
 * uint32_t before it is combined, so the result keeps the low 32 bits
 * of every input. */
template<typename... Args>
constexpr uint32_t bitwise_or_reduce(uint32_t first, Args... rest) {
    return (first | ... | static_cast<uint32_t>(rest));
}
1313
/** Conjunction of an empty pack: vacuously true. */
constexpr bool and_reduce() {
    return true;
}

/** Compile-time logical AND over a pack of booleans, written as a C++17
 * left fold seeded with `first`. This file uses it to combine per-argument
 * flags, e.g. `and_reduce((Args::canonical)...)`. */
template<typename... Args>
constexpr bool and_reduce(bool first, Args... rest) {
    return (first && ... && rest);
}
1322
/** constexpr minimum of two ints.
 * NOTE: std::min has been constexpr since C++14, and this header already
 * relies on C++17 (`if constexpr`), so this helper could become std::min
 * once <algorithm> is included here. */
constexpr int const_min(int a, int b) {
    if (b < a) {
        return b;
    }
    return a;
}
1327
1328template<Call::IntrinsicOp intrin>
1330 bool check(const Type &) const {
1331 return true;
1332 }
1333};
1334
1335template<>
1338 bool check(const Type &t) const {
1339 return t == Type(type);
1340 }
1341};
1342
1343template<Call::IntrinsicOp intrin, typename... Args>
1344struct Intrin {
1345 struct pattern_tag {};
1346 std::tuple<Args...> args;
1347 // The type of the output of the intrinsic node.
1348 // Only necessary in cases where it can't be inferred
1349 // from the input types (e.g. saturating_cast).
1350
1352
1354
1357 constexpr static bool canonical = and_reduce((Args::canonical)...);
1358
1359 template<int i,
1360 uint32_t bound,
1361 typename = typename std::enable_if<(i < sizeof...(Args))>::type>
1362 HALIDE_ALWAYS_INLINE bool match_args(int, const Call &c, MatcherState &state) const noexcept {
1363 using T = decltype(std::get<i>(args));
1364 return (std::get<i>(args).template match<bound>(*c.args[i].get(), state) &&
1365 match_args<i + 1, bound | bindings<T>::mask>(0, c, state));
1366 }
1367
1368 template<int i, uint32_t binds>
1369 HALIDE_ALWAYS_INLINE bool match_args(double, const Call &c, MatcherState &state) const noexcept {
1370 return true;
1371 }
1372
1373 template<uint32_t bound>
1374 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1375 if (e.node_type != IRNodeType::Call) {
1376 return false;
1377 }
1378 const Call &c = (const Call &)e;
1379 return (c.is_intrinsic(intrin) &&
1380 optional_type_hint.check(e.type) &&
1381 match_args<0, bound>(0, c, state));
1382 }
1383
1384 template<int i,
1385 typename = typename std::enable_if<(i < sizeof...(Args))>::type>
1386 HALIDE_ALWAYS_INLINE void print_args(int, std::ostream &s) const {
1387 s << std::get<i>(args);
1388 if (i + 1 < sizeof...(Args)) {
1389 s << ", ";
1390 }
1391 print_args<i + 1>(0, s);
1392 }
1393
1394 template<int i>
1395 HALIDE_ALWAYS_INLINE void print_args(double, std::ostream &s) const {
1396 }
1397
1399 void print_args(std::ostream &s) const {
1400 print_args<0>(0, s);
1401 }
1402
1404 Expr make(MatcherState &state, halide_type_t type_hint) const {
1405 Expr arg0 = std::get<0>(args).make(state, type_hint);
1406 if (intrin == Call::likely) {
1407 return likely(std::move(arg0));
1408 } else if (intrin == Call::likely_if_innermost) {
1409 return likely_if_innermost(std::move(arg0));
1410 } else if (intrin == Call::abs) {
1411 return abs(std::move(arg0));
1412 } else if constexpr (intrin == Call::saturating_cast) {
1413 return saturating_cast(optional_type_hint.type, std::move(arg0));
1414 }
1415
1416 Expr arg1 = std::get<const_min(1, sizeof...(Args) - 1)>(args).make(state, type_hint);
1417 if (intrin == Call::absd) {
1418 return absd(std::move(arg0), std::move(arg1));
1419 } else if (intrin == Call::widen_right_add) {
1420 return widen_right_add(std::move(arg0), std::move(arg1));
1421 } else if (intrin == Call::widen_right_mul) {
1422 return widen_right_mul(std::move(arg0), std::move(arg1));
1423 } else if (intrin == Call::widen_right_sub) {
1424 return widen_right_sub(std::move(arg0), std::move(arg1));
1425 } else if (intrin == Call::widening_add) {
1426 return widening_add(std::move(arg0), std::move(arg1));
1427 } else if (intrin == Call::widening_sub) {
1428 return widening_sub(std::move(arg0), std::move(arg1));
1429 } else if (intrin == Call::widening_mul) {
1430 return widening_mul(std::move(arg0), std::move(arg1));
1431 } else if (intrin == Call::saturating_add) {
1432 return saturating_add(std::move(arg0), std::move(arg1));
1433 } else if (intrin == Call::saturating_sub) {
1434 return saturating_sub(std::move(arg0), std::move(arg1));
1435 } else if (intrin == Call::halving_add) {
1436 return halving_add(std::move(arg0), std::move(arg1));
1437 } else if (intrin == Call::halving_sub) {
1438 return halving_sub(std::move(arg0), std::move(arg1));
1439 } else if (intrin == Call::rounding_halving_add) {
1440 return rounding_halving_add(std::move(arg0), std::move(arg1));
1441 } else if (intrin == Call::shift_left) {
1442 return std::move(arg0) << std::move(arg1);
1443 } else if (intrin == Call::shift_right) {
1444 return std::move(arg0) >> std::move(arg1);
1445 } else if (intrin == Call::rounding_shift_left) {
1446 return rounding_shift_left(std::move(arg0), std::move(arg1));
1447 } else if (intrin == Call::rounding_shift_right) {
1448 return rounding_shift_right(std::move(arg0), std::move(arg1));
1449 }
1450
1451 Expr arg2 = std::get<const_min(2, sizeof...(Args) - 1)>(args).make(state, type_hint);
1452 if (intrin == Call::mul_shift_right) {
1453 return mul_shift_right(std::move(arg0), std::move(arg1), std::move(arg2));
1454 } else if (intrin == Call::rounding_mul_shift_right) {
1455 return rounding_mul_shift_right(std::move(arg0), std::move(arg1), std::move(arg2));
1456 }
1457
1458 internal_error << "Unhandled intrinsic in IRMatcher: " << intrin;
1459 return Expr();
1460 }
1461
1462 constexpr static bool foldable = true;
1463
1466 // Assuming the args have the same type as the intrinsic is incorrect in
1467 // general. But for the intrinsics we can fold (just shifts), the LHS
1468 // has the same type as the intrinsic, and we can always treat the RHS
1469 // as a signed int, because we're using 64 bits for it.
1470 std::get<0>(args).make_folded_const(val, ty, state);
1471 halide_type_t signed_ty = ty;
1472 signed_ty.code = halide_type_int;
1473 // We can just directly get the second arg here, because we only want to
1474 // instantiate this method for shifts, which have two args.
1475 std::get<1>(args).make_folded_const(arg1, signed_ty, state);
1476
1477 if (intrin == Call::shift_left) {
1478 if (arg1.u.i64 < 0) {
1479 if (ty.code == halide_type_int) {
1480 // Arithmetic shift
1481 val.u.i64 >>= -arg1.u.i64;
1482 } else {
1483 // Logical shift
1484 val.u.u64 >>= -arg1.u.i64;
1485 }
1486 } else {
1487 val.u.u64 <<= arg1.u.i64;
1488 }
1489 } else if (intrin == Call::shift_right) {
1490 if (arg1.u.i64 > 0) {
1491 if (ty.code == halide_type_int) {
1492 // Arithmetic shift
1493 val.u.i64 >>= arg1.u.i64;
1494 } else {
1495 // Logical shift
1496 val.u.u64 >>= arg1.u.i64;
1497 }
1498 } else {
1499 val.u.u64 <<= -arg1.u.i64;
1500 }
1501 } else {
1502 internal_error << "Folding not implemented for intrinsic: " << intrin;
1503 }
1504 }
1505
1507 Intrin(Args... args) noexcept
1508 : args(args...) {
1509 }
1510};
1511
1512template<Call::IntrinsicOp intrin, typename... Args>
1513std::ostream &operator<<(std::ostream &s, const Intrin<intrin, Args...> &op) {
1514 s << intrin << "(";
1515 op.print_args(s);
1516 s << ")";
1517 return s;
1518}
1519
1520template<typename A, typename B>
1521auto widen_right_add(A &&a, B &&b) noexcept -> Intrin<Call::widen_right_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1522 return {pattern_arg(a), pattern_arg(b)};
1523}
1524template<typename A, typename B>
1525auto widen_right_mul(A &&a, B &&b) noexcept -> Intrin<Call::widen_right_mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1526 return {pattern_arg(a), pattern_arg(b)};
1527}
1528template<typename A, typename B>
1529auto widen_right_sub(A &&a, B &&b) noexcept -> Intrin<Call::widen_right_sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1530 return {pattern_arg(a), pattern_arg(b)};
1531}
1532
1533template<typename A, typename B>
1534auto widening_add(A &&a, B &&b) noexcept -> Intrin<Call::widening_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1535 return {pattern_arg(a), pattern_arg(b)};
1536}
1537template<typename A, typename B>
1538auto widening_sub(A &&a, B &&b) noexcept -> Intrin<Call::widening_sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1539 return {pattern_arg(a), pattern_arg(b)};
1540}
1541template<typename A, typename B>
1542auto widening_mul(A &&a, B &&b) noexcept -> Intrin<Call::widening_mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1543 return {pattern_arg(a), pattern_arg(b)};
1544}
1545template<typename A, typename B>
1546auto saturating_add(A &&a, B &&b) noexcept -> Intrin<Call::saturating_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1547 return {pattern_arg(a), pattern_arg(b)};
1548}
1549template<typename A, typename B>
1550auto saturating_sub(A &&a, B &&b) noexcept -> Intrin<Call::saturating_sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1551 return {pattern_arg(a), pattern_arg(b)};
1552}
1553template<typename A>
1554auto saturating_cast(const Type &t, A &&a) noexcept -> Intrin<Call::saturating_cast, decltype(pattern_arg(a))> {
1555 Intrin<Call::saturating_cast, decltype(pattern_arg(a))> p = {pattern_arg(a)};
1556 p.optional_type_hint.type = t;
1557 return p;
1558}
1559template<typename A, typename B>
1560auto halving_add(A &&a, B &&b) noexcept -> Intrin<Call::halving_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1561 return {pattern_arg(a), pattern_arg(b)};
1562}
1563template<typename A, typename B>
1564auto halving_sub(A &&a, B &&b) noexcept -> Intrin<Call::halving_sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1565 return {pattern_arg(a), pattern_arg(b)};
1566}
1567template<typename A, typename B>
1568auto rounding_halving_add(A &&a, B &&b) noexcept -> Intrin<Call::rounding_halving_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1569 return {pattern_arg(a), pattern_arg(b)};
1570}
1571template<typename A, typename B>
1572auto shift_left(A &&a, B &&b) noexcept -> Intrin<Call::shift_left, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1573 return {pattern_arg(a), pattern_arg(b)};
1574}
1575template<typename A, typename B>
1576auto shift_right(A &&a, B &&b) noexcept -> Intrin<Call::shift_right, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1577 return {pattern_arg(a), pattern_arg(b)};
1578}
1579template<typename A, typename B>
1580auto rounding_shift_left(A &&a, B &&b) noexcept -> Intrin<Call::rounding_shift_left, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1581 return {pattern_arg(a), pattern_arg(b)};
1582}
1583template<typename A, typename B>
1584auto rounding_shift_right(A &&a, B &&b) noexcept -> Intrin<Call::rounding_shift_right, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1585 return {pattern_arg(a), pattern_arg(b)};
1586}
1587template<typename A, typename B, typename C>
1588auto mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin<Call::mul_shift_right, decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
1589 return {pattern_arg(a), pattern_arg(b), pattern_arg(c)};
1590}
1591template<typename A, typename B, typename C>
1592auto rounding_mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin<Call::rounding_mul_shift_right, decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
1593 return {pattern_arg(a), pattern_arg(b), pattern_arg(c)};
1594}
1595
1596template<typename A>
1597auto abs(A &&a) noexcept -> Intrin<Call::abs, decltype(pattern_arg(a))> {
1598 return {pattern_arg(a)};
1599}
1600
1601template<typename A, typename B>
1602auto absd(A &&a, B &&b) noexcept -> Intrin<Call::absd, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1603 return {pattern_arg(a), pattern_arg(b)};
1604}
1605
1606template<typename A>
1607auto likely(A &&a) noexcept -> Intrin<Call::likely, decltype(pattern_arg(a))> {
1608 return {pattern_arg(a)};
1609}
1610
1611template<typename A>
1612auto likely_if_innermost(A &&a) noexcept -> Intrin<Call::likely_if_innermost, decltype(pattern_arg(a))> {
1613 return {pattern_arg(a)};
1614}
1615
1616template<typename A>
1617struct NotOp {
1618 struct pattern_tag {};
1619 A a;
1620
1622
1625 constexpr static bool canonical = A::canonical;
1626
1627 template<uint32_t bound>
1628 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1629 if (e.node_type != IRNodeType::Not) {
1630 return false;
1631 }
1632 const Not &op = (const Not &)e;
1633 return (a.template match<bound>(*op.a.get(), state));
1634 }
1635
1636 template<uint32_t bound, typename A2>
1637 HALIDE_ALWAYS_INLINE bool match(const NotOp<A2> &op, MatcherState &state) const noexcept {
1638 return a.template match<bound>(unwrap(op.a), state);
1639 }
1640
1642 Expr make(MatcherState &state, halide_type_t type_hint) const {
1643 return Not::make(a.make(state, type_hint));
1644 }
1645
1646 constexpr static bool foldable = A::foldable;
1647
1648 template<typename A1 = A>
1650 a.make_folded_const(val, ty, state);
1651 val.u.u64 = ~val.u.u64;
1652 val.u.u64 &= 1;
1653 }
1654};
1655
1656template<typename A>
1657HALIDE_ALWAYS_INLINE auto operator!(A &&a) noexcept -> NotOp<decltype(pattern_arg(a))> {
1659 return {pattern_arg(a)};
1660}
1661
1662template<typename A>
1667
1668template<typename A>
1669inline std::ostream &operator<<(std::ostream &s, const NotOp<A> &op) {
1670 s << "!(" << op.a << ")";
1671 return s;
1672}
1673
1674template<typename C, typename T, typename F>
1675struct SelectOp {
1676 struct pattern_tag {};
1678 T t;
1679 F f;
1680
1682
1685
1686 constexpr static bool canonical = C::canonical && T::canonical && F::canonical;
1687
1688 template<uint32_t bound>
1689 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1690 if (e.node_type != Select::_node_type) {
1691 return false;
1692 }
1693 const Select &op = (const Select &)e;
1694 return (c.template match<bound>(*op.condition.get(), state) &&
1695 t.template match<bound | bindings<C>::mask>(*op.true_value.get(), state) &&
1696 f.template match<bound | bindings<C>::mask | bindings<T>::mask>(*op.false_value.get(), state));
1697 }
1698 template<uint32_t bound, typename C2, typename T2, typename F2>
1699 HALIDE_ALWAYS_INLINE bool match(const SelectOp<C2, T2, F2> &instance, MatcherState &state) const noexcept {
1700 return (c.template match<bound>(unwrap(instance.c), state) &&
1701 t.template match<bound | bindings<C>::mask>(unwrap(instance.t), state) &&
1702 f.template match<bound | bindings<C>::mask | bindings<T>::mask>(unwrap(instance.f), state));
1703 }
1704
1706 Expr make(MatcherState &state, halide_type_t type_hint) const {
1707 return Select::make(c.make(state, {}), t.make(state, type_hint), f.make(state, type_hint));
1708 }
1709
1710 constexpr static bool foldable = C::foldable && T::foldable && F::foldable;
1711
1712 template<typename C1 = C>
1714 halide_scalar_value_t c_val, t_val, f_val;
1715 halide_type_t c_ty;
1716 c.make_folded_const(c_val, c_ty, state);
1717 if ((c_val.u.u64 & 1) == 1) {
1718 t.make_folded_const(val, ty, state);
1719 } else {
1720 f.make_folded_const(val, ty, state);
1721 }
1723 }
1724};
1725
1726template<typename C, typename T, typename F>
1727std::ostream &operator<<(std::ostream &s, const SelectOp<C, T, F> &op) {
1728 s << "select(" << op.c << ", " << op.t << ", " << op.f << ")";
1729 return s;
1730}
1731
1732template<typename C, typename T, typename F>
1733HALIDE_ALWAYS_INLINE auto select(C &&c, T &&t, F &&f) noexcept -> SelectOp<decltype(pattern_arg(c)), decltype(pattern_arg(t)), decltype(pattern_arg(f))> {
1737 return {pattern_arg(c), pattern_arg(t), pattern_arg(f)};
1738}
1739
1740template<typename A, typename B>
1742 struct pattern_tag {};
1743 A a;
1745
1747
1750
1751 constexpr static bool canonical = A::canonical && B::canonical;
1752
1753 template<uint32_t bound>
1754 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1755 if (e.node_type == Broadcast::_node_type) {
1756 const Broadcast &op = (const Broadcast &)e;
1757 if (a.template match<bound>(*op.value.get(), state) &&
1758 lanes.template match<bound>(op.lanes, state)) {
1759 return true;
1760 }
1761 }
1762 return false;
1763 }
1764
1765 template<uint32_t bound, typename A2, typename B2>
1766 HALIDE_ALWAYS_INLINE bool match(const BroadcastOp<A2, B2> &op, MatcherState &state) const noexcept {
1767 return (a.template match<bound>(unwrap(op.a), state) &&
1768 lanes.template match<bound | bindings<A>::mask>(unwrap(op.lanes), state));
1769 }
1770
1772 Expr make(MatcherState &state, halide_type_t type_hint) const {
1773 halide_scalar_value_t lanes_val;
1774 halide_type_t ty;
1775 lanes.make_folded_const(lanes_val, ty, state);
1776 int32_t l = (int32_t)lanes_val.u.i64;
1777 type_hint.lanes /= l;
1778 Expr val = a.make(state, type_hint);
1779 if (l == 1) {
1780 return val;
1781 } else {
1782 return Broadcast::make(std::move(val), l);
1783 }
1784 }
1785
1786 constexpr static bool foldable = false;
1787
1788 template<typename A1 = A>
1790 halide_scalar_value_t lanes_val;
1791 halide_type_t lanes_ty;
1792 lanes.make_folded_const(lanes_val, lanes_ty, state);
1793 uint16_t l = (uint16_t)lanes_val.u.i64;
1794 a.make_folded_const(val, ty, state);
1795 ty.lanes = l | (ty.lanes & MatcherState::special_values_mask);
1796 }
1797};
1798
1799template<typename A, typename B>
1800inline std::ostream &operator<<(std::ostream &s, const BroadcastOp<A, B> &op) {
1801 s << "broadcast(" << op.a << ", " << op.lanes << ")";
1802 return s;
1803}
1804
1805template<typename A, typename B>
1806HALIDE_ALWAYS_INLINE auto broadcast(A &&a, B lanes) noexcept -> BroadcastOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes))> {
1808 return {pattern_arg(a), pattern_arg(lanes)};
1809}
1810
1811template<typename A, typename B, typename C>
1812struct RampOp {
1813 struct pattern_tag {};
1814 A a;
1815 B b;
1817
1819
1822
1823 constexpr static bool canonical = A::canonical && B::canonical && C::canonical;
1824
1825 template<uint32_t bound>
1826 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1827 if (e.node_type != Ramp::_node_type) {
1828 return false;
1829 }
1830 const Ramp &op = (const Ramp &)e;
1831 if (a.template match<bound>(*op.base.get(), state) &&
1832 b.template match<bound | bindings<A>::mask>(*op.stride.get(), state) &&
1833 lanes.template match<bound | bindings<A>::mask | bindings<B>::mask>(op.lanes, state)) {
1834 return true;
1835 } else {
1836 return false;
1837 }
1838 }
1839
1840 template<uint32_t bound, typename A2, typename B2, typename C2>
1841 HALIDE_ALWAYS_INLINE bool match(const RampOp<A2, B2, C2> &op, MatcherState &state) const noexcept {
1842 return (a.template match<bound>(unwrap(op.a), state) &&
1843 b.template match<bound | bindings<A>::mask>(unwrap(op.b), state) &&
1844 lanes.template match<bound | bindings<A>::mask | bindings<B>::mask>(unwrap(op.lanes), state));
1845 }
1846
1848 Expr make(MatcherState &state, halide_type_t type_hint) const {
1849 halide_scalar_value_t lanes_val;
1850 halide_type_t ty;
1851 lanes.make_folded_const(lanes_val, ty, state);
1852 int32_t l = (int32_t)lanes_val.u.i64;
1853 type_hint.lanes /= l;
1854 Expr ea, eb;
1855 eb = b.make(state, type_hint);
1856 ea = a.make(state, eb.type());
1857 return Ramp::make(std::move(ea), std::move(eb), l);
1858 }
1859
1860 constexpr static bool foldable = false;
1861};
1862
1863template<typename A, typename B, typename C>
1864std::ostream &operator<<(std::ostream &s, const RampOp<A, B, C> &op) {
1865 s << "ramp(" << op.a << ", " << op.b << ", " << op.lanes << ")";
1866 return s;
1867}
1868
1869template<typename A, typename B, typename C>
1870HALIDE_ALWAYS_INLINE auto ramp(A &&a, B &&b, C &&c) noexcept -> RampOp<decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
1874 return {pattern_arg(a), pattern_arg(b), pattern_arg(c)};
1875}
1876
1877template<typename A, typename B, VectorReduce::Operator reduce_op>
1879 struct pattern_tag {};
1880 A a;
1882
1884
1887 constexpr static bool canonical = A::canonical;
1888
1889 template<uint32_t bound>
1890 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1891 if (e.node_type == VectorReduce::_node_type) {
1892 const VectorReduce &op = (const VectorReduce &)e;
1893 if (op.op == reduce_op &&
1894 a.template match<bound>(*op.value.get(), state) &&
1895 lanes.template match<bound | bindings<A>::mask>(op.type.lanes(), state)) {
1896 return true;
1897 }
1898 }
1899 return false;
1900 }
1901
1902 template<uint32_t bound, typename A2, typename B2, VectorReduce::Operator reduce_op_2>
1904 return (reduce_op == reduce_op_2 &&
1905 a.template match<bound>(unwrap(op.a), state) &&
1906 lanes.template match<bound | bindings<A>::mask>(unwrap(op.lanes), state));
1907 }
1908
1910 Expr make(MatcherState &state, halide_type_t type_hint) const {
1911 halide_scalar_value_t lanes_val;
1912 halide_type_t ty;
1913 lanes.make_folded_const(lanes_val, ty, state);
1914 int l = (int)lanes_val.u.i64;
1915 return VectorReduce::make(reduce_op, a.make(state, type_hint), l);
1916 }
1917
1918 constexpr static bool foldable = false;
1919};
1920
1921template<typename A, typename B, VectorReduce::Operator reduce_op>
1922inline std::ostream &operator<<(std::ostream &s, const VectorReduceOp<A, B, reduce_op> &op) {
1923 s << "vector_reduce(" << reduce_op << ", " << op.a << ", " << op.lanes << ")";
1924 return s;
1925}
1926
1927template<typename A, typename B>
1928HALIDE_ALWAYS_INLINE auto h_add(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Add> {
1930 return {pattern_arg(a), pattern_arg(lanes)};
1931}
1932
1933template<typename A, typename B>
1934HALIDE_ALWAYS_INLINE auto h_min(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Min> {
1936 return {pattern_arg(a), pattern_arg(lanes)};
1937}
1938
1939template<typename A, typename B>
1940HALIDE_ALWAYS_INLINE auto h_max(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Max> {
1942 return {pattern_arg(a), pattern_arg(lanes)};
1943}
1944
1945template<typename A, typename B>
1946HALIDE_ALWAYS_INLINE auto h_and(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::And> {
1948 return {pattern_arg(a), pattern_arg(lanes)};
1949}
1950
1951template<typename A, typename B>
1952HALIDE_ALWAYS_INLINE auto h_or(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Or> {
1954 return {pattern_arg(a), pattern_arg(lanes)};
1955}
1956
1957template<typename A>
1958struct NegateOp {
1959 struct pattern_tag {};
1960 A a;
1961
1963
1966
1967 constexpr static bool canonical = A::canonical;
1968
1969 template<uint32_t bound>
1970 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1971 if (e.node_type != Sub::_node_type) {
1972 return false;
1973 }
1974 const Sub &op = (const Sub &)e;
1975 return (a.template match<bound>(*op.b.get(), state) &&
1976 is_const_zero(op.a));
1977 }
1978
1979 template<uint32_t bound, typename A2>
1980 HALIDE_ALWAYS_INLINE bool match(NegateOp<A2> &&p, MatcherState &state) const noexcept {
1981 return a.template match<bound>(unwrap(p.a), state);
1982 }
1983
1985 Expr make(MatcherState &state, halide_type_t type_hint) const {
1986 Expr ea = a.make(state, type_hint);
1987 Expr z = make_zero(ea.type());
1988 return Sub::make(std::move(z), std::move(ea));
1989 }
1990
1991 constexpr static bool foldable = A::foldable;
1992
1993 template<typename A1 = A>
1995 a.make_folded_const(val, ty, state);
1996 int dead_bits = 64 - ty.bits;
1997 switch (ty.code) {
1998 case halide_type_int:
1999 if (ty.bits >= 32 && val.u.u64 && (val.u.u64 << (65 - ty.bits)) == 0) {
2000 // Trying to negate the most negative signed int for a no-overflow type.
2002 } else {
2003 // Negate, drop the high bits, and then sign-extend them back
2004 val.u.i64 = int64_t(uint64_t(-val.u.i64) << dead_bits) >> dead_bits;
2005 }
2006 break;
2007 case halide_type_uint:
2008 val.u.u64 = ((-val.u.u64) << dead_bits) >> dead_bits;
2009 break;
2010 case halide_type_float:
2011 case halide_type_bfloat:
2012 val.u.f64 = -val.u.f64;
2013 break;
2014 default:
2015 // unreachable
2016 ;
2017 }
2018 }
2019};
2020
2021template<typename A>
2022std::ostream &operator<<(std::ostream &s, const NegateOp<A> &op) {
2023 s << "-" << op.a;
2024 return s;
2025}
2026
2027template<typename A>
2028HALIDE_ALWAYS_INLINE auto operator-(A &&a) noexcept -> NegateOp<decltype(pattern_arg(a))> {
2030 return {pattern_arg(a)};
2031}
2032
2033template<typename A>
2038
2039template<typename A>
2040struct CastOp {
2041 struct pattern_tag {};
2043 A a;
2044
2046
2049 constexpr static bool canonical = A::canonical;
2050
2051 template<uint32_t bound>
2052 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
2053 if (e.node_type != Cast::_node_type) {
2054 return false;
2055 }
2056 const Cast &op = (const Cast &)e;
2057 return (e.type == t &&
2058 a.template match<bound>(*op.value.get(), state));
2059 }
2060 template<uint32_t bound, typename A2>
2061 HALIDE_ALWAYS_INLINE bool match(const CastOp<A2> &op, MatcherState &state) const noexcept {
2062 return t == op.t && a.template match<bound>(unwrap(op.a), state);
2063 }
2064
2066 Expr make(MatcherState &state, halide_type_t type_hint) const {
2067 return cast(t, a.make(state, {}));
2068 }
2069
2070 constexpr static bool foldable = false;
2071};
2072
2073template<typename A>
2074std::ostream &operator<<(std::ostream &s, const CastOp<A> &op) {
2075 s << "cast(" << op.t << ", " << op.a << ")";
2076 return s;
2077}
2078
2079template<typename A>
2080HALIDE_ALWAYS_INLINE auto cast(halide_type_t t, A &&a) noexcept -> CastOp<decltype(pattern_arg(a))> {
2082 return {t, pattern_arg(a)};
2083}
2084
2085template<typename A>
2086struct WidenOp {
2087 struct pattern_tag {};
2088 A a;
2089
2091
2094 constexpr static bool canonical = A::canonical;
2095
2096 template<uint32_t bound>
2097 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
2098 if (e.node_type != Cast::_node_type) {
2099 return false;
2100 }
2101 const Cast &op = (const Cast &)e;
2102 return (e.type == op.value.type().widen() &&
2103 a.template match<bound>(*op.value.get(), state));
2104 }
2105 template<uint32_t bound, typename A2>
2106 HALIDE_ALWAYS_INLINE bool match(const WidenOp<A2> &op, MatcherState &state) const noexcept {
2107 return a.template match<bound>(unwrap(op.a), state);
2108 }
2109
2111 Expr make(MatcherState &state, halide_type_t type_hint) const {
2112 Expr e = a.make(state, {});
2113 Type w = e.type().widen();
2114 return cast(w, std::move(e));
2115 }
2116
2117 constexpr static bool foldable = false;
2118};
2119
2120template<typename A>
2121std::ostream &operator<<(std::ostream &s, const WidenOp<A> &op) {
2122 s << "widen(" << op.a << ")";
2123 return s;
2124}
2125
2126template<typename A>
2127HALIDE_ALWAYS_INLINE auto widen(A &&a) noexcept -> WidenOp<decltype(pattern_arg(a))> {
2129 return {pattern_arg(a)};
2130}
2131
2132template<typename Vec, typename Base, typename Stride, typename Lanes>
2133struct SliceOp {
2134 struct pattern_tag {};
2135 Vec vec;
2136 Base base;
2137 Stride stride;
2138 Lanes lanes;
2139
2140 static constexpr uint32_t binds = Vec::binds | Base::binds | Stride::binds | Lanes::binds;
2141
2144 constexpr static bool canonical = Vec::canonical && Base::canonical && Stride::canonical && Lanes::canonical;
2145
2146 template<uint32_t bound>
2147 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
2148 if (e.node_type != IRNodeType::Shuffle) {
2149 return false;
2150 }
2151 const Shuffle &v = (const Shuffle &)e;
2152 return v.vectors.size() == 1 &&
2153 v.is_slice() &&
2154 vec.template match<bound>(*v.vectors[0].get(), state) &&
2155 base.template match<bound | bindings<Vec>::mask>(v.slice_begin(), state) &&
2156 stride.template match<bound | bindings<Vec>::mask | bindings<Base>::mask>(v.slice_stride(), state) &&
2157 lanes.template match<bound | bindings<Vec>::mask | bindings<Base>::mask | bindings<Stride>::mask>(v.type.lanes(), state);
2158 }
2159
2161 Expr make(MatcherState &state, halide_type_t type_hint) const {
2162 halide_scalar_value_t base_val, stride_val, lanes_val;
2163 halide_type_t ty;
2164 base.make_folded_const(base_val, ty, state);
2165 int b = (int)base_val.u.i64;
2166 stride.make_folded_const(stride_val, ty, state);
2167 int s = (int)stride_val.u.i64;
2168 lanes.make_folded_const(lanes_val, ty, state);
2169 int l = (int)lanes_val.u.i64;
2170 return Shuffle::make_slice(vec.make(state, type_hint), b, s, l);
2171 }
2172
2173 constexpr static bool foldable = false;
2174
2176 SliceOp(Vec v, Base b, Stride s, Lanes l)
2177 : vec(v), base(b), stride(s), lanes(l) {
2178 static_assert(Base::foldable, "Base of slice should consist only of operations that constant-fold");
2179 static_assert(Stride::foldable, "Stride of slice should consist only of operations that constant-fold");
2180 static_assert(Lanes::foldable, "Lanes of slice should consist only of operations that constant-fold");
2181 }
2182};
2183
2184template<typename Vec, typename Base, typename Stride, typename Lanes>
2185std::ostream &operator<<(std::ostream &s, const SliceOp<Vec, Base, Stride, Lanes> &op) {
2186 s << "slice(" << op.vec << ", " << op.base << ", " << op.stride << ", " << op.lanes << ")";
2187 return s;
2188}
2189
2190template<typename Vec, typename Base, typename Stride, typename Lanes>
2191HALIDE_ALWAYS_INLINE auto slice(Vec vec, Base base, Stride stride, Lanes lanes) noexcept
2192 -> SliceOp<decltype(pattern_arg(vec)), decltype(pattern_arg(base)), decltype(pattern_arg(stride)), decltype(pattern_arg(lanes))> {
2193 return {pattern_arg(vec), pattern_arg(base), pattern_arg(stride), pattern_arg(lanes)};
2194}
2195
// Fold(a): a pattern node (built by fold()) whose value is the
// constant-folded value of the sub-pattern a. It can be materialized as an
// Expr on the right-hand side of a rewrite (make) or folded further as part
// of an enclosing fold (make_folded_const).
2196template<typename A>
2197struct Fold {
2198 struct pattern_tag {};
    // The sub-pattern whose value is folded.
2199 A a;
2200
2202
2205 constexpr static bool canonical = true;
2206
    // Fold a to a constant and turn it into an Expr. ty starts as
    // type_hint and may be refined by a's fold; when type_hint has a
    // nonzero bits field, its code and bits are imposed on the result.
2208 Expr make(MatcherState &state, halide_type_t type_hint) const noexcept {
2210 halide_type_t ty = type_hint;
2211 a.make_folded_const(c, ty, state);
2212
2213 // The result of the fold may have an underspecified type
2214 // (e.g. because it's from an int literal). Make the type code
2215 // and bits match the required type, if there is one (we can
2216 // tell from the bits field).
2217 if (type_hint.bits) {
        // Only an int fold requested as halide_type_float is converted by
        // value. NOTE(review): other mismatches (e.g. an int fold with a
        // bfloat hint, or a uint fold with a float hint) only relabel the
        // type code/bits below without converting the stored value —
        // confirm that is intended.
2218 if (((int)ty.code == (int)halide_type_int) &&
2219 ((int)type_hint.code == (int)halide_type_float)) {
2220 int64_t x = c.u.i64;
2221 c.u.f64 = (double)x;
2222 }
2223 ty.code = type_hint.code;
2224 ty.bits = type_hint.bits;
2225 }
2226
2227 return make_const_expr(c, ty);
2228 }
2229
    // A Fold can itself be folded exactly when its operand can.
2230 constexpr static bool foldable = A::foldable;
2231
    // Folding a Fold simply folds the wrapped sub-pattern.
2232 template<typename A1 = A>
2234 a.make_folded_const(val, ty, state);
2235 }
2236};
2237
// fold(a): wrap a (via pattern_arg) in a Fold node so that the rewrite
// result is the constant-folded value of a rather than a itself.
2238template<typename A>
2239HALIDE_ALWAYS_INLINE auto fold(A &&a) noexcept -> Fold<decltype(pattern_arg(a))> {
2241 return {pattern_arg(a)};
2242}
2243
2244template<typename A>
2245std::ostream &operator<<(std::ostream &s, const Fold<A> &op) {
2246 s << "fold(" << op.a << ")";
2247 return s;
2248}
2249
// Overflows(a) (built by overflows()): a predicate pattern that
// constant-folds a and reports whether that fold flagged a special value,
// i.e. set any MatcherState::special_values_mask bits in the folded type's
// lanes field.
2250template<typename A>
2252 struct pattern_tag {};
    // The sub-pattern being tested.
2253 A a;
2254
2256
2257 // This rule is a predicate, so it always evaluates to a boolean,
2258 // which has IRNodeType UIntImm
2261 constexpr static bool canonical = true;
2262
2263 constexpr static bool foldable = A::foldable;
2264
    // Fold a, then replace the result with a scalar uint64 flag: 1 when the
    // fold marked a special value in ty.lanes, 0 otherwise.
2265 template<typename A1 = A>
2267 a.make_folded_const(val, ty, state);
2268 ty.code = halide_type_uint;
2269 ty.bits = 64;
2270 val.u.u64 = (ty.lanes & MatcherState::special_values_mask) != 0;
    // Clear the lanes field (and with it the special-value bits) so the
    // flag itself is an ordinary scalar.
2271 ty.lanes = 1;
2272 }
2273};
2274
// overflows(a): predicate that is true when constant-folding a flags a
// special value (see Overflows::make_folded_const).
2275template<typename A>
2276HALIDE_ALWAYS_INLINE auto overflows(A &&a) noexcept -> Overflows<decltype(pattern_arg(a))> {
2278 return {pattern_arg(a)};
2279}
2280
2281template<typename A>
2282std::ostream &operator<<(std::ostream &s, const Overflows<A> &op) {
2283 s << "overflows(" << op.a << ")";
2284 return s;
2285}
2286
// Overflow: a leaf pattern for the overflow intrinsic. It binds no
// wildcards, matches only Call nodes, and when built or folded produces a
// special (overflow) constant rather than an ordinary value.
2287struct Overflow {
2288 struct pattern_tag {};
2289
    // No wildcards are bound by this pattern.
2290 constexpr static uint32_t binds = 0;
2291
2292 // Overflow is an intrinsic, represented as a Call node
2295 constexpr static bool canonical = true;
2296
    // Reject any node that is not a Call; the remaining test is made on the
    // Call node op (presumably that it is the overflow intrinsic — confirm).
2297 template<uint32_t bound>
2298 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
2299 if (e.node_type != Call::_node_type) {
2300 return false;
2301 }
    // Safe downcast: node_type was checked just above.
2302 const Call &op = (const Call &)e;
2304 }
2305
    // Materialize as a special constant expression of the requested type.
2307 Expr make(MatcherState &state, halide_type_t type_hint) const {
2309 return make_const_special_expr(type_hint);
2310 }
2311
2312 constexpr static bool foldable = true;
2313
    // Folded form: the value payload is zeroed.
2316 val.u.u64 = 0;
2318 }
2319};
2320
2321inline std::ostream &operator<<(std::ostream &s, const Overflow &op) {
2322 s << "overflow()";
2323 return s;
2324}
2325
// IsConst(a) (built by is_const()): a predicate pattern that builds a as an
// Expr and tests it with Internal::is_const, either plain (check_v false)
// or against the specific value v (check_v true).
2326template<typename A>
2327struct IsConst {
2328 struct pattern_tag {};
2329
2331
2332 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2335 constexpr static bool canonical = true;
2336
    // The sub-pattern to test.
2337 A a;
2340
    // Always foldable, even if a is not: a is materialized with make()
    // rather than folded.
2341 constexpr static bool foldable = true;
2342
    // Result is a scalar uint64 flag (1 or 0).
2343 template<typename A1 = A>
2345 Expr e = a.make(state, {});
2346 ty.code = halide_type_uint;
2347 ty.bits = 64;
2348 ty.lanes = 1;
2349 if (check_v) {
2350 val.u.u64 = ::Halide::Internal::is_const(e, v) ? 1 : 0;
2351 } else {
2352 val.u.u64 = ::Halide::Internal::is_const(e) ? 1 : 0;
2353 }
2354 }
2355};
2356
// is_const(a): predicate that is true when a, once built, satisfies
// Internal::is_const (no specific value is checked: check_v = false).
2357template<typename A>
2358HALIDE_ALWAYS_INLINE auto is_const(A &&a) noexcept -> IsConst<decltype(pattern_arg(a))> {
2360 return {pattern_arg(a), false, 0};
2361}
2362
// is_const(a, value): predicate forwarding to Internal::is_const(e, value)
// on the built expression e (check_v = true).
2363template<typename A>
2364HALIDE_ALWAYS_INLINE auto is_const(A &&a, int64_t value) noexcept -> IsConst<decltype(pattern_arg(a))> {
2366 return {pattern_arg(a), true, value};
2367}
2368
2369template<typename A>
2370std::ostream &operator<<(std::ostream &s, const IsConst<A> &op) {
2371 if (op.check_v) {
2372 s << "is_const(" << op.a << ")";
2373 } else {
2374 s << "is_const(" << op.a << ", " << op.v << ")";
2375 }
2376 return s;
2377}
2378
// CanProve(a, prover) (built by can_prove()): a predicate pattern that
// builds a as an Expr, runs it through the given prover's mutate(), and is
// true when the result is the constant one (is_const_one).
2379template<typename A, typename Prover>
2380struct CanProve {
2381 struct pattern_tag {};
2382 A a;
    // Must be non-null: it is dereferenced whenever the predicate is folded.
2383 Prover *prover; // An existing simplifying mutator
2384
2386
2387 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2390 constexpr static bool canonical = true;
2391
2392 constexpr static bool foldable = true;
2393
2394 // Includes a raw call to an inlined make method, so don't inline.
2396 Expr condition = a.make(state, {});
2397 condition = prover->mutate(condition, nullptr);
2398 val.u.u64 = is_const_one(condition);
    // One-bit result whose lane count follows the (mutated) condition.
2400 ty.bits = 1;
2401 ty.lanes = condition.type().lanes();
2402 }
2403};
2404
// can_prove(a, p): predicate that asks the simplifying mutator p to reduce
// a and checks whether it became the constant one. p must be non-null.
2405template<typename A, typename Prover>
2406HALIDE_ALWAYS_INLINE auto can_prove(A &&a, Prover *p) noexcept -> CanProve<decltype(pattern_arg(a)), Prover> {
2408 return {pattern_arg(a), p};
2409}
2410
2411template<typename A, typename Prover>
2412std::ostream &operator<<(std::ostream &s, const CanProve<A, Prover> &op) {
2413 s << "can_prove(" << op.a << ")";
2414 return s;
2415}
2416
// IsFloat(a) (built by is_float()): a predicate pattern that is true when
// the type of a, once built, is a floating-point type.
2417template<typename A>
2418struct IsFloat {
2419 struct pattern_tag {};
2420 A a;
2421
2423
2424 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2427 constexpr static bool canonical = true;
2428
2429 constexpr static bool foldable = true;
2430
2433 // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2434 Type t = a.make(state, {}).type();
2435 val.u.u64 = t.is_float();
    // One-bit result with the same lane count as a.
2437 ty.bits = 1;
2438 ty.lanes = t.lanes();
2439 }
2440};
2441
// is_float(a): predicate that is true when a has a floating-point type.
2442template<typename A>
2443HALIDE_ALWAYS_INLINE auto is_float(A &&a) noexcept -> IsFloat<decltype(pattern_arg(a))> {
2445 return {pattern_arg(a)};
2446}
2447
2448template<typename A>
2449std::ostream &operator<<(std::ostream &s, const IsFloat<A> &op) {
2450 s << "is_float(" << op.a << ")";
2451 return s;
2452}
2453
// IsInt(a, bits, lanes) (built by is_int()): a predicate pattern that is
// true when a, once built, has a signed integer type whose bit width and
// lane count match bits and lanes; a zero for either means "any".
2454template<typename A>
2455struct IsInt {
2456 struct pattern_tag {};
2457 A a;
2460
2462
2463 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2466 constexpr static bool canonical = true;
2467
2468 constexpr static bool foldable = true;
2469
2472 // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2473 Type t = a.make(state, {}).type();
2474 val.u.u64 = t.is_int() && (bits == 0 || t.bits() == bits) && (lanes == 0 || t.lanes() == lanes);
    // One-bit result with the same lane count as a.
2476 ty.bits = 1;
2477 ty.lanes = t.lanes();
2478 }
2479};
2480
// is_int(a, bits, lanes): predicate that is true when a is a signed integer
// of the given bit width and lane count (0 = don't care for either).
2481template<typename A>
2482HALIDE_ALWAYS_INLINE auto is_int(A &&a, uint8_t bits = 0, uint16_t lanes = 0) noexcept -> IsInt<decltype(pattern_arg(a))> {
2484 return {pattern_arg(a), bits, lanes};
2485}
2486
2487template<typename A>
2488std::ostream &operator<<(std::ostream &s, const IsInt<A> &op) {
2489 s << "is_int(" << op.a;
2490 if (op.bits > 0) {
2491 s << ", " << op.bits;
2492 }
2493 if (op.lanes > 0) {
2494 s << ", " << op.lanes;
2495 }
2496 s << ")";
2497 return s;
2498}
2499
// IsUInt(a, bits, lanes) (built by is_uint()): a predicate pattern that is
// true when a, once built, has an unsigned integer type whose bit width
// and lane count match bits and lanes; a zero for either means "any".
2500template<typename A>
2501struct IsUInt {
2502 struct pattern_tag {};
2503 A a;
2506
2508
2509 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2512 constexpr static bool canonical = true;
2513
2514 constexpr static bool foldable = true;
2515
2518 // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2519 Type t = a.make(state, {}).type();
2520 val.u.u64 = t.is_uint() && (bits == 0 || t.bits() == bits) && (lanes == 0 || t.lanes() == lanes);
    // One-bit result with the same lane count as a.
2522 ty.bits = 1;
2523 ty.lanes = t.lanes();
2524 }
2525};
2526
// is_uint(a, bits, lanes): predicate that is true when a is an unsigned
// integer of the given bit width and lane count (0 = don't care for either).
2527template<typename A>
2528HALIDE_ALWAYS_INLINE auto is_uint(A &&a, uint8_t bits = 0, uint16_t lanes = 0) noexcept -> IsUInt<decltype(pattern_arg(a))> {
2530 return {pattern_arg(a), bits, lanes};
2531}
2532
2533template<typename A>
2534std::ostream &operator<<(std::ostream &s, const IsUInt<A> &op) {
2535 s << "is_uint(" << op.a;
2536 if (op.bits > 0) {
2537 s << ", " << op.bits;
2538 }
2539 if (op.lanes > 0) {
2540 s << ", " << op.lanes;
2541 }
2542 s << ")";
2543 return s;
2544}
2545
// IsScalar(a) (built by is_scalar()): a predicate pattern that is true when
// the type of a, once built, is a scalar (single-lane) type.
2546template<typename A>
2547struct IsScalar {
2548 struct pattern_tag {};
2549 A a;
2550
2552
2553 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2556 constexpr static bool canonical = true;
2557
2558 constexpr static bool foldable = true;
2559
2562 // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2563 Type t = a.make(state, {}).type();
2564 val.u.u64 = t.is_scalar();
    // One-bit result with the same lane count as a.
2566 ty.bits = 1;
2567 ty.lanes = t.lanes();
2568 }
2569};
2570
// is_scalar(a): predicate that is true when a has a scalar type.
2571template<typename A>
2572HALIDE_ALWAYS_INLINE auto is_scalar(A &&a) noexcept -> IsScalar<decltype(pattern_arg(a))> {
2574 return {pattern_arg(a)};
2575}
2576
2577template<typename A>
2578std::ostream &operator<<(std::ostream &s, const IsScalar<A> &op) {
2579 s << "is_scalar(" << op.a << ")";
2580 return s;
2581}
2582
// IsMaxValue(a) (built by is_max_value()): a predicate pattern that folds a
// to a constant and is true when it equals the largest value of its integer
// type: 2^bits - 1 for unsigned, 2^(bits-1) - 1 for signed. Non-integer
// types always yield false.
2583template<typename A>
2585 struct pattern_tag {};
2586 A a;
2587
2589
2590 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2593 constexpr static bool canonical = true;
2594
2595 constexpr static bool foldable = true;
2596
2599 // a is constant-folded (not built with make), then compared against the max value of its folded type.
2600 a.make_folded_const(val, ty, state);
    // Low `bits` ones for unsigned, low `bits - 1` ones for signed.
    // NOTE(review): the shift is evaluated before the type-code check, and
    // a folded ty.bits of 0 would make the shift amount >= 64 (undefined
    // behaviour) — confirm ty.bits is always nonzero here.
2601 const uint64_t max_bits = (uint64_t)(-1) >> (64 - ty.bits + (ty.code == halide_type_int));
2602 if (ty.code == halide_type_uint || ty.code == halide_type_int) {
2603 val.u.u64 = (val.u.u64 == max_bits);
2604 } else {
2605 val.u.u64 = 0;
2606 }
2608 ty.bits = 1;
2609 }
2610};
2611
// is_max_value(a): predicate that is true when a folds to the maximum value
// of its integer type.
2612template<typename A>
2613HALIDE_ALWAYS_INLINE auto is_max_value(A &&a) noexcept -> IsMaxValue<decltype(pattern_arg(a))> {
2615 return {pattern_arg(a)};
2616}
2617
2618template<typename A>
2619std::ostream &operator<<(std::ostream &s, const IsMaxValue<A> &op) {
2620 s << "is_max_value(" << op.a << ")";
2621 return s;
2622}
2623
// IsMinValue(a) (built by is_min_value()): a predicate pattern that folds a
// to a constant and is true when it equals the smallest value of its
// integer type: 0 for unsigned, -2^(bits-1) for signed. Non-integer types
// always yield false.
2624template<typename A>
2626 struct pattern_tag {};
2627 A a;
2628
2630
2631 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2634 constexpr static bool canonical = true;
2635
2636 constexpr static bool foldable = true;
2637
2640 // a is constant-folded (not built with make), then compared against the min value of its folded type.
2641 a.make_folded_const(val, ty, state);
2642 if (ty.code == halide_type_int) {
        // Bit pattern of -2^(bits-1) widened to 64 bits; presumably folded
        // signed values are stored sign-extended — confirm. NOTE(review):
        // ty.bits == 0 would make this a negative shift (undefined
        // behaviour).
2643 const uint64_t min_bits = (uint64_t)(-1) << (ty.bits - 1);
2644 val.u.u64 = (val.u.u64 == min_bits);
2645 } else if (ty.code == halide_type_uint) {
2646 val.u.u64 = (val.u.u64 == 0);
2647 } else {
2648 val.u.u64 = 0;
2649 }
2651 ty.bits = 1;
2652 }
2653};
2654
// is_min_value(a): predicate that is true when a folds to the minimum value
// of its integer type.
2655template<typename A>
2656HALIDE_ALWAYS_INLINE auto is_min_value(A &&a) noexcept -> IsMinValue<decltype(pattern_arg(a))> {
2658 return {pattern_arg(a)};
2659}
2660
2661template<typename A>
2662std::ostream &operator<<(std::ostream &s, const IsMinValue<A> &op) {
2663 s << "is_min_value(" << op.a << ")";
2664 return s;
2665}
2666
// LanesOf(a) (built by lanes_of()): a foldable pattern whose value is the
// number of lanes in the type of a, once built.
2667template<typename A>
2668struct LanesOf {
2669 struct pattern_tag {};
2670 A a;
2671
2673
2674 // Not a predicate: this rule evaluates to the lane count of a, as a scalar 32-bit constant.
2677 constexpr static bool canonical = true;
2678
2679 constexpr static bool foldable = true;
2680
2683 // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2684 Type t = a.make(state, {}).type();
2685 val.u.u64 = t.lanes();
    // Result is a 32-bit scalar regardless of a's width.
2687 ty.bits = 32;
2688 ty.lanes = 1;
2689 }
2690};
2691
// lanes_of(a): the lane count of a's type, usable inside foldable patterns
// (e.g. as the lanes argument of slice()).
2692template<typename A>
2693HALIDE_ALWAYS_INLINE auto lanes_of(A &&a) noexcept -> LanesOf<decltype(pattern_arg(a))> {
2695 return {pattern_arg(a)};
2696}
2697
2698template<typename A>
2699std::ostream &operator<<(std::ostream &s, const LanesOf<A> &op) {
2700 s << "lanes_of(" << op.a << ")";
2701 return s;
2702}
2703
// Fuzz-test one rewrite rule (before -> after, guarded by pred) on random
// scalar constants.
//
// Selected only when both sides are constant-foldable. The rule is tested
// at most once per wildcard type per instantiation of this function,
// because the `tested` set is a function-local static. For each type, a
// python/z3-style validate(...) line is printed. Then, over 100 trials,
// every wildcard slot is bound to random values; trials where the
// predicate fails are skipped; both sides are folded and their values are
// compared. Mismatches are reported via debug(0).
//
// NOTE(review): `tested` and `rng` are function-local statics mutated
// without any visible synchronization; this is a debugging aid, presumably
// only run single-threaded — confirm.
2704// Verify properties of each rewrite rule. Currently just fuzz tests them.
2705template<typename Before,
2706 typename After,
2707 typename Predicate,
2708 typename = typename std::enable_if<std::decay<Before>::type::foldable &&
2709 std::decay<After>::type::foldable>::type>
2710HALIDE_NEVER_INLINE void fuzz_test_rule(Before &&before, After &&after, Predicate &&pred,
2711 halide_type_t wildcard_type, halide_type_t output_type) noexcept {
2712
2713 // We only validate the rules in the scalar case
2714 wildcard_type.lanes = output_type.lanes = 1;
2715
2716 // Track which types this rule has been tested for before
2717 static std::set<uint32_t> tested;
2718
2719 if (!tested.insert(reinterpret_bits<uint32_t>(wildcard_type)).second) {
2720 return;
2721 }
2722
2723 // Print it in a form where it can be piped into a python/z3 validator
2724 debug(0) << "validate('" << before << "', '" << after << "', '" << pred << "', " << Type(wildcard_type) << ", " << Type(output_type) << ")\n";
2725
2726 // Substitute some random constants into the before and after
2727 // expressions and see if the rule holds true. This should catch
2728 // silly errors, but not necessarily corner cases.
    // Fixed seed, so runs are reproducible.
2729 static std::mt19937_64 rng(0);
2730 MatcherState state;
2731
    // Keeps the constant Exprs alive while the state refers to them.
2732 Expr exprs[max_wild];
2733
2734 for (int trials = 0; trials < 100; trials++) {
2735 // We want to test small constants more frequently than
2736 // large ones, otherwise we'll just get coverage of
2737 // overflow rules.
    // Masking with bits - 1 yields a shift in [0, bits - 1] for
    // power-of-two widths.
2738 int shift = (int)(rng() & (wildcard_type.bits - 1));
2739
2740 for (int i = 0; i < max_wild; i++) {
2741 // Bind all the exprs and constants
    // Each slot i gets an independent random constant (for constant
    // wildcards) and a random constant Expr (for expression wildcards).
2742 switch (wildcard_type.code) {
2743 case halide_type_uint: {
2744 // Normalize to the type's range by adding zero
2745 uint64_t val = constant_fold_bin_op<Add>(wildcard_type, (uint64_t)rng() >> shift, 0);
2746 state.set_bound_const(i, val, wildcard_type);
2747 val = constant_fold_bin_op<Add>(wildcard_type, (uint64_t)rng() >> shift, 0);
2748 exprs[i] = make_const(wildcard_type, val);
    // NOTE(review): redundant — set_binding is also called for every type
    // right after the switch.
2749 state.set_binding(i, *exprs[i].get());
2750 } break;
2751 case halide_type_int: {
2752 int64_t val = constant_fold_bin_op<Add>(wildcard_type, (int64_t)rng() >> shift, 0);
2753 state.set_bound_const(i, val, wildcard_type);
2754 val = constant_fold_bin_op<Add>(wildcard_type, (int64_t)rng() >> shift, 0);
2755 exprs[i] = make_const(wildcard_type, val);
2756 } break;
2757 case halide_type_float:
2758 case halide_type_bfloat: {
2759 // Use a very narrow range of precise floats, so
2760 // that none of the rules a human is likely to
2761 // write have instabilities.
    // Values are multiples of 0.5 in [-4, 3.5].
2762 double val = ((int64_t)(rng() & 15) - 8) / 2.0;
2763 state.set_bound_const(i, val, wildcard_type);
2764 val = ((int64_t)(rng() & 15) - 8) / 2.0;
2765 exprs[i] = make_const(wildcard_type, val);
2766 } break;
2767 default:
2768 return; // Don't care about handles
2769 }
2770 state.set_binding(i, *exprs[i].get());
2771 }
2772
    // NOTE(review): val_pred is declared but never used.
2773 halide_scalar_value_t val_pred, val_before, val_after;
2774 halide_type_t type = output_type;
2775 if (!evaluate_predicate(pred, state)) {
2776 continue;
2777 }
2778 before.make_folded_const(val_before, type, state);
    // Accumulate the lanes fields of both folds; folding can report special
    // values (see MatcherState::special_values_mask) through this field.
2779 uint16_t lanes = type.lanes;
2780 after.make_folded_const(val_after, type, state);
2781 lanes |= type.lanes;
2782
2784 continue;
2785 }
2786
2787 bool ok = true;
2788 switch (output_type.code) {
2789 case halide_type_uint:
2790 // Compare normalized representations
2791 ok &= (constant_fold_bin_op<Add>(output_type, val_before.u.u64, 0) ==
2792 constant_fold_bin_op<Add>(output_type, val_after.u.u64, 0));
2793 break;
2794 case halide_type_int:
2795 ok &= (constant_fold_bin_op<Add>(output_type, val_before.u.i64, 0) ==
2796 constant_fold_bin_op<Add>(output_type, val_after.u.i64, 0));
2797 break;
2798 case halide_type_float:
2799 case halide_type_bfloat: {
2800 double error = std::abs(val_before.u.f64 - val_after.u.f64);
2801 // We accept an equal bit pattern (e.g. inf vs inf),
2802 // a small floating point difference, or turning a nan into not-a-nan.
2803 ok &= (error < 0.01 ||
2804 val_before.u.u64 == val_after.u.u64 ||
2805 std::isnan(val_before.u.f64));
2806 break;
2807 }
2808 default:
2809 return;
2810 }
2811
    // Report every bound constant and expression, plus both folded results,
    // for the failing trial.
2812 if (!ok) {
2813 debug(0) << "Fails with values:\n";
2814 for (int i = 0; i < max_wild; i++) {
2816 state.get_bound_const(i, val, wildcard_type);
2817 debug(0) << " c" << i << ": " << make_const_expr(val, wildcard_type) << "\n";
2818 }
2819 for (int i = 0; i < max_wild; i++) {
2820 debug(0) << " _" << i << ": " << Expr(state.get_binding(i)) << "\n";
2821 }
2822 debug(0) << " Before: " << make_const_expr(val_before, output_type) << "\n";
2823 debug(0) << " After: " << make_const_expr(val_after, output_type) << "\n";
2824 debug(0) << val_before.u.u64 << " " << val_after.u.u64 << "\n";
2826 }
2827 }
2828}
2829
2830template<typename Before,
2831 typename After,
2832 typename Predicate,
2833 typename = typename std::enable_if<!(std::decay<Before>::type::foldable &&
2834 std::decay<After>::type::foldable)>::type>
2835HALIDE_ALWAYS_INLINE void fuzz_test_rule(Before &&before, After &&after, Predicate &&pred,
2836 halide_type_t, halide_type_t, int dummy = 0) noexcept {
2837 // We can't verify rewrite rules that can't be constant-folded.
2838}
2839
// evaluate_predicate for a plain bool predicate: returns it unchanged. The
// Rewriter overloads that take no predicate pass the literal `true` to
// fuzz_test_rule, which then evaluates it through this overload.
2841bool evaluate_predicate(bool x, MatcherState &) noexcept {
2842 return x;
2843}
2844
// evaluate_predicate for a pattern predicate: constant-folds it (with a
// bool type as the starting type) under the current bindings. Returns true
// only if the folded value is nonzero and the fold did not flag a special
// value (e.g. overflow) through MatcherState::special_values_mask in
// ty.lanes.
2845template<typename Pattern,
2846 typename = typename enable_if_pattern<Pattern>::type>
2849 halide_type_t ty = halide_type_of<bool>();
2850 p.make_folded_const(c, ty, state);
2851 // Overflow counts as a failed predicate
2852 return (c.u.u64 != 0) && ((ty.lanes & MatcherState::special_values_mask) == 0);
2853}
2854
// #defines for testing. Each one defaults to 0 (off). Each is guarded with
// #ifndef so it can be turned on from the build, e.g.
// -DHALIDE_FUZZ_TEST_RULES=1, without editing this header. Previously a
// command-line definition would trigger a macro-redefinition warning and
// be overridden by the 0 here.

// Print all successful or failed matches
#ifndef HALIDE_DEBUG_MATCHED_RULES
#define HALIDE_DEBUG_MATCHED_RULES 0
#endif
#ifndef HALIDE_DEBUG_UNMATCHED_RULES
#define HALIDE_DEBUG_UNMATCHED_RULES 0
#endif

// Set to 1 if you want to fuzz test every rewrite passed to
// operator() to ensure the input and the output have the same value
// for lots of random values of the wildcards. Run
// correctness_simplify with this on.
#ifndef HALIDE_FUZZ_TEST_RULES
#define HALIDE_FUZZ_TEST_RULES 0
#endif
2866
// Rewriter: holds the instance being simplified and tries rewrite rules
// against it, one operator() call per rule. Each overload takes a "before"
// pattern and a replacement: a pattern, a concrete Expr, or an int64_t
// constant. Some overloads also take a constant-foldable predicate. If
// `before` matches the instance (and the predicate, if any, holds), the
// replacement is stored in `result` and true is returned; otherwise false
// is returned and `result` is not assigned.
2867template<typename Instance>
2868struct Rewriter {
2869 Instance instance;
2874
    // Takes the instance by move. ot is the type hint used when building
    // the replacement (after.make / make_const); wt is used only as the
    // wildcard type when fuzz testing rules.
2877 : instance(std::move(instance)), output_type(ot), wildcard_type(wt) {
2878 }
2879
    // Build `result` from the pattern `after`, using the bindings captured
    // in `state` by the preceding successful match.
2880 template<typename After>
2882#if HALIDE_DEBUG_MATCHED_RULES
2883 debug(0) << instance << " -> " << after << "\n";
2884#endif
2885 result = after.make(state, output_type);
2886 }
2887
    // Rule: before -> after (pattern). Statically checks that `after` only
    // uses wildcards bound by `before` and that both sides are canonical.
2888 template<typename Before,
2889 typename After,
2890 typename = typename enable_if_pattern<Before>::type,
2891 typename = typename enable_if_pattern<After>::type>
2892 HALIDE_ALWAYS_INLINE bool operator()(Before before, After after) {
2893 static_assert((Before::binds & After::binds) == After::binds, "Rule result uses unbound values");
2894 static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2895 static_assert(After::canonical, "RHS of rewrite rule should be in canonical form");
2896#if HALIDE_FUZZ_TEST_RULES
2897 fuzz_test_rule(before, after, true, wildcard_type, output_type);
2898#endif
2899 if (before.template match<0>(unwrap(instance), state)) {
2900 build_replacement(after);
2901#if HALIDE_DEBUG_MATCHED_RULES
2902 debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n";
2903#endif
2904 return true;
2905 } else {
2906#if HALIDE_DEBUG_UNMATCHED_RULES
2907 debug(0) << instance << " does not match " << before << "\n";
2908#endif
2909 return false;
2910 }
2911 }
2912
    // Rule: before -> a concrete Expr, used as-is. Not fuzz tested.
2913 template<typename Before,
2914 typename = typename enable_if_pattern<Before>::type>
2915 HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after) noexcept {
2916 static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2917 if (before.template match<0>(unwrap(instance), state)) {
2918 result = after;
2919#if HALIDE_DEBUG_MATCHED_RULES
2920 debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n";
2921#endif
2922 return true;
2923 } else {
2924#if HALIDE_DEBUG_UNMATCHED_RULES
2925 debug(0) << instance << " does not match " << before << "\n";
2926#endif
2927 return false;
2928 }
2929 }
2930
    // Rule: before -> an integer constant, materialized with output_type.
2931 template<typename Before,
2932 typename = typename enable_if_pattern<Before>::type>
2933 HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after) noexcept {
2934 static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2935#if HALIDE_FUZZ_TEST_RULES
2936 fuzz_test_rule(before, IntLiteral(after), true, wildcard_type, output_type);
2937#endif
2938 if (before.template match<0>(unwrap(instance), state)) {
2939 result = make_const(output_type, after);
2940#if HALIDE_DEBUG_MATCHED_RULES
2941 debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n";
2942#endif
2943 return true;
2944 } else {
2945#if HALIDE_DEBUG_UNMATCHED_RULES
2946 debug(0) << instance << " does not match " << before << "\n";
2947#endif
2948 return false;
2949 }
2950 }
2951
    // Rule: before -> after (pattern) when pred. The predicate is evaluated
    // only after a successful match, because it reads the match bindings.
2952 template<typename Before,
2953 typename After,
2954 typename Predicate,
2955 typename = typename enable_if_pattern<Before>::type,
2956 typename = typename enable_if_pattern<After>::type,
2957 typename = typename enable_if_pattern<Predicate>::type>
2958 HALIDE_ALWAYS_INLINE bool operator()(Before before, After after, Predicate pred) {
2959 static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold");
2960 static_assert((Before::binds & After::binds) == After::binds, "Rule result uses unbound values");
2961 static_assert((Before::binds & Predicate::binds) == Predicate::binds, "Rule predicate uses unbound values");
2962 static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2963 static_assert(After::canonical, "RHS of rewrite rule should be in canonical form");
2964
2965#if HALIDE_FUZZ_TEST_RULES
2966 fuzz_test_rule(before, after, pred, wildcard_type, output_type);
2967#endif
2968 if (before.template match<0>(unwrap(instance), state) &&
2969 evaluate_predicate(pred, state)) {
2970 build_replacement(after);
2971#if HALIDE_DEBUG_MATCHED_RULES
2972 debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n";
2973#endif
2974 return true;
2975 } else {
2976#if HALIDE_DEBUG_UNMATCHED_RULES
2977 debug(0) << instance << " does not match " << before << "\n";
2978#endif
2979 return false;
2980 }
2981 }
2982
    // Rule: before -> a concrete Expr when pred. Not fuzz tested.
    // NOTE(review): unlike the pattern/pattern/predicate overload, this one
    // does not static_assert that pred only uses wildcards bound by before.
2983 template<typename Before,
2984 typename Predicate,
2985 typename = typename enable_if_pattern<Before>::type,
2986 typename = typename enable_if_pattern<Predicate>::type>
2987 HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after, Predicate pred) {
2988 static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold");
2989 static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2990
2991 if (before.template match<0>(unwrap(instance), state) &&
2992 evaluate_predicate(pred, state)) {
2993 result = after;
2994#if HALIDE_DEBUG_MATCHED_RULES
2995 debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n";
2996#endif
2997 return true;
2998 } else {
2999#if HALIDE_DEBUG_UNMATCHED_RULES
3000 debug(0) << instance << " does not match " << before << "\n";
3001#endif
3002 return false;
3003 }
3004 }
3005
    // Rule: before -> an integer constant when pred.
    // NOTE(review): also lacks the predicate-binds static_assert.
3006 template<typename Before,
3007 typename Predicate,
3008 typename = typename enable_if_pattern<Before>::type,
3009 typename = typename enable_if_pattern<Predicate>::type>
3010 HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after, Predicate pred) {
3011 static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold");
3012 static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
3013#if HALIDE_FUZZ_TEST_RULES
3014 fuzz_test_rule(before, IntLiteral(after), pred, wildcard_type, output_type);
3015#endif
3016 if (before.template match<0>(unwrap(instance), state) &&
3017 evaluate_predicate(pred, state)) {
3018 result = make_const(output_type, after);
3019#if HALIDE_DEBUG_MATCHED_RULES
3020 debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n";
3021#endif
3022 return true;
3023 } else {
3024#if HALIDE_DEBUG_UNMATCHED_RULES
3025 debug(0) << instance << " does not match " << before << "\n";
3026#endif
3027 return false;
3028 }
3029 }
3030};
3031
3032/** Construct a rewriter for the given instance, which may be a pattern
 * with concrete expressions as leaves, or just an expression. The
 * optional wildcard_type argument is a hint as to what the type of the
 * wildcards is likely to be. If omitted, it defaults to the output type
 * (for pattern instances) or to the type of the expression itself (for
 * Expr instances). The wildcards are not required to have this type, but
 * the rule will only be tested for wildcards of that type when testing
 * (HALIDE_FUZZ_TEST_RULES) is enabled.
3039 *
3040 * The rewriter can be used to check to see if the instance is one of
3041 * some number of patterns and if so rewrite it into another form,
3042 * using its operator() method. See Simplify.cpp for a bunch of
3043 * example usage.
3044 *
3045 * Important: Any Exprs in patterns are captured by reference, not by
3046 * value, so ensure they outlive the rewriter.
3047 */
3048// @{
3049template<typename Instance,
3050 typename = typename enable_if_pattern<Instance>::type>
3051HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type, halide_type_t wildcard_type) noexcept -> Rewriter<decltype(pattern_arg(instance))> {
3052 return {pattern_arg(instance), output_type, wildcard_type};
3053}
3054
3055template<typename Instance,
3056 typename = typename enable_if_pattern<Instance>::type>
3057HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type) noexcept -> Rewriter<decltype(pattern_arg(instance))> {
3058 return {pattern_arg(instance), output_type, output_type};
3059}
3060
// Rewriter over a concrete Expr: replacements take e's own type, while
// wildcard_type is the hint used for fuzz testing.
3062auto rewriter(const Expr &e, halide_type_t wildcard_type) noexcept -> Rewriter<decltype(pattern_arg(e))> {
3063 return {pattern_arg(e), e.type(), wildcard_type};
3064}
3065
// Rewriter over a concrete Expr, using e's own type both for replacements
// and as the wildcard type.
3067auto rewriter(const Expr &e) noexcept -> Rewriter<decltype(pattern_arg(e))> {
3068 return {pattern_arg(e), e.type(), e.type()};
3069}
3070// @}
3071
3072} // namespace IRMatcher
3073
3074} // namespace Internal
3075} // namespace Halide
3076
3077#endif
#define internal_error
Definition Errors.h:23
@ halide_type_float
IEEE floating point numbers.
@ halide_type_bfloat
floating point numbers in the bfloat format
@ halide_type_int
signed integers
@ halide_type_uint
unsigned integers
#define HALIDE_NEVER_INLINE
#define HALIDE_ALWAYS_INLINE
Subtypes for Halide expressions (Halide::Expr) and statements (Halide::Internal::Stmt)
Methods to test Exprs and Stmts for equality of value.
Defines various operator overloads and utility functions that make it more pleasant to work with Hali...
For optional debugging during codegen, use the debug class as follows:
Definition Debug.h:49
HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type, halide_type_t wildcard_type) noexcept -> Rewriter< decltype(pattern_arg(instance))>
Construct a rewriter for the given instance, which may be a pattern with concrete expressions as leav...
Definition IRMatch.h:3051
HALIDE_ALWAYS_INLINE T pattern_arg(T t)
Definition IRMatch.h:567
auto rounding_halving_add(A &&a, B &&b) noexcept -> Intrin< Call::rounding_halving_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1568
HALIDE_ALWAYS_INLINE auto or_op(A &&a, B &&b) -> decltype(IRMatcher::operator||(a, b))
Definition IRMatch.h:1259
HALIDE_ALWAYS_INLINE auto operator!(A &&a) noexcept -> NotOp< decltype(pattern_arg(a))>
Definition IRMatch.h:1657
auto shift_right(A &&a, B &&b) noexcept -> Intrin< Call::shift_right, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1576
HALIDE_ALWAYS_INLINE auto min(A &&a, B &&b) noexcept -> BinOp< Min, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1060
auto widening_add(A &&a, B &&b) noexcept -> Intrin< Call::widening_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1534
HALIDE_ALWAYS_INLINE auto is_int(A &&a, uint8_t bits=0, uint16_t lanes=0) noexcept -> IsInt< decltype(pattern_arg(a))>
Definition IRMatch.h:2482
HALIDE_ALWAYS_INLINE bool evaluate_predicate(bool x, MatcherState &) noexcept
Definition IRMatch.h:2841
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Div >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:1016
auto abs(A &&a) noexcept -> Intrin< Call::abs, decltype(pattern_arg(a))>
Definition IRMatch.h:1597
HALIDE_ALWAYS_INLINE auto ne(A &&a, B &&b) -> decltype(IRMatcher::operator!=(a, b))
Definition IRMatch.h:1234
HALIDE_ALWAYS_INLINE auto is_uint(A &&a, uint8_t bits=0, uint16_t lanes=0) noexcept -> IsUInt< decltype(pattern_arg(a))>
Definition IRMatch.h:2528
HALIDE_ALWAYS_INLINE auto negate(A &&a) -> decltype(IRMatcher::operator-(a))
Definition IRMatch.h:2034
uint64_t constant_fold_cmp_op(int64_t, int64_t) noexcept
HALIDE_ALWAYS_INLINE auto operator<=(A &&a, B &&b) noexcept -> CmpOp< LE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1154
HALIDE_ALWAYS_INLINE auto operator+(A &&a, B &&b) noexcept -> BinOp< Add, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:905
HALIDE_ALWAYS_INLINE auto is_max_value(A &&a) noexcept -> IsMaxValue< decltype(pattern_arg(a))>
Definition IRMatch.h:2613
std::ostream & operator<<(std::ostream &s, const SpecificExpr &e)
Definition IRMatch.h:217
HALIDE_ALWAYS_INLINE auto and_op(A &&a, B &&b) -> decltype(IRMatcher::operator&&(a, b))
Definition IRMatch.h:1285
HALIDE_ALWAYS_INLINE auto h_and(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::And >
Definition IRMatch.h:1946
HALIDE_ALWAYS_INLINE auto gt(A &&a, B &&b) -> decltype(IRMatcher::operator>(a, b))
Definition IRMatch.h:1134
HALIDE_ALWAYS_INLINE auto is_const(A &&a) noexcept -> IsConst< decltype(pattern_arg(a))>
Definition IRMatch.h:2358
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< LE >(int64_t a, int64_t b) noexcept
Definition IRMatch.h:1164
HALIDE_ALWAYS_INLINE auto operator*(A &&a, B &&b) noexcept -> BinOp< Mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:971
HALIDE_ALWAYS_INLINE auto add(A &&a, B &&b) -> decltype(IRMatcher::operator+(a, b))
Definition IRMatch.h:912
auto widen_right_add(A &&a, B &&b) noexcept -> Intrin< Call::widen_right_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1521
HALIDE_ALWAYS_INLINE auto div(A &&a, B &&b) -> decltype(IRMatcher::operator/(a, b))
Definition IRMatch.h:1011
auto widen_right_mul(A &&a, B &&b) noexcept -> Intrin< Call::widen_right_mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1525
HALIDE_ALWAYS_INLINE auto mul(A &&a, B &&b) -> decltype(IRMatcher::operator*(a, b))
Definition IRMatch.h:978
HALIDE_ALWAYS_INLINE auto max(A &&a, B &&b) noexcept -> BinOp< Max, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1082
auto absd(A &&a, B &&b) noexcept -> Intrin< Call::absd, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1602
HALIDE_ALWAYS_INLINE auto slice(Vec vec, Base base, Stride stride, Lanes lanes) noexcept -> SliceOp< decltype(pattern_arg(vec)), decltype(pattern_arg(base)), decltype(pattern_arg(stride)), decltype(pattern_arg(lanes))>
Definition IRMatch.h:2191
HALIDE_ALWAYS_INLINE auto ramp(A &&a, B &&b, C &&c) noexcept -> RampOp< decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
Definition IRMatch.h:1870
HALIDE_ALWAYS_INLINE auto operator/(A &&a, B &&b) noexcept -> BinOp< Div, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1004
HALIDE_ALWAYS_INLINE auto widen(A &&a) noexcept -> WidenOp< decltype(pattern_arg(a))>
Definition IRMatch.h:2127
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Mod >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:1045
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< And >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:1290
HALIDE_ALWAYS_INLINE int64_t unwrap(IntLiteral t)
Definition IRMatch.h:559
auto widening_mul(A &&a, B &&b) noexcept -> Intrin< Call::widening_mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1542
HALIDE_ALWAYS_INLINE auto operator>(A &&a, B &&b) noexcept -> CmpOp< GT, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1129
HALIDE_ALWAYS_INLINE auto cast(halide_type_t t, A &&a) noexcept -> CastOp< decltype(pattern_arg(a))>
Definition IRMatch.h:2080
HALIDE_ALWAYS_INLINE auto overflows(A &&a) noexcept -> Overflows< decltype(pattern_arg(a))>
Definition IRMatch.h:2276
auto saturating_cast(const Type &t, A &&a) noexcept -> Intrin< Call::saturating_cast, decltype(pattern_arg(a))>
Definition IRMatch.h:1554
HALIDE_ALWAYS_INLINE void assert_is_lvalue_if_expr()
Definition IRMatch.h:576
HALIDE_ALWAYS_INLINE auto operator%(A &&a, B &&b) noexcept -> BinOp< Mod, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1031
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Sub >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:952
auto rounding_shift_left(A &&a, B &&b) noexcept -> Intrin< Call::rounding_shift_left, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1580
HALIDE_ALWAYS_INLINE auto is_scalar(A &&a) noexcept -> IsScalar< decltype(pattern_arg(a))>
Definition IRMatch.h:2572
HALIDE_ALWAYS_INLINE auto fold(A &&a) noexcept -> Fold< decltype(pattern_arg(a))>
Definition IRMatch.h:2239
HALIDE_ALWAYS_INLINE auto not_op(A &&a) -> decltype(IRMatcher::operator!(a))
Definition IRMatch.h:1663
auto likely(A &&a) noexcept -> Intrin< Call::likely, decltype(pattern_arg(a))>
Definition IRMatch.h:1607
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Max >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:1089
constexpr bool and_reduce()
Definition IRMatch.h:1314
HALIDE_ALWAYS_INLINE auto operator||(A &&a, B &&b) noexcept -> BinOp< Or, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1254
constexpr int max_wild
Definition IRMatch.h:74
HALIDE_ALWAYS_INLINE auto operator!=(A &&a, B &&b) noexcept -> CmpOp< NE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1229
auto halving_add(A &&a, B &&b) noexcept -> Intrin< Call::halving_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1560
auto mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin< Call::mul_shift_right, decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
Definition IRMatch.h:1588
HALIDE_ALWAYS_INLINE auto is_float(A &&a) noexcept -> IsFloat< decltype(pattern_arg(a))>
Definition IRMatch.h:2443
auto widening_sub(A &&a, B &&b) noexcept -> Intrin< Call::widening_sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1538
HALIDE_ALWAYS_INLINE auto operator>=(A &&a, B &&b) noexcept -> CmpOp< GE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1179
HALIDE_ALWAYS_INLINE auto operator<(A &&a, B &&b) noexcept -> CmpOp< LT, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1104
HALIDE_ALWAYS_INLINE auto operator&&(A &&a, B &&b) noexcept -> BinOp< And, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1280
HALIDE_ALWAYS_INLINE auto h_or(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Or >
Definition IRMatch.h:1952
constexpr bool commutative(IRNodeType t)
Definition IRMatch.h:615
HALIDE_ALWAYS_INLINE auto sub(A &&a, B &&b) -> decltype(IRMatcher::operator-(a, b))
Definition IRMatch.h:945
auto likely_if_innermost(A &&a) noexcept -> Intrin< Call::likely_if_innermost, decltype(pattern_arg(a))>
Definition IRMatch.h:1612
HALIDE_ALWAYS_INLINE auto h_max(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Max >
Definition IRMatch.h:1940
HALIDE_ALWAYS_INLINE auto broadcast(A &&a, B lanes) noexcept -> BroadcastOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes))>
Definition IRMatch.h:1806
HALIDE_ALWAYS_INLINE auto select(C &&c, T &&t, F &&f) noexcept -> SelectOp< decltype(pattern_arg(c)), decltype(pattern_arg(t)), decltype(pattern_arg(f))>
Definition IRMatch.h:1733
auto saturating_sub(A &&a, B &&b) noexcept -> Intrin< Call::saturating_sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1550
HALIDE_ALWAYS_INLINE auto is_min_value(A &&a) noexcept -> IsMinValue< decltype(pattern_arg(a))>
Definition IRMatch.h:2656
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Min >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:1067
HALIDE_NEVER_INLINE void fuzz_test_rule(Before &&before, After &&after, Predicate &&pred, halide_type_t wildcard_type, halide_type_t output_type) noexcept
Definition IRMatch.h:2710
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< GT >(int64_t a, int64_t b) noexcept
Definition IRMatch.h:1139
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Mul >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:985
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< GE >(int64_t a, int64_t b) noexcept
Definition IRMatch.h:1189
HALIDE_ALWAYS_INLINE auto operator-(A &&a, B &&b) noexcept -> BinOp< Sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:938
auto rounding_mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin< Call::rounding_mul_shift_right, decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
Definition IRMatch.h:1592
auto saturating_add(A &&a, B &&b) noexcept -> Intrin< Call::saturating_add, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1546
HALIDE_ALWAYS_INLINE auto le(A &&a, B &&b) -> decltype(IRMatcher::operator<=(a, b))
Definition IRMatch.h:1159
HALIDE_ALWAYS_INLINE auto lt(A &&a, B &&b) -> decltype(IRMatcher::operator<(a, b))
Definition IRMatch.h:1109
auto shift_left(A &&a, B &&b) noexcept -> Intrin< Call::shift_left, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1572
HALIDE_ALWAYS_INLINE auto lanes_of(A &&a) noexcept -> LanesOf< decltype(pattern_arg(a))>
Definition IRMatch.h:2693
auto rounding_shift_right(A &&a, B &&b) noexcept -> Intrin< Call::rounding_shift_right, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1584
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< LT >(int64_t a, int64_t b) noexcept
Definition IRMatch.h:1114
HALIDE_ALWAYS_INLINE auto h_min(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Min >
Definition IRMatch.h:1934
HALIDE_ALWAYS_INLINE auto h_add(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Add >
Definition IRMatch.h:1928
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Or >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:1264
HALIDE_ALWAYS_INLINE Expr make_const_expr(halide_scalar_value_t val, halide_type_t ty)
Definition IRMatch.h:160
constexpr uint32_t bitwise_or_reduce()
Definition IRMatch.h:1305
int64_t constant_fold_bin_op(halide_type_t &, int64_t, int64_t) noexcept
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< EQ >(int64_t a, int64_t b) noexcept
Definition IRMatch.h:1214
HALIDE_NEVER_INLINE Expr make_const_special_expr(halide_type_t ty)
Definition IRMatch.h:149
auto halving_sub(A &&a, B &&b) noexcept -> Intrin< Call::halving_sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1564
HALIDE_ALWAYS_INLINE auto ge(A &&a, B &&b) -> decltype(IRMatcher::operator>=(a, b))
Definition IRMatch.h:1184
constexpr int const_min(int a, int b)
Definition IRMatch.h:1324
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< NE >(int64_t a, int64_t b) noexcept
Definition IRMatch.h:1239
HALIDE_ALWAYS_INLINE auto mod(A &&a, B &&b) -> decltype(IRMatcher::operator%(a, b))
Definition IRMatch.h:1038
HALIDE_ALWAYS_INLINE auto operator==(A &&a, B &&b) noexcept -> CmpOp< EQ, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1204
auto widen_right_sub(A &&a, B &&b) noexcept -> Intrin< Call::widen_right_sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1529
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Add >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:919
HALIDE_ALWAYS_INLINE auto can_prove(A &&a, Prover *p) noexcept -> CanProve< decltype(pattern_arg(a)), Prover >
Definition IRMatch.h:2406
HALIDE_ALWAYS_INLINE auto eq(A &&a, B &&b) -> decltype(IRMatcher::operator==(a, b))
Definition IRMatch.h:1209
T div_imp(T a, T b)
Definition IROperator.h:273
bool is_const_zero(const Expr &e)
Is the expression a const (as defined by is_const), and also equal to zero (in all lanes,...
Expr make_zero(Type t)
Construct the representation of zero in the given type.
void expr_match_test()
bool is_const_one(const Expr &e)
Is the expression a const (as defined by is_const), and also equal to one (in all lanes,...
bool equal(const RDom &bounds0, const RDom &bounds1)
Return true if bounds0 and bounds1 represent the same bounds.
constexpr IRNodeType StrongestExprNodeType
Definition Expr.h:81
Expr make_const(Type t, int64_t val)
Construct an immediate of the given type from any numeric C++ type.
T mod_imp(T a, T b)
Implementations of division and mod that are specific to Halide.
Definition IROperator.h:252
bool sub_would_overflow(int bits, int64_t a, int64_t b)
bool add_would_overflow(int bits, int64_t a, int64_t b)
Routines to test if math would overflow for signed integers with the given number of bits.
DstType reinterpret_bits(const SrcType &src)
An aggressive form of reinterpret cast used for correct type-punning.
Definition Util.h:135
bool mul_would_overflow(int bits, int64_t a, int64_t b)
Expr with_lanes(const Expr &x, int lanes)
Rewrite the expression x to have lanes lanes.
bool expr_match(const Expr &pattern, const Expr &expr, std::vector< Expr > &result)
Does the first expression have the same structure as the second? Variables in the first expression wi...
Expr make_signed_integer_overflow(Type type)
Construct a unique signed_integer_overflow Expr.
IRNodeType
All our IR node types get unique IDs for the purposes of RTTI.
Definition Expr.h:25
bool is_const(const Expr &e)
Is the expression either an IntImm, a FloatImm, a StringImm, or a Cast of the same,...
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
@ Internal
Not visible externally, similar to 'static' linkage in C.
@ Predicate
Guard the loads and stores in the loop with an if statement that prevents evaluation beyond the origi...
@ C
No name mangling.
unsigned __INT64_TYPE__ uint64_t
signed __INT64_TYPE__ int64_t
signed __INT32_TYPE__ int32_t
unsigned __INT8_TYPE__ uint8_t
unsigned __INT16_TYPE__ uint16_t
unsigned __INT32_TYPE__ uint32_t
A fragment of Halide syntax.
Definition Expr.h:258
HALIDE_ALWAYS_INLINE Type type() const
Get the type of this expression node.
Definition Expr.h:327
HALIDE_ALWAYS_INLINE const Internal::BaseExprNode * get() const
Override get() to return a BaseExprNode * instead of an IRNode *.
Definition Expr.h:321
A base class for expression nodes.
Definition Expr.h:143
A vector with 'lanes' elements, in which every element is 'value'.
Definition IR.h:259
static Expr make(Expr value, int lanes)
static const IRNodeType _node_type
Definition IR.h:265
A function call.
Definition IR.h:490
bool is_intrinsic() const
Definition IR.h:721
static const IRNodeType _node_type
Definition IR.h:766
std::vector< Expr > args
Definition IR.h:492
The actual IR nodes begin here.
Definition IR.h:30
static const IRNodeType _node_type
Definition IR.h:35
Floating point constants.
Definition Expr.h:236
static const FloatImm * make(Type t, double value)
static constexpr bool canonical
Definition IRMatch.h:641
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:664
static constexpr uint32_t binds
Definition IRMatch.h:633
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:645
static constexpr bool foldable
Definition IRMatch.h:661
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const noexcept
Definition IRMatch.h:707
HALIDE_ALWAYS_INLINE bool match(const BinOp< Op2, A2, B2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:655
static constexpr IRNodeType max_node_type
Definition IRMatch.h:636
static constexpr IRNodeType min_node_type
Definition IRMatch.h:635
static constexpr IRNodeType min_node_type
Definition IRMatch.h:1748
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:1772
HALIDE_ALWAYS_INLINE bool match(const BroadcastOp< A2, B2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:1766
static constexpr uint32_t binds
Definition IRMatch.h:1746
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:1754
static constexpr IRNodeType max_node_type
Definition IRMatch.h:1749
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:1789
HALIDE_NEVER_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2395
static constexpr uint32_t binds
Definition IRMatch.h:2385
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2388
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2389
static constexpr bool foldable
Definition IRMatch.h:2392
static constexpr bool canonical
Definition IRMatch.h:2390
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2048
static constexpr bool foldable
Definition IRMatch.h:2070
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:2052
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2047
static constexpr uint32_t binds
Definition IRMatch.h:2045
static constexpr bool canonical
Definition IRMatch.h:2049
HALIDE_ALWAYS_INLINE bool match(const CastOp< A2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:2061
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:2066
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:800
static constexpr IRNodeType max_node_type
Definition IRMatch.h:739
static constexpr uint32_t binds
Definition IRMatch.h:736
static constexpr bool canonical
Definition IRMatch.h:740
static constexpr bool foldable
Definition IRMatch.h:763
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:747
static constexpr IRNodeType min_node_type
Definition IRMatch.h:738
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:766
HALIDE_ALWAYS_INLINE bool match(const CmpOp< Op2, A2, B2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:757
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2204
static constexpr uint32_t binds
Definition IRMatch.h:2201
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2203
static constexpr bool canonical
Definition IRMatch.h:2205
static constexpr bool foldable
Definition IRMatch.h:2230
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const noexcept
Definition IRMatch.h:2208
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:2233
static constexpr IRNodeType max_node_type
Definition IRMatch.h:495
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:504
HALIDE_ALWAYS_INLINE IntLiteral(int64_t v)
Definition IRMatch.h:499
HALIDE_ALWAYS_INLINE bool match(const IntLiteral &b, MatcherState &state) const noexcept
Definition IRMatch.h:527
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:539
static constexpr IRNodeType min_node_type
Definition IRMatch.h:494
static constexpr bool canonical
Definition IRMatch.h:496
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:532
HALIDE_ALWAYS_INLINE bool match(int64_t val, MatcherState &state) const noexcept
Definition IRMatch.h:522
static constexpr uint32_t binds
Definition IRMatch.h:492
HALIDE_ALWAYS_INLINE Intrin(Args... args) noexcept
Definition IRMatch.h:1507
HALIDE_ALWAYS_INLINE void print_args(std::ostream &s) const
Definition IRMatch.h:1399
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:1374
static constexpr bool foldable
Definition IRMatch.h:1462
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:1464
std::tuple< Args... > args
Definition IRMatch.h:1346
static constexpr uint32_t binds
Definition IRMatch.h:1353
HALIDE_ALWAYS_INLINE void print_args(double, std::ostream &s) const
Definition IRMatch.h:1395
HALIDE_ALWAYS_INLINE bool match_args(int, const Call &c, MatcherState &state) const noexcept
Definition IRMatch.h:1362
static constexpr bool canonical
Definition IRMatch.h:1357
HALIDE_ALWAYS_INLINE bool match_args(double, const Call &c, MatcherState &state) const noexcept
Definition IRMatch.h:1369
HALIDE_ALWAYS_INLINE void print_args(int, std::ostream &s) const
Definition IRMatch.h:1386
static constexpr IRNodeType max_node_type
Definition IRMatch.h:1356
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:1404
static constexpr IRNodeType min_node_type
Definition IRMatch.h:1355
OptionalIntrinType< intrin > optional_type_hint
Definition IRMatch.h:1351
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2333
static constexpr bool canonical
Definition IRMatch.h:2335
static constexpr bool foldable
Definition IRMatch.h:2341
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2334
static constexpr uint32_t binds
Definition IRMatch.h:2330
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:2344
static constexpr bool foldable
Definition IRMatch.h:2429
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2432
static constexpr bool canonical
Definition IRMatch.h:2427
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2425
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2426
static constexpr uint32_t binds
Definition IRMatch.h:2422
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2465
static constexpr bool foldable
Definition IRMatch.h:2468
static constexpr uint32_t binds
Definition IRMatch.h:2461
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2471
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2464
static constexpr bool canonical
Definition IRMatch.h:2466
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2591
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2592
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2598
static constexpr uint32_t binds
Definition IRMatch.h:2588
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2632
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2639
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2633
static constexpr uint32_t binds
Definition IRMatch.h:2629
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2555
static constexpr uint32_t binds
Definition IRMatch.h:2551
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2561
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2554
static constexpr bool foldable
Definition IRMatch.h:2558
static constexpr bool canonical
Definition IRMatch.h:2556
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2517
static constexpr bool foldable
Definition IRMatch.h:2514
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2510
static constexpr bool canonical
Definition IRMatch.h:2512
static constexpr uint32_t binds
Definition IRMatch.h:2507
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2511
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2676
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2682
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2675
static constexpr bool foldable
Definition IRMatch.h:2679
static constexpr uint32_t binds
Definition IRMatch.h:2672
static constexpr bool canonical
Definition IRMatch.h:2677
To save stack space, the matcher objects are largely stateless and immutable.
Definition IRMatch.h:82
HALIDE_ALWAYS_INLINE void get_bound_const(int i, halide_scalar_value_t &val, halide_type_t &type) const noexcept
Definition IRMatch.h:127
HALIDE_ALWAYS_INLINE void set_bound_const(int i, int64_t s, halide_type_t t) noexcept
Definition IRMatch.h:103
HALIDE_ALWAYS_INLINE void set_bound_const(int i, double f, halide_type_t t) noexcept
Definition IRMatch.h:115
static constexpr uint16_t special_values_mask
Definition IRMatch.h:88
HALIDE_ALWAYS_INLINE void set_bound_const(int i, halide_scalar_value_t val, halide_type_t t) noexcept
Definition IRMatch.h:121
halide_type_t bound_const_type[max_wild]
Definition IRMatch.h:90
HALIDE_ALWAYS_INLINE void set_binding(int i, const BaseExprNode &n) noexcept
Definition IRMatch.h:93
HALIDE_ALWAYS_INLINE MatcherState() noexcept
Definition IRMatch.h:134
HALIDE_ALWAYS_INLINE const BaseExprNode * get_binding(int i) const noexcept
Definition IRMatch.h:98
halide_scalar_value_t bound_const[max_wild]
Definition IRMatch.h:84
HALIDE_ALWAYS_INLINE void set_bound_const(int i, uint64_t u, halide_type_t t) noexcept
Definition IRMatch.h:109
static constexpr uint16_t signed_integer_overflow
Definition IRMatch.h:87
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:1970
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:1985
HALIDE_ALWAYS_INLINE bool match(NegateOp< A2 > &&p, MatcherState &state) const noexcept
Definition IRMatch.h:1980
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:1994
static constexpr uint32_t binds
Definition IRMatch.h:1962
static constexpr bool canonical
Definition IRMatch.h:1967
static constexpr bool foldable
Definition IRMatch.h:1991
static constexpr IRNodeType max_node_type
Definition IRMatch.h:1965
static constexpr IRNodeType min_node_type
Definition IRMatch.h:1964
static constexpr uint32_t binds
Definition IRMatch.h:1621
static constexpr bool foldable
Definition IRMatch.h:1646
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:1628
static constexpr IRNodeType max_node_type
Definition IRMatch.h:1624
static constexpr bool canonical
Definition IRMatch.h:1625
HALIDE_ALWAYS_INLINE bool match(const NotOp< A2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:1637
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:1642
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:1649
static constexpr IRNodeType min_node_type
Definition IRMatch.h:1623
static constexpr uint32_t binds
Definition IRMatch.h:2290
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2294
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:2298
static constexpr bool canonical
Definition IRMatch.h:2295
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:2307
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:2315
static constexpr bool foldable
Definition IRMatch.h:2312
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2293
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:2266
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2259
static constexpr uint32_t binds
Definition IRMatch.h:2255
static constexpr bool canonical
Definition IRMatch.h:2261
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2260
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:1848
static constexpr bool canonical
Definition IRMatch.h:1823
static constexpr IRNodeType max_node_type
Definition IRMatch.h:1821
static constexpr IRNodeType min_node_type
Definition IRMatch.h:1820
static constexpr uint32_t binds
Definition IRMatch.h:1818
HALIDE_ALWAYS_INLINE bool match(const RampOp< A2, B2, C2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:1841
static constexpr bool foldable
Definition IRMatch.h:1860
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:1826
HALIDE_NEVER_INLINE void build_replacement(After after)
Definition IRMatch.h:2881
HALIDE_ALWAYS_INLINE bool operator()(Before before, After after, Predicate pred)
Definition IRMatch.h:2958
HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after) noexcept
Definition IRMatch.h:2933
HALIDE_ALWAYS_INLINE Rewriter(Instance instance, halide_type_t ot, halide_type_t wt)
Definition IRMatch.h:2876
HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after, Predicate pred)
Definition IRMatch.h:2987
HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after) noexcept
Definition IRMatch.h:2915
HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after, Predicate pred)
Definition IRMatch.h:3010
HALIDE_ALWAYS_INLINE bool operator()(Before before, After after)
Definition IRMatch.h:2892
static constexpr uint32_t binds
Definition IRMatch.h:1681
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:1713
static constexpr bool foldable
Definition IRMatch.h:1710
static constexpr bool canonical
Definition IRMatch.h:1686
HALIDE_ALWAYS_INLINE bool match(const SelectOp< C2, T2, F2 > &instance, MatcherState &state) const noexcept
Definition IRMatch.h:1699
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:1689
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:1706
static constexpr IRNodeType max_node_type
Definition IRMatch.h:1684
static constexpr IRNodeType min_node_type
Definition IRMatch.h:1683
static constexpr bool canonical
Definition IRMatch.h:2144
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2143
static constexpr bool foldable
Definition IRMatch.h:2173
HALIDE_ALWAYS_INLINE SliceOp(Vec v, Base b, Stride s, Lanes l)
Definition IRMatch.h:2176
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2142
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:2147
static constexpr uint32_t binds
Definition IRMatch.h:2140
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:2161
static constexpr IRNodeType min_node_type
Definition IRMatch.h:198
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:205
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:210
static constexpr IRNodeType max_node_type
Definition IRMatch.h:199
static constexpr uint32_t binds
Definition IRMatch.h:195
HALIDE_ALWAYS_INLINE bool match(const VectorReduceOp< A2, B2, reduce_op_2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:1903
static constexpr IRNodeType min_node_type
Definition IRMatch.h:1885
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:1890
static constexpr IRNodeType max_node_type
Definition IRMatch.h:1886
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:1910
static constexpr uint32_t binds
Definition IRMatch.h:2090
static constexpr bool canonical
Definition IRMatch.h:2094
static constexpr bool foldable
Definition IRMatch.h:2117
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:2111
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2093
HALIDE_ALWAYS_INLINE bool match(const WidenOp< A2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:2106
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2092
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:2097
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:352
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:373
static constexpr IRNodeType max_node_type
Definition IRMatch.h:348
static constexpr IRNodeType min_node_type
Definition IRMatch.h:347
static constexpr uint32_t binds
Definition IRMatch.h:345
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:383
static constexpr bool canonical
Definition IRMatch.h:403
static constexpr IRNodeType max_node_type
Definition IRMatch.h:402
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:431
static constexpr uint32_t binds
Definition IRMatch.h:399
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:406
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:441
HALIDE_ALWAYS_INLINE bool match(int64_t e, MatcherState &state) const noexcept
Definition IRMatch.h:425
static constexpr IRNodeType min_node_type
Definition IRMatch.h:401
static constexpr bool foldable
Definition IRMatch.h:438
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:267
static constexpr uint32_t binds
Definition IRMatch.h:226
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:277
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:233
HALIDE_ALWAYS_INLINE bool match(int64_t value, MatcherState &state) const noexcept
Definition IRMatch.h:254
static constexpr IRNodeType min_node_type
Definition IRMatch.h:228
static constexpr IRNodeType max_node_type
Definition IRMatch.h:229
static constexpr uint32_t binds
Definition IRMatch.h:292
static constexpr IRNodeType max_node_type
Definition IRMatch.h:295
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:299
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:330
static constexpr IRNodeType min_node_type
Definition IRMatch.h:294
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:320
static constexpr IRNodeType min_node_type
Definition IRMatch.h:459
static constexpr uint32_t binds
Definition IRMatch.h:457
static constexpr IRNodeType max_node_type
Definition IRMatch.h:460
static constexpr bool canonical
Definition IRMatch.h:461
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:473
static constexpr bool foldable
Definition IRMatch.h:477
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:464
static constexpr uint32_t mask
Definition IRMatch.h:146
IRNodeType node_type
Each IR node subclass has a unique identifier.
Definition Expr.h:113
Integer constants.
Definition Expr.h:218
static const IntImm * make(Type t, int64_t value)
Logical not - true if the expression is false.
Definition IR.h:193
static Expr make(Expr a)
A linear ramp vector node.
Definition IR.h:247
static const IRNodeType _node_type
Definition IR.h:253
static Expr make(Expr base, Expr stride, int lanes)
A ternary operator.
Definition IR.h:204
static Expr make(Expr condition, Expr true_value, Expr false_value)
static const IRNodeType _node_type
Definition IR.h:209
Construct a new vector by taking elements from another sequence of vectors.
Definition IR.h:855
static Expr make_slice(Expr vector, int begin, int stride, int size)
Convenience constructor for making a shuffle representing a contiguous subset of a vector.
std::vector< Expr > vectors
Definition IR.h:856
bool is_slice() const
Check if this shuffle is a contiguous strict subset of the vector arguments, and if so,...
int slice_stride() const
Check if this shuffle is a contiguous strict subset of the vector arguments, and if so,...
Definition IR.h:909
int slice_begin() const
Check if this shuffle is a contiguous strict subset of the vector arguments, and if so,...
Definition IR.h:906
The difference of two expressions.
Definition IR.h:65
static const IRNodeType _node_type
Definition IR.h:70
static Expr make(Expr a, Expr b)
Unsigned integer constants.
Definition Expr.h:227
static const UIntImm * make(Type t, uint64_t value)
Horizontally reduce a vector to a scalar or narrower vector using the given commutative and associati...
Definition IR.h:979
static const IRNodeType _node_type
Definition IR.h:998
static Expr make(Operator op, Expr vec, int lanes)
Types in the halide type system.
Definition Type.h:283
Type widen() const
Return Type with the same type code and number of lanes, but with at least twice as many bits.
Definition Type.h:378
HALIDE_ALWAYS_INLINE bool is_int() const
Is this type a signed integer type?
Definition Type.h:435
HALIDE_ALWAYS_INLINE int lanes() const
Return the number of vector elements in this type.
Definition Type.h:355
HALIDE_ALWAYS_INLINE bool is_uint() const
Is this type an unsigned integer type?
Definition Type.h:441
HALIDE_ALWAYS_INLINE int bits() const
Return the bit size of a single element of this type.
Definition Type.h:349
HALIDE_ALWAYS_INLINE bool is_scalar() const
Is this type a scalar type? (lanes() == 1).
Definition Type.h:417
HALIDE_ALWAYS_INLINE bool is_float() const
Is this type a floating point type (float or double).
Definition Type.h:423
halide_scalar_value_t is a simple union able to represent all the well-known scalar values in a filte...
union halide_scalar_value_t::@3 u
A runtime tag for a type in the halide type system.
uint8_t bits
The number of bits of precision of a single scalar value of this type.
uint16_t lanes
How many elements in a vector.
uint8_t code
The basic type code: signed integer, unsigned integer, or floating point.