Halide 19.0.0
Halide compiler and libraries
Loading...
Searching...
No Matches
IRMatch.h
Go to the documentation of this file.
1#ifndef HALIDE_IR_MATCH_H
2#define HALIDE_IR_MATCH_H
3
4/** \file
5 * Defines a method to match a fragment of IR against a pattern containing wildcards
6 */
7
8#include <map>
9#include <random>
10#include <set>
11#include <vector>
12
13#include "IR.h"
14#include "IREquality.h"
15#include "IROperator.h"
16
17namespace Halide {
18namespace Internal {
19
20/** Does the first expression have the same structure as the second?
21 * Variables in the first expression with the name * are interpreted
22 * as wildcards, and their matching equivalent in the second
23 * expression is placed in the vector give as the third argument.
24 * Wildcards require the types to match. For the type bits and width,
25 * a 0 indicates "match anything". So an Int(8, 0) will match 8-bit
26 * integer vectors of any width (including scalars), and a UInt(0, 0)
27 * will match any unsigned integer type.
28 *
29 * For example:
30 \code
31 Expr x = Variable::make(Int(32), "*");
32 match(x + x, 3 + (2*k), result)
33 \endcode
34 * should return true, and set result[0] to 3 and
35 * result[1] to 2*k.
36 */
37bool expr_match(const Expr &pattern, const Expr &expr, std::vector<Expr> &result);
38
39/** Does the first expression have the same structure as the second?
40 * Variables are matched consistently. The first time a variable is
41 * matched, it assumes the value of the matching part of the second
42 * expression. Subsequent matches must be equal to the first match.
43 *
44 * For example:
45 \code
46 Var x("x"), y("y");
47 match(x*(x + y), a*(a + b), result)
48 \endcode
49 * should return true, and set result["x"] = a, and result["y"] = b.
50 */
51bool expr_match(const Expr &pattern, const Expr &expr, std::map<std::string, Expr> &result);
52
53/** Rewrite the expression x to have `lanes` lanes. This is useful
54 * for substituting the results of expr_match into a pattern expression. */
55Expr with_lanes(const Expr &x, int lanes);
56
58
59/** An alternative template-metaprogramming approach to expression
60 * matching. Potentially more efficient. We lift the expression
61 * pattern into a type, and then use force-inlined functions to
62 * generate efficient matching and reconstruction code for any
63 * pattern. Pattern elements are either one of the classes in the
64 * namespace IRMatcher, or are non-null Exprs (represented as
65 * BaseExprNode &).
66 *
67 * Pattern elements that are fully specified by their pattern can be
68 * built into an expression using the make method. Some patterns,
69 * such as a broadcast that matches any number of lanes, don't have
70 * enough information to recreate an Expr.
71 */
72namespace IRMatcher {
73
// Exclusive upper bound on wildcard indices; the wildcard matchers
// static_assert that their index i satisfies 0 <= i < max_wild.
constexpr int max_wild = 6;
75
76static const halide_type_t i64_type = {halide_type_int, 64, 1};
77
78/** To save stack space, the matcher objects are largely stateless and
79 * immutable. This state object is built up during matching and then
80 * consumed when constructing a replacement Expr.
81 */
85
86 // values of the lanes field with special meaning.
87 static constexpr uint16_t signed_integer_overflow = 0x8000;
88 static constexpr uint16_t special_values_mask = 0x8000; // currently only one
89
91
93 void set_binding(int i, const BaseExprNode &n) noexcept {
94 bindings[i] = &n;
95 }
96
98 const BaseExprNode *get_binding(int i) const noexcept {
99 return bindings[i];
100 }
101
103 void set_bound_const(int i, int64_t s, halide_type_t t) noexcept {
104 bound_const[i].u.i64 = s;
105 bound_const_type[i] = t;
106 }
107
109 void set_bound_const(int i, uint64_t u, halide_type_t t) noexcept {
110 bound_const[i].u.u64 = u;
111 bound_const_type[i] = t;
112 }
113
115 void set_bound_const(int i, double f, halide_type_t t) noexcept {
116 bound_const[i].u.f64 = f;
117 bound_const_type[i] = t;
118 }
119
122 bound_const[i] = val;
123 bound_const_type[i] = t;
124 }
125
127 void get_bound_const(int i, halide_scalar_value_t &val, halide_type_t &type) const noexcept {
128 val = bound_const[i];
129 type = bound_const_type[i];
130 }
131
133 // NOLINTNEXTLINE(modernize-use-equals-default): Can't use `= default`; clang-tidy complains about noexcept mismatch
134 MatcherState() noexcept {
135 }
136};
137
138template<typename T,
139 typename = typename std::remove_reference<T>::type::pattern_tag>
141 struct type {};
142};
143
/** Exposes the `binds` bitmask of a pattern type as `mask`. Any
 * reference qualifier on T is stripped first, so a pattern type and a
 * reference to it report the same mask. */
template<typename T>
struct bindings {
    constexpr static uint32_t mask = std::remove_reference<T>::type::binds;
};
148
151 ty.lanes &= ~MatcherState::special_values_mask;
154 }
155 // unreachable
156 return Expr();
157}
158
161 halide_type_t scalar_type = ty;
162 if (scalar_type.lanes & MatcherState::special_values_mask) {
163 return make_const_special_expr(scalar_type);
164 }
165
166 const int lanes = scalar_type.lanes;
167 scalar_type.lanes = 1;
168
169 Expr e;
170 switch (scalar_type.code) {
171 case halide_type_int:
172 e = IntImm::make(scalar_type, val.u.i64);
173 break;
174 case halide_type_uint:
175 e = UIntImm::make(scalar_type, val.u.u64);
176 break;
179 e = FloatImm::make(scalar_type, val.u.f64);
180 break;
181 default:
182 // Unreachable
183 return Expr();
184 }
185 if (lanes > 1) {
186 e = Broadcast::make(e, lanes);
187 }
188 return e;
189}
190
191// A pattern that matches a specific expression
193 struct pattern_tag {};
194
195 constexpr static uint32_t binds = 0;
196
197 // What is the weakest and strongest IR node this could possibly be
200 constexpr static bool canonical = true;
201
203
204 template<uint32_t bound>
205 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
206 return equal(expr, e);
207 }
208
210 Expr make(MatcherState &state, halide_type_t type_hint) const {
211 return Expr(&expr);
212 }
213
214 constexpr static bool foldable = false;
215};
216
217inline std::ostream &operator<<(std::ostream &s, const SpecificExpr &e) {
218 s << Expr(&e.expr);
219 return s;
220}
221
222template<int i>
224 struct pattern_tag {};
225
226 constexpr static uint32_t binds = 1 << i;
227
230 constexpr static bool canonical = true;
231
232 template<uint32_t bound>
233 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
234 static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
235 const BaseExprNode *op = &e;
236 if (op->node_type == IRNodeType::Broadcast) {
237 op = ((const Broadcast *)op)->value.get();
238 }
239 if (op->node_type != IRNodeType::IntImm) {
240 return false;
241 }
242 int64_t value = ((const IntImm *)op)->value;
243 if (bound & binds) {
245 halide_type_t type;
246 state.get_bound_const(i, val, type);
247 return (halide_type_t)e.type == type && value == val.u.i64;
248 }
249 state.set_bound_const(i, value, e.type);
250 return true;
251 }
252
253 template<uint32_t bound>
254 HALIDE_ALWAYS_INLINE bool match(int64_t value, MatcherState &state) const noexcept {
255 static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
256 if (bound & binds) {
258 halide_type_t type;
259 state.get_bound_const(i, val, type);
260 return type == i64_type && value == val.u.i64;
261 }
262 state.set_bound_const(i, value, i64_type);
263 return true;
264 }
265
267 Expr make(MatcherState &state, halide_type_t type_hint) const {
269 halide_type_t type;
270 state.get_bound_const(i, val, type);
271 return make_const_expr(val, type);
272 }
273
274 constexpr static bool foldable = true;
275
278 state.get_bound_const(i, val, ty);
279 }
280};
281
282template<int i>
283std::ostream &operator<<(std::ostream &s, const WildConstInt<i> &c) {
284 s << "ci" << i;
285 return s;
286}
287
288template<int i>
290 struct pattern_tag {};
291
292 constexpr static uint32_t binds = 1 << i;
293
296 constexpr static bool canonical = true;
297
298 template<uint32_t bound>
299 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
300 static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
301 const BaseExprNode *op = &e;
302 if (op->node_type == IRNodeType::Broadcast) {
303 op = ((const Broadcast *)op)->value.get();
304 }
305 if (op->node_type != IRNodeType::UIntImm) {
306 return false;
307 }
308 uint64_t value = ((const UIntImm *)op)->value;
309 if (bound & binds) {
311 halide_type_t type;
312 state.get_bound_const(i, val, type);
313 return (halide_type_t)e.type == type && value == val.u.u64;
314 }
315 state.set_bound_const(i, value, e.type);
316 return true;
317 }
318
320 Expr make(MatcherState &state, halide_type_t type_hint) const {
322 halide_type_t type;
323 state.get_bound_const(i, val, type);
324 return make_const_expr(val, type);
325 }
326
327 constexpr static bool foldable = true;
328
330 void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
331 state.get_bound_const(i, val, ty);
332 }
333};
334
335template<int i>
336std::ostream &operator<<(std::ostream &s, const WildConstUInt<i> &c) {
337 s << "cu" << i;
338 return s;
339}
340
341template<int i>
343 struct pattern_tag {};
344
345 constexpr static uint32_t binds = 1 << i;
346
349 constexpr static bool canonical = true;
350
351 template<uint32_t bound>
352 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
353 static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
354 const BaseExprNode *op = &e;
355 if (op->node_type == IRNodeType::Broadcast) {
356 op = ((const Broadcast *)op)->value.get();
357 }
358 if (op->node_type != IRNodeType::FloatImm) {
359 return false;
360 }
361 double value = ((const FloatImm *)op)->value;
362 if (bound & binds) {
364 halide_type_t type;
365 state.get_bound_const(i, val, type);
366 return (halide_type_t)e.type == type && value == val.u.f64;
367 }
368 state.set_bound_const(i, value, e.type);
369 return true;
370 }
371
373 Expr make(MatcherState &state, halide_type_t type_hint) const {
375 halide_type_t type;
376 state.get_bound_const(i, val, type);
377 return make_const_expr(val, type);
378 }
379
380 constexpr static bool foldable = true;
381
383 void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
384 state.get_bound_const(i, val, ty);
385 }
386};
387
388template<int i>
389std::ostream &operator<<(std::ostream &s, const WildConstFloat<i> &c) {
390 s << "cf" << i;
391 return s;
392}
393
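// Matches and binds to any constant Expr by dispatching to WildConstInt,
// WildConstUInt or WildConstFloat. NOTE(review): this comment used to say
// "Does not support constant-folding", but the struct sets foldable = true
// and implements make_folded_const, so that claim looks stale - confirm.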
// Wildcard with index i that binds a constant; matching is delegated to
// WildConstInt / WildConstUInt / WildConstFloat with the same index.
// NOTE(review): several listing lines (401-402, 413, 415, 417, 430, 432,
// 440) are absent from this view, so the code is left exactly as found.
395template<int i>
396struct WildConst {
397 struct pattern_tag {};
398
// Bit i of the bound-mask identifies this wildcard.
399 constexpr static uint32_t binds = 1 << i;
400
403 constexpr static bool canonical = true;
404
// Match against an IR node. A Broadcast is looked through to choose the
// branch, but the delegated matcher receives the original node e.
405 template<uint32_t bound>
406 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
407 static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
408 const BaseExprNode *op = &e;
409 if (op->node_type == IRNodeType::Broadcast) {
410 op = ((const Broadcast *)op)->value.get();
411 }
412 switch (op->node_type) {
// NOTE(review): the case labels for the three returns below are missing
// from this listing.
414 return WildConstInt<i>().template match<bound>(e, state);
416 return WildConstUInt<i>().template match<bound>(e, state);
418 return WildConstFloat<i>().template match<bound>(e, state);
419 default:
420 return false;
421 }
422 }
423
// Match against a raw int64_t by delegating to WildConstInt<i>.
424 template<uint32_t bound>
425 HALIDE_ALWAYS_INLINE bool match(int64_t e, MatcherState &state) const noexcept {
426 static_assert(i >= 0 && i < max_wild, "Wild with out-of-range index");
427 return WildConstInt<i>().template match<bound>(e, state);
428 }
429
// Rebuild an Expr from the constant and type stored in state for index i.
// The type_hint argument is not used here.
431 Expr make(MatcherState &state, halide_type_t type_hint) const {
433 halide_type_t type;
434 state.get_bound_const(i, val, type);
435 return make_const_expr(val, type);
436 }
437
438 constexpr static bool foldable = true;
439
// Constant folding just reads back the bound value and type for index i.
441 void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
442 state.get_bound_const(i, val, ty);
443 }
444};
445
446template<int i>
447std::ostream &operator<<(std::ostream &s, const WildConst<i> &c) {
448 s << "c" << i;
449 return s;
450}
451
452// Matches and binds to any Expr
// Wildcard with index i that binds to an arbitrary expression node.
// NOTE(review): listing lines 459-460 and 472 are absent from this view,
// so the code is left exactly as found.
453template<int i>
454struct Wild {
455 struct pattern_tag {};
456
// Uses bit (i + 16) of the bound-mask, whereas the WildConst* matchers
// use bit i.
457 constexpr static uint32_t binds = 1 << (i + 16);
458
461 constexpr static bool canonical = true;
462
// If this wildcard's bit is already set in `bound`, require e to equal the
// stored binding; otherwise record e as the binding and succeed.
463 template<uint32_t bound>
464 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
465 if (bound & binds) {
466 return equal(*state.get_binding(i), e);
467 }
468 state.set_binding(i, e);
469 return true;
470 }
471
// Returns the binding stored in state for index i; type_hint is not used.
473 Expr make(MatcherState &state, halide_type_t type_hint) const {
474 return state.get_binding(i);
475 }
476
477 constexpr static bool foldable = false;
478};
479
480template<int i>
481std::ostream &operator<<(std::ostream &s, const Wild<i> &op) {
482 s << "_" << i;
483 return s;
484}
485
486// Matches a specific constant or broadcast of that constant. The
487// constant must be representable as an int64_t.
489 struct pattern_tag {};
491
492 constexpr static uint32_t binds = 0;
493
496 constexpr static bool canonical = true;
497
500 : v(v) {
501 }
502
503 template<uint32_t bound>
504 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
505 const BaseExprNode *op = &e;
506 if (e.node_type == IRNodeType::Broadcast) {
507 op = ((const Broadcast *)op)->value.get();
508 }
509 switch (op->node_type) {
511 return ((const IntImm *)op)->value == (int64_t)v;
513 return ((const UIntImm *)op)->value == (uint64_t)v;
515 return ((const FloatImm *)op)->value == (double)v;
516 default:
517 return false;
518 }
519 }
520
521 template<uint32_t bound>
522 HALIDE_ALWAYS_INLINE bool match(int64_t val, MatcherState &state) const noexcept {
523 return v == val;
524 }
525
526 template<uint32_t bound>
527 HALIDE_ALWAYS_INLINE bool match(const IntLiteral &b, MatcherState &state) const noexcept {
528 return v == b.v;
529 }
530
532 Expr make(MatcherState &state, halide_type_t type_hint) const {
533 return make_const(type_hint, v);
534 }
535
536 constexpr static bool foldable = true;
537
539 void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
540 // Assume type is already correct
541 switch (ty.code) {
542 case halide_type_int:
543 val.u.i64 = v;
544 break;
545 case halide_type_uint:
546 val.u.u64 = (uint64_t)v;
547 break;
550 val.u.f64 = (double)v;
551 break;
552 default:
553 // Unreachable
554 ;
555 }
556 }
557};
558
562
563// Convert a provided pattern, expr, or constant int into the internal
564// representation we use in the matcher trees.
565template<typename T,
566 typename = typename std::decay<T>::type::pattern_tag>
568 return t;
569}
572 return IntLiteral{x};
573}
574
575template<typename T>
577 static_assert(!std::is_same<typename std::decay<T>::type, Expr>::value || std::is_lvalue_reference<T>::value,
578 "Exprs are captured by reference by IRMatcher objects and so must be lvalues");
579}
580
582 return {*e.get()};
583}
584
585// Helpers to deref SpecificExprs to const BaseExprNode & rather than
586// passing them by value anywhere (incurring lots of refcounting)
587template<typename T,
588 // T must be a pattern node
589 typename = typename std::decay<T>::type::pattern_tag,
590 // But T may not be SpecificExpr
591 typename = typename std::enable_if<!std::is_same<typename std::decay<T>::type, SpecificExpr>::value>::type>
593 return t;
594}
595
598 return e.expr;
599}
600
601inline std::ostream &operator<<(std::ostream &s, const IntLiteral &op) {
602 s << op.v;
603 return s;
604}
605
606template<typename Op>
608
609template<typename Op>
611
612template<typename Op>
613double constant_fold_bin_op(halide_type_t &, double, double) noexcept;
614
615constexpr bool commutative(IRNodeType t) {
616 return (t == IRNodeType::Add ||
617 t == IRNodeType::Mul ||
618 t == IRNodeType::And ||
619 t == IRNodeType::Or ||
620 t == IRNodeType::Min ||
621 t == IRNodeType::Max ||
622 t == IRNodeType::EQ ||
623 t == IRNodeType::NE);
624}
625
626// Matches one of the binary operators
// Pattern for a binary IR node of type Op with sub-patterns a and b.
// NOTE(review): listing lines 633, 663, 695-697 and 706 are absent from
// this view, so the code is left exactly as found.
627template<typename Op, typename A, typename B>
628struct BinOp {
629 struct pattern_tag {};
630 A a;
631 B b;
632
634
635 constexpr static IRNodeType min_node_type = Op::_node_type;
636 constexpr static IRNodeType max_node_type = Op::_node_type;
637
638 // For commutative bin ops, we expect the weaker IR node type on
639 // the right. That is, for the rule to be canonical it must be
640 // possible that A is at least as strong as B.
641 constexpr static bool canonical =
642 A::canonical && B::canonical && (!commutative(Op::_node_type) || (A::max_node_type >= B::min_node_type));
643
// Match an IR node: it must have Op's node type, then a matches the left
// operand and b the right. b is matched with A's bindings added to the
// bound-mask, so wildcards bound by a are compared rather than rebound.
644 template<uint32_t bound>
645 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
646 if (e.node_type != Op::_node_type) {
647 return false;
648 }
649 const Op &op = (const Op &)e;
650 return (a.template match<bound>(*op.a.get(), state) &&
651 b.template match<bound | bindings<A>::mask>(*op.b.get(), state));
652 }
653
// Match another BinOp pattern: the operator types must be identical and
// the unwrapped sub-patterns must match pairwise.
654 template<uint32_t bound, typename Op2, typename A2, typename B2>
655 HALIDE_ALWAYS_INLINE bool match(const BinOp<Op2, A2, B2> &op, MatcherState &state) const noexcept {
656 return (std::is_same<Op, Op2>::value &&
657 a.template match<bound>(unwrap(op.a), state) &&
658 b.template match<bound | bindings<A>::mask>(unwrap(op.b), state));
659 }
660
661 constexpr static bool foldable = A::foldable && B::foldable;
662
// Fold both operands to constants and combine them with
// constant_fold_bin_op<Op>. When A is an IntLiteral, b is folded first so
// that ty is filled in before a is evaluated. And/Or return early when the
// first-folded operand already decides the result (0 for And, 1 for Or).
664 void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
665 halide_scalar_value_t val_a, val_b;
666 if (std::is_same<A, IntLiteral>::value) {
667 b.make_folded_const(val_b, ty, state);
668 if ((std::is_same<Op, And>::value && val_b.u.u64 == 0) ||
669 (std::is_same<Op, Or>::value && val_b.u.u64 == 1)) {
670 // Short circuit
671 val = val_b;
672 return;
673 }
674 const uint16_t l = ty.lanes;
675 a.make_folded_const(val_a, ty, state);
676 ty.lanes |= l; // Make sure the overflow bits are sticky
677 } else {
678 a.make_folded_const(val_a, ty, state);
679 if ((std::is_same<Op, And>::value && val_a.u.u64 == 0) ||
680 (std::is_same<Op, Or>::value && val_a.u.u64 == 1)) {
681 // Short circuit
682 val = val_a;
683 return;
684 }
685 const uint16_t l = ty.lanes;
686 b.make_folded_const(val_b, ty, state);
687 ty.lanes |= l;
688 }
// Choose the union member and folder overload from the type code.
689 switch (ty.code) {
690 case halide_type_int:
691 val.u.i64 = constant_fold_bin_op<Op>(ty, val_a.u.i64, val_b.u.i64);
692 break;
693 case halide_type_uint:
694 val.u.u64 = constant_fold_bin_op<Op>(ty, val_a.u.u64, val_b.u.u64);
695 break;
// NOTE(review): the case label(s) for the f64 branch below are missing
// from this listing.
698 val.u.f64 = constant_fold_bin_op<Op>(ty, val_a.u.f64, val_b.u.f64);
699 break;
700 default:
701 // unreachable
702 ;
703 }
704 }
705
// Build the Expr for this pattern. When A is an IntLiteral, b is built
// first and its type becomes a's type hint; otherwise a is built first.
// NOTE(review): this make is declared noexcept while CmpOp::make is not -
// confirm that is intended.
707 Expr make(MatcherState &state, halide_type_t type_hint) const noexcept {
708 Expr ea, eb;
709 if (std::is_same<A, IntLiteral>::value) {
710 eb = b.make(state, type_hint);
711 ea = a.make(state, eb.type());
712 } else {
713 ea = a.make(state, type_hint);
714 eb = b.make(state, ea.type());
715 }
716 // We sometimes mix vectors and scalars in the rewrite rules,
717 // so insert a broadcast if necessary.
718 if (ea.type().is_vector() && !eb.type().is_vector()) {
719 eb = Broadcast::make(eb, ea.type().lanes());
720 }
721 if (eb.type().is_vector() && !ea.type().is_vector()) {
722 ea = Broadcast::make(ea, eb.type().lanes());
723 }
724 return Op::make(std::move(ea), std::move(eb));
725 }
726};
727
728template<typename Op>
730
731template<typename Op>
733
/** Floating-point overload of the per-operator comparison folder. It is
 * declared here only; the result is returned as a uint64_t. */
template<typename Op>
uint64_t constant_fold_cmp_op(double, double) noexcept;
736
737// Matches one of the comparison operators
// Pattern for a comparison IR node of type Op with sub-patterns a and b.
// NOTE(review): listing lines 744, 773, 795-796 and 807 are absent from
// this view, so the code is left exactly as found.
738template<typename Op, typename A, typename B>
739struct CmpOp {
740 struct pattern_tag {};
741 A a;
742 B b;
743
745
746 constexpr static IRNodeType min_node_type = Op::_node_type;
747 constexpr static IRNodeType max_node_type = Op::_node_type;
// Same operand-ordering rule as BinOp for commutative ops; in addition,
// GE and GT patterns are never considered canonical.
748 constexpr static bool canonical = (A::canonical &&
749 B::canonical &&
750 (!commutative(Op::_node_type) || A::max_node_type >= B::min_node_type) &&
751 (Op::_node_type != IRNodeType::GE) &&
752 (Op::_node_type != IRNodeType::GT));
753
// Match an IR node: it must have Op's node type, then a matches the left
// operand and b the right, with A's bindings added to b's bound-mask.
754 template<uint32_t bound>
755 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
756 if (e.node_type != Op::_node_type) {
757 return false;
758 }
759 const Op &op = (const Op &)e;
760 return (a.template match<bound>(*op.a.get(), state) &&
761 b.template match<bound | bindings<A>::mask>(*op.b.get(), state));
762 }
763
// Match another CmpOp pattern: the operator types must be identical and
// the unwrapped sub-patterns must match pairwise.
764 template<uint32_t bound, typename Op2, typename A2, typename B2>
765 HALIDE_ALWAYS_INLINE bool match(const CmpOp<Op2, A2, B2> &op, MatcherState &state) const noexcept {
766 return (std::is_same<Op, Op2>::value &&
767 a.template match<bound>(unwrap(op.a), state) &&
768 b.template match<bound | bindings<A>::mask>(unwrap(op.b), state));
769 }
770
771 constexpr static bool foldable = A::foldable && B::foldable;
772
// Fold both operands, compare them with constant_fold_cmp_op<Op>, and set
// the result type code to uint with 1 bit.
774 void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept {
775 halide_scalar_value_t val_a, val_b;
776 // If one side is an untyped const, evaluate the other side first to get a type hint.
777 if (std::is_same<A, IntLiteral>::value) {
778 b.make_folded_const(val_b, ty, state);
779 const uint16_t l = ty.lanes;
780 a.make_folded_const(val_a, ty, state);
781 ty.lanes |= l;
782 } else {
783 a.make_folded_const(val_a, ty, state);
784 const uint16_t l = ty.lanes;
785 b.make_folded_const(val_b, ty, state);
786 ty.lanes |= l;
787 }
// The operand type code selects which union member is compared; the
// result is always written to val.u.u64.
788 switch (ty.code) {
789 case halide_type_int:
790 val.u.u64 = constant_fold_cmp_op<Op>(val_a.u.i64, val_b.u.i64);
791 break;
792 case halide_type_uint:
793 val.u.u64 = constant_fold_cmp_op<Op>(val_a.u.u64, val_b.u.u64);
794 break;
// NOTE(review): the case label(s) for the f64 branch below are missing
// from this listing.
797 val.u.u64 = constant_fold_cmp_op<Op>(val_a.u.f64, val_b.u.f64);
798 break;
799 default:
800 // unreachable
801 ;
802 }
803 ty.code = halide_type_uint;
804 ty.bits = 1;
805 }
806
// Build the Expr for this comparison. The incoming type_hint is not
// passed to the operands: the first operand built gets an empty hint and
// its type becomes the hint for the other.
808 Expr make(MatcherState &state, halide_type_t type_hint) const {
809 // If one side is an untyped const, evaluate the other side first to get a type hint.
810 Expr ea, eb;
811 if (std::is_same<A, IntLiteral>::value) {
812 eb = b.make(state, {});
813 ea = a.make(state, eb.type());
814 } else {
815 ea = a.make(state, {});
816 eb = b.make(state, ea.type());
817 }
818 // We sometimes mix vectors and scalars in the rewrite rules,
819 // so insert a broadcast if necessary.
820 if (ea.type().is_vector() && !eb.type().is_vector()) {
821 eb = Broadcast::make(eb, ea.type().lanes());
822 }
823 if (eb.type().is_vector() && !ea.type().is_vector()) {
824 ea = Broadcast::make(ea, eb.type().lanes());
825 }
826 return Op::make(std::move(ea), std::move(eb));
827 }
828};
829
830template<typename A, typename B>
831std::ostream &operator<<(std::ostream &s, const BinOp<Add, A, B> &op) {
832 s << "(" << op.a << " + " << op.b << ")";
833 return s;
834}
835
836template<typename A, typename B>
837std::ostream &operator<<(std::ostream &s, const BinOp<Sub, A, B> &op) {
838 s << "(" << op.a << " - " << op.b << ")";
839 return s;
840}
841
842template<typename A, typename B>
843std::ostream &operator<<(std::ostream &s, const BinOp<Mul, A, B> &op) {
844 s << "(" << op.a << " * " << op.b << ")";
845 return s;
846}
847
848template<typename A, typename B>
849std::ostream &operator<<(std::ostream &s, const BinOp<Div, A, B> &op) {
850 s << "(" << op.a << " / " << op.b << ")";
851 return s;
852}
853
854template<typename A, typename B>
855std::ostream &operator<<(std::ostream &s, const BinOp<And, A, B> &op) {
856 s << "(" << op.a << " && " << op.b << ")";
857 return s;
858}
859
860template<typename A, typename B>
861std::ostream &operator<<(std::ostream &s, const BinOp<Or, A, B> &op) {
862 s << "(" << op.a << " || " << op.b << ")";
863 return s;
864}
865
866template<typename A, typename B>
867std::ostream &operator<<(std::ostream &s, const BinOp<Min, A, B> &op) {
868 s << "min(" << op.a << ", " << op.b << ")";
869 return s;
870}
871
872template<typename A, typename B>
873std::ostream &operator<<(std::ostream &s, const BinOp<Max, A, B> &op) {
874 s << "max(" << op.a << ", " << op.b << ")";
875 return s;
876}
877
878template<typename A, typename B>
879std::ostream &operator<<(std::ostream &s, const CmpOp<LE, A, B> &op) {
880 s << "(" << op.a << " <= " << op.b << ")";
881 return s;
882}
883
884template<typename A, typename B>
885std::ostream &operator<<(std::ostream &s, const CmpOp<LT, A, B> &op) {
886 s << "(" << op.a << " < " << op.b << ")";
887 return s;
888}
889
890template<typename A, typename B>
891std::ostream &operator<<(std::ostream &s, const CmpOp<GE, A, B> &op) {
892 s << "(" << op.a << " >= " << op.b << ")";
893 return s;
894}
895
896template<typename A, typename B>
897std::ostream &operator<<(std::ostream &s, const CmpOp<GT, A, B> &op) {
898 s << "(" << op.a << " > " << op.b << ")";
899 return s;
900}
901
902template<typename A, typename B>
903std::ostream &operator<<(std::ostream &s, const CmpOp<EQ, A, B> &op) {
904 s << "(" << op.a << " == " << op.b << ")";
905 return s;
906}
907
908template<typename A, typename B>
909std::ostream &operator<<(std::ostream &s, const CmpOp<NE, A, B> &op) {
910 s << "(" << op.a << " != " << op.b << ")";
911 return s;
912}
913
914template<typename A, typename B>
915std::ostream &operator<<(std::ostream &s, const BinOp<Mod, A, B> &op) {
916 s << "(" << op.a << " % " << op.b << ")";
917 return s;
918}
919
// Builds a BinOp<Add, ...> pattern from two operands, each converted with
// pattern_arg.
// NOTE(review): listing lines 922-923 are absent from this view.
920template<typename A, typename B>
921HALIDE_ALWAYS_INLINE auto operator+(A &&a, B &&b) noexcept -> BinOp<Add, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
924 return {pattern_arg(a), pattern_arg(b)};
925}
926
// Named-function spelling of the + pattern operator; forwards to
// IRMatcher::operator+.
// NOTE(review): listing lines 929-930 are absent from this view.
927template<typename A, typename B>
928HALIDE_ALWAYS_INLINE auto add(A &&a, B &&b) -> decltype(IRMatcher::operator+(a, b)) {
931 return IRMatcher::operator+(a, b);
932}
933
934template<>
936 t.lanes |= ((t.bits >= 32) && add_would_overflow(t.bits, a, b)) ? MatcherState::signed_integer_overflow : 0;
937 int dead_bits = 64 - t.bits;
938 // Drop the high bits then sign-extend them back
939 return int64_t((uint64_t(a) + uint64_t(b)) << dead_bits) >> dead_bits;
940}
941
942template<>
944 uint64_t ones = (uint64_t)(-1);
945 return (a + b) & (ones >> (64 - t.bits));
946}
947
948template<>
949HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Add>(halide_type_t &t, double a, double b) noexcept {
950 return a + b;
951}
952
// Builds a BinOp<Sub, ...> pattern from two operands, each converted with
// pattern_arg.
// NOTE(review): listing lines 955-956 are absent from this view.
953template<typename A, typename B>
954HALIDE_ALWAYS_INLINE auto operator-(A &&a, B &&b) noexcept -> BinOp<Sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
957 return {pattern_arg(a), pattern_arg(b)};
958}
959
// Named-function spelling of the - pattern operator; forwards to
// IRMatcher::operator-.
// NOTE(review): listing lines 962-963 are absent from this view.
960template<typename A, typename B>
961HALIDE_ALWAYS_INLINE auto sub(A &&a, B &&b) -> decltype(IRMatcher::operator-(a, b)) {
964 return IRMatcher::operator-(a, b);
965}
966
967template<>
969 t.lanes |= ((t.bits >= 32) && sub_would_overflow(t.bits, a, b)) ? MatcherState::signed_integer_overflow : 0;
970 // Drop the high bits then sign-extend them back
971 int dead_bits = 64 - t.bits;
972 return int64_t((uint64_t(a) - uint64_t(b)) << dead_bits) >> dead_bits;
973}
974
975template<>
977 uint64_t ones = (uint64_t)(-1);
978 return (a - b) & (ones >> (64 - t.bits));
979}
980
981template<>
982HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Sub>(halide_type_t &t, double a, double b) noexcept {
983 return a - b;
984}
985
// Builds a BinOp<Mul, ...> pattern from two operands, each converted with
// pattern_arg.
// NOTE(review): listing lines 988-989 are absent from this view.
986template<typename A, typename B>
987HALIDE_ALWAYS_INLINE auto operator*(A &&a, B &&b) noexcept -> BinOp<Mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
990 return {pattern_arg(a), pattern_arg(b)};
991}
992
// Named-function spelling of the * pattern operator; forwards to
// IRMatcher::operator*.
// NOTE(review): listing lines 995-996 are absent from this view.
993template<typename A, typename B>
994HALIDE_ALWAYS_INLINE auto mul(A &&a, B &&b) -> decltype(IRMatcher::operator*(a, b)) {
997 return IRMatcher::operator*(a, b);
998}
999
1000template<>
1002 t.lanes |= ((t.bits >= 32) && mul_would_overflow(t.bits, a, b)) ? MatcherState::signed_integer_overflow : 0;
1003 int dead_bits = 64 - t.bits;
1004 // Drop the high bits then sign-extend them back
1005 return int64_t((uint64_t(a) * uint64_t(b)) << dead_bits) >> dead_bits;
1006}
1007
1008template<>
1010 uint64_t ones = (uint64_t)(-1);
1011 return (a * b) & (ones >> (64 - t.bits));
1012}
1013
1014template<>
1015HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Mul>(halide_type_t &t, double a, double b) noexcept {
1016 return a * b;
1017}
1018
// Builds a BinOp<Div, ...> pattern from two operands, each converted with
// pattern_arg.
// NOTE(review): listing lines 1021-1022 are absent from this view.
1019template<typename A, typename B>
1020HALIDE_ALWAYS_INLINE auto operator/(A &&a, B &&b) noexcept -> BinOp<Div, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1023 return {pattern_arg(a), pattern_arg(b)};
1024}
1025
1026template<typename A, typename B>
1027HALIDE_ALWAYS_INLINE auto div(A &&a, B &&b) -> decltype(IRMatcher::operator/(a, b)) {
1028 return IRMatcher::operator/(a, b);
1029}
1030
1031template<>
1035
1036template<>
1040
1041template<>
1042HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Div>(halide_type_t &t, double a, double b) noexcept {
1043 return div_imp(a, b);
1044}
1045
// Builds a BinOp<Mod, ...> pattern from two operands, each converted with
// pattern_arg.
// NOTE(review): listing lines 1048-1049 are absent from this view.
1046template<typename A, typename B>
1047HALIDE_ALWAYS_INLINE auto operator%(A &&a, B &&b) noexcept -> BinOp<Mod, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1050 return {pattern_arg(a), pattern_arg(b)};
1051}
1052
// Named-function spelling of the % pattern operator; forwards to
// IRMatcher::operator%.
// NOTE(review): listing lines 1055-1056 are absent from this view.
1053template<typename A, typename B>
1054HALIDE_ALWAYS_INLINE auto mod(A &&a, B &&b) -> decltype(IRMatcher::operator%(a, b)) {
1057 return IRMatcher::operator%(a, b);
1058}
1059
1060template<>
1064
1065template<>
1069
1070template<>
1071HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Mod>(halide_type_t &t, double a, double b) noexcept {
1072 return mod_imp(a, b);
1073}
1074
// Builds a BinOp<Min, ...> pattern from two operands, each converted with
// pattern_arg.
// NOTE(review): listing lines 1077-1078 are absent from this view.
1075template<typename A, typename B>
1076HALIDE_ALWAYS_INLINE auto min(A &&a, B &&b) noexcept -> BinOp<Min, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1079 return {pattern_arg(a), pattern_arg(b)};
1080}
1081
1082template<>
1084 return std::min(a, b);
1085}
1086
1087template<>
1089 return std::min(a, b);
1090}
1091
1092template<>
1093HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Min>(halide_type_t &t, double a, double b) noexcept {
1094 return std::min(a, b);
1095}
1096
// Builds a BinOp<Max, ...> pattern from two operands. Unlike the sibling
// builders visible in this file, the operands are passed to pattern_arg
// through std::forward.
// NOTE(review): listing lines 1099-1100 are absent from this view.
1097template<typename A, typename B>
1098HALIDE_ALWAYS_INLINE auto max(A &&a, B &&b) noexcept -> BinOp<Max, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1101 return {pattern_arg(std::forward<A>(a)), pattern_arg(std::forward<B>(b))};
1102}
1103
1104template<>
1106 return std::max(a, b);
1107}
1108
1109template<>
1111 return std::max(a, b);
1112}
1113
1114template<>
1115HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Max>(halide_type_t &t, double a, double b) noexcept {
1116 return std::max(a, b);
1117}
1118
1119template<typename A, typename B>
1120HALIDE_ALWAYS_INLINE auto operator<(A &&a, B &&b) noexcept -> CmpOp<LT, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1121 return {pattern_arg(a), pattern_arg(b)};
1122}
1123
1124template<typename A, typename B>
1125HALIDE_ALWAYS_INLINE auto lt(A &&a, B &&b) -> decltype(IRMatcher::operator<(a, b)) {
1126 return IRMatcher::operator<(a, b);
1127}
1128
1129template<>
1133
1134template<>
1138
1139template<>
1141 return a < b;
1142}
1143
1144template<typename A, typename B>
1145HALIDE_ALWAYS_INLINE auto operator>(A &&a, B &&b) noexcept -> CmpOp<GT, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1146 return {pattern_arg(a), pattern_arg(b)};
1147}
1148
1149template<typename A, typename B>
1150HALIDE_ALWAYS_INLINE auto gt(A &&a, B &&b) -> decltype(IRMatcher::operator>(a, b)) {
1151 return IRMatcher::operator>(a, b);
1152}
1153
1154template<>
1158
1159template<>
1163
1164template<>
1166 return a > b;
1167}
1168
1169template<typename A, typename B>
1170HALIDE_ALWAYS_INLINE auto operator<=(A &&a, B &&b) noexcept -> CmpOp<LE, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1171 return {pattern_arg(a), pattern_arg(b)};
1172}
1173
1174template<typename A, typename B>
1175HALIDE_ALWAYS_INLINE auto le(A &&a, B &&b) -> decltype(IRMatcher::operator<=(a, b)) {
1176 return IRMatcher::operator<=(a, b);
1177}
1178
1179template<>
1181 return a <= b;
1182}
1183
1184template<>
1188
1189template<>
1191 return a <= b;
1192}
1193
1194template<typename A, typename B>
1195HALIDE_ALWAYS_INLINE auto operator>=(A &&a, B &&b) noexcept -> CmpOp<GE, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1196 return {pattern_arg(a), pattern_arg(b)};
1197}
1198
1199template<typename A, typename B>
1200HALIDE_ALWAYS_INLINE auto ge(A &&a, B &&b) -> decltype(IRMatcher::operator>=(a, b)) {
1201 return IRMatcher::operator>=(a, b);
1202}
1203
1204template<>
1206 return a >= b;
1207}
1208
1209template<>
1213
1214template<>
1216 return a >= b;
1217}
1218
1219template<typename A, typename B>
1220HALIDE_ALWAYS_INLINE auto operator==(A &&a, B &&b) noexcept -> CmpOp<EQ, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1221 return {pattern_arg(a), pattern_arg(b)};
1222}
1223
1224template<typename A, typename B>
1225HALIDE_ALWAYS_INLINE auto eq(A &&a, B &&b) -> decltype(IRMatcher::operator==(a, b)) {
1226 return IRMatcher::operator==(a, b);
1227}
1228
1229template<>
1231 return a == b;
1232}
1233
1234template<>
1238
1239template<>
1241 return a == b;
1242}
1243
1244template<typename A, typename B>
1245HALIDE_ALWAYS_INLINE auto operator!=(A &&a, B &&b) noexcept -> CmpOp<NE, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1246 return {pattern_arg(a), pattern_arg(b)};
1247}
1248
1249template<typename A, typename B>
1250HALIDE_ALWAYS_INLINE auto ne(A &&a, B &&b) -> decltype(IRMatcher::operator!=(a, b)) {
1251 return IRMatcher::operator!=(a, b);
1252}
1253
1254template<>
1256 return a != b;
1257}
1258
1259template<>
1263
1264template<>
1266 return a != b;
1267}
1268
1269template<typename A, typename B>
1270HALIDE_ALWAYS_INLINE auto operator||(A &&a, B &&b) noexcept -> BinOp<Or, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1271 return {pattern_arg(a), pattern_arg(b)};
1272}
1273
1274template<typename A, typename B>
1275HALIDE_ALWAYS_INLINE auto or_op(A &&a, B &&b) -> decltype(IRMatcher::operator||(a, b)) {
1276 return IRMatcher::operator||(a, b);
1277}
1278
1279template<>
1281 return (a | b) & 1;
1282}
1283
1284template<>
1286 return (a | b) & 1;
1287}
1288
1289template<>
1290HALIDE_ALWAYS_INLINE double constant_fold_bin_op<Or>(halide_type_t &t, double a, double b) noexcept {
1291 // Unreachable, as it would be a type mismatch.
1292 return 0;
1293}
1294
1295template<typename A, typename B>
1296HALIDE_ALWAYS_INLINE auto operator&&(A &&a, B &&b) noexcept -> BinOp<And, decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1297 return {pattern_arg(a), pattern_arg(b)};
1298}
1299
1300template<typename A, typename B>
1301HALIDE_ALWAYS_INLINE auto and_op(A &&a, B &&b) -> decltype(IRMatcher::operator&&(a, b)) {
1302 return IRMatcher::operator&&(a, b);
1303}
1304
1305template<>
1307 return a & b & 1;
1308}
1309
1310template<>
1314
1315template<>
1316HALIDE_ALWAYS_INLINE double constant_fold_bin_op<And>(halide_type_t &t, double a, double b) noexcept {
1317 // Unreachable
1318 return 0;
1319}
1320
1322 return 0;
1323}
1324
// Bitwise OR of all arguments, computed recursively: the result for the
// trailing arguments is OR-ed with the leading one. The recursion bottoms out
// in the zero-argument overload of bitwise_or_reduce.
template<typename... Args>
constexpr uint32_t bitwise_or_reduce(uint32_t first, Args... rest) {
    return bitwise_or_reduce(rest...) | first;
}
1329
// Recursion terminator for and_reduce: the conjunction of no values is true.
constexpr bool and_reduce() {
    return true;
}

// Logical AND of all arguments. Returns false as soon as the leading argument
// is false; otherwise the answer is the conjunction of the remaining ones.
template<typename... Args>
constexpr bool and_reduce(bool first, Args... rest) {
    return first ? and_reduce(rest...) : false;
}
1338
// Compile-time minimum of two ints.
// TODO: std::min() is constexpr from C++14 onward, but it is not a drop-in
// replacement while call sites pass mixed int / size_t arguments (e.g.
// const_min(1, sizeof...(Args) - 1)), which this signature converts to int.
constexpr int const_min(int a, int b) {
    return b < a ? b : a;
}
1343
1344template<typename... Args>
1345struct Intrin {
1346 struct pattern_tag {};
1348 std::tuple<Args...> args;
1349 // The type of the output of the intrinsic node.
1350 // Only necessary in cases where it can't be inferred
1351 // from the input types (e.g. saturating_cast).
1353
1355
1358 constexpr static bool canonical = and_reduce((Args::canonical)...);
1359
1360 template<int i,
1361 uint32_t bound,
1362 typename = typename std::enable_if<(i < sizeof...(Args))>::type>
1363 HALIDE_ALWAYS_INLINE bool match_args(int, const Call &c, MatcherState &state) const noexcept {
1364 using T = decltype(std::get<i>(args));
1365 return (std::get<i>(args).template match<bound>(*c.args[i].get(), state) &&
1366 match_args<i + 1, bound | bindings<T>::mask>(0, c, state));
1367 }
1368
1369 template<int i, uint32_t binds>
1370 HALIDE_ALWAYS_INLINE bool match_args(double, const Call &c, MatcherState &state) const noexcept {
1371 return true;
1372 }
1373
1374 template<uint32_t bound>
1375 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1376 if (e.node_type != IRNodeType::Call) {
1377 return false;
1378 }
1379 const Call &c = (const Call &)e;
1380 return (c.is_intrinsic(intrin) &&
1381 ((optional_type_hint == Type()) || optional_type_hint == e.type) &&
1382 match_args<0, bound>(0, c, state));
1383 }
1384
1385 template<int i,
1386 typename = typename std::enable_if<(i < sizeof...(Args))>::type>
1387 HALIDE_ALWAYS_INLINE void print_args(int, std::ostream &s) const {
1388 s << std::get<i>(args);
1389 if (i + 1 < sizeof...(Args)) {
1390 s << ", ";
1391 }
1392 print_args<i + 1>(0, s);
1393 }
1394
1395 template<int i>
1396 HALIDE_ALWAYS_INLINE void print_args(double, std::ostream &s) const {
1397 }
1398
1400 void print_args(std::ostream &s) const {
1401 print_args<0>(0, s);
1402 }
1403
1405 Expr make(MatcherState &state, halide_type_t type_hint) const {
1406 Expr arg0 = std::get<0>(args).make(state, type_hint);
1407 if (intrin == Call::likely) {
1408 return likely(arg0);
1409 } else if (intrin == Call::likely_if_innermost) {
1410 return likely_if_innermost(arg0);
1411 } else if (intrin == Call::abs) {
1412 return abs(arg0);
1413 } else if (intrin == Call::saturating_cast) {
1414 return saturating_cast(optional_type_hint, arg0);
1415 }
1416
1417 Expr arg1 = std::get<const_min(1, sizeof...(Args) - 1)>(args).make(state, type_hint);
1418 if (intrin == Call::absd) {
1419 return absd(arg0, arg1);
1420 } else if (intrin == Call::widen_right_add) {
1421 return widen_right_add(arg0, arg1);
1422 } else if (intrin == Call::widen_right_mul) {
1423 return widen_right_mul(arg0, arg1);
1424 } else if (intrin == Call::widen_right_sub) {
1425 return widen_right_sub(arg0, arg1);
1426 } else if (intrin == Call::widening_add) {
1427 return widening_add(arg0, arg1);
1428 } else if (intrin == Call::widening_sub) {
1429 return widening_sub(arg0, arg1);
1430 } else if (intrin == Call::widening_mul) {
1431 return widening_mul(arg0, arg1);
1432 } else if (intrin == Call::saturating_add) {
1433 return saturating_add(arg0, arg1);
1434 } else if (intrin == Call::saturating_sub) {
1435 return saturating_sub(arg0, arg1);
1436 } else if (intrin == Call::halving_add) {
1437 return halving_add(arg0, arg1);
1438 } else if (intrin == Call::halving_sub) {
1439 return halving_sub(arg0, arg1);
1440 } else if (intrin == Call::rounding_halving_add) {
1441 return rounding_halving_add(arg0, arg1);
1442 } else if (intrin == Call::shift_left) {
1443 return arg0 << arg1;
1444 } else if (intrin == Call::shift_right) {
1445 return arg0 >> arg1;
1446 } else if (intrin == Call::rounding_shift_left) {
1447 return rounding_shift_left(arg0, arg1);
1448 } else if (intrin == Call::rounding_shift_right) {
1449 return rounding_shift_right(arg0, arg1);
1450 }
1451
1452 Expr arg2 = std::get<const_min(2, sizeof...(Args) - 1)>(args).make(state, type_hint);
1454 return mul_shift_right(arg0, arg1, arg2);
1455 } else if (intrin == Call::rounding_mul_shift_right) {
1456 return rounding_mul_shift_right(arg0, arg1, arg2);
1457 }
1458
1459 internal_error << "Unhandled intrinsic in IRMatcher: " << intrin;
1460 return Expr();
1461 }
1462
1463 constexpr static bool foldable = true;
1464
1467 // Assuming the args have the same type as the intrinsic is incorrect in
1468 // general. But for the intrinsics we can fold (just shifts), the LHS
1469 // has the same type as the intrinsic, and we can always treat the RHS
1470 // as a signed int, because we're using 64 bits for it.
1471 std::get<0>(args).make_folded_const(val, ty, state);
1472 halide_type_t signed_ty = ty;
1473 signed_ty.code = halide_type_int;
1474 // We can just directly get the second arg here, because we only want to
1475 // instantiate this method for shifts, which have two args.
1476 std::get<1>(args).make_folded_const(arg1, signed_ty, state);
1477
1478 if (intrin == Call::shift_left) {
1479 if (arg1.u.i64 < 0) {
1480 if (ty.code == halide_type_int) {
1481 // Arithmetic shift
1482 val.u.i64 >>= -arg1.u.i64;
1483 } else {
1484 // Logical shift
1485 val.u.u64 >>= -arg1.u.i64;
1486 }
1487 } else {
1488 val.u.u64 <<= arg1.u.i64;
1489 }
1490 } else if (intrin == Call::shift_right) {
1491 if (arg1.u.i64 > 0) {
1492 if (ty.code == halide_type_int) {
1493 // Arithmetic shift
1494 val.u.i64 >>= arg1.u.i64;
1495 } else {
1496 // Logical shift
1497 val.u.u64 >>= arg1.u.i64;
1498 }
1499 } else {
1500 val.u.u64 <<= -arg1.u.i64;
1501 }
1502 } else {
1503 internal_error << "Folding not implemented for intrinsic: " << intrin;
1504 }
1505 }
1506
1509 : intrin(intrin), args(args...) {
1510 }
1511};
1512
1513template<typename... Args>
1514std::ostream &operator<<(std::ostream &s, const Intrin<Args...> &op) {
1515 s << op.intrin << "(";
1516 op.print_args(s);
1517 s << ")";
1518 return s;
1519}
1520
1521template<typename... Args>
1522HALIDE_ALWAYS_INLINE auto intrin(Call::IntrinsicOp intrinsic_op, Args... args) noexcept -> Intrin<decltype(pattern_arg(args))...> {
1523 return {intrinsic_op, pattern_arg(args)...};
1524}
1525
1526template<typename A, typename B>
1527auto widen_right_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1529}
1530template<typename A, typename B>
1531auto widen_right_mul(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1533}
1534template<typename A, typename B>
1535auto widen_right_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1537}
1538
1539template<typename A, typename B>
1540auto widening_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1542}
1543template<typename A, typename B>
1544auto widening_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1546}
1547template<typename A, typename B>
1548auto widening_mul(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1550}
1551template<typename A, typename B>
1552auto saturating_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1554}
1555template<typename A, typename B>
1556auto saturating_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1558}
1559template<typename A>
1560auto saturating_cast(const Type &t, A &&a) noexcept -> Intrin<decltype(pattern_arg(a))> {
1561 Intrin<decltype(pattern_arg(a))> p = {Call::saturating_cast, pattern_arg(a)};
1562 p.optional_type_hint = t;
1563 return p;
1564}
1565template<typename A, typename B>
1566auto halving_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1567 return {Call::halving_add, pattern_arg(a), pattern_arg(b)};
1568}
1569template<typename A, typename B>
1570auto halving_sub(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1571 return {Call::halving_sub, pattern_arg(a), pattern_arg(b)};
1572}
1573template<typename A, typename B>
1574auto rounding_halving_add(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1576}
1577template<typename A, typename B>
1578auto shift_left(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1579 return {Call::shift_left, pattern_arg(a), pattern_arg(b)};
1580}
1581template<typename A, typename B>
1582auto shift_right(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1583 return {Call::shift_right, pattern_arg(a), pattern_arg(b)};
1584}
1585template<typename A, typename B>
1586auto rounding_shift_left(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1588}
1589template<typename A, typename B>
1590auto rounding_shift_right(A &&a, B &&b) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b))> {
1592}
1593template<typename A, typename B, typename C>
1594auto mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
1596}
1597template<typename A, typename B, typename C>
1598auto rounding_mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin<decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
1600}
1601
1602template<typename A>
1603struct NotOp {
1604 struct pattern_tag {};
1605 A a;
1606
1608
1611 constexpr static bool canonical = A::canonical;
1612
1613 template<uint32_t bound>
1614 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1615 if (e.node_type != IRNodeType::Not) {
1616 return false;
1617 }
1618 const Not &op = (const Not &)e;
1619 return (a.template match<bound>(*op.a.get(), state));
1620 }
1621
1622 template<uint32_t bound, typename A2>
1623 HALIDE_ALWAYS_INLINE bool match(const NotOp<A2> &op, MatcherState &state) const noexcept {
1624 return a.template match<bound>(unwrap(op.a), state);
1625 }
1626
1628 Expr make(MatcherState &state, halide_type_t type_hint) const {
1629 return Not::make(a.make(state, type_hint));
1630 }
1631
1632 constexpr static bool foldable = A::foldable;
1633
1634 template<typename A1 = A>
1636 a.make_folded_const(val, ty, state);
1637 val.u.u64 = ~val.u.u64;
1638 val.u.u64 &= 1;
1639 }
1640};
1641
1642template<typename A>
1643HALIDE_ALWAYS_INLINE auto operator!(A &&a) noexcept -> NotOp<decltype(pattern_arg(a))> {
1645 return {pattern_arg(a)};
1646}
1647
1648template<typename A>
1653
1654template<typename A>
1655inline std::ostream &operator<<(std::ostream &s, const NotOp<A> &op) {
1656 s << "!(" << op.a << ")";
1657 return s;
1658}
1659
1660template<typename C, typename T, typename F>
1661struct SelectOp {
1662 struct pattern_tag {};
1664 T t;
1665 F f;
1666
1668
1671
1672 constexpr static bool canonical = C::canonical && T::canonical && F::canonical;
1673
1674 template<uint32_t bound>
1675 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1676 if (e.node_type != Select::_node_type) {
1677 return false;
1678 }
1679 const Select &op = (const Select &)e;
1680 return (c.template match<bound>(*op.condition.get(), state) &&
1681 t.template match<bound | bindings<C>::mask>(*op.true_value.get(), state) &&
1682 f.template match<bound | bindings<C>::mask | bindings<T>::mask>(*op.false_value.get(), state));
1683 }
1684 template<uint32_t bound, typename C2, typename T2, typename F2>
1685 HALIDE_ALWAYS_INLINE bool match(const SelectOp<C2, T2, F2> &instance, MatcherState &state) const noexcept {
1686 return (c.template match<bound>(unwrap(instance.c), state) &&
1687 t.template match<bound | bindings<C>::mask>(unwrap(instance.t), state) &&
1688 f.template match<bound | bindings<C>::mask | bindings<T>::mask>(unwrap(instance.f), state));
1689 }
1690
1692 Expr make(MatcherState &state, halide_type_t type_hint) const {
1693 return Select::make(c.make(state, {}), t.make(state, type_hint), f.make(state, type_hint));
1694 }
1695
1696 constexpr static bool foldable = C::foldable && T::foldable && F::foldable;
1697
1698 template<typename C1 = C>
1700 halide_scalar_value_t c_val, t_val, f_val;
1701 halide_type_t c_ty;
1702 c.make_folded_const(c_val, c_ty, state);
1703 if ((c_val.u.u64 & 1) == 1) {
1704 t.make_folded_const(val, ty, state);
1705 } else {
1706 f.make_folded_const(val, ty, state);
1707 }
1709 }
1710};
1711
1712template<typename C, typename T, typename F>
1713std::ostream &operator<<(std::ostream &s, const SelectOp<C, T, F> &op) {
1714 s << "select(" << op.c << ", " << op.t << ", " << op.f << ")";
1715 return s;
1716}
1717
1718template<typename C, typename T, typename F>
1719HALIDE_ALWAYS_INLINE auto select(C &&c, T &&t, F &&f) noexcept -> SelectOp<decltype(pattern_arg(c)), decltype(pattern_arg(t)), decltype(pattern_arg(f))> {
1723 return {pattern_arg(c), pattern_arg(t), pattern_arg(f)};
1724}
1725
1726template<typename A, typename B>
1728 struct pattern_tag {};
1729 A a;
1731
1733
1736
1737 constexpr static bool canonical = A::canonical && B::canonical;
1738
1739 template<uint32_t bound>
1740 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1741 if (e.node_type == Broadcast::_node_type) {
1742 const Broadcast &op = (const Broadcast &)e;
1743 if (a.template match<bound>(*op.value.get(), state) &&
1744 lanes.template match<bound>(op.lanes, state)) {
1745 return true;
1746 }
1747 }
1748 return false;
1749 }
1750
1751 template<uint32_t bound, typename A2, typename B2>
1752 HALIDE_ALWAYS_INLINE bool match(const BroadcastOp<A2, B2> &op, MatcherState &state) const noexcept {
1753 return (a.template match<bound>(unwrap(op.a), state) &&
1754 lanes.template match<bound | bindings<A>::mask>(unwrap(op.lanes), state));
1755 }
1756
1758 Expr make(MatcherState &state, halide_type_t type_hint) const {
1759 halide_scalar_value_t lanes_val;
1760 halide_type_t ty;
1761 lanes.make_folded_const(lanes_val, ty, state);
1762 int32_t l = (int32_t)lanes_val.u.i64;
1763 type_hint.lanes /= l;
1764 Expr val = a.make(state, type_hint);
1765 if (l == 1) {
1766 return val;
1767 } else {
1768 return Broadcast::make(std::move(val), l);
1769 }
1770 }
1771
1772 constexpr static bool foldable = false;
1773
1774 template<typename A1 = A>
1776 halide_scalar_value_t lanes_val;
1777 halide_type_t lanes_ty;
1778 lanes.make_folded_const(lanes_val, lanes_ty, state);
1779 uint16_t l = (uint16_t)lanes_val.u.i64;
1780 a.make_folded_const(val, ty, state);
1781 ty.lanes = l | (ty.lanes & MatcherState::special_values_mask);
1782 }
1783};
1784
1785template<typename A, typename B>
1786inline std::ostream &operator<<(std::ostream &s, const BroadcastOp<A, B> &op) {
1787 s << "broadcast(" << op.a << ", " << op.lanes << ")";
1788 return s;
1789}
1790
1791template<typename A, typename B>
1792HALIDE_ALWAYS_INLINE auto broadcast(A &&a, B lanes) noexcept -> BroadcastOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes))> {
1794 return {pattern_arg(a), pattern_arg(lanes)};
1795}
1796
1797template<typename A, typename B, typename C>
1798struct RampOp {
1799 struct pattern_tag {};
1800 A a;
1801 B b;
1803
1805
1808
1809 constexpr static bool canonical = A::canonical && B::canonical && C::canonical;
1810
1811 template<uint32_t bound>
1812 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1813 if (e.node_type != Ramp::_node_type) {
1814 return false;
1815 }
1816 const Ramp &op = (const Ramp &)e;
1817 if (a.template match<bound>(*op.base.get(), state) &&
1818 b.template match<bound | bindings<A>::mask>(*op.stride.get(), state) &&
1819 lanes.template match<bound | bindings<A>::mask | bindings<B>::mask>(op.lanes, state)) {
1820 return true;
1821 } else {
1822 return false;
1823 }
1824 }
1825
1826 template<uint32_t bound, typename A2, typename B2, typename C2>
1827 HALIDE_ALWAYS_INLINE bool match(const RampOp<A2, B2, C2> &op, MatcherState &state) const noexcept {
1828 return (a.template match<bound>(unwrap(op.a), state) &&
1829 b.template match<bound | bindings<A>::mask>(unwrap(op.b), state) &&
1830 lanes.template match<bound | bindings<A>::mask | bindings<B>::mask>(unwrap(op.lanes), state));
1831 }
1832
1834 Expr make(MatcherState &state, halide_type_t type_hint) const {
1835 halide_scalar_value_t lanes_val;
1836 halide_type_t ty;
1837 lanes.make_folded_const(lanes_val, ty, state);
1838 int32_t l = (int32_t)lanes_val.u.i64;
1839 type_hint.lanes /= l;
1840 Expr ea, eb;
1841 eb = b.make(state, type_hint);
1842 ea = a.make(state, eb.type());
1843 return Ramp::make(ea, eb, l);
1844 }
1845
1846 constexpr static bool foldable = false;
1847};
1848
1849template<typename A, typename B, typename C>
1850std::ostream &operator<<(std::ostream &s, const RampOp<A, B, C> &op) {
1851 s << "ramp(" << op.a << ", " << op.b << ", " << op.lanes << ")";
1852 return s;
1853}
1854
1855template<typename A, typename B, typename C>
1856HALIDE_ALWAYS_INLINE auto ramp(A &&a, B &&b, C &&c) noexcept -> RampOp<decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))> {
1860 return {pattern_arg(a), pattern_arg(b), pattern_arg(c)};
1861}
1862
1863template<typename A, typename B, VectorReduce::Operator reduce_op>
1865 struct pattern_tag {};
1866 A a;
1868
1870
1873 constexpr static bool canonical = A::canonical;
1874
1875 template<uint32_t bound>
1876 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1877 if (e.node_type == VectorReduce::_node_type) {
1878 const VectorReduce &op = (const VectorReduce &)e;
1879 if (op.op == reduce_op &&
1880 a.template match<bound>(*op.value.get(), state) &&
1881 lanes.template match<bound | bindings<A>::mask>(op.type.lanes(), state)) {
1882 return true;
1883 }
1884 }
1885 return false;
1886 }
1887
1888 template<uint32_t bound, typename A2, typename B2, VectorReduce::Operator reduce_op_2>
1890 return (reduce_op == reduce_op_2 &&
1891 a.template match<bound>(unwrap(op.a), state) &&
1892 lanes.template match<bound | bindings<A>::mask>(unwrap(op.lanes), state));
1893 }
1894
1896 Expr make(MatcherState &state, halide_type_t type_hint) const {
1897 halide_scalar_value_t lanes_val;
1898 halide_type_t ty;
1899 lanes.make_folded_const(lanes_val, ty, state);
1900 int l = (int)lanes_val.u.i64;
1901 return VectorReduce::make(reduce_op, a.make(state, type_hint), l);
1902 }
1903
1904 constexpr static bool foldable = false;
1905};
1906
1907template<typename A, typename B, VectorReduce::Operator reduce_op>
1908inline std::ostream &operator<<(std::ostream &s, const VectorReduceOp<A, B, reduce_op> &op) {
1909 s << "vector_reduce(" << reduce_op << ", " << op.a << ", " << op.lanes << ")";
1910 return s;
1911}
1912
1913template<typename A, typename B>
1914HALIDE_ALWAYS_INLINE auto h_add(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Add> {
1916 return {pattern_arg(a), pattern_arg(lanes)};
1917}
1918
1919template<typename A, typename B>
1920HALIDE_ALWAYS_INLINE auto h_min(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Min> {
1922 return {pattern_arg(a), pattern_arg(lanes)};
1923}
1924
1925template<typename A, typename B>
1926HALIDE_ALWAYS_INLINE auto h_max(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Max> {
1928 return {pattern_arg(a), pattern_arg(lanes)};
1929}
1930
1931template<typename A, typename B>
1932HALIDE_ALWAYS_INLINE auto h_and(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::And> {
1934 return {pattern_arg(a), pattern_arg(lanes)};
1935}
1936
1937template<typename A, typename B>
1938HALIDE_ALWAYS_INLINE auto h_or(A &&a, B lanes) noexcept -> VectorReduceOp<decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Or> {
1940 return {pattern_arg(a), pattern_arg(lanes)};
1941}
1942
1943template<typename A>
1944struct NegateOp {
1945 struct pattern_tag {};
1946 A a;
1947
1949
1952
1953 constexpr static bool canonical = A::canonical;
1954
1955 template<uint32_t bound>
1956 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
1957 if (e.node_type != Sub::_node_type) {
1958 return false;
1959 }
1960 const Sub &op = (const Sub &)e;
1961 return (a.template match<bound>(*op.b.get(), state) &&
1962 is_const_zero(op.a));
1963 }
1964
1965 template<uint32_t bound, typename A2>
1966 HALIDE_ALWAYS_INLINE bool match(NegateOp<A2> &&p, MatcherState &state) const noexcept {
1967 return a.template match<bound>(unwrap(p.a), state);
1968 }
1969
1971 Expr make(MatcherState &state, halide_type_t type_hint) const {
1972 Expr ea = a.make(state, type_hint);
1973 Expr z = make_zero(ea.type());
1974 return Sub::make(std::move(z), std::move(ea));
1975 }
1976
1977 constexpr static bool foldable = A::foldable;
1978
1979 template<typename A1 = A>
1981 a.make_folded_const(val, ty, state);
1982 int dead_bits = 64 - ty.bits;
1983 switch (ty.code) {
1984 case halide_type_int:
1985 if (ty.bits >= 32 && val.u.u64 && (val.u.u64 << (65 - ty.bits)) == 0) {
1986 // Trying to negate the most negative signed int for a no-overflow type.
1988 } else {
1989 // Negate, drop the high bits, and then sign-extend them back
1990 val.u.i64 = int64_t(uint64_t(-val.u.i64) << dead_bits) >> dead_bits;
1991 }
1992 break;
1993 case halide_type_uint:
1994 val.u.u64 = ((-val.u.u64) << dead_bits) >> dead_bits;
1995 break;
1996 case halide_type_float:
1997 case halide_type_bfloat:
1998 val.u.f64 = -val.u.f64;
1999 break;
2000 default:
2001 // unreachable
2002 ;
2003 }
2004 }
2005};
2006
2007template<typename A>
2008std::ostream &operator<<(std::ostream &s, const NegateOp<A> &op) {
2009 s << "-" << op.a;
2010 return s;
2011}
2012
2013template<typename A>
2014HALIDE_ALWAYS_INLINE auto operator-(A &&a) noexcept -> NegateOp<decltype(pattern_arg(a))> {
2016 return {pattern_arg(a)};
2017}
2018
2019template<typename A>
2024
2025template<typename A>
2026struct CastOp {
2027 struct pattern_tag {};
2029 A a;
2030
2032
2035 constexpr static bool canonical = A::canonical;
2036
2037 template<uint32_t bound>
2038 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
2039 if (e.node_type != Cast::_node_type) {
2040 return false;
2041 }
2042 const Cast &op = (const Cast &)e;
2043 return (e.type == t &&
2044 a.template match<bound>(*op.value.get(), state));
2045 }
2046 template<uint32_t bound, typename A2>
2047 HALIDE_ALWAYS_INLINE bool match(const CastOp<A2> &op, MatcherState &state) const noexcept {
2048 return t == op.t && a.template match<bound>(unwrap(op.a), state);
2049 }
2050
2052 Expr make(MatcherState &state, halide_type_t type_hint) const {
2053 return cast(t, a.make(state, {}));
2054 }
2055
2056 constexpr static bool foldable = false;
2057};
2058
2059template<typename A>
2060std::ostream &operator<<(std::ostream &s, const CastOp<A> &op) {
2061 s << "cast(" << op.t << ", " << op.a << ")";
2062 return s;
2063}
2064
2065template<typename A>
2066HALIDE_ALWAYS_INLINE auto cast(halide_type_t t, A &&a) noexcept -> CastOp<decltype(pattern_arg(a))> {
2068 return {t, pattern_arg(a)};
2069}
2070
2071template<typename A>
2072struct WidenOp {
2073 struct pattern_tag {};
2074 A a;
2075
2077
2080 constexpr static bool canonical = A::canonical;
2081
2082 template<uint32_t bound>
2083 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
2084 if (e.node_type != Cast::_node_type) {
2085 return false;
2086 }
2087 const Cast &op = (const Cast &)e;
2088 return (e.type == op.value.type().widen() &&
2089 a.template match<bound>(*op.value.get(), state));
2090 }
2091 template<uint32_t bound, typename A2>
2092 HALIDE_ALWAYS_INLINE bool match(const WidenOp<A2> &op, MatcherState &state) const noexcept {
2093 return a.template match<bound>(unwrap(op.a), state);
2094 }
2095
2097 Expr make(MatcherState &state, halide_type_t type_hint) const {
2098 Expr e = a.make(state, {});
2099 Type w = e.type().widen();
2100 return cast(w, std::move(e));
2101 }
2102
2103 constexpr static bool foldable = false;
2104};
2105
2106template<typename A>
2107std::ostream &operator<<(std::ostream &s, const WidenOp<A> &op) {
2108 s << "widen(" << op.a << ")";
2109 return s;
2110}
2111
2112template<typename A>
2113HALIDE_ALWAYS_INLINE auto widen(A &&a) noexcept -> WidenOp<decltype(pattern_arg(a))> {
2115 return {pattern_arg(a)};
2116}
2117
2118template<typename Vec, typename Base, typename Stride, typename Lanes>
2119struct SliceOp {
2120 struct pattern_tag {};
2121 Vec vec;
2122 Base base;
2123 Stride stride;
2124 Lanes lanes;
2125
2126 static constexpr uint32_t binds = Vec::binds | Base::binds | Stride::binds | Lanes::binds;
2127
2130 constexpr static bool canonical = Vec::canonical && Base::canonical && Stride::canonical && Lanes::canonical;
2131
2132 template<uint32_t bound>
2133 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
2134 if (e.node_type != IRNodeType::Shuffle) {
2135 return false;
2136 }
2137 const Shuffle &v = (const Shuffle &)e;
2138 return v.vectors.size() == 1 &&
2139 v.is_slice() &&
2140 vec.template match<bound>(*v.vectors[0].get(), state) &&
2141 base.template match<bound | bindings<Vec>::mask>(v.slice_begin(), state) &&
2142 stride.template match<bound | bindings<Vec>::mask | bindings<Base>::mask>(v.slice_stride(), state) &&
2143 lanes.template match<bound | bindings<Vec>::mask | bindings<Base>::mask | bindings<Stride>::mask>(v.type.lanes(), state);
2144 }
2145
2147 Expr make(MatcherState &state, halide_type_t type_hint) const {
2148 halide_scalar_value_t base_val, stride_val, lanes_val;
2149 halide_type_t ty;
2150 base.make_folded_const(base_val, ty, state);
2151 int b = (int)base_val.u.i64;
2152 stride.make_folded_const(stride_val, ty, state);
2153 int s = (int)stride_val.u.i64;
2154 lanes.make_folded_const(lanes_val, ty, state);
2155 int l = (int)lanes_val.u.i64;
2156 return Shuffle::make_slice(vec.make(state, type_hint), b, s, l);
2157 }
2158
2159 constexpr static bool foldable = false;
2160
2162 SliceOp(Vec v, Base b, Stride s, Lanes l)
2163 : vec(v), base(b), stride(s), lanes(l) {
2164 static_assert(Base::foldable, "Base of slice should consist only of operations that constant-fold");
2165 static_assert(Stride::foldable, "Stride of slice should consist only of operations that constant-fold");
2166 static_assert(Lanes::foldable, "Lanes of slice should consist only of operations that constant-fold");
2167 }
2168};
2169
2170template<typename Vec, typename Base, typename Stride, typename Lanes>
2171std::ostream &operator<<(std::ostream &s, const SliceOp<Vec, Base, Stride, Lanes> &op) {
2172 s << "slice(" << op.vec << ", " << op.base << ", " << op.stride << ", " << op.lanes << ")";
2173 return s;
2174}
2175
2176template<typename Vec, typename Base, typename Stride, typename Lanes>
2177HALIDE_ALWAYS_INLINE auto slice(Vec vec, Base base, Stride stride, Lanes lanes) noexcept
2178 -> SliceOp<decltype(pattern_arg(vec)), decltype(pattern_arg(base)), decltype(pattern_arg(stride)), decltype(pattern_arg(lanes))> {
2179 return {pattern_arg(vec), pattern_arg(base), pattern_arg(stride), pattern_arg(lanes)};
2180}
2181
2182template<typename A>
2183struct Fold {
2184 struct pattern_tag {};
2185 A a;
2186
2188
2191 constexpr static bool canonical = true;
2192
2194 Expr make(MatcherState &state, halide_type_t type_hint) const noexcept {
2196 halide_type_t ty = type_hint;
2197 a.make_folded_const(c, ty, state);
2198
2199 // The result of the fold may have an underspecified type
2200 // (e.g. because it's from an int literal). Make the type code
2201 // and bits match the required type, if there is one (we can
2202 // tell from the bits field).
2203 if (type_hint.bits) {
2204 if (((int)ty.code == (int)halide_type_int) &&
2205 ((int)type_hint.code == (int)halide_type_float)) {
2206 int64_t x = c.u.i64;
2207 c.u.f64 = (double)x;
2208 }
2209 ty.code = type_hint.code;
2210 ty.bits = type_hint.bits;
2211 }
2212
2213 Expr e = make_const_expr(c, ty);
2214 return e;
2215 }
2216
2217 constexpr static bool foldable = A::foldable;
2218
2219 template<typename A1 = A>
2221 a.make_folded_const(val, ty, state);
2222 }
2223};
2224
2225template<typename A>
2226HALIDE_ALWAYS_INLINE auto fold(A &&a) noexcept -> Fold<decltype(pattern_arg(a))> {
2228 return {pattern_arg(a)};
2229}
2230
2231template<typename A>
2232std::ostream &operator<<(std::ostream &s, const Fold<A> &op) {
2233 s << "fold(" << op.a << ")";
2234 return s;
2235}
2236
2237template<typename A>
2239 struct pattern_tag {};
2240 A a;
2241
2243
2244 // This rule is a predicate, so it always evaluates to a boolean,
2245 // which has IRNodeType UIntImm
2248 constexpr static bool canonical = true;
2249
2250 constexpr static bool foldable = A::foldable;
2251
2252 template<typename A1 = A>
2254 a.make_folded_const(val, ty, state);
2255 ty.code = halide_type_uint;
2256 ty.bits = 64;
2257 val.u.u64 = (ty.lanes & MatcherState::special_values_mask) != 0;
2258 ty.lanes = 1;
2259 }
2260};
2261
2262template<typename A>
2263HALIDE_ALWAYS_INLINE auto overflows(A &&a) noexcept -> Overflows<decltype(pattern_arg(a))> {
2265 return {pattern_arg(a)};
2266}
2267
2268template<typename A>
2269std::ostream &operator<<(std::ostream &s, const Overflows<A> &op) {
2270 s << "overflows(" << op.a << ")";
2271 return s;
2272}
2273
2274struct Overflow {
2275 struct pattern_tag {};
2276
2277 constexpr static uint32_t binds = 0;
2278
2279 // Overflow is an intrinsic, represented as a Call node
2282 constexpr static bool canonical = true;
2283
2284 template<uint32_t bound>
2285 HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept {
2286 if (e.node_type != Call::_node_type) {
2287 return false;
2288 }
2289 const Call &op = (const Call &)e;
2291 }
2292
2294 Expr make(MatcherState &state, halide_type_t type_hint) const {
2296 return make_const_special_expr(type_hint);
2297 }
2298
2299 constexpr static bool foldable = true;
2300
2303 val.u.u64 = 0;
2305 }
2306};
2307
2308inline std::ostream &operator<<(std::ostream &s, const Overflow &op) {
2309 s << "overflow()";
2310 return s;
2311}
2312
2313template<typename A>
2314struct IsConst {
2315 struct pattern_tag {};
2316
2318
2319 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2322 constexpr static bool canonical = true;
2323
2324 A a;
2327
2328 constexpr static bool foldable = true;
2329
2330 template<typename A1 = A>
2332 Expr e = a.make(state, {});
2333 ty.code = halide_type_uint;
2334 ty.bits = 64;
2335 ty.lanes = 1;
2336 if (check_v) {
2337 val.u.u64 = ::Halide::Internal::is_const(e, v) ? 1 : 0;
2338 } else {
2339 val.u.u64 = ::Halide::Internal::is_const(e) ? 1 : 0;
2340 }
2341 }
2342};
2343
2344template<typename A>
2345HALIDE_ALWAYS_INLINE auto is_const(A &&a) noexcept -> IsConst<decltype(pattern_arg(a))> {
2347 return {pattern_arg(a), false, 0};
2348}
2349
2350template<typename A>
2351HALIDE_ALWAYS_INLINE auto is_const(A &&a, int64_t value) noexcept -> IsConst<decltype(pattern_arg(a))> {
2353 return {pattern_arg(a), true, value};
2354}
2355
2356template<typename A>
2357std::ostream &operator<<(std::ostream &s, const IsConst<A> &op) {
2358 if (op.check_v) {
2359 s << "is_const(" << op.a << ")";
2360 } else {
2361 s << "is_const(" << op.a << ", " << op.v << ")";
2362 }
2363 return s;
2364}
2365
2366template<typename A, typename Prover>
2367struct CanProve {
2368 struct pattern_tag {};
2369 A a;
2370 Prover *prover; // An existing simplifying mutator
2371
2373
2374 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2377 constexpr static bool canonical = true;
2378
2379 constexpr static bool foldable = true;
2380
2381 // Includes a raw call to an inlined make method, so don't inline.
2383 Expr condition = a.make(state, {});
2384 condition = prover->mutate(condition, nullptr);
2385 val.u.u64 = is_const_one(condition);
2387 ty.bits = 1;
2388 ty.lanes = condition.type().lanes();
2389 }
2390};
2391
2392template<typename A, typename Prover>
2393HALIDE_ALWAYS_INLINE auto can_prove(A &&a, Prover *p) noexcept -> CanProve<decltype(pattern_arg(a)), Prover> {
2395 return {pattern_arg(a), p};
2396}
2397
2398template<typename A, typename Prover>
2399std::ostream &operator<<(std::ostream &s, const CanProve<A, Prover> &op) {
2400 s << "can_prove(" << op.a << ")";
2401 return s;
2402}
2403
2404template<typename A>
2405struct IsFloat {
2406 struct pattern_tag {};
2407 A a;
2408
2410
2411 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2414 constexpr static bool canonical = true;
2415
2416 constexpr static bool foldable = true;
2417
2420 // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2421 Type t = a.make(state, {}).type();
2422 val.u.u64 = t.is_float();
2424 ty.bits = 1;
2425 ty.lanes = t.lanes();
2426 }
2427};
2428
2429template<typename A>
2430HALIDE_ALWAYS_INLINE auto is_float(A &&a) noexcept -> IsFloat<decltype(pattern_arg(a))> {
2432 return {pattern_arg(a)};
2433}
2434
2435template<typename A>
2436std::ostream &operator<<(std::ostream &s, const IsFloat<A> &op) {
2437 s << "is_float(" << op.a << ")";
2438 return s;
2439}
2440
2441template<typename A>
2442struct IsInt {
2443 struct pattern_tag {};
2444 A a;
2446
2448
2449 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2452 constexpr static bool canonical = true;
2453
2454 constexpr static bool foldable = true;
2455
2458 // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2459 Type t = a.make(state, {}).type();
2460 val.u.u64 = t.is_int() && (bits == 0 || t.bits() == bits) && (lanes == 0 || t.lanes() == lanes);
2462 ty.bits = 1;
2463 ty.lanes = t.lanes();
2464 }
2465};
2466
2467template<typename A>
2468HALIDE_ALWAYS_INLINE auto is_int(A &&a, int bits = 0, int lanes = 0) noexcept -> IsInt<decltype(pattern_arg(a))> {
2470 return {pattern_arg(a), bits, lanes};
2471}
2472
2473template<typename A>
2474std::ostream &operator<<(std::ostream &s, const IsInt<A> &op) {
2475 s << "is_int(" << op.a;
2476 if (op.bits > 0) {
2477 s << ", " << op.bits;
2478 }
2479 if (op.lanes > 0) {
2480 s << ", " << op.lanes;
2481 }
2482 s << ")";
2483 return s;
2484}
2485
2486template<typename A>
2487struct IsUInt {
2488 struct pattern_tag {};
2489 A a;
2491
2493
2494 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2497 constexpr static bool canonical = true;
2498
2499 constexpr static bool foldable = true;
2500
2503 // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2504 Type t = a.make(state, {}).type();
2505 val.u.u64 = t.is_uint() && (bits == 0 || t.bits() == bits) && (lanes == 0 || t.lanes() == lanes);
2507 ty.bits = 1;
2508 ty.lanes = t.lanes();
2509 }
2510};
2511
2512template<typename A>
2513HALIDE_ALWAYS_INLINE auto is_uint(A &&a, int bits = 0, int lanes = 0) noexcept -> IsUInt<decltype(pattern_arg(a))> {
2515 return {pattern_arg(a), bits, lanes};
2516}
2517
2518template<typename A>
2519std::ostream &operator<<(std::ostream &s, const IsUInt<A> &op) {
2520 s << "is_uint(" << op.a;
2521 if (op.bits > 0) {
2522 s << ", " << op.bits;
2523 }
2524 if (op.lanes > 0) {
2525 s << ", " << op.lanes;
2526 }
2527 s << ")";
2528 return s;
2529}
2530
2531template<typename A>
2532struct IsScalar {
2533 struct pattern_tag {};
2534 A a;
2535
2537
2538 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2541 constexpr static bool canonical = true;
2542
2543 constexpr static bool foldable = true;
2544
2547 // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2548 Type t = a.make(state, {}).type();
2549 val.u.u64 = t.is_scalar();
2551 ty.bits = 1;
2552 ty.lanes = t.lanes();
2553 }
2554};
2555
2556template<typename A>
2557HALIDE_ALWAYS_INLINE auto is_scalar(A &&a) noexcept -> IsScalar<decltype(pattern_arg(a))> {
2559 return {pattern_arg(a)};
2560}
2561
2562template<typename A>
2563std::ostream &operator<<(std::ostream &s, const IsScalar<A> &op) {
2564 s << "is_scalar(" << op.a << ")";
2565 return s;
2566}
2567
2568template<typename A>
2570 struct pattern_tag {};
2571 A a;
2572
2574
2575 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2578 constexpr static bool canonical = true;
2579
2580 constexpr static bool foldable = true;
2581
2584 // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2585 a.make_folded_const(val, ty, state);
2586 const uint64_t max_bits = (uint64_t)(-1) >> (64 - ty.bits + (ty.code == halide_type_int));
2587 if (ty.code == halide_type_uint || ty.code == halide_type_int) {
2588 val.u.u64 = (val.u.u64 == max_bits);
2589 } else {
2590 val.u.u64 = 0;
2591 }
2593 ty.bits = 1;
2594 }
2595};
2596
2597template<typename A>
2598HALIDE_ALWAYS_INLINE auto is_max_value(A &&a) noexcept -> IsMaxValue<decltype(pattern_arg(a))> {
2600 return {pattern_arg(a)};
2601}
2602
2603template<typename A>
2604std::ostream &operator<<(std::ostream &s, const IsMaxValue<A> &op) {
2605 s << "is_max_value(" << op.a << ")";
2606 return s;
2607}
2608
2609template<typename A>
2611 struct pattern_tag {};
2612 A a;
2613
2615
2616 // This rule is a boolean-valued predicate. Bools have type UIntImm.
2619 constexpr static bool canonical = true;
2620
2621 constexpr static bool foldable = true;
2622
2625 // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2626 a.make_folded_const(val, ty, state);
2627 if (ty.code == halide_type_int) {
2628 const uint64_t min_bits = (uint64_t)(-1) << (ty.bits - 1);
2629 val.u.u64 = (val.u.u64 == min_bits);
2630 } else if (ty.code == halide_type_uint) {
2631 val.u.u64 = (val.u.u64 == 0);
2632 } else {
2633 val.u.u64 = 0;
2634 }
2636 ty.bits = 1;
2637 }
2638};
2639
2640template<typename A>
2641HALIDE_ALWAYS_INLINE auto is_min_value(A &&a) noexcept -> IsMinValue<decltype(pattern_arg(a))> {
2643 return {pattern_arg(a)};
2644}
2645
2646template<typename A>
2647std::ostream &operator<<(std::ostream &s, const IsMinValue<A> &op) {
2648 s << "is_min_value(" << op.a << ")";
2649 return s;
2650}
2651
2652template<typename A>
2653struct LanesOf {
2654 struct pattern_tag {};
2655 A a;
2656
2658
2659 // This rule folds to the lane count of its argument (an integer), not a boolean.
2662 constexpr static bool canonical = true;
2663
2664 constexpr static bool foldable = true;
2665
2668 // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method.
2669 Type t = a.make(state, {}).type();
2670 val.u.u64 = t.lanes();
2672 ty.bits = 32;
2673 ty.lanes = 1;
2674 }
2675};
2676
2677template<typename A>
2678HALIDE_ALWAYS_INLINE auto lanes_of(A &&a) noexcept -> LanesOf<decltype(pattern_arg(a))> {
2680 return {pattern_arg(a)};
2681}
2682
2683template<typename A>
2684std::ostream &operator<<(std::ostream &s, const LanesOf<A> &op) {
2685 s << "lanes_of(" << op.a << ")";
2686 return s;
2687}
2688
2689// Verify properties of each rewrite rule. Currently just fuzz tests them.
2690template<typename Before,
2691 typename After,
2692 typename Predicate,
2693 typename = typename std::enable_if<std::decay<Before>::type::foldable &&
2694 std::decay<After>::type::foldable>::type>
2695HALIDE_NEVER_INLINE void fuzz_test_rule(Before &&before, After &&after, Predicate &&pred,
2696 halide_type_t wildcard_type, halide_type_t output_type) noexcept {
2697
2698 // We only validate the rules in the scalar case
2699 wildcard_type.lanes = output_type.lanes = 1;
2700
2701 // Track which types this rule has been tested for before
2702 static std::set<uint32_t> tested;
2703
2704 if (!tested.insert(reinterpret_bits<uint32_t>(wildcard_type)).second) {
2705 return;
2706 }
2707
2708 // Print it in a form where it can be piped into a python/z3 validator
2709 debug(0) << "validate('" << before << "', '" << after << "', '" << pred << "', " << Type(wildcard_type) << ", " << Type(output_type) << ")\n";
2710
2711 // Substitute some random constants into the before and after
2712 // expressions and see if the rule holds true. This should catch
2713 // silly errors, but not necessarily corner cases.
2714 static std::mt19937_64 rng(0);
2715 MatcherState state;
2716
2717 Expr exprs[max_wild];
2718
2719 for (int trials = 0; trials < 100; trials++) {
2720 // We want to test small constants more frequently than
2721 // large ones, otherwise we'll just get coverage of
2722 // overflow rules.
2723 int shift = (int)(rng() & (wildcard_type.bits - 1));
2724
2725 for (int i = 0; i < max_wild; i++) {
2726 // Bind all the exprs and constants
2727 switch (wildcard_type.code) {
2728 case halide_type_uint: {
2729 // Normalize to the type's range by adding zero
2730 uint64_t val = constant_fold_bin_op<Add>(wildcard_type, (uint64_t)rng() >> shift, 0);
2731 state.set_bound_const(i, val, wildcard_type);
2732 val = constant_fold_bin_op<Add>(wildcard_type, (uint64_t)rng() >> shift, 0);
2733 exprs[i] = make_const(wildcard_type, val);
2734 state.set_binding(i, *exprs[i].get());
2735 } break;
2736 case halide_type_int: {
2737 int64_t val = constant_fold_bin_op<Add>(wildcard_type, (int64_t)rng() >> shift, 0);
2738 state.set_bound_const(i, val, wildcard_type);
2739 val = constant_fold_bin_op<Add>(wildcard_type, (int64_t)rng() >> shift, 0);
2740 exprs[i] = make_const(wildcard_type, val);
2741 } break;
2742 case halide_type_float:
2743 case halide_type_bfloat: {
2744 // Use a very narrow range of precise floats, so
2745 // that none of the rules a human is likely to
2746 // write have instabilities.
2747 double val = ((int64_t)(rng() & 15) - 8) / 2.0;
2748 state.set_bound_const(i, val, wildcard_type);
2749 val = ((int64_t)(rng() & 15) - 8) / 2.0;
2750 exprs[i] = make_const(wildcard_type, val);
2751 } break;
2752 default:
2753 return; // Don't care about handles
2754 }
2755 state.set_binding(i, *exprs[i].get());
2756 }
2757
2758 halide_scalar_value_t val_pred, val_before, val_after;
2759 halide_type_t type = output_type;
2760 if (!evaluate_predicate(pred, state)) {
2761 continue;
2762 }
2763 before.make_folded_const(val_before, type, state);
2764 uint16_t lanes = type.lanes;
2765 after.make_folded_const(val_after, type, state);
2766 lanes |= type.lanes;
2767
2769 continue;
2770 }
2771
2772 bool ok = true;
2773 switch (output_type.code) {
2774 case halide_type_uint:
2775 // Compare normalized representations
2776 ok &= (constant_fold_bin_op<Add>(output_type, val_before.u.u64, 0) ==
2777 constant_fold_bin_op<Add>(output_type, val_after.u.u64, 0));
2778 break;
2779 case halide_type_int:
2780 ok &= (constant_fold_bin_op<Add>(output_type, val_before.u.i64, 0) ==
2781 constant_fold_bin_op<Add>(output_type, val_after.u.i64, 0));
2782 break;
2783 case halide_type_float:
2784 case halide_type_bfloat: {
2785 double error = std::abs(val_before.u.f64 - val_after.u.f64);
2786 // We accept an equal bit pattern (e.g. inf vs inf),
2787 // a small floating point difference, or turning a nan into not-a-nan.
2788 ok &= (error < 0.01 ||
2789 val_before.u.u64 == val_after.u.u64 ||
2790 std::isnan(val_before.u.f64));
2791 break;
2792 }
2793 default:
2794 return;
2795 }
2796
2797 if (!ok) {
2798 debug(0) << "Fails with values:\n";
2799 for (int i = 0; i < max_wild; i++) {
2801 state.get_bound_const(i, val, wildcard_type);
2802 debug(0) << " c" << i << ": " << make_const_expr(val, wildcard_type) << "\n";
2803 }
2804 for (int i = 0; i < max_wild; i++) {
2805 debug(0) << " _" << i << ": " << Expr(state.get_binding(i)) << "\n";
2806 }
2807 debug(0) << " Before: " << make_const_expr(val_before, output_type) << "\n";
2808 debug(0) << " After: " << make_const_expr(val_after, output_type) << "\n";
2809 debug(0) << val_before.u.u64 << " " << val_after.u.u64 << "\n";
2811 }
2812 }
2813}
2814
2815template<typename Before,
2816 typename After,
2817 typename Predicate,
2818 typename = typename std::enable_if<!(std::decay<Before>::type::foldable &&
2819 std::decay<After>::type::foldable)>::type>
2820HALIDE_ALWAYS_INLINE void fuzz_test_rule(Before &&before, After &&after, Predicate &&pred,
2821 halide_type_t, halide_type_t, int dummy = 0) noexcept {
2822 // We can't verify rewrite rules that can't be constant-folded.
2823}
2824
2826bool evaluate_predicate(bool x, MatcherState &) noexcept {
2827 return x;
2828}
2829
2830template<typename Pattern,
2831 typename = typename enable_if_pattern<Pattern>::type>
2834 halide_type_t ty = halide_type_of<bool>();
2835 p.make_folded_const(c, ty, state);
2836 // Overflow counts as a failed predicate
2837 return (c.u.u64 != 0) && ((ty.lanes & MatcherState::special_values_mask) == 0);
2838}
2839
2840// #defines for testing
2841
2842// Print all successful or failed matches
2843#define HALIDE_DEBUG_MATCHED_RULES 0
2844#define HALIDE_DEBUG_UNMATCHED_RULES 0
2845
2846// Set to true if you want to fuzz test every rewrite passed to
2847// operator() to ensure the input and the output have the same value
2848// for lots of random values of the wildcards. Run
2849// correctness_simplify with this on.
2850#define HALIDE_FUZZ_TEST_RULES 0
2851
2852template<typename Instance>
2853struct Rewriter {
2854 Instance instance;
2859
2862 : instance(std::move(instance)), output_type(ot), wildcard_type(wt) {
2863 }
2864
2865 template<typename After>
2867 result = after.make(state, output_type);
2868 }
2869
2870 template<typename Before,
2871 typename After,
2872 typename = typename enable_if_pattern<Before>::type,
2873 typename = typename enable_if_pattern<After>::type>
2874 HALIDE_ALWAYS_INLINE bool operator()(Before before, After after) {
2875 static_assert((Before::binds & After::binds) == After::binds, "Rule result uses unbound values");
2876 static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2877 static_assert(After::canonical, "RHS of rewrite rule should be in canonical form");
2878#if HALIDE_FUZZ_TEST_RULES
2879 fuzz_test_rule(before, after, true, wildcard_type, output_type);
2880#endif
2881 if (before.template match<0>(unwrap(instance), state)) {
2882 build_replacement(after);
2883#if HALIDE_DEBUG_MATCHED_RULES
2884 debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n";
2885#endif
2886 return true;
2887 } else {
2888#if HALIDE_DEBUG_UNMATCHED_RULES
2889 debug(0) << instance << " does not match " << before << "\n";
2890#endif
2891 return false;
2892 }
2893 }
2894
2895 template<typename Before,
2896 typename = typename enable_if_pattern<Before>::type>
2897 HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after) noexcept {
2898 static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2899 if (before.template match<0>(unwrap(instance), state)) {
2900 result = after;
2901#if HALIDE_DEBUG_MATCHED_RULES
2902 debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n";
2903#endif
2904 return true;
2905 } else {
2906#if HALIDE_DEBUG_UNMATCHED_RULES
2907 debug(0) << instance << " does not match " << before << "\n";
2908#endif
2909 return false;
2910 }
2911 }
2912
2913 template<typename Before,
2914 typename = typename enable_if_pattern<Before>::type>
2915 HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after) noexcept {
2916 static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2917#if HALIDE_FUZZ_TEST_RULES
2918 fuzz_test_rule(before, IntLiteral(after), true, wildcard_type, output_type);
2919#endif
2920 if (before.template match<0>(unwrap(instance), state)) {
2921 result = make_const(output_type, after);
2922#if HALIDE_DEBUG_MATCHED_RULES
2923 debug(0) << instance << " -> " << result << " via " << before << " -> " << after << "\n";
2924#endif
2925 return true;
2926 } else {
2927#if HALIDE_DEBUG_UNMATCHED_RULES
2928 debug(0) << instance << " does not match " << before << "\n";
2929#endif
2930 return false;
2931 }
2932 }
2933
2934 template<typename Before,
2935 typename After,
2936 typename Predicate,
2937 typename = typename enable_if_pattern<Before>::type,
2938 typename = typename enable_if_pattern<After>::type,
2939 typename = typename enable_if_pattern<Predicate>::type>
2940 HALIDE_ALWAYS_INLINE bool operator()(Before before, After after, Predicate pred) {
2941 static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold");
2942 static_assert((Before::binds & After::binds) == After::binds, "Rule result uses unbound values");
2943 static_assert((Before::binds & Predicate::binds) == Predicate::binds, "Rule predicate uses unbound values");
2944 static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2945 static_assert(After::canonical, "RHS of rewrite rule should be in canonical form");
2946
2947#if HALIDE_FUZZ_TEST_RULES
2948 fuzz_test_rule(before, after, pred, wildcard_type, output_type);
2949#endif
2950 if (before.template match<0>(unwrap(instance), state) &&
2951 evaluate_predicate(pred, state)) {
2952 build_replacement(after);
2953#if HALIDE_DEBUG_MATCHED_RULES
2954 debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n";
2955#endif
2956 return true;
2957 } else {
2958#if HALIDE_DEBUG_UNMATCHED_RULES
2959 debug(0) << instance << " does not match " << before << "\n";
2960#endif
2961 return false;
2962 }
2963 }
2964
2965 template<typename Before,
2966 typename Predicate,
2967 typename = typename enable_if_pattern<Before>::type,
2968 typename = typename enable_if_pattern<Predicate>::type>
2969 HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after, Predicate pred) {
2970 static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold");
2971 static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2972
2973 if (before.template match<0>(unwrap(instance), state) &&
2974 evaluate_predicate(pred, state)) {
2975 result = after;
2976#if HALIDE_DEBUG_MATCHED_RULES
2977 debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n";
2978#endif
2979 return true;
2980 } else {
2981#if HALIDE_DEBUG_UNMATCHED_RULES
2982 debug(0) << instance << " does not match " << before << "\n";
2983#endif
2984 return false;
2985 }
2986 }
2987
2988 template<typename Before,
2989 typename Predicate,
2990 typename = typename enable_if_pattern<Before>::type,
2991 typename = typename enable_if_pattern<Predicate>::type>
2992 HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after, Predicate pred) {
2993 static_assert(Predicate::foldable, "Predicates must consist only of operations that can constant-fold");
2994 static_assert(Before::canonical, "LHS of rewrite rule should be in canonical form");
2995#if HALIDE_FUZZ_TEST_RULES
2996 fuzz_test_rule(before, IntLiteral(after), pred, wildcard_type, output_type);
2997#endif
2998 if (before.template match<0>(unwrap(instance), state) &&
2999 evaluate_predicate(pred, state)) {
3000 result = make_const(output_type, after);
3001#if HALIDE_DEBUG_MATCHED_RULES
3002 debug(0) << instance << " -> " << result << " via " << before << " -> " << after << " when " << pred << "\n";
3003#endif
3004 return true;
3005 } else {
3006#if HALIDE_DEBUG_UNMATCHED_RULES
3007 debug(0) << instance << " does not match " << before << "\n";
3008#endif
3009 return false;
3010 }
3011 }
3012};
3013
3014/** Construct a rewriter for the given instance, which may be a pattern
3015 * with concrete expressions as leaves, or just an expression. The
3016 * optional final argument (wildcard_type) is a hint as to what the
3017 * type of the wildcards is likely to be. If omitted, the output type
3018 * (the type of the expression itself) is used. Wildcards are not required
3019 * to have this type, but the rule will only be fuzz tested for wildcards
3020 * of that type when HALIDE_FUZZ_TEST_RULES is enabled.
3021 *
3022 * The rewriter can be used to check to see if the instance is one of
3023 * some number of patterns and if so rewrite it into another form,
3024 * using its operator() method. See Simplify.cpp for a bunch of
3025 * example usage.
3026 *
3027 * Important: Any Exprs in patterns are captured by reference, not by
3028 * value, so ensure they outlive the rewriter.
3029 */
3030// @{
3031template<typename Instance,
3032 typename = typename enable_if_pattern<Instance>::type>
3033HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type, halide_type_t wildcard_type) noexcept -> Rewriter<decltype(pattern_arg(instance))> {
3034 return {pattern_arg(instance), output_type, wildcard_type};
3035}
3036
3037template<typename Instance,
3038 typename = typename enable_if_pattern<Instance>::type>
3039HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type) noexcept -> Rewriter<decltype(pattern_arg(instance))> {
3040 return {pattern_arg(instance), output_type, output_type};
3041}
3042
3044auto rewriter(const Expr &e, halide_type_t wildcard_type) noexcept -> Rewriter<decltype(pattern_arg(e))> {
3045 return {pattern_arg(e), e.type(), wildcard_type};
3046}
3047
3049auto rewriter(const Expr &e) noexcept -> Rewriter<decltype(pattern_arg(e))> {
3050 return {pattern_arg(e), e.type(), e.type()};
3051}
3052// @}
3053
3054} // namespace IRMatcher
3055
3056} // namespace Internal
3057} // namespace Halide
3058
3059#endif
#define internal_error
Definition Errors.h:23
@ halide_type_float
IEEE floating point numbers.
@ halide_type_bfloat
floating point numbers in the bfloat format
@ halide_type_int
signed integers
@ halide_type_uint
unsigned integers
#define HALIDE_NEVER_INLINE
#define HALIDE_ALWAYS_INLINE
Subtypes for Halide expressions (Halide::Expr) and statements (Halide::Internal::Stmt)
Methods to test Exprs and Stmts for equality of value.
Defines various operator overloads and utility functions that make it more pleasant to work with Hali...
For optional debugging during codegen, use the debug class as follows:
Definition Debug.h:49
auto rounding_shift_left(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1586
auto shift_left(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1578
HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type, halide_type_t wildcard_type) noexcept -> Rewriter< decltype(pattern_arg(instance))>
Construct a rewriter for the given instance, which may be a pattern with concrete expressions as leav...
Definition IRMatch.h:3033
HALIDE_ALWAYS_INLINE T pattern_arg(T t)
Definition IRMatch.h:567
auto widen_right_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1527
HALIDE_ALWAYS_INLINE auto or_op(A &&a, B &&b) -> decltype(IRMatcher::operator||(a, b))
Definition IRMatch.h:1275
HALIDE_ALWAYS_INLINE auto operator!(A &&a) noexcept -> NotOp< decltype(pattern_arg(a))>
Definition IRMatch.h:1643
HALIDE_ALWAYS_INLINE auto min(A &&a, B &&b) noexcept -> BinOp< Min, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1076
HALIDE_ALWAYS_INLINE bool evaluate_predicate(bool x, MatcherState &) noexcept
Definition IRMatch.h:2826
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Div >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:1032
HALIDE_ALWAYS_INLINE auto ne(A &&a, B &&b) -> decltype(IRMatcher::operator!=(a, b))
Definition IRMatch.h:1250
HALIDE_ALWAYS_INLINE auto negate(A &&a) -> decltype(IRMatcher::operator-(a))
Definition IRMatch.h:2020
uint64_t constant_fold_cmp_op(int64_t, int64_t) noexcept
HALIDE_ALWAYS_INLINE auto operator<=(A &&a, B &&b) noexcept -> CmpOp< LE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1170
HALIDE_ALWAYS_INLINE auto operator+(A &&a, B &&b) noexcept -> BinOp< Add, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:921
HALIDE_ALWAYS_INLINE auto is_max_value(A &&a) noexcept -> IsMaxValue< decltype(pattern_arg(a))>
Definition IRMatch.h:2598
std::ostream & operator<<(std::ostream &s, const SpecificExpr &e)
Definition IRMatch.h:217
HALIDE_ALWAYS_INLINE auto and_op(A &&a, B &&b) -> decltype(IRMatcher::operator&&(a, b))
Definition IRMatch.h:1301
HALIDE_ALWAYS_INLINE auto h_and(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::And >
Definition IRMatch.h:1932
HALIDE_ALWAYS_INLINE auto gt(A &&a, B &&b) -> decltype(IRMatcher::operator>(a, b))
Definition IRMatch.h:1150
HALIDE_ALWAYS_INLINE auto is_const(A &&a) noexcept -> IsConst< decltype(pattern_arg(a))>
Definition IRMatch.h:2345
HALIDE_ALWAYS_INLINE auto intrin(Call::IntrinsicOp intrinsic_op, Args... args) noexcept -> Intrin< decltype(pattern_arg(args))... >
Definition IRMatch.h:1522
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< LE >(int64_t a, int64_t b) noexcept
Definition IRMatch.h:1180
HALIDE_ALWAYS_INLINE auto operator*(A &&a, B &&b) noexcept -> BinOp< Mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:987
auto rounding_halving_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1574
auto rounding_shift_right(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1590
auto widen_right_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1535
HALIDE_ALWAYS_INLINE auto add(A &&a, B &&b) -> decltype(IRMatcher::operator+(a, b))
Definition IRMatch.h:928
HALIDE_ALWAYS_INLINE auto div(A &&a, B &&b) -> decltype(IRMatcher::operator/(a, b))
Definition IRMatch.h:1027
auto saturating_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1552
HALIDE_ALWAYS_INLINE auto mul(A &&a, B &&b) -> decltype(IRMatcher::operator*(a, b))
Definition IRMatch.h:994
HALIDE_ALWAYS_INLINE auto max(A &&a, B &&b) noexcept -> BinOp< Max, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1098
HALIDE_ALWAYS_INLINE auto slice(Vec vec, Base base, Stride stride, Lanes lanes) noexcept -> SliceOp< decltype(pattern_arg(vec)), decltype(pattern_arg(base)), decltype(pattern_arg(stride)), decltype(pattern_arg(lanes))>
Definition IRMatch.h:2177
HALIDE_ALWAYS_INLINE auto ramp(A &&a, B &&b, C &&c) noexcept -> RampOp< decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
Definition IRMatch.h:1856
HALIDE_ALWAYS_INLINE auto operator/(A &&a, B &&b) noexcept -> BinOp< Div, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1020
HALIDE_ALWAYS_INLINE auto widen(A &&a) noexcept -> WidenOp< decltype(pattern_arg(a))>
Definition IRMatch.h:2113
auto widening_mul(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1548
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Mod >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:1061
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< And >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:1306
HALIDE_ALWAYS_INLINE int64_t unwrap(IntLiteral t)
Definition IRMatch.h:559
HALIDE_ALWAYS_INLINE auto operator>(A &&a, B &&b) noexcept -> CmpOp< GT, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1145
HALIDE_ALWAYS_INLINE auto cast(halide_type_t t, A &&a) noexcept -> CastOp< decltype(pattern_arg(a))>
Definition IRMatch.h:2066
HALIDE_ALWAYS_INLINE auto overflows(A &&a) noexcept -> Overflows< decltype(pattern_arg(a))>
Definition IRMatch.h:2263
auto widening_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1540
HALIDE_ALWAYS_INLINE void assert_is_lvalue_if_expr()
Definition IRMatch.h:576
HALIDE_ALWAYS_INLINE auto operator%(A &&a, B &&b) noexcept -> BinOp< Mod, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1047
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Sub >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:968
HALIDE_ALWAYS_INLINE auto is_scalar(A &&a) noexcept -> IsScalar< decltype(pattern_arg(a))>
Definition IRMatch.h:2557
HALIDE_ALWAYS_INLINE auto fold(A &&a) noexcept -> Fold< decltype(pattern_arg(a))>
Definition IRMatch.h:2226
HALIDE_ALWAYS_INLINE auto not_op(A &&a) -> decltype(IRMatcher::operator!(a))
Definition IRMatch.h:1649
auto halving_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1566
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Max >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:1105
constexpr bool and_reduce()
Definition IRMatch.h:1330
HALIDE_ALWAYS_INLINE auto operator||(A &&a, B &&b) noexcept -> BinOp< Or, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1270
auto widening_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1544
constexpr int max_wild
Definition IRMatch.h:74
HALIDE_ALWAYS_INLINE auto operator!=(A &&a, B &&b) noexcept -> CmpOp< NE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1245
HALIDE_ALWAYS_INLINE auto is_float(A &&a) noexcept -> IsFloat< decltype(pattern_arg(a))>
Definition IRMatch.h:2430
HALIDE_ALWAYS_INLINE auto operator>=(A &&a, B &&b) noexcept -> CmpOp< GE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1195
HALIDE_ALWAYS_INLINE auto operator<(A &&a, B &&b) noexcept -> CmpOp< LT, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1120
HALIDE_ALWAYS_INLINE auto operator&&(A &&a, B &&b) noexcept -> BinOp< And, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1296
HALIDE_ALWAYS_INLINE auto is_uint(A &&a, int bits=0, int lanes=0) noexcept -> IsUInt< decltype(pattern_arg(a))>
Definition IRMatch.h:2513
HALIDE_ALWAYS_INLINE auto h_or(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Or >
Definition IRMatch.h:1938
constexpr bool commutative(IRNodeType t)
Definition IRMatch.h:615
auto widen_right_mul(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1531
HALIDE_ALWAYS_INLINE auto sub(A &&a, B &&b) -> decltype(IRMatcher::operator-(a, b))
Definition IRMatch.h:961
HALIDE_ALWAYS_INLINE auto h_max(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Max >
Definition IRMatch.h:1926
HALIDE_ALWAYS_INLINE auto broadcast(A &&a, B lanes) noexcept -> BroadcastOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes))>
Definition IRMatch.h:1792
HALIDE_ALWAYS_INLINE auto is_int(A &&a, int bits=0, int lanes=0) noexcept -> IsInt< decltype(pattern_arg(a))>
Definition IRMatch.h:2468
HALIDE_ALWAYS_INLINE auto select(C &&c, T &&t, F &&f) noexcept -> SelectOp< decltype(pattern_arg(c)), decltype(pattern_arg(t)), decltype(pattern_arg(f))>
Definition IRMatch.h:1719
HALIDE_ALWAYS_INLINE auto is_min_value(A &&a) noexcept -> IsMinValue< decltype(pattern_arg(a))>
Definition IRMatch.h:2641
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Min >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:1083
HALIDE_NEVER_INLINE void fuzz_test_rule(Before &&before, After &&after, Predicate &&pred, halide_type_t wildcard_type, halide_type_t output_type) noexcept
Definition IRMatch.h:2695
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< GT >(int64_t a, int64_t b) noexcept
Definition IRMatch.h:1155
auto halving_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1570
auto saturating_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1556
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Mul >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:1001
auto mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
Definition IRMatch.h:1594
auto shift_right(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1582
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< GE >(int64_t a, int64_t b) noexcept
Definition IRMatch.h:1205
HALIDE_ALWAYS_INLINE auto operator-(A &&a, B &&b) noexcept -> BinOp< Sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:954
HALIDE_ALWAYS_INLINE auto le(A &&a, B &&b) -> decltype(IRMatcher::operator<=(a, b))
Definition IRMatch.h:1175
HALIDE_ALWAYS_INLINE auto lt(A &&a, B &&b) -> decltype(IRMatcher::operator<(a, b))
Definition IRMatch.h:1125
HALIDE_ALWAYS_INLINE auto lanes_of(A &&a) noexcept -> LanesOf< decltype(pattern_arg(a))>
Definition IRMatch.h:2678
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< LT >(int64_t a, int64_t b) noexcept
Definition IRMatch.h:1130
HALIDE_ALWAYS_INLINE auto h_min(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Min >
Definition IRMatch.h:1920
HALIDE_ALWAYS_INLINE auto h_add(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Add >
Definition IRMatch.h:1914
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Or >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:1280
HALIDE_ALWAYS_INLINE Expr make_const_expr(halide_scalar_value_t val, halide_type_t ty)
Definition IRMatch.h:160
constexpr uint32_t bitwise_or_reduce()
Definition IRMatch.h:1321
auto rounding_mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
Definition IRMatch.h:1598
int64_t constant_fold_bin_op(halide_type_t &, int64_t, int64_t) noexcept
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< EQ >(int64_t a, int64_t b) noexcept
Definition IRMatch.h:1230
HALIDE_NEVER_INLINE Expr make_const_special_expr(halide_type_t ty)
Definition IRMatch.h:149
HALIDE_ALWAYS_INLINE auto ge(A &&a, B &&b) -> decltype(IRMatcher::operator>=(a, b))
Definition IRMatch.h:1200
auto saturating_cast(const Type &t, A &&a) noexcept -> Intrin< decltype(pattern_arg(a))>
Definition IRMatch.h:1560
constexpr int const_min(int a, int b)
Definition IRMatch.h:1340
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< NE >(int64_t a, int64_t b) noexcept
Definition IRMatch.h:1255
HALIDE_ALWAYS_INLINE auto mod(A &&a, B &&b) -> decltype(IRMatcher::operator%(a, b))
Definition IRMatch.h:1054
HALIDE_ALWAYS_INLINE auto operator==(A &&a, B &&b) noexcept -> CmpOp< EQ, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Definition IRMatch.h:1220
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Add >(halide_type_t &t, int64_t a, int64_t b) noexcept
Definition IRMatch.h:935
HALIDE_ALWAYS_INLINE auto can_prove(A &&a, Prover *p) noexcept -> CanProve< decltype(pattern_arg(a)), Prover >
Definition IRMatch.h:2393
HALIDE_ALWAYS_INLINE auto eq(A &&a, B &&b) -> decltype(IRMatcher::operator==(a, b))
Definition IRMatch.h:1225
T div_imp(T a, T b)
Definition IROperator.h:268
bool is_const_zero(const Expr &e)
Is the expression a const (as defined by is_const), and also equal to zero (in all lanes,...
Expr make_zero(Type t)
Construct the representation of zero in the given type.
void expr_match_test()
bool is_const_one(const Expr &e)
Is the expression a const (as defined by is_const), and also equal to one (in all lanes,...
bool equal(const RDom &bounds0, const RDom &bounds1)
Return true if bounds0 and bounds1 represent the same bounds.
constexpr IRNodeType StrongestExprNodeType
Definition Expr.h:81
Expr make_const(Type t, int64_t val)
Construct an immediate of the given type from any numeric C++ type.
T mod_imp(T a, T b)
Implementations of division and mod that are specific to Halide.
Definition IROperator.h:247
bool sub_would_overflow(int bits, int64_t a, int64_t b)
bool add_would_overflow(int bits, int64_t a, int64_t b)
Routines to test if math would overflow for signed integers with the given number of bits.
DstType reinterpret_bits(const SrcType &src)
An aggressive form of reinterpret cast used for correct type-punning.
Definition Util.h:135
bool mul_would_overflow(int bits, int64_t a, int64_t b)
Expr with_lanes(const Expr &x, int lanes)
Rewrite the expression x to have lanes lanes.
bool expr_match(const Expr &pattern, const Expr &expr, std::vector< Expr > &result)
Does the first expression have the same structure as the second? Variables in the first expression wi...
ConstantInterval abs(const ConstantInterval &a)
Expr make_signed_integer_overflow(Type type)
Construct a unique signed_integer_overflow Expr.
IRNodeType
All our IR node types get unique IDs for the purposes of RTTI.
Definition Expr.h:25
bool is_const(const Expr &e)
Is the expression either an IntImm, a FloatImm, a StringImm, or a Cast of the same,...
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
@ Internal
Not visible externally, similar to 'static' linkage in C.
@ Predicate
Guard the loads and stores in the loop with an if statement that prevents evaluation beyond the origi...
Expr absd(Expr a, Expr b)
Return the absolute difference between two values.
@ C
No name mangling.
Expr likely_if_innermost(Expr e)
Equivalent to likely, but only triggers a loop partitioning if found in an innermost loop.
Expr likely(Expr e)
Expressions tagged with this intrinsic are considered to be part of the steady state of some loop wit...
unsigned __INT64_TYPE__ uint64_t
signed __INT64_TYPE__ int64_t
signed __INT32_TYPE__ int32_t
unsigned __INT16_TYPE__ uint16_t
unsigned __INT32_TYPE__ uint32_t
A fragment of Halide syntax.
Definition Expr.h:258
HALIDE_ALWAYS_INLINE Type type() const
Get the type of this expression node.
Definition Expr.h:327
HALIDE_ALWAYS_INLINE const Internal::BaseExprNode * get() const
Override get() to return a BaseExprNode * instead of an IRNode *.
Definition Expr.h:321
A base class for expression nodes.
Definition Expr.h:143
A vector with 'lanes' elements, in which every element is 'value'.
Definition IR.h:259
static Expr make(Expr value, int lanes)
static const IRNodeType _node_type
Definition IR.h:265
A function call.
Definition IR.h:490
bool is_intrinsic() const
Definition IR.h:721
static const IRNodeType _node_type
Definition IR.h:766
std::vector< Expr > args
Definition IR.h:492
The actual IR nodes begin here.
Definition IR.h:30
static const IRNodeType _node_type
Definition IR.h:35
Floating point constants.
Definition Expr.h:236
static const FloatImm * make(Type t, double value)
static constexpr bool canonical
Definition IRMatch.h:641
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:664
static constexpr uint32_t binds
Definition IRMatch.h:633
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:645
static constexpr bool foldable
Definition IRMatch.h:661
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const noexcept
Definition IRMatch.h:707
HALIDE_ALWAYS_INLINE bool match(const BinOp< Op2, A2, B2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:655
static constexpr IRNodeType max_node_type
Definition IRMatch.h:636
static constexpr IRNodeType min_node_type
Definition IRMatch.h:635
static constexpr IRNodeType min_node_type
Definition IRMatch.h:1734
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:1758
HALIDE_ALWAYS_INLINE bool match(const BroadcastOp< A2, B2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:1752
static constexpr uint32_t binds
Definition IRMatch.h:1732
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:1740
static constexpr IRNodeType max_node_type
Definition IRMatch.h:1735
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:1775
HALIDE_NEVER_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2382
static constexpr uint32_t binds
Definition IRMatch.h:2372
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2375
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2376
static constexpr bool foldable
Definition IRMatch.h:2379
static constexpr bool canonical
Definition IRMatch.h:2377
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2034
static constexpr bool foldable
Definition IRMatch.h:2056
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:2038
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2033
static constexpr uint32_t binds
Definition IRMatch.h:2031
static constexpr bool canonical
Definition IRMatch.h:2035
HALIDE_ALWAYS_INLINE bool match(const CastOp< A2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:2047
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:2052
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:808
static constexpr IRNodeType max_node_type
Definition IRMatch.h:747
static constexpr uint32_t binds
Definition IRMatch.h:744
static constexpr bool canonical
Definition IRMatch.h:748
static constexpr bool foldable
Definition IRMatch.h:771
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:755
static constexpr IRNodeType min_node_type
Definition IRMatch.h:746
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:774
HALIDE_ALWAYS_INLINE bool match(const CmpOp< Op2, A2, B2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:765
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2190
static constexpr uint32_t binds
Definition IRMatch.h:2187
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2189
static constexpr bool canonical
Definition IRMatch.h:2191
static constexpr bool foldable
Definition IRMatch.h:2217
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const noexcept
Definition IRMatch.h:2194
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:2220
static constexpr IRNodeType max_node_type
Definition IRMatch.h:495
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:504
HALIDE_ALWAYS_INLINE IntLiteral(int64_t v)
Definition IRMatch.h:499
HALIDE_ALWAYS_INLINE bool match(const IntLiteral &b, MatcherState &state) const noexcept
Definition IRMatch.h:527
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:539
static constexpr IRNodeType min_node_type
Definition IRMatch.h:494
static constexpr bool canonical
Definition IRMatch.h:496
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:532
HALIDE_ALWAYS_INLINE bool match(int64_t val, MatcherState &state) const noexcept
Definition IRMatch.h:522
static constexpr uint32_t binds
Definition IRMatch.h:492
HALIDE_ALWAYS_INLINE bool match_args(double, const Call &c, MatcherState &state) const noexcept
Definition IRMatch.h:1370
static constexpr IRNodeType max_node_type
Definition IRMatch.h:1357
static constexpr bool canonical
Definition IRMatch.h:1358
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:1405
HALIDE_ALWAYS_INLINE void print_args(std::ostream &s) const
Definition IRMatch.h:1400
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:1465
static constexpr uint32_t binds
Definition IRMatch.h:1354
static constexpr bool foldable
Definition IRMatch.h:1463
HALIDE_ALWAYS_INLINE bool match_args(int, const Call &c, MatcherState &state) const noexcept
Definition IRMatch.h:1363
HALIDE_ALWAYS_INLINE void print_args(int, std::ostream &s) const
Definition IRMatch.h:1387
std::tuple< Args... > args
Definition IRMatch.h:1348
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:1375
HALIDE_ALWAYS_INLINE void print_args(double, std::ostream &s) const
Definition IRMatch.h:1396
HALIDE_ALWAYS_INLINE Intrin(Call::IntrinsicOp intrin, Args... args) noexcept
Definition IRMatch.h:1508
static constexpr IRNodeType min_node_type
Definition IRMatch.h:1356
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2320
static constexpr bool canonical
Definition IRMatch.h:2322
static constexpr bool foldable
Definition IRMatch.h:2328
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2321
static constexpr uint32_t binds
Definition IRMatch.h:2317
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:2331
static constexpr bool foldable
Definition IRMatch.h:2416
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2419
static constexpr bool canonical
Definition IRMatch.h:2414
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2412
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2413
static constexpr uint32_t binds
Definition IRMatch.h:2409
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2451
static constexpr bool foldable
Definition IRMatch.h:2454
static constexpr uint32_t binds
Definition IRMatch.h:2447
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2457
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2450
static constexpr bool canonical
Definition IRMatch.h:2452
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2576
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2577
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2583
static constexpr uint32_t binds
Definition IRMatch.h:2573
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2617
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2624
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2618
static constexpr uint32_t binds
Definition IRMatch.h:2614
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2540
static constexpr uint32_t binds
Definition IRMatch.h:2536
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2546
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2539
static constexpr bool foldable
Definition IRMatch.h:2543
static constexpr bool canonical
Definition IRMatch.h:2541
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2502
static constexpr bool foldable
Definition IRMatch.h:2499
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2495
static constexpr bool canonical
Definition IRMatch.h:2497
static constexpr uint32_t binds
Definition IRMatch.h:2492
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2496
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2661
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:2667
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2660
static constexpr bool foldable
Definition IRMatch.h:2664
static constexpr uint32_t binds
Definition IRMatch.h:2657
static constexpr bool canonical
Definition IRMatch.h:2662
To save stack space, the matcher objects are largely stateless and immutable.
Definition IRMatch.h:82
HALIDE_ALWAYS_INLINE void get_bound_const(int i, halide_scalar_value_t &val, halide_type_t &type) const noexcept
Definition IRMatch.h:127
HALIDE_ALWAYS_INLINE void set_bound_const(int i, int64_t s, halide_type_t t) noexcept
Definition IRMatch.h:103
HALIDE_ALWAYS_INLINE void set_bound_const(int i, double f, halide_type_t t) noexcept
Definition IRMatch.h:115
static constexpr uint16_t special_values_mask
Definition IRMatch.h:88
HALIDE_ALWAYS_INLINE void set_bound_const(int i, halide_scalar_value_t val, halide_type_t t) noexcept
Definition IRMatch.h:121
halide_type_t bound_const_type[max_wild]
Definition IRMatch.h:90
HALIDE_ALWAYS_INLINE void set_binding(int i, const BaseExprNode &n) noexcept
Definition IRMatch.h:93
HALIDE_ALWAYS_INLINE MatcherState() noexcept
Definition IRMatch.h:134
HALIDE_ALWAYS_INLINE const BaseExprNode * get_binding(int i) const noexcept
Definition IRMatch.h:98
halide_scalar_value_t bound_const[max_wild]
Definition IRMatch.h:84
HALIDE_ALWAYS_INLINE void set_bound_const(int i, uint64_t u, halide_type_t t) noexcept
Definition IRMatch.h:109
static constexpr uint16_t signed_integer_overflow
Definition IRMatch.h:87
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:1956
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:1971
HALIDE_ALWAYS_INLINE bool match(NegateOp< A2 > &&p, MatcherState &state) const noexcept
Definition IRMatch.h:1966
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:1980
static constexpr uint32_t binds
Definition IRMatch.h:1948
static constexpr bool canonical
Definition IRMatch.h:1953
static constexpr bool foldable
Definition IRMatch.h:1977
static constexpr IRNodeType max_node_type
Definition IRMatch.h:1951
static constexpr IRNodeType min_node_type
Definition IRMatch.h:1950
static constexpr uint32_t binds
Definition IRMatch.h:1607
static constexpr bool foldable
Definition IRMatch.h:1632
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:1614
static constexpr IRNodeType max_node_type
Definition IRMatch.h:1610
static constexpr bool canonical
Definition IRMatch.h:1611
HALIDE_ALWAYS_INLINE bool match(const NotOp< A2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:1623
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:1628
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:1635
static constexpr IRNodeType min_node_type
Definition IRMatch.h:1609
static constexpr uint32_t binds
Definition IRMatch.h:2277
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2281
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:2285
static constexpr bool canonical
Definition IRMatch.h:2282
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:2294
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:2302
static constexpr bool foldable
Definition IRMatch.h:2299
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2280
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:2253
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2246
static constexpr uint32_t binds
Definition IRMatch.h:2242
static constexpr bool canonical
Definition IRMatch.h:2248
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2247
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:1834
static constexpr bool canonical
Definition IRMatch.h:1809
static constexpr IRNodeType max_node_type
Definition IRMatch.h:1807
static constexpr IRNodeType min_node_type
Definition IRMatch.h:1806
static constexpr uint32_t binds
Definition IRMatch.h:1804
HALIDE_ALWAYS_INLINE bool match(const RampOp< A2, B2, C2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:1827
static constexpr bool foldable
Definition IRMatch.h:1846
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:1812
HALIDE_NEVER_INLINE void build_replacement(After after)
Definition IRMatch.h:2866
HALIDE_ALWAYS_INLINE bool operator()(Before before, After after, Predicate pred)
Definition IRMatch.h:2940
HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after) noexcept
Definition IRMatch.h:2915
HALIDE_ALWAYS_INLINE Rewriter(Instance instance, halide_type_t ot, halide_type_t wt)
Definition IRMatch.h:2861
HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after, Predicate pred)
Definition IRMatch.h:2969
HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after) noexcept
Definition IRMatch.h:2897
HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after, Predicate pred)
Definition IRMatch.h:2992
HALIDE_ALWAYS_INLINE bool operator()(Before before, After after)
Definition IRMatch.h:2874
static constexpr uint32_t binds
Definition IRMatch.h:1667
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:1699
static constexpr bool foldable
Definition IRMatch.h:1696
static constexpr bool canonical
Definition IRMatch.h:1672
HALIDE_ALWAYS_INLINE bool match(const SelectOp< C2, T2, F2 > &instance, MatcherState &state) const noexcept
Definition IRMatch.h:1685
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:1675
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:1692
static constexpr IRNodeType max_node_type
Definition IRMatch.h:1670
static constexpr IRNodeType min_node_type
Definition IRMatch.h:1669
static constexpr bool canonical
Definition IRMatch.h:2130
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2129
static constexpr bool foldable
Definition IRMatch.h:2159
HALIDE_ALWAYS_INLINE SliceOp(Vec v, Base b, Stride s, Lanes l)
Definition IRMatch.h:2162
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2128
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:2133
static constexpr uint32_t binds
Definition IRMatch.h:2126
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:2147
static constexpr IRNodeType min_node_type
Definition IRMatch.h:198
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:205
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:210
static constexpr IRNodeType max_node_type
Definition IRMatch.h:199
static constexpr uint32_t binds
Definition IRMatch.h:195
HALIDE_ALWAYS_INLINE bool match(const VectorReduceOp< A2, B2, reduce_op_2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:1889
static constexpr IRNodeType min_node_type
Definition IRMatch.h:1871
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:1876
static constexpr IRNodeType max_node_type
Definition IRMatch.h:1872
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:1896
static constexpr uint32_t binds
Definition IRMatch.h:2076
static constexpr bool canonical
Definition IRMatch.h:2080
static constexpr bool foldable
Definition IRMatch.h:2103
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:2097
static constexpr IRNodeType max_node_type
Definition IRMatch.h:2079
HALIDE_ALWAYS_INLINE bool match(const WidenOp< A2 > &op, MatcherState &state) const noexcept
Definition IRMatch.h:2092
static constexpr IRNodeType min_node_type
Definition IRMatch.h:2078
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:2083
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:352
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:373
static constexpr IRNodeType max_node_type
Definition IRMatch.h:348
static constexpr IRNodeType min_node_type
Definition IRMatch.h:347
static constexpr uint32_t binds
Definition IRMatch.h:345
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:383
static constexpr bool canonical
Definition IRMatch.h:403
static constexpr IRNodeType max_node_type
Definition IRMatch.h:402
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:431
static constexpr uint32_t binds
Definition IRMatch.h:399
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:406
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:441
HALIDE_ALWAYS_INLINE bool match(int64_t e, MatcherState &state) const noexcept
Definition IRMatch.h:425
static constexpr IRNodeType min_node_type
Definition IRMatch.h:401
static constexpr bool foldable
Definition IRMatch.h:438
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:267
static constexpr uint32_t binds
Definition IRMatch.h:226
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Definition IRMatch.h:277
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:233
HALIDE_ALWAYS_INLINE bool match(int64_t value, MatcherState &state) const noexcept
Definition IRMatch.h:254
static constexpr IRNodeType min_node_type
Definition IRMatch.h:228
static constexpr IRNodeType max_node_type
Definition IRMatch.h:229
static constexpr uint32_t binds
Definition IRMatch.h:292
static constexpr IRNodeType max_node_type
Definition IRMatch.h:295
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:299
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
Definition IRMatch.h:330
static constexpr IRNodeType min_node_type
Definition IRMatch.h:294
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:320
static constexpr IRNodeType min_node_type
Definition IRMatch.h:459
static constexpr uint32_t binds
Definition IRMatch.h:457
static constexpr IRNodeType max_node_type
Definition IRMatch.h:460
static constexpr bool canonical
Definition IRMatch.h:461
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
Definition IRMatch.h:473
static constexpr bool foldable
Definition IRMatch.h:477
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Definition IRMatch.h:464
static constexpr uint32_t mask
Definition IRMatch.h:146
IRNodeType node_type
Each IR node subclass has a unique identifier.
Definition Expr.h:113
Integer constants.
Definition Expr.h:218
static const IntImm * make(Type t, int64_t value)
Logical not - true if the expression is false.
Definition IR.h:193
static Expr make(Expr a)
A linear ramp vector node.
Definition IR.h:247
static const IRNodeType _node_type
Definition IR.h:253
static Expr make(Expr base, Expr stride, int lanes)
A ternary operator.
Definition IR.h:204
static Expr make(Expr condition, Expr true_value, Expr false_value)
static const IRNodeType _node_type
Definition IR.h:209
Construct a new vector by taking elements from another sequence of vectors.
Definition IR.h:855
static Expr make_slice(Expr vector, int begin, int stride, int size)
Convenience constructor for making a shuffle representing a contiguous subset of a vector.
std::vector< Expr > vectors
Definition IR.h:856
bool is_slice() const
Check if this shuffle is a contiguous strict subset of the vector arguments, and if so, the offset and stride of the slice.
int slice_stride() const
Check if this shuffle is a contiguous strict subset of the vector arguments, and if so, the offset and stride of the slice.
Definition IR.h:909
int slice_begin() const
Check if this shuffle is a contiguous strict subset of the vector arguments, and if so, the offset and stride of the slice.
Definition IR.h:906
The difference of two expressions.
Definition IR.h:65
static const IRNodeType _node_type
Definition IR.h:70
static Expr make(Expr a, Expr b)
Unsigned integer constants.
Definition Expr.h:227
static const UIntImm * make(Type t, uint64_t value)
Horizontally reduce a vector to a scalar or narrower vector using the given commutative and associative binary operator.
Definition IR.h:979
static const IRNodeType _node_type
Definition IR.h:998
static Expr make(Operator op, Expr vec, int lanes)
Types in the halide type system.
Definition Type.h:283
Type widen() const
Return Type with the same type code and number of lanes, but with at least twice as many bits.
Definition Type.h:378
HALIDE_ALWAYS_INLINE bool is_int() const
Is this type a signed integer type?
Definition Type.h:435
HALIDE_ALWAYS_INLINE int lanes() const
Return the number of vector elements in this type.
Definition Type.h:355
HALIDE_ALWAYS_INLINE bool is_uint() const
Is this type an unsigned integer type?
Definition Type.h:441
HALIDE_ALWAYS_INLINE int bits() const
Return the bit size of a single element of this type.
Definition Type.h:349
HALIDE_ALWAYS_INLINE bool is_vector() const
Is this type a vector type? (lanes() != 1).
Definition Type.h:410
HALIDE_ALWAYS_INLINE bool is_scalar() const
Is this type a scalar type? (lanes() == 1).
Definition Type.h:417
HALIDE_ALWAYS_INLINE bool is_float() const
Is this type a floating point type (float or double).
Definition Type.h:423
halide_scalar_value_t is a simple union able to represent all the well-known scalar values in a filter argument.
union halide_scalar_value_t::@3 u
A runtime tag for a type in the halide type system.
uint8_t bits
The number of bits of precision of a single scalar value of this type.
uint16_t lanes
How many elements in a vector.
uint8_t code
The basic type code: signed integer, unsigned integer, or floating point.