Go to the documentation of this file. 1 #ifndef HALIDE_IR_MATCH_H
2 #define HALIDE_IR_MATCH_H
37 bool expr_match(
const Expr &pattern,
const Expr &expr, std::vector<Expr> &result);
51 bool expr_match(
const Expr &pattern,
const Expr &expr, std::map<std::string, Expr> &result);
139 typename =
typename std::remove_reference<T>::type::pattern_tag>
146 constexpr
static uint32_t mask = std::remove_reference<T>::type::binds;
166 const int lanes = scalar_type.
lanes;
167 scalar_type.
lanes = 1;
170 switch (scalar_type.
code) {
198 ((a.type == b.type) &&
199 (a.node_type == b.node_type) &&
216 template<u
int32_t bound>
244 template<u
int32_t bound>
246 static_assert(i >= 0 && i <
max_wild,
"Wild with out-of-range index");
249 op = ((
const Broadcast *)op)->value.get();
258 state.get_bound_const(i, val, type);
261 state.set_bound_const(i, value, e.type);
265 template<u
int32_t bound>
267 static_assert(i >= 0 && i <
max_wild,
"Wild with out-of-range index");
271 state.get_bound_const(i, val, type);
272 return type == i64_type && value == val.
u.
i64;
274 state.set_bound_const(i, value, i64_type);
310 template<u
int32_t bound>
312 static_assert(i >= 0 && i <
max_wild,
"Wild with out-of-range index");
315 op = ((
const Broadcast *)op)->value.get();
324 state.get_bound_const(i, val, type);
327 state.set_bound_const(i, value, e.type);
343 state.get_bound_const(i, val, ty);
363 template<u
int32_t bound>
365 static_assert(i >= 0 && i <
max_wild,
"Wild with out-of-range index");
368 op = ((
const Broadcast *)op)->value.get();
373 double value = ((
const FloatImm *)op)->value;
377 state.get_bound_const(i, val, type);
380 state.set_bound_const(i, value, e.type);
396 state.get_bound_const(i, val, ty);
417 template<u
int32_t bound>
419 static_assert(i >= 0 && i <
max_wild,
"Wild with out-of-range index");
422 op = ((
const Broadcast *)op)->value.get();
436 template<u
int32_t bound>
438 static_assert(i >= 0 && i <
max_wild,
"Wild with out-of-range index");
454 state.get_bound_const(i, val, ty);
475 template<u
int32_t bound>
478 return equal(*state.get_binding(i), e);
480 state.set_binding(i, e);
515 template<u
int32_t bound>
519 op = ((
const Broadcast *)op)->value.get();
527 return ((
const FloatImm *)op)->value == (
double)
v;
533 template<u
int32_t bound>
538 template<u
int32_t bound>
562 val.u.f64 = (double)
v;
578 typename =
typename std::decay<T>::type::pattern_tag>
589 static_assert(!std::is_same<
typename std::decay<T>::type,
Expr>::value || std::is_lvalue_reference<T>::value,
590 "Exprs are captured by reference by IRMatcher objects and so must be lvalues");
601 typename =
typename std::decay<T>::type::pattern_tag,
603 typename =
typename std::enable_if<!std::is_same<typename std::decay<T>::type, SpecificExpr>::value>::type>
618 template<
typename Op>
621 template<
typename Op>
624 template<
typename Op>
639 template<
typename Op,
typename A,
typename B>
654 A::canonical && B::canonical && (!
commutative(Op::_node_type) || (A::max_node_type >= B::min_node_type));
656 template<u
int32_t bound>
658 if (e.node_type != Op::_node_type) {
661 const Op &op = (
const Op &)e;
662 return (
a.template match<bound>(*op.a.get(), state) &&
663 b.template match<bound | bindings<A>::mask>(*op.b.get(), state));
666 template<u
int32_t bound,
typename Op2,
typename A2,
typename B2>
668 return (std::is_same<Op, Op2>::value &&
669 a.template match<bound>(
unwrap(op.a), state) &&
673 constexpr
static bool foldable = A::foldable && B::foldable;
678 if (std::is_same<A, IntLiteral>::value) {
679 b.make_folded_const(val_b, ty, state);
680 if ((std::is_same<Op, And>::value && val_b.
u.
u64 == 0) ||
681 (std::is_same<Op, Or>::value && val_b.
u.
u64 == 1)) {
687 a.make_folded_const(val_a, ty, state);
690 a.make_folded_const(val_a, ty, state);
691 if ((std::is_same<Op, And>::value && val_a.
u.
u64 == 0) ||
692 (std::is_same<Op, Or>::value && val_a.
u.
u64 == 1)) {
698 b.make_folded_const(val_b, ty, state);
703 val.u.i64 = constant_fold_bin_op<Op>(ty, val_a.
u.
i64, val_b.
u.
i64);
706 val.u.u64 = constant_fold_bin_op<Op>(ty, val_a.
u.
u64, val_b.
u.
u64);
710 val.u.f64 = constant_fold_bin_op<Op>(ty, val_a.
u.
f64, val_b.
u.
f64);
721 if (std::is_same<A, IntLiteral>::value) {
722 eb =
b.make(state, type_hint);
723 ea =
a.make(state, eb.
type());
725 ea =
a.make(state, type_hint);
726 eb =
b.make(state, ea.
type());
736 return Op::make(std::move(ea), std::move(eb));
740 template<
typename Op>
743 template<
typename Op>
746 template<
typename Op>
750 template<
typename Op,
typename A,
typename B>
762 (!
commutative(Op::_node_type) || A::max_node_type >= B::min_node_type) &&
766 template<u
int32_t bound>
768 if (e.node_type != Op::_node_type) {
771 const Op &op = (
const Op &)e;
772 return (
a.template match<bound>(*op.a.get(), state) &&
773 b.template match<bound | bindings<A>::mask>(*op.b.get(), state));
776 template<u
int32_t bound,
typename Op2,
typename A2,
typename B2>
778 return (std::is_same<Op, Op2>::value &&
779 a.template match<bound>(
unwrap(op.a), state) &&
783 constexpr
static bool foldable = A::foldable && B::foldable;
789 if (std::is_same<A, IntLiteral>::value) {
790 b.make_folded_const(val_b, ty, state);
792 a.make_folded_const(val_a, ty, state);
795 a.make_folded_const(val_a, ty, state);
797 b.make_folded_const(val_b, ty, state);
802 val.u.u64 = constant_fold_cmp_op<Op>(val_a.
u.
i64, val_b.
u.
i64);
805 val.u.u64 = constant_fold_cmp_op<Op>(val_a.
u.
u64, val_b.
u.
u64);
809 val.u.u64 = constant_fold_cmp_op<Op>(val_a.
u.
f64, val_b.
u.
f64);
823 if (std::is_same<A, IntLiteral>::value) {
824 eb =
b.make(state, {});
825 ea =
a.make(state, eb.
type());
827 ea =
a.make(state, {});
828 eb =
b.make(state, ea.
type());
838 return Op::make(std::move(ea), std::move(eb));
842 template<
typename A,
typename B>
844 s <<
"(" << op.
a <<
" + " << op.
b <<
")";
848 template<
typename A,
typename B>
850 s <<
"(" << op.
a <<
" - " << op.
b <<
")";
854 template<
typename A,
typename B>
856 s <<
"(" << op.
a <<
" * " << op.
b <<
")";
860 template<
typename A,
typename B>
862 s <<
"(" << op.
a <<
" / " << op.
b <<
")";
866 template<
typename A,
typename B>
868 s <<
"(" << op.
a <<
" && " << op.
b <<
")";
872 template<
typename A,
typename B>
874 s <<
"(" << op.
a <<
" || " << op.
b <<
")";
878 template<
typename A,
typename B>
880 s <<
"min(" << op.
a <<
", " << op.
b <<
")";
884 template<
typename A,
typename B>
886 s <<
"max(" << op.
a <<
", " << op.
b <<
")";
890 template<
typename A,
typename B>
892 s <<
"(" << op.
a <<
" <= " << op.
b <<
")";
896 template<
typename A,
typename B>
898 s <<
"(" << op.
a <<
" < " << op.
b <<
")";
902 template<
typename A,
typename B>
904 s <<
"(" << op.
a <<
" >= " << op.
b <<
")";
908 template<
typename A,
typename B>
910 s <<
"(" << op.
a <<
" > " << op.
b <<
")";
914 template<
typename A,
typename B>
916 s <<
"(" << op.
a <<
" == " << op.
b <<
")";
920 template<
typename A,
typename B>
922 s <<
"(" << op.
a <<
" != " << op.
b <<
")";
926 template<
typename A,
typename B>
928 s <<
"(" << op.
a <<
" % " << op.
b <<
")";
932 template<
typename A,
typename B>
934 assert_is_lvalue_if_expr<A>();
935 assert_is_lvalue_if_expr<B>();
939 template<
typename A,
typename B>
941 assert_is_lvalue_if_expr<A>();
942 assert_is_lvalue_if_expr<B>();
949 int dead_bits = 64 - t.bits;
957 return (a + b) & (ones >> (64 - t.bits));
965 template<
typename A,
typename B>
967 assert_is_lvalue_if_expr<A>();
968 assert_is_lvalue_if_expr<B>();
972 template<
typename A,
typename B>
974 assert_is_lvalue_if_expr<A>();
975 assert_is_lvalue_if_expr<B>();
983 int dead_bits = 64 - t.bits;
990 return (a - b) & (ones >> (64 - t.bits));
998 template<
typename A,
typename B>
1000 assert_is_lvalue_if_expr<A>();
1001 assert_is_lvalue_if_expr<B>();
1005 template<
typename A,
typename B>
1007 assert_is_lvalue_if_expr<A>();
1008 assert_is_lvalue_if_expr<B>();
1015 int dead_bits = 64 - t.bits;
1023 return (a * b) & (ones >> (64 - t.bits));
1031 template<
typename A,
typename B>
1033 assert_is_lvalue_if_expr<A>();
1034 assert_is_lvalue_if_expr<B>();
1038 template<
typename A,
typename B>
1058 template<
typename A,
typename B>
1060 assert_is_lvalue_if_expr<A>();
1061 assert_is_lvalue_if_expr<B>();
1065 template<
typename A,
typename B>
1067 assert_is_lvalue_if_expr<A>();
1068 assert_is_lvalue_if_expr<B>();
1087 template<
typename A,
typename B>
1089 assert_is_lvalue_if_expr<A>();
1090 assert_is_lvalue_if_expr<B>();
1109 template<
typename A,
typename B>
1111 assert_is_lvalue_if_expr<A>();
1112 assert_is_lvalue_if_expr<B>();
1131 template<
typename A,
typename B>
1136 template<
typename A,
typename B>
1156 template<
typename A,
typename B>
1161 template<
typename A,
typename B>
1181 template<
typename A,
typename B>
1186 template<
typename A,
typename B>
1206 template<
typename A,
typename B>
1211 template<
typename A,
typename B>
1231 template<
typename A,
typename B>
1236 template<
typename A,
typename B>
1256 template<
typename A,
typename B>
1261 template<
typename A,
typename B>
1281 template<
typename A,
typename B>
1286 template<
typename A,
typename B>
1307 template<
typename A,
typename B>
1312 template<
typename A,
typename B>
1337 template<
typename... Args>
1346 template<
typename... Args>
1353 return a < b ? a : b;
1356 template<
typename... Args>
1374 typename =
typename std::enable_if<(i <
sizeof...(Args))>::type>
1376 using T = decltype(std::get<i>(
args));
1377 return (std::get<i>(
args).
template match<bound>(*c.args[i].get(), state) &&
1381 template<
int i, u
int32_t binds>
1386 template<u
int32_t bound>
1394 match_args<0, bound>(0, c, state));
1398 typename =
typename std::enable_if<(i <
sizeof...(Args))>::type>
1400 s << std::get<i>(
args);
1401 if (i + 1 <
sizeof...(Args)) {
1404 print_args<i + 1>(0, s);
1413 print_args<0>(0, s);
1418 Expr arg0 = std::get<0>(
args).make(state, type_hint);
1431 return absd(arg0, arg1);
1455 return arg0 << arg1;
1457 return arg0 >> arg1;
1483 std::get<0>(
args).make_folded_const(val, ty, state);
1488 std::get<1>(
args).make_folded_const(arg1, signed_ty, state);
1491 if (arg1.
u.
i64 < 0) {
1494 val.u.i64 >>= -arg1.
u.
i64;
1497 val.u.u64 >>= -arg1.
u.
i64;
1500 val.u.u64 <<= arg1.
u.
i64;
1503 if (arg1.
u.
i64 > 0) {
1506 val.u.i64 >>= arg1.
u.
i64;
1509 val.u.u64 >>= arg1.
u.
i64;
1512 val.u.u64 <<= -arg1.
u.
i64;
1525 template<
typename... Args>
1533 template<
typename... Args>
1538 template<
typename A,
typename B>
1542 template<
typename A,
typename B>
1546 template<
typename A,
typename B>
1551 template<
typename A,
typename B>
1555 template<
typename A,
typename B>
1559 template<
typename A,
typename B>
1563 template<
typename A,
typename B>
1567 template<
typename A,
typename B>
1571 template<
typename A>
1577 template<
typename A,
typename B>
1581 template<
typename A,
typename B>
1585 template<
typename A,
typename B>
1589 template<
typename A,
typename B>
1593 template<
typename A,
typename B>
1597 template<
typename A,
typename B>
1601 template<
typename A,
typename B>
1605 template<
typename A,
typename B,
typename C>
1609 template<
typename A,
typename B,
typename C>
1614 template<
typename A>
1625 template<u
int32_t bound>
1630 const Not &op = (
const Not &)e;
1631 return (
a.template match<bound>(*op.
a.
get(), state));
1634 template<u
int32_t bound,
typename A2>
1636 return a.template match<bound>(
unwrap(op.a), state);
1646 template<
typename A1 = A>
1648 a.make_folded_const(val, ty, state);
1649 val.u.u64 = ~val.u.u64;
1654 template<
typename A>
1656 assert_is_lvalue_if_expr<A>();
1660 template<
typename A>
1662 assert_is_lvalue_if_expr<A>();
1666 template<
typename A>
1668 s <<
"!(" << op.
a <<
")";
1672 template<
typename C,
typename T,
typename F>
1684 constexpr
static bool canonical = C::canonical && T::canonical && F::canonical;
1686 template<u
int32_t bound>
1692 return (
c.template match<bound>(*op.
condition.
get(), state) &&
1693 t.template match<bound | bindings<C>::mask>(*op.
true_value.
get(), state) &&
1696 template<u
int32_t bound,
typename C2,
typename T2,
typename F2>
1698 return (
c.template match<bound>(
unwrap(instance.c), state) &&
1705 return Select::make(
c.make(state, {}),
t.make(state, type_hint),
f.make(state, type_hint));
1708 constexpr
static bool foldable = C::foldable && T::foldable && F::foldable;
1710 template<
typename C1 = C>
1714 c.make_folded_const(c_val, c_ty, state);
1715 if ((c_val.
u.
u64 & 1) == 1) {
1716 t.make_folded_const(val, ty, state);
1718 f.make_folded_const(val, ty, state);
1724 template<
typename C,
typename T,
typename F>
1726 s <<
"select(" << op.
c <<
", " << op.
t <<
", " << op.
f <<
")";
1730 template<
typename C,
typename T,
typename F>
1732 assert_is_lvalue_if_expr<C>();
1733 assert_is_lvalue_if_expr<T>();
1734 assert_is_lvalue_if_expr<F>();
1738 template<
typename A,
typename B>
1749 constexpr
static bool canonical = A::canonical && B::canonical;
1751 template<u
int32_t bound>
1755 if (
a.template match<bound>(*op.
value.
get(), state) &&
1756 lanes.template match<bound>(op.
lanes, state)) {
1763 template<u
int32_t bound,
typename A2,
typename B2>
1765 return (
a.template match<bound>(
unwrap(op.a), state) &&
1773 lanes.make_folded_const(lanes_val, ty, state);
1775 type_hint.
lanes /= l;
1776 Expr val =
a.make(state, type_hint);
1786 template<
typename A1 = A>
1790 lanes.make_folded_const(lanes_val, lanes_ty, state);
1792 a.make_folded_const(val, ty, state);
1797 template<
typename A,
typename B>
1799 s <<
"broadcast(" << op.
a <<
", " << op.
lanes <<
")";
1803 template<
typename A,
typename B>
1805 assert_is_lvalue_if_expr<A>();
1809 template<
typename A,
typename B,
typename C>
1821 constexpr
static bool canonical = A::canonical && B::canonical && C::canonical;
1823 template<u
int32_t bound>
1829 if (
a.template match<bound>(*op.
base.
get(), state) &&
1830 b.template match<bound | bindings<A>::mask>(*op.
stride.
get(), state) &&
1838 template<u
int32_t bound,
typename A2,
typename B2,
typename C2>
1840 return (
a.template match<bound>(
unwrap(op.a), state) &&
1849 lanes.make_folded_const(lanes_val, ty, state);
1851 type_hint.
lanes /= l;
1853 eb =
b.make(state, type_hint);
1854 ea =
a.make(state, eb.type());
1861 template<
typename A,
typename B,
typename C>
1863 s <<
"ramp(" << op.
a <<
", " << op.
b <<
", " << op.
lanes <<
")";
1867 template<
typename A,
typename B,
typename C>
1869 assert_is_lvalue_if_expr<A>();
1870 assert_is_lvalue_if_expr<B>();
1871 assert_is_lvalue_if_expr<C>();
1875 template<
typename A,
typename B, VectorReduce::Operator reduce_op>
1887 template<u
int32_t bound>
1891 if (op.
op == reduce_op &&
1892 a.template match<bound>(*op.
value.
get(), state) &&
1893 lanes.template match<bound | bindings<A>::mask>(op.
type.
lanes(), state)) {
1900 template<u
int32_t bound,
typename A2,
typename B2, VectorReduce::Operator reduce_op_2>
1902 return (reduce_op == reduce_op_2 &&
1903 a.template match<bound>(
unwrap(op.a), state) &&
1911 lanes.make_folded_const(lanes_val, ty, state);
1912 int l = (int)lanes_val.
u.
i64;
1919 template<
typename A,
typename B, VectorReduce::Operator reduce_op>
1921 s <<
"vector_reduce(" << reduce_op <<
", " << op.
a <<
", " << op.
lanes <<
")";
1925 template<
typename A,
typename B>
1927 assert_is_lvalue_if_expr<A>();
1931 template<
typename A,
typename B>
1933 assert_is_lvalue_if_expr<A>();
1937 template<
typename A,
typename B>
1939 assert_is_lvalue_if_expr<A>();
1943 template<
typename A,
typename B>
1945 assert_is_lvalue_if_expr<A>();
1949 template<
typename A,
typename B>
1951 assert_is_lvalue_if_expr<A>();
1955 template<
typename A>
1967 template<u
int32_t bound>
1972 const Sub &op = (
const Sub &)e;
1973 return (
a.template match<bound>(*op.
b.
get(), state) &&
1977 template<u
int32_t bound,
typename A2>
1979 return a.template match<bound>(
unwrap(p.a), state);
1984 Expr ea =
a.make(state, type_hint);
1986 return Sub::make(std::move(z), std::move(ea));
1991 template<
typename A1 = A>
1993 a.make_folded_const(val, ty, state);
1994 int dead_bits = 64 - ty.bits;
1997 if (ty.bits >= 32 && val.u.u64 && (val.u.u64 << (65 - ty.bits)) == 0) {
2006 val.u.u64 = ((-val.u.u64) << dead_bits) >> dead_bits;
2010 val.u.f64 = -val.u.f64;
2019 template<
typename A>
2025 template<
typename A>
2027 assert_is_lvalue_if_expr<A>();
2031 template<
typename A>
2033 assert_is_lvalue_if_expr<A>();
2037 template<
typename A>
2049 template<u
int32_t bound>
2055 return (e.type ==
t &&
2056 a.template match<bound>(*op.
value.
get(), state));
2058 template<u
int32_t bound,
typename A2>
2060 return t == op.t &&
a.template match<bound>(
unwrap(op.a), state);
2065 return cast(
t,
a.make(state, {}));
2071 template<
typename A>
2073 s <<
"cast(" << op.
t <<
", " << op.
a <<
")";
2077 template<
typename A>
2079 assert_is_lvalue_if_expr<A>();
2083 template<
typename Vec,
typename Base,
typename Str
ide,
typename Lanes>
2091 static constexpr
uint32_t binds = Vec::binds | Base::binds | Stride::binds | Lanes::binds;
2095 constexpr
static bool canonical = Vec::canonical && Base::canonical && Stride::canonical && Lanes::canonical;
2097 template<u
int32_t bound>
2103 return v.
vectors.size() == 1 &&
2104 vec.template match<bound>(*v.
vectors[0].get(), state) &&
2105 base.template match<bound | bindings<Vec>::mask>(v.
slice_begin(), state) &&
2114 base.make_folded_const(base_val, ty, state);
2115 int b = (int)base_val.
u.
i64;
2116 stride.make_folded_const(stride_val, ty, state);
2117 int s = (int)stride_val.
u.
i64;
2118 lanes.make_folded_const(lanes_val, ty, state);
2119 int l = (int)lanes_val.
u.
i64;
2128 static_assert(Base::foldable,
"Base of slice should consist only of operations that constant-fold");
2129 static_assert(Stride::foldable,
"Stride of slice should consist only of operations that constant-fold");
2130 static_assert(Lanes::foldable,
"Lanes of slice should consist only of operations that constant-fold");
2134 template<
typename Vec,
typename Base,
typename Str
ide,
typename Lanes>
2136 s <<
"slice(" << op.
vec <<
", " << op.
base <<
", " << op.
stride <<
", " << op.
lanes <<
")";
2140 template<
typename Vec,
typename Base,
typename Str
ide,
typename Lanes>
2146 template<
typename A>
2161 a.make_folded_const(c, ty, state);
2167 if (type_hint.bits) {
2171 c.
u.
f64 = (double)x;
2173 ty.
code = type_hint.code;
2174 ty.
bits = type_hint.bits;
2183 template<
typename A1 = A>
2185 a.make_folded_const(val, ty, state);
2189 template<
typename A>
2191 assert_is_lvalue_if_expr<A>();
2195 template<
typename A>
2197 s <<
"fold(" << op.
a <<
")";
2201 template<
typename A>
2216 template<
typename A1 = A>
2218 a.make_folded_const(val, ty, state);
2226 template<
typename A>
2228 assert_is_lvalue_if_expr<A>();
2232 template<
typename A>
2234 s <<
"overflows(" << op.
a <<
")";
2248 template<u
int32_t bound>
2277 template<
typename A>
2294 template<
typename A1 = A>
2296 Expr e =
a.make(state, {});
2308 template<
typename A>
2310 assert_is_lvalue_if_expr<A>();
2314 template<
typename A>
2316 assert_is_lvalue_if_expr<A>();
2320 template<
typename A>
2323 s <<
"is_const(" << op.
a <<
")";
2325 s <<
"is_const(" << op.
a <<
", " << op.
v <<
")";
2330 template<
typename A,
typename Prover>
2347 Expr condition =
a.make(state, {});
2348 condition =
prover->mutate(condition,
nullptr);
2356 template<
typename A,
typename Prover>
2358 assert_is_lvalue_if_expr<A>();
2362 template<
typename A,
typename Prover>
2364 s <<
"can_prove(" << op.
a <<
")";
2368 template<
typename A>
2385 Type t =
a.make(state, {}).type();
2393 template<
typename A>
2395 assert_is_lvalue_if_expr<A>();
2399 template<
typename A>
2401 s <<
"is_float(" << op.
a <<
")";
2405 template<
typename A>
2423 Type t =
a.make(state, {}).type();
2431 template<
typename A>
2433 assert_is_lvalue_if_expr<A>();
2437 template<
typename A>
2439 s <<
"is_int(" << op.
a;
2441 s <<
", " << op.
bits;
2444 s <<
", " << op.
lanes;
2450 template<
typename A>
2468 Type t =
a.make(state, {}).type();
2476 template<
typename A>
2478 assert_is_lvalue_if_expr<A>();
2482 template<
typename A>
2484 s <<
"is_uint(" << op.
a;
2486 s <<
", " << op.
bits;
2489 s <<
", " << op.
lanes;
2495 template<
typename A>
2512 Type t =
a.make(state, {}).type();
2520 template<
typename A>
2522 assert_is_lvalue_if_expr<A>();
2526 template<
typename A>
2528 s <<
"is_scalar(" << op.
a <<
")";
2532 template<
typename A>
2549 a.make_folded_const(val, ty, state);
2552 val.
u.
u64 = (val.
u.
u64 == max_bits);
2561 template<
typename A>
2563 assert_is_lvalue_if_expr<A>();
2567 template<
typename A>
2569 s <<
"is_max_value(" << op.
a <<
")";
2573 template<
typename A>
2590 a.make_folded_const(val, ty, state);
2593 val.
u.
u64 = (val.
u.
u64 == min_bits);
2604 template<
typename A>
2606 assert_is_lvalue_if_expr<A>();
2610 template<
typename A>
2612 s <<
"is_min_value(" << op.
a <<
")";
2616 template<
typename A>
2633 Type t =
a.make(state, {}).type();
2641 template<
typename A>
2643 assert_is_lvalue_if_expr<A>();
2647 template<
typename A>
2649 s <<
"lanes_of(" << op.
a <<
")";
2654 template<
typename Before,
2657 typename =
typename std::enable_if<std::decay<Before>::type::foldable &&
2658 std::decay<After>::type::foldable>::type>
2663 wildcard_type.lanes = output_type.lanes = 1;
2666 static std::set<uint32_t> tested;
2668 if (!tested.insert(reinterpret_bits<uint32_t>(wildcard_type)).second) {
2673 debug(0) <<
"validate('" << before <<
"', '" << after <<
"', '" << pred <<
"', " <<
Type(wildcard_type) <<
", " <<
Type(output_type) <<
")\n";
2678 static std::mt19937_64 rng(0);
2683 for (
int trials = 0; trials < 100; trials++) {
2687 int shift = (int)(rng() & (wildcard_type.bits - 1));
2689 for (
int i = 0; i <
max_wild; i++) {
2691 switch (wildcard_type.code) {
2711 double val = ((
int64_t)(rng() & 15) - 8) / 2.0;
2713 val = ((
int64_t)(rng() & 15) - 8) / 2.0;
2727 before.make_folded_const(val_before, type, state);
2729 after.make_folded_const(val_after, type, state);
2730 lanes |= type.
lanes;
2737 switch (output_type.code) {
2752 ok &= (error < 0.01 ||
2753 val_before.
u.
u64 == val_after.
u.
u64 ||
2754 std::isnan(val_before.
u.
f64));
2762 debug(0) <<
"Fails with values:\n";
2763 for (
int i = 0; i <
max_wild; i++) {
2768 for (
int i = 0; i <
max_wild; i++) {
2773 debug(0) << val_before.
u.
u64 <<
" " << val_after.
u.
u64 <<
"\n";
2779 template<
typename Before,
2782 typename =
typename std::enable_if<!(std::decay<Before>::type::foldable &&
2783 std::decay<After>::type::foldable)>::type>
2794 template<
typename Pattern,
2795 typename =
typename enable_if_pattern<Pattern>::type>
2799 p.make_folded_const(c, ty, state);
2807 #define HALIDE_DEBUG_MATCHED_RULES 0
2808 #define HALIDE_DEBUG_UNMATCHED_RULES 0
2814 #define HALIDE_FUZZ_TEST_RULES 0
2816 template<
typename Instance>
2829 template<
typename After>
2834 template<
typename Before,
2839 static_assert((Before::binds & After::binds) == After::binds,
"Rule result uses unbound values");
2840 static_assert(Before::canonical,
"LHS of rewrite rule should be in canonical form");
2841 static_assert(After::canonical,
"RHS of rewrite rule should be in canonical form");
2842 #if HALIDE_FUZZ_TEST_RULES
2847 #if HALIDE_DEBUG_MATCHED_RULES
2852 #if HALIDE_DEBUG_UNMATCHED_RULES
2853 debug(0) <<
instance <<
" does not match " << before <<
"\n";
2859 template<
typename Before,
2862 static_assert(Before::canonical,
"LHS of rewrite rule should be in canonical form");
2865 #if HALIDE_DEBUG_MATCHED_RULES
2870 #if HALIDE_DEBUG_UNMATCHED_RULES
2871 debug(0) <<
instance <<
" does not match " << before <<
"\n";
2877 template<
typename Before,
2880 static_assert(Before::canonical,
"LHS of rewrite rule should be in canonical form");
2881 #if HALIDE_FUZZ_TEST_RULES
2886 #if HALIDE_DEBUG_MATCHED_RULES
2891 #if HALIDE_DEBUG_UNMATCHED_RULES
2892 debug(0) <<
instance <<
" does not match " << before <<
"\n";
2898 template<
typename Before,
2905 static_assert(Predicate::foldable,
"Predicates must consist only of operations that can constant-fold");
2906 static_assert((Before::binds & After::binds) == After::binds,
"Rule result uses unbound values");
2907 static_assert((Before::binds & Predicate::binds) == Predicate::binds,
"Rule predicate uses unbound values");
2908 static_assert(Before::canonical,
"LHS of rewrite rule should be in canonical form");
2909 static_assert(After::canonical,
"RHS of rewrite rule should be in canonical form");
2911 #if HALIDE_FUZZ_TEST_RULES
2917 #if HALIDE_DEBUG_MATCHED_RULES
2918 debug(0) <<
instance <<
" -> " <<
result <<
" via " << before <<
" -> " << after <<
" when " << pred <<
"\n";
2922 #if HALIDE_DEBUG_UNMATCHED_RULES
2923 debug(0) <<
instance <<
" does not match " << before <<
"\n";
2929 template<
typename Before,
2934 static_assert(Predicate::foldable,
"Predicates must consist only of operations that can constant-fold");
2935 static_assert(Before::canonical,
"LHS of rewrite rule should be in canonical form");
2940 #if HALIDE_DEBUG_MATCHED_RULES
2941 debug(0) <<
instance <<
" -> " <<
result <<
" via " << before <<
" -> " << after <<
" when " << pred <<
"\n";
2945 #if HALIDE_DEBUG_UNMATCHED_RULES
2946 debug(0) <<
instance <<
" does not match " << before <<
"\n";
2952 template<
typename Before,
2957 static_assert(Predicate::foldable,
"Predicates must consist only of operations that can constant-fold");
2958 static_assert(Before::canonical,
"LHS of rewrite rule should be in canonical form");
2959 #if HALIDE_FUZZ_TEST_RULES
2965 #if HALIDE_DEBUG_MATCHED_RULES
2966 debug(0) <<
instance <<
" -> " <<
result <<
" via " << before <<
" -> " << after <<
" when " << pred <<
"\n";
2970 #if HALIDE_DEBUG_UNMATCHED_RULES
2971 debug(0) <<
instance <<
" does not match " << before <<
"\n";
2995 template<
typename Instance,
2996 typename =
typename enable_if_pattern<Instance>::type>
2998 return {
pattern_arg(instance), output_type, wildcard_type};
3001 template<
typename Instance,
3002 typename =
typename enable_if_pattern<Instance>::type>
3004 return {
pattern_arg(instance), output_type, output_type};
static constexpr uint32_t binds
signed __INT32_TYPE__ int32_t
HALIDE_ALWAYS_INLINE auto not_op(A &&a) -> decltype(IRMatcher::operator!(a))
HALIDE_ALWAYS_INLINE auto cast(halide_type_t t, A &&a) noexcept -> CastOp< decltype(pattern_arg(a))>
constexpr static bool foldable
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
constexpr int const_min(int a, int b)
constexpr static uint32_t binds
@ signed_integer_overflow
HALIDE_ALWAYS_INLINE auto intrin(Call::IntrinsicOp intrinsic_op, Args... args) noexcept -> Intrin< decltype(pattern_arg(args))... >
The sum of two expressions.
int slice_stride() const
Check if this shuffle is a contiguous strict subset of the vector arguments, and if so,...
HALIDE_ALWAYS_INLINE void assert_is_lvalue_if_expr()
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Div >(halide_type_t &t, int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< LE >(int64_t a, int64_t b) noexcept
constexpr static bool foldable
HALIDE_ALWAYS_INLINE auto operator%(A &&a, B &&b) noexcept -> BinOp< Mod, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
constexpr static bool canonical
HALIDE_ALWAYS_INLINE auto mod(A &&a, B &&b) -> decltype(IRMatcher::operator%(a, b))
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
constexpr static bool canonical
constexpr static bool foldable
constexpr uint32_t bitwise_or_reduce()
static constexpr uint16_t signed_integer_overflow
constexpr static bool canonical
constexpr static bool canonical
constexpr static bool foldable
HALIDE_ALWAYS_INLINE auto operator+(A &&a, B &&b) noexcept -> BinOp< Add, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
auto widen_right_mul(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE auto ramp(A &&a, B &&b, C &&c) noexcept -> RampOp< decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
constexpr static IRNodeType min_node_type
HALIDE_ALWAYS_INLINE auto gt(A &&a, B &&b) -> decltype(IRMatcher::operator>(a, b))
constexpr static bool canonical
constexpr static bool canonical
HALIDE_ALWAYS_INLINE auto operator&&(A &&a, B &&b) noexcept -> BinOp< And, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE bool match_args(double, const Call &c, MatcherState &state) const noexcept
constexpr static IRNodeType min_node_type
HALIDE_ALWAYS_INLINE bool match(const VectorReduceOp< A2, B2, reduce_op_2 > &op, MatcherState &state) const noexcept
constexpr IRNodeType StrongestExprNodeType
std::vector< Expr > vectors
constexpr static bool canonical
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
constexpr static bool canonical
constexpr static bool canonical
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< LT >(int64_t a, int64_t b) noexcept
T mod_imp(T a, T b)
Implementations of division and mod that are specific to Halide.
HALIDE_ALWAYS_INLINE SliceOp(Vec v, Base b, Stride s, Lanes l)
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE Rewriter(Instance instance, halide_type_t ot, halide_type_t wt)
@ halide_type_bfloat
floating point numbers in the bfloat format
constexpr static IRNodeType min_node_type
auto widening_mul(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
HALIDE_NEVER_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
Horizontally reduce a vector to a scalar or narrower vector using the given commutative and associati...
unsigned __INT16_TYPE__ uint16_t
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
constexpr static bool foldable
bool sub_would_overflow(int bits, int64_t a, int64_t b)
constexpr static IRNodeType min_node_type
@ Predicate
Guard the loads and stores in the loop with an if statement that prevents evaluation beyond the origi...
HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after, Predicate pred)
constexpr static bool foldable
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const noexcept
HALIDE_ALWAYS_INLINE auto ne(A &&a, B &&b) -> decltype(IRMatcher::operator!=(a, b))
HALIDE_ALWAYS_INLINE Expr make_const_expr(halide_scalar_value_t val, halide_type_t ty)
HALIDE_ALWAYS_INLINE void set_bound_const(int i, halide_scalar_value_t val, halide_type_t t) noexcept
constexpr static bool foldable
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Mul >(halide_type_t &t, int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
constexpr static bool foldable
Is the first expression greater than or equal to the second.
HALIDE_ALWAYS_INLINE MatcherState() noexcept
constexpr static bool foldable
@ halide_type_float
IEEE floating point numbers.
auto shift_right(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
IRNodeType
All our IR node types get unique IDs for the purposes of RTTI.
HALIDE_NEVER_INLINE void fuzz_test_rule(Before &&before, After &&after, Predicate &&pred, halide_type_t wildcard_type, halide_type_t output_type) noexcept
constexpr static IRNodeType min_node_type
constexpr static bool canonical
constexpr static IRNodeType min_node_type
constexpr static IRNodeType max_node_type
HALIDE_ALWAYS_INLINE bool operator()(Before before, const Expr &after) noexcept
constexpr static bool canonical
constexpr static bool foldable
constexpr static uint32_t binds
const BaseExprNode & expr
auto rounding_mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
constexpr static IRNodeType max_node_type
int64_t constant_fold_bin_op(halide_type_t &, int64_t, int64_t) noexcept
Floating point constants.
HALIDE_ALWAYS_INLINE auto h_max(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Max >
constexpr static IRNodeType max_node_type
HALIDE_ALWAYS_INLINE int lanes() const
Return the number of vector elements in this type.
auto saturating_cast(const Type &t, A &&a) noexcept -> Intrin< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE auto min(A &&a, B &&b) noexcept -> BinOp< Min, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
halide_type_t output_type
static const UIntImm * make(Type t, uint64_t value)
constexpr static uint32_t binds
constexpr static uint32_t binds
constexpr static IRNodeType min_node_type
A vector with 'lanes' elements, in which every element is 'value'.
The ratio of two expressions.
HALIDE_ALWAYS_INLINE auto is_uint(A &&a, int bits=0, int lanes=0) noexcept -> IsUInt< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE bool is_float() const
Is this type a floating point type (float or double).
constexpr bool commutative(IRNodeType t)
HALIDE_ALWAYS_INLINE bool is_uint() const
Is this type an unsigned integer type?
constexpr static IRNodeType min_node_type
bool is_const_zero(const Expr &e)
Is the expression a const (as defined by is_const), and also equal to zero (in all lanes,...
auto halving_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE bool match(const BroadcastOp< A2, B2 > &op, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
constexpr static IRNodeType max_node_type
constexpr static uint32_t mask
HALIDE_ALWAYS_INLINE auto select(C &&c, T &&t, F &&f) noexcept -> SelectOp< decltype(pattern_arg(c)), decltype(pattern_arg(t)), decltype(pattern_arg(f))>
constexpr static uint32_t binds
constexpr static uint32_t binds
HALIDE_ALWAYS_INLINE auto lt(A &&a, B &&b) -> decltype(IRMatcher::operator<(a, b))
uint8_t bits
The number of bits of precision of a single scalar value of this type.
HALIDE_ALWAYS_INLINE auto operator||(A &&a, B &&b) noexcept -> BinOp< Or, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto is_const(A &&a) noexcept -> IsConst< decltype(pattern_arg(a))>
Expr with_lanes(const Expr &x, int lanes)
Rewrite the expression x to have lanes lanes.
HALIDE_ALWAYS_INLINE bool match(const IntLiteral &b, MatcherState &state) const noexcept
Expr absd(Expr a, Expr b)
Return the absolute difference between two values.
HALIDE_ALWAYS_INLINE bool match(int64_t e, MatcherState &state) const noexcept
constexpr static uint32_t binds
HALIDE_ALWAYS_INLINE bool match_args(int, const Call &c, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE bool match(const NotOp< A2 > &op, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
constexpr static bool foldable
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE void set_binding(int i, const BaseExprNode &n) noexcept
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
constexpr static uint32_t binds
constexpr static bool canonical
std::ostream & operator<<(std::ostream &s, const SpecificExpr &e)
@ rounding_mul_shift_right
constexpr static IRNodeType min_node_type
constexpr static bool foldable
constexpr static IRNodeType max_node_type
HALIDE_ALWAYS_INLINE bool operator()(Before before, After after, Predicate pred)
The actual IR nodes begin here.
Is the first expression less than or equal to the second.
Expr make_signed_integer_overflow(Type type)
Construct a unique signed_integer_overflow Expr.
static const IRNodeType _node_type
A runtime tag for a type in the halide type system.
bool equal_helper(const BaseExprNode &a, const BaseExprNode &b) noexcept
auto rounding_shift_left(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Is the first expression not equal to the second.
constexpr static bool foldable
constexpr static bool canonical
auto widening_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE bool match(NegateOp< A2 > &&p, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
constexpr static IRNodeType min_node_type
HALIDE_ALWAYS_INLINE auto operator<(A &&a, B &&b) noexcept -> CmpOp< LT, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto operator<=(A &&a, B &&b) noexcept -> CmpOp< LE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
constexpr static IRNodeType min_node_type
constexpr static uint32_t binds
HALIDE_ALWAYS_INLINE auto is_int(A &&a, int bits=0, int lanes=0) noexcept -> IsInt< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
HALIDE_ALWAYS_INLINE auto operator>(A &&a, B &&b) noexcept -> CmpOp< GT, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
auto rounding_halving_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto or_op(A &&a, B &&b) -> decltype(IRMatcher::operator||(a, b))
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE bool match(int64_t value, MatcherState &state) const noexcept
constexpr static bool canonical
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
auto widen_right_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
unsigned __INT64_TYPE__ uint64_t
int slice_begin() const
Check if this shuffle is a contiguous strict subset of the vector arguments, and if so,...
uint64_t constant_fold_cmp_op(int64_t, int64_t) noexcept
HALIDE_ALWAYS_INLINE auto overflows(A &&a) noexcept -> Overflows< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE bool is_vector() const
Is this type a vector type? (lanes() != 1).
const HALIDE_ALWAYS_INLINE BaseExprNode * get_binding(int i) const noexcept
HALIDE_ALWAYS_INLINE bool operator()(Before before, After after)
Types in the halide type system.
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
constexpr static IRNodeType min_node_type
#define HALIDE_NEVER_INLINE
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Mod >(halide_type_t &t, int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE auto operator*(A &&a, B &&b) noexcept -> BinOp< Mul, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
static Expr make(Expr value, int lanes)
auto saturating_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
bool expr_match(const Expr &pattern, const Expr &expr, std::vector< Expr > &result)
Does the first expression have the same structure as the second? Variables in the first expression wi...
bool add_would_overflow(int bits, int64_t a, int64_t b)
Routines to test if math would overflow for signed integers with the given number of bits.
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< GT >(int64_t a, int64_t b) noexcept
constexpr static bool foldable
constexpr static bool foldable
static Expr make_slice(Expr vector, int begin, int stride, int size)
Convenience constructor for making a shuffle representing a contiguous subset of a vector.
HALIDE_ALWAYS_INLINE auto is_const(A &&a, int64_t value) noexcept -> IsConst< decltype(pattern_arg(a))>
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
constexpr static bool canonical
Logical or - is at least one of the expressions true.
Is the first expression equal to the second.
constexpr static IRNodeType max_node_type
constexpr static IRNodeType min_node_type
constexpr static IRNodeType max_node_type
constexpr static uint32_t binds
constexpr static bool canonical
auto mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b)), decltype(pattern_arg(c))>
HALIDE_ALWAYS_INLINE auto negate(A &&a) -> decltype(IRMatcher::operator-(a))
constexpr static bool canonical
constexpr static IRNodeType max_node_type
constexpr static bool foldable
halide_scalar_value_t is a simple union able to represent all the well-known scalar values in a filte...
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
constexpr static IRNodeType min_node_type
HALIDE_ALWAYS_INLINE T pattern_arg(T t)
constexpr static IRNodeType max_node_type
@ Internal
Not visible externally, similar to 'static' linkage in C.
The greater of two values.
constexpr static IRNodeType min_node_type
constexpr static IRNodeType min_node_type
HALIDE_ALWAYS_INLINE auto sub(A &&a, B &&b) -> decltype(IRMatcher::operator-(a, b))
static constexpr uint32_t binds
constexpr static IRNodeType max_node_type
Expr make_zero(Type t)
Construct the representation of zero in the given type.
HALIDE_NEVER_INLINE Expr make_const_special_expr(halide_type_t ty)
constexpr static uint32_t binds
constexpr static IRNodeType min_node_type
constexpr static uint32_t binds
constexpr static IRNodeType max_node_type
auto shift_left(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
constexpr static IRNodeType max_node_type
HALIDE_ALWAYS_INLINE auto slice(Vec vec, Base base, Stride stride, Lanes lanes) noexcept -> SliceOp< decltype(pattern_arg(vec)), decltype(pattern_arg(base)), decltype(pattern_arg(stride)), decltype(pattern_arg(lanes))>
HALIDE_ALWAYS_INLINE auto is_scalar(A &&a) noexcept -> IsScalar< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE auto eq(A &&a, B &&b) -> decltype(IRMatcher::operator==(a, b))
constexpr static IRNodeType min_node_type
static Expr make(Operator op, Expr vec, int lanes)
constexpr static IRNodeType max_node_type
union halide_scalar_value_t::@4 u
constexpr static bool canonical
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
constexpr static uint32_t binds
static Expr make(Expr condition, Expr true_value, Expr false_value)
HALIDE_ALWAYS_INLINE auto le(A &&a, B &&b) -> decltype(IRMatcher::operator<=(a, b))
constexpr static bool foldable
constexpr static uint32_t binds
constexpr static IRNodeType max_node_type
constexpr static bool foldable
constexpr static bool canonical
constexpr static bool canonical
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
A linear ramp vector node.
static Expr make(Expr base, Expr stride, int lanes)
#define HALIDE_ALWAYS_INLINE
HALIDE_ALWAYS_INLINE auto is_min_value(A &&a) noexcept -> IsMinValue< decltype(pattern_arg(a))>
constexpr bool and_reduce()
HALIDE_ALWAYS_INLINE auto is_float(A &&a) noexcept -> IsFloat< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE Type type() const
Get the type of this expression node.
uint16_t lanes
How many elements in a vector.
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
HALIDE_ALWAYS_INLINE int bits() const
Return the bit size of a single element of this type.
constexpr static IRNodeType max_node_type
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
halide_scalar_value_t bound_const[max_wild]
For optional debugging during codegen, use the debug class as follows:
HALIDE_ALWAYS_INLINE void set_bound_const(int i, int64_t s, halide_type_t t) noexcept
constexpr static bool canonical
HALIDE_ALWAYS_INLINE void print_args(int, std::ostream &s) const
constexpr static IRNodeType max_node_type
signed __INT64_TYPE__ int64_t
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
A base class for expression nodes.
static constexpr uint16_t special_values_mask
@ halide_type_uint
unsigned integers
constexpr static IRNodeType min_node_type
constexpr static IRNodeType max_node_type
Expr abs(Expr a)
Returns the absolute value of a signed integer or floating-point expression.
HALIDE_ALWAYS_INLINE Intrin(Call::IntrinsicOp intrin, Args... args) noexcept
HALIDE_ALWAYS_INLINE void print_args(std::ostream &s) const
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Expr make_const(Type t, int64_t val)
Construct an immediate of the given type from any numeric C++ type.
HALIDE_ALWAYS_INLINE auto operator!(A &&a) noexcept -> NotOp< decltype(pattern_arg(a))>
constexpr static uint32_t binds
constexpr static uint32_t binds
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< EQ >(int64_t a, int64_t b) noexcept
constexpr static bool canonical
The lesser of two values.
constexpr static IRNodeType max_node_type
To save stack space, the matcher objects are largely stateless and immutable.
static const IntImm * make(Type t, int64_t value)
auto widening_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
constexpr static IRNodeType max_node_type
constexpr static bool foldable
HALIDE_ALWAYS_INLINE void get_bound_const(int i, halide_scalar_value_t &val, halide_type_t &type) const noexcept
HALIDE_ALWAYS_INLINE void set_bound_const(int i, double f, halide_type_t t) noexcept
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
constexpr static IRNodeType min_node_type
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE int64_t unwrap(IntLiteral t)
Expr likely_if_innermost(Expr e)
Equivalent to likely, but only triggers a loop partitioning if found in an innermost loop.
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE auto is_max_value(A &&a) noexcept -> IsMaxValue< decltype(pattern_arg(a))>
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE bool match(const CastOp< A2 > &op, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE bool is_scalar() const
Is this type a scalar type? (lanes() == 1).
HALIDE_ALWAYS_INLINE bool evaluate_predicate(bool x, MatcherState &) noexcept
HALIDE_ALWAYS_INLINE bool match(const BinOp< Op2, A2, B2 > &op, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
constexpr static IRNodeType min_node_type
HALIDE_ALWAYS_INLINE auto operator!=(A &&a, B &&b) noexcept -> CmpOp< NE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
static const FloatImm * make(Type t, double value)
constexpr static bool foldable
constexpr static uint32_t binds
constexpr static bool foldable
static const IRNodeType _node_type
constexpr static IRNodeType max_node_type
constexpr static bool foldable
constexpr static IRNodeType max_node_type
HALIDE_ALWAYS_INLINE void set_bound_const(int i, uint64_t u, halide_type_t t) noexcept
static const IRNodeType _node_type
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const noexcept
halide_type_t wildcard_type
HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after, Predicate pred)
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
halide_type_t bound_const_type[max_wild]
HALIDE_ALWAYS_INLINE auto div(A &&a, B &&b) -> decltype(IRMatcher::operator/(a, b))
HALIDE_ALWAYS_INLINE auto fold(A &&a) noexcept -> Fold< decltype(pattern_arg(a))>
bool is_intrinsic() const
HALIDE_ALWAYS_INLINE IntLiteral(int64_t v)
constexpr static bool foldable
constexpr static bool canonical
constexpr static uint32_t binds
constexpr static bool canonical
constexpr static bool canonical
constexpr static IRNodeType max_node_type
HALIDE_ALWAYS_INLINE bool is_int() const
Is this type a signed integer type?
HALIDE_ALWAYS_INLINE auto max(A &&a, B &&b) noexcept -> BinOp< Max, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE auto broadcast(A &&a, B lanes) noexcept -> BroadcastOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes))>
auto rounding_shift_right(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Add >(halide_type_t &t, int64_t a, int64_t b) noexcept
constexpr static uint32_t binds
constexpr static bool foldable
constexpr static bool foldable
Expr likely(Expr e)
Expressions tagged with this intrinsic are considered to be part of the steady state of some loop wit...
HALIDE_ALWAYS_INLINE auto mul(A &&a, B &&b) -> decltype(IRMatcher::operator*(a, b))
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
constexpr static bool canonical
constexpr static bool foldable
constexpr static IRNodeType max_node_type
HALIDE_ALWAYS_INLINE auto rewriter(Instance instance, halide_type_t output_type, halide_type_t wildcard_type) noexcept -> Rewriter< decltype(pattern_arg(instance))>
Construct a rewriter for the given instance, which may be a pattern with concrete expressions as leav...
HALIDE_ALWAYS_INLINE auto h_add(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Add >
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
IRNodeType node_type
Each IR node subclass has a unique identifier.
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const
constexpr static uint32_t binds
HALIDE_ALWAYS_INLINE auto operator>=(A &&a, B &&b) noexcept -> CmpOp< GE, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
constexpr static bool canonical
static Expr make(Expr a, Expr b)
HALIDE_ALWAYS_INLINE auto can_prove(A &&a, Prover *p) noexcept -> CanProve< decltype(pattern_arg(a)), Prover >
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE auto ge(A &&a, B &&b) -> decltype(IRMatcher::operator>=(a, b))
static const IRNodeType _node_type
static const IRNodeType _node_type
static const IRNodeType _node_type
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
A fragment of Halide syntax.
HALIDE_ALWAYS_INLINE auto h_and(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::And >
constexpr static bool foldable
constexpr static IRNodeType min_node_type
constexpr static bool foldable
bool mul_would_overflow(int bits, int64_t a, int64_t b)
constexpr static IRNodeType max_node_type
@ halide_type_int
signed integers
HALIDE_ALWAYS_INLINE bool match(int64_t val, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE auto h_min(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Min >
constexpr static bool canonical
auto widen_right_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
Is the first expression greater than the second.
constexpr static bool foldable
constexpr static uint32_t binds
HALIDE_ALWAYS_INLINE auto operator/(A &&a, B &&b) noexcept -> BinOp< Div, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
Construct a new vector by taking elements from another sequence of vectors.
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
HALIDE_ALWAYS_INLINE auto lanes_of(A &&a) noexcept -> LanesOf< decltype(pattern_arg(a))>
constexpr static IRNodeType min_node_type
HALIDE_ALWAYS_INLINE auto h_or(A &&a, B lanes) noexcept -> VectorReduceOp< decltype(pattern_arg(a)), decltype(pattern_arg(lanes)), VectorReduce::Or >
constexpr static IRNodeType max_node_type
constexpr static IRNodeType min_node_type
const HALIDE_ALWAYS_INLINE Internal::BaseExprNode * get() const
Override get() to return a BaseExprNode * instead of an IRNode *.
auto saturating_add(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
std::tuple< Args... > args
unsigned __INT32_TYPE__ uint32_t
constexpr static uint32_t binds
Unsigned integer constants.
constexpr static uint32_t binds
constexpr static uint32_t binds
constexpr static bool canonical
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Min >(halide_type_t &t, int64_t a, int64_t b) noexcept
constexpr static IRNodeType min_node_type
HALIDE_ALWAYS_INLINE auto operator-(A &&a, B &&b) noexcept -> BinOp< Sub, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< NE >(int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE uint64_t constant_fold_cmp_op< GE >(int64_t a, int64_t b) noexcept
constexpr static IRNodeType max_node_type
constexpr static IRNodeType min_node_type
bool is_const_one(const Expr &e)
Is the expression a const (as defined by is_const), and also equal to one (in all lanes,...
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Max >(halide_type_t &t, int64_t a, int64_t b) noexcept
constexpr static IRNodeType min_node_type
The difference of two expressions.
HALIDE_ALWAYS_INLINE auto and_op(A &&a, B &&b) -> decltype(IRMatcher::operator&&(a, b))
constexpr static bool foldable
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Sub >(halide_type_t &t, int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< Or >(halide_type_t &t, int64_t a, int64_t b) noexcept
constexpr static uint32_t binds
HALIDE_ALWAYS_INLINE bool match(const CmpOp< Op2, A2, B2 > &op, MatcherState &state) const noexcept
HALIDE_ALWAYS_INLINE bool match(const SelectOp< C2, T2, F2 > &instance, MatcherState &state) const noexcept
constexpr static IRNodeType min_node_type
constexpr static IRNodeType max_node_type
Logical and - are both expressions true.
HALIDE_ALWAYS_INLINE auto operator==(A &&a, B &&b) noexcept -> CmpOp< EQ, decltype(pattern_arg(a)), decltype(pattern_arg(b))>
HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept
auto halving_sub(A &&a, B &&b) noexcept -> Intrin< decltype(pattern_arg(a)), decltype(pattern_arg(b))>
constexpr static IRNodeType min_node_type
static const IRNodeType _node_type
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
constexpr static IRNodeType max_node_type
HALIDE_ALWAYS_INLINE bool match(const RampOp< A2, B2, C2 > &op, MatcherState &state) const noexcept
constexpr static uint32_t binds
HALIDE_ALWAYS_INLINE bool operator()(Before before, int64_t after) noexcept
constexpr static bool canonical
Logical not - true if the expression is false.
HALIDE_ALWAYS_INLINE void print_args(double, std::ostream &s) const
Is the first expression less than the second.
constexpr static uint32_t binds
HALIDE_ALWAYS_INLINE int64_t constant_fold_bin_op< And >(halide_type_t &t, int64_t a, int64_t b) noexcept
HALIDE_ALWAYS_INLINE bool equal(const BaseExprNode &a, const BaseExprNode &b) noexcept
The product of two expressions.
uint8_t code
The basic type code: signed integer, unsigned integer, or floating point.
HALIDE_ALWAYS_INLINE auto add(A &&a, B &&b) -> decltype(IRMatcher::operator+(a, b))
HALIDE_NEVER_INLINE void build_replacement(After after)
constexpr static IRNodeType max_node_type
HALIDE_ALWAYS_INLINE Expr make(MatcherState &state, halide_type_t type_hint) const
constexpr static uint32_t binds
constexpr static IRNodeType min_node_type
constexpr static IRNodeType max_node_type