1 #ifndef HALIDE_GENERATOR_H_
2 #define HALIDE_GENERATOR_H_
262 #include <functional>
270 #include <type_traits>
281 #if !(__cplusplus >= 201703L || _MSVC_LANG >= 201703L)
282 #error "Halide requires C++17 or later; please upgrade your compiler."
287 class GeneratorContext;
299 for (
const auto &key_value : enum_map) {
300 if (t == key_value.second) {
301 return key_value.first;
304 user_error <<
"Enumeration value not found.\n";
310 auto it = enum_map.find(s);
311 user_assert(it != enum_map.end()) <<
"Enumeration value not found: " << s <<
"\n";
338 virtual std::vector<std::string>
enumerate()
const = 0;
380 template<
bool B,
typename T>
386 template<
typename First,
typename... Rest>
387 struct select_type : std::conditional<First::value, typename First::type, typename select_type<Rest...>::type> {};
389 template<
typename First>
391 using type =
typename std::conditional<First::value, typename First::type, void>::type;
401 inline const std::string &
name()
const {
413 #define HALIDE_GENERATOR_PARAM_TYPED_SETTER(TYPE) \
414 virtual void set(const TYPE &new_value) = 0;
432 #undef HALIDE_GENERATOR_PARAM_TYPED_SETTER
435 void set(
const std::string &new_value) {
438 void set(
const char *new_value) {
453 virtual std::string
call_to_string(
const std::string &v)
const = 0;
473 const std::string name_;
490 template<
typename FROM,
typename TO>
492 template<typename TO2 = TO, typename std::enable_if<!std::is_same<TO2, bool>::value>::type * =
nullptr>
493 inline static TO2
value(
const FROM &from) {
494 return static_cast<TO2
>(from);
497 template<typename TO2 = TO, typename std::enable_if<std::is_same<TO2, bool>::value>::type * =
nullptr>
498 inline static TO2
value(
const FROM &from) {
518 return this->
value();
525 #define HALIDE_GENERATOR_PARAM_TYPED_SETTER(TYPE) \
526 void set(const TYPE &new_value) override { \
527 typed_setter_impl<TYPE>(new_value, #TYPE); \
546 #undef HALIDE_GENERATOR_PARAM_TYPED_SETTER
549 void set(
const std::string &new_value) {
565 template<
typename FROM,
typename std::enable_if<
566 !std::is_convertible<FROM, T>::value>
::type * =
nullptr>
572 template<
typename FROM,
typename std::enable_if<
573 std::is_same<FROM, T>::value>
::type * =
nullptr>
580 template<
typename FROM,
typename std::enable_if<
581 !std::is_same<FROM, T>::value &&
582 std::is_convertible<FROM, T>::value &&
583 std::is_convertible<T, FROM>::value>
::type * =
nullptr>
588 if (value2 !=
value) {
595 template<
typename FROM,
typename std::enable_if<
596 !std::is_same<FROM, T>::value &&
597 std::is_convertible<FROM, T>::value &&
598 !std::is_convertible<T, FROM>::value>
::type * =
nullptr>
622 return this->
value().to_string();
626 std::ostringstream oss;
627 oss << v <<
".to_string()";
648 bool try_set(
const std::string &key,
const std::string &
value);
680 if (new_value_string ==
"root") {
682 }
else if (new_value_string ==
"inlined") {
685 user_error <<
"Unable to parse " << this->
name() <<
": " << new_value_string;
702 return "LoopLevel::inlined()";
704 return "LoopLevel::root()";
713 return std::string();
730 const T &min = std::numeric_limits<T>::lowest(),
738 user_assert(new_value >= min && new_value <= max) <<
"Value out of range: " << new_value;
743 std::istringstream iss(new_value_string);
748 if (
sizeof(T) ==
sizeof(
char) && !std::is_same<T, bool>::value) {
755 user_assert(!iss.fail() && iss.get() == EOF) <<
"Unable to parse: " << new_value_string;
760 std::ostringstream oss;
761 oss << this->
value();
762 if (std::is_same<T, float>::value) {
765 if (oss.str().find(
'.') == std::string::npos) {
774 std::ostringstream oss;
775 oss <<
"std::to_string(" << v <<
")";
780 std::ostringstream oss;
781 if (std::is_same<T, float>::value) {
783 }
else if (std::is_same<T, double>::value) {
785 }
else if (std::is_integral<T>::value) {
786 if (std::is_unsigned<T>::value) {
789 oss <<
"int" << (
sizeof(T) * 8) <<
"_t";
810 if (new_value_string ==
"true" || new_value_string ==
"True") {
812 }
else if (new_value_string ==
"false" || new_value_string ==
"False") {
815 user_assert(
false) <<
"Unable to parse bool: " << new_value_string;
821 return this->
value() ?
"true" :
"false";
825 std::ostringstream oss;
826 oss <<
"std::string((" << v <<
") ? \"true\" : \"false\")";
845 template<typename T2 = T, typename std::enable_if<!std::is_same<T2, Type>::value>
::type * =
nullptr>
851 auto it = enum_map.find(new_value_string);
852 user_assert(it != enum_map.end()) <<
"Enumeration value not found: " << new_value_string;
857 return "Enum_" + this->
name() +
"_map().at(" + v +
")";
861 return "Enum_" + this->
name();
869 std::ostringstream oss;
870 oss <<
"enum class Enum_" << this->
name() <<
" {\n";
871 for (
auto key_value : enum_map) {
872 oss <<
" " << key_value.first <<
",\n";
879 oss <<
"inline HALIDE_NO_USER_CODE_INLINE const std::map<Enum_" << this->
name() <<
", std::string>& Enum_" << this->
name() <<
"_map() {\n";
880 oss <<
" static const std::map<Enum_" << this->
name() <<
", std::string> m = {\n";
881 for (
auto key_value : enum_map) {
882 oss <<
" { Enum_" << this->
name() <<
"::" << key_value.first <<
", \"" << key_value.first <<
"\"},\n";
885 oss <<
" return m;\n";
891 const std::map<std::string, T> enum_map;
902 return "Halide::Internal::halide_type_to_enum_string(" + v +
")";
925 this->
set(new_value_string);
929 return "\"" + this->
value() +
"\"";
937 return "std::string";
943 typename select_type<
988 template<typename T2 = T, typename std::enable_if<!std::is_same<T2, std::string>::value>::type * =
nullptr>
997 GeneratorParam(
const std::string &name,
const T &value,
const std::map<std::string, T> &enum_map)
1009 template<
typename Other,
typename T>
1013 template<
typename Other,
typename T>
1022 template<
typename Other,
typename T>
1026 template<
typename Other,
typename T>
1035 template<
typename Other,
typename T>
1039 template<
typename Other,
typename T>
1048 template<
typename Other,
typename T>
1052 template<
typename Other,
typename T>
1061 template<
typename Other,
typename T>
1065 template<
typename Other,
typename T>
1074 template<
typename Other,
typename T>
1078 template<
typename Other,
typename T>
1087 template<
typename Other,
typename T>
1091 template<
typename Other,
typename T>
1100 template<
typename Other,
typename T>
1104 template<
typename Other,
typename T>
1113 template<
typename Other,
typename T>
1117 template<
typename Other,
typename T>
1126 template<
typename Other,
typename T>
1130 template<
typename Other,
typename T>
1139 template<
typename Other,
typename T>
1143 template<
typename Other,
typename T>
1152 template<
typename Other,
typename T>
1156 template<
typename Other,
typename T>
1160 template<
typename T>
1162 return (T)a && (T)b;
1169 template<
typename Other,
typename T>
1173 template<
typename Other,
typename T>
1177 template<
typename T>
1179 return (T)a || (T)b;
1187 namespace Internal {
1188 namespace GeneratorMinMax {
1193 template<
typename Other,
typename T>
1195 return min(a, (T)b);
1197 template<
typename Other,
typename T>
1199 return min((T)a, b);
1202 template<
typename Other,
typename T>
1204 return max(a, (T)b);
1206 template<
typename Other,
typename T>
1208 return max((T)a, b);
1217 template<
typename Other,
typename T>
1221 template<
typename Other,
typename T>
1230 template<
typename Other,
typename T>
1234 template<
typename Other,
typename T>
1241 template<
typename T>
1246 namespace Internal {
1248 template<
typename T2>
1263 template<
typename T2>
1265 template<
typename T2,
int D2>
1279 template<
typename T2,
int D2>
1295 template<
typename T2,
int D2>
1297 : parameter_(parameter_from_buffer(b)) {
1300 template<
typename T2>
1302 return {t.parameter_};
1305 template<
typename T2>
1307 std::vector<Parameter> r;
1308 r.reserve(v.size());
1309 for (
const auto &s : v) {
1310 r.push_back(s.parameter_);
1316 class AbstractGenerator;
1331 template<
typename... Args>
1336 template<
typename Dst>
1354 template<
typename T =
void>
1356 template<
typename T2>
1366 const std::shared_ptr<AbstractGenerator> &gen) {
1367 std::vector<StubOutputBuffer<T>> result;
1368 for (
const Func &
f : v) {
1388 template<
typename T2>
1455 const std::string &
name()
const;
1459 const std::vector<Type> &
gio_types()
const;
1465 const std::vector<Func> &
funcs()
const;
1466 const std::vector<Expr> &
exprs()
const;
1469 const std::string &
name,
1471 const std::vector<Type> &types,
1504 template<
typename ElemType>
1505 const std::vector<ElemType> &
get_values()
const;
1514 template<
typename T>
1526 inline const std::vector<Expr> &GIOBase::get_values<Expr>()
const {
1531 inline const std::vector<Func> &GIOBase::get_values<Func>()
const {
1538 const std::string &
name,
1540 const std::vector<Type> &t,
1553 void set_inputs(
const std::vector<StubInput> &inputs);
1577 template<
typename T,
typename ValueType>
1580 using TBase =
typename std::remove_all_extents<T>::type;
1583 return std::is_array<T>::value;
1586 template<
typename T2 = T,
typename std::enable_if<
1588 !std::is_array<T2>::value>::type * =
nullptr>
1593 template<
typename T2 = T,
typename std::enable_if<
1595 std::is_array<T2>::value && std::rank<T2>::value == 1 && (std::extent<T2, 0>::value > 0)>::type * =
nullptr>
1600 template<
typename T2 = T,
typename std::enable_if<
1602 std::is_array<T2>::value && std::rank<T2>::value == 1 && std::extent<T2, 0>::value == 0>::type * =
nullptr>
1608 template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * =
nullptr>
1611 return get_values<ValueType>().size();
1614 template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * =
nullptr>
1617 return get_values<ValueType>()[i];
1620 template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * =
nullptr>
1621 const ValueType &
at(
size_t i)
const {
1623 return get_values<ValueType>().at(i);
1626 template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * =
nullptr>
1627 typename std::vector<ValueType>::const_iterator
begin()
const {
1629 return get_values<ValueType>().begin();
1632 template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * =
nullptr>
1633 typename std::vector<ValueType>::const_iterator
end()
const {
1635 return get_values<ValueType>().end();
1647 #define HALIDE_FORWARD_METHOD(Class, Method) \
1648 template<typename... Args> \
1649 inline auto Method(Args &&...args)->typename std::remove_reference<decltype(std::declval<Class>().Method(std::forward<Args>(args)...))>::type { \
1650 return this->template as<Class>().Method(std::forward<Args>(args)...); \
1653 #define HALIDE_FORWARD_METHOD_CONST(Class, Method) \
1654 template<typename... Args> \
1655 inline auto Method(Args &&...args) const-> \
1656 typename std::remove_reference<decltype(std::declval<Class>().Method(std::forward<Args>(args)...))>::type { \
1657 this->check_gio_access(); \
1658 return this->template as<Class>().Method(std::forward<Args>(args)...); \
1661 template<
typename T>
1662 class GeneratorInput_Buffer :
public GeneratorInputImpl<T, Func> {
1664 using Super = GeneratorInputImpl<T, Func>;
1669 friend class ::Halide::Func;
1670 friend class ::Halide::Stage;
1673 if (TBase::has_static_halide_type) {
1674 return "Halide::Internal::StubInputBuffer<" +
1678 return "Halide::Internal::StubInputBuffer<>";
1682 template<
typename T2>
1690 TBase::has_static_halide_type ? std::vector<
Type>{TBase::static_halide_type()} : std::vector<Type>{},
1691 TBase::has_static_dimensions ? TBase::static_dimensions() : -1) {
1696 static_assert(!TBase::has_static_halide_type,
"You can only specify a Type argument for Input<Buffer<T>> if T is void or omitted.");
1697 static_assert(!TBase::has_static_dimensions,
"You can only specify a dimension argument for Input<Buffer<T, D>> if D is -1 or omitted.");
1702 static_assert(!TBase::has_static_halide_type,
"You can only specify a Type argument for Input<Buffer<T>> if T is void or omitted.");
1707 TBase::has_static_halide_type ? std::vector<
Type>{TBase::static_halide_type()} : std::vector<Type>{},
1709 static_assert(!TBase::has_static_dimensions,
"You can only specify a dimension argument for Input<Buffer<T, D>> if D is -1 or omitted.");
1712 template<
typename... Args>
1714 this->check_gio_access();
1715 return Func(*
this)(std::forward<Args>(args)...);
1719 this->check_gio_access();
1720 return Func(*
this)(std::move(args));
1723 template<
typename T2>
1725 user_assert(!this->is_array()) <<
"Cannot assign an array type to a non-array type for Input " << this->name();
1730 this->check_gio_access();
1731 return this->funcs().at(0);
1735 this->check_gio_access();
1740 this->check_gio_access();
1741 this->set_estimate_impl(var,
min, extent);
1746 this->check_gio_access();
1747 this->set_estimates_impl(estimates);
1752 this->check_gio_access();
1757 this->check_gio_access();
1758 return Func(*this).
in(other);
1762 this->check_gio_access();
1763 return Func(*this).
in(others);
1767 this->check_gio_access();
1768 user_assert(!this->is_array()) <<
"Cannot convert an Input<Buffer<>[]> to an ImageParam; use an explicit subscript operator: " << this->name();
1772 template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * =
nullptr>
1774 this->check_gio_access();
1775 return this->parameters_.size();
1778 template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * =
nullptr>
1780 this->check_gio_access();
1781 return ImageParam(this->parameters_.at(i), this->funcs().at(i));
1784 template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * =
nullptr>
1786 this->check_gio_access();
1787 return ImageParam(this->parameters_.at(i), this->funcs().at(i));
1790 template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * =
nullptr>
1791 typename std::vector<ImageParam>::const_iterator
begin()
const {
1792 user_error <<
"Input<Buffer<>>::begin() is not supported.";
1796 template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * =
nullptr>
1797 typename std::vector<ImageParam>::const_iterator
end()
const {
1798 user_error <<
"Input<Buffer<>>::end() is not supported.";
1823 template<
typename T>
1835 template<
typename T2>
1879 template<
typename... Args>
1882 return this->
funcs().at(0)(std::forward<Args>(args)...);
1887 return this->
funcs().at(0)(args);
1892 return this->
funcs().at(0);
1919 return Func(*this).
in(other);
1924 return Func(*this).
in(others);
1947 template<
typename T>
1952 static_assert(std::is_same<
typename std::remove_all_extents<T>::type,
Expr>::value,
"GeneratorInput_DynamicScalar is only legal to use with T=Expr for now");
1962 user_assert(!std::is_array<T>::value) <<
"Input<Expr[]> is not allowed";
1969 return this->
exprs().at(0);
1991 template<
typename T>
2015 template<typename TBase2 = TBase, typename std::enable_if<!std::is_pointer<TBase2>::value>
::type * =
nullptr>
2017 return cast<TBase>(
Expr(value));
2020 template<typename TBase2 = TBase, typename std::enable_if<std::is_pointer<TBase2>::value>
::type * =
nullptr>
2022 user_assert(value == 0) <<
"Zero is the only legal default value for Inputs which are pointer types.\n";
2036 const std::string &
name)
2041 const std::string &
name,
2050 return this->
exprs().at(0);
2060 template<typename T2 = T, typename std::enable_if<std::is_pointer<T2>::value>
::type * =
nullptr>
2063 user_assert(value ==
nullptr) <<
"nullptr is the only valid estimate for Input<PointerType>";
2070 template<typename T2 = T, typename std::enable_if<!std::is_array<T2>::value && !std::is_pointer<T2>::value>
::type * =
nullptr>
2074 if (std::is_same<T2, bool>::value) {
2082 template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>
::type * =
nullptr>
2086 if (std::is_same<T2, bool>::value) {
2097 template<
typename T>
2110 if (!std::is_same<TBase, bool>::value) {
2133 const std::string &
name)
2138 const std::string &
name,
2151 const std::string &
name,
2164 template<
typename T2,
typename =
void>
2167 template<
typename T2>
2170 template<typename T, typename TBase = typename std::remove_all_extents<T>::type>
2181 template<
typename T>
2208 : Super(name, def) {
2212 : Super(array_size, name, def) {
2217 : Super(name, def,
min,
max) {
2222 : Super(array_size, name, def,
min,
max) {
2226 : Super(name, t, d) {
2239 : Super(array_size, name, t, d) {
2243 : Super(array_size, name, t) {
2249 : Super(array_size, name, d) {
2253 : Super(array_size, name) {
2257 namespace Internal {
2261 template<typename T2, typename std::enable_if<std::is_same<T2, Func>::value>::type * =
nullptr>
2263 static_assert(std::is_same<T2, Func>::value,
"Only Func allowed here");
2267 user_assert(
funcs_.size() == 1) <<
"Use [] to access individual Funcs in Output<Func[]>";
2332 #undef HALIDE_OUTPUT_FORWARD
2333 #undef HALIDE_OUTPUT_FORWARD_CONST
2337 const std::string &
name,
2339 const std::vector<Type> &t,
2344 const std::vector<Type> &t,
2351 void resize(
size_t size);
2367 template<
typename T>
2370 using TBase =
typename std::remove_all_extents<T>::type;
2374 return std::is_array<T>::value;
2377 template<
typename T2 = T,
typename std::enable_if<
2379 !std::is_array<T2>::value>::type * =
nullptr>
2384 template<
typename T2 = T,
typename std::enable_if<
2386 std::is_array<T2>::value && std::rank<T2>::value == 1 && (std::extent<T2, 0>::value > 0)>::type * =
nullptr>
2391 template<
typename T2 = T,
typename std::enable_if<
2393 std::is_array<T2>::value && std::rank<T2>::value == 1 && std::extent<T2, 0>::value == 0>::type * =
nullptr>
2399 template<
typename... Args,
typename T2 = T,
typename std::enable_if<!std::is_array<T2>::value>::type * =
nullptr>
2402 return get_values<ValueType>().at(0)(std::forward<Args>(args)...);
2405 template<typename ExprOrVar, typename T2 = T, typename std::enable_if<!std::is_array<T2>::value>::type * =
nullptr>
2408 return get_values<ValueType>().at(0)(args);
2411 template<typename T2 = T, typename std::enable_if<!std::is_array<T2>::value>::type * =
nullptr>
2414 return get_values<ValueType>().at(0);
2417 template<typename T2 = T, typename std::enable_if<!std::is_array<T2>::value>::type * =
nullptr>
2420 return get_values<ValueType>().at(0);
2423 template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * =
nullptr>
2426 return get_values<ValueType>().size();
2429 template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * =
nullptr>
2432 return get_values<ValueType>()[i];
2435 template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * =
nullptr>
2438 return get_values<ValueType>().at(i);
2441 template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * =
nullptr>
2442 typename std::vector<ValueType>::const_iterator
begin()
const {
2444 return get_values<ValueType>().begin();
2447 template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * =
nullptr>
2448 typename std::vector<ValueType>::const_iterator
end()
const {
2450 return get_values<ValueType>().end();
2453 template<
typename T2 = T,
typename std::enable_if<
2455 std::is_array<T2>::value && std::rank<T2>::value == 1 && std::extent<T2, 0>::value == 0>::type * =
nullptr>
2462 template<
typename T>
2473 const auto &my_types = this->
gio_types();
2475 <<
"Cannot assign Func \"" << f.
name()
2476 <<
"\" to Output \"" << this->
name() <<
"\"\n"
2477 <<
"Output " << this->
name()
2478 <<
" is declared to have " << my_types.size() <<
" tuple elements"
2479 <<
" but Func " << f.
name()
2480 <<
" has " << f.
types().size() <<
" tuple elements.\n";
2481 for (
size_t i = 0; i < my_types.size(); i++) {
2483 <<
"Cannot assign Func \"" << f.
name()
2484 <<
"\" to Output \"" << this->
name() <<
"\"\n"
2485 << (my_types.size() > 1 ?
"In tuple element " + std::to_string(i) +
", " :
"")
2486 <<
"Output " << this->
name()
2487 <<
" has declared type " << my_types[i]
2488 <<
" but Func " << f.
name()
2489 <<
" has type " << f.
types().at(i) <<
"\n";
2494 <<
"Cannot assign Func \"" << f.
name()
2495 <<
"\" to Output \"" << this->
name() <<
"\"\n"
2496 <<
"Output " << this->
name()
2497 <<
" has declared dimensionality " << this->
dims()
2498 <<
" but Func " << f.
name()
2499 <<
" has dimensionality " << f.
dimensions() <<
"\n";
2512 TBase::has_static_halide_type ? std::vector<
Type>{TBase::static_halide_type()} : std::vector<Type>{},
2513 TBase::has_static_dimensions ? TBase::static_dimensions() : -1) {
2520 static_assert(!TBase::has_static_halide_type,
"You can only specify a Type argument for Output<Buffer<T, D>> if T is void or omitted.");
2521 static_assert(!TBase::has_static_dimensions,
"You can only specify a dimension argument for Output<Buffer<T, D>> if D is -1 or omitted.");
2527 static_assert(!TBase::has_static_halide_type,
"You can only specify a Type argument for Output<Buffer<T, D>> if T is void or omitted.");
2532 TBase::has_static_halide_type ? std::vector<
Type>{TBase::static_halide_type()} : std::vector<Type>{},
2535 static_assert(!TBase::has_static_dimensions,
"You can only specify a dimension argument for Output<Buffer<T, D>> if D is -1 or omitted.");
2540 TBase::has_static_halide_type ? std::vector<
Type>{TBase::static_halide_type()} : std::vector<Type>{},
2541 TBase::has_static_dimensions ? TBase::static_dimensions() : -1) {
2548 static_assert(!TBase::has_static_halide_type,
"You can only specify a Type argument for Output<Buffer<T, D>> if T is void or omitted.");
2549 static_assert(!TBase::has_static_dimensions,
"You can only specify a dimension argument for Output<Buffer<T, D>> if D is -1 or omitted.");
2555 static_assert(!TBase::has_static_halide_type,
"You can only specify a Type argument for Output<Buffer<T, D>> if T is void or omitted.");
2560 TBase::has_static_halide_type ? std::vector<
Type>{TBase::static_halide_type()} : std::vector<Type>{},
2563 static_assert(!TBase::has_static_dimensions,
"You can only specify a dimension argument for Output<Buffer<T, D>> if D is -1 or omitted.");
2567 if (TBase::has_static_halide_type) {
2568 return "Halide::Internal::StubOutputBuffer<" +
2572 return "Halide::Internal::StubOutputBuffer<>";
2576 template<typename T2, typename std::enable_if<!std::is_same<T2, Func>::value>::type * =
nullptr>
2588 template<
typename T2,
int D2>
2594 <<
"Cannot assign to the Output \"" << this->
name()
2595 <<
"\": the expression is not convertible to the same Buffer type and/or dimensions.\n";
2599 <<
"Output " << this->
name() <<
" should have type=" << this->
gio_type() <<
" but saw type=" <<
Type(buffer.
type()) <<
"\n";
2603 <<
"Output " << this->
name() <<
" should have dim=" << this->
dims() <<
" but saw dim=" << buffer.dimensions() <<
"\n";
2608 this->
funcs_.at(0)(_) = buffer(_);
2616 template<
typename T2>
2619 assign_from_func(stub_output_buffer.
f);
2628 assign_from_func(f);
2634 user_assert(!this->
is_array()) <<
"Cannot convert an Output<Buffer<>[]> to an ImageParam; use an explicit subscript operator: " << this->
name();
2636 return this->
funcs_.at(0).output_buffer();
2642 user_assert(!this->
is_array()) <<
"Cannot call set_estimates() on an array Output; use an explicit subscript operator: " << this->
name();
2644 this->
funcs_.at(0).set_estimates(estimates);
2648 template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * =
nullptr>
2651 return this->
template get_values<Func>()[i];
2655 template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * =
nullptr>
2658 return this->
template get_values<Func>()[i];
2679 template<
typename T>
2686 return this->funcs_.at(i);
2714 template<typename T2 = T, typename std::enable_if<!std::is_array<T2>::value>::type * =
nullptr>
2716 this->check_gio_access();
2717 this->check_value_writable();
2721 get_assignable_func_ref(0) = f;
2726 template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * =
nullptr>
2728 this->check_gio_access();
2729 this->check_value_writable();
2730 return get_assignable_func_ref(i);
2734 template<typename T2 = T, typename std::enable_if<std::is_array<T2>::value>::type * =
nullptr>
2736 this->check_gio_access();
2737 return Super::operator[](i);
2741 this->check_gio_access();
2743 for (
Func &f : this->funcs_) {
2744 f.set_estimate(var,
min, extent);
2750 this->check_gio_access();
2752 for (
Func &f : this->funcs_) {
2753 f.set_estimates(estimates);
2759 template<
typename T>
2776 template<typename T, typename TBase = typename std::remove_all_extents<T>::type>
2778 typename select_type<
2785 template<
typename T>
2806 : Super(array_size, name) {
2814 : Super(name, {t}) {
2822 : Super(name, {t}, d) {
2826 : Super(name, t, d) {
2830 : Super(array_size, name, d) {
2834 : Super(array_size, name, {t}) {
2837 explicit GeneratorOutput(
size_t array_size,
const std::string &name,
const std::vector<Type> &t)
2838 : Super(array_size, name, t) {
2842 : Super(array_size, name, {t}, d) {
2845 explicit GeneratorOutput(
size_t array_size,
const std::string &name,
const std::vector<Type> &t,
int d)
2846 : Super(array_size, name, t, d) {
2852 template<
typename T2,
int D2>
2854 Super::operator=(buffer);
2858 template<
typename T2>
2860 Super::operator=(stub_output_buffer);
2865 Super::operator=(f);
2870 namespace Internal {
2872 template<
typename T>
2874 std::istringstream iss(value);
2877 user_assert(!iss.fail() && iss.get() == EOF) <<
"Unable to parse: " << value;
2889 template<
typename T>
2895 if (!error_msg.empty()) {
2898 set_from_string_impl<T>(new_value_string);
2903 return std::string();
2908 return std::string();
2913 return std::string();
2923 static std::unique_ptr<Internal::GeneratorParamBase> make(
2925 const std::string &generator_name,
2926 const std::string &gpname,
2930 std::string error_msg = defined ?
"Cannot set the GeneratorParam " + gpname +
" for " + generator_name +
" because the value is explicitly specified in the C++ source." :
"";
2931 return std::unique_ptr<GeneratorParam_Synthetic<T>>(
2939 template<typename T2 = T, typename std::enable_if<std::is_same<T2, ::Halide::Type>::value>
::type * =
nullptr>
2940 void set_from_string_impl(
const std::string &new_value_string) {
2945 template<typename T2 = T, typename std::enable_if<std::is_integral<T2>::value>
::type * =
nullptr>
2946 void set_from_string_impl(
const std::string &new_value_string) {
2948 gio.
dims_ = parse_scalar<T2>(new_value_string);
2950 gio.
array_size_ = parse_scalar<T2>(new_value_string);
2958 const std::string error_msg;
3016 return autoscheduler_params_;
3024 template<
typename T>
3026 return T::create(*
this);
3028 template<
typename T,
typename... Args>
3029 inline std::unique_ptr<T>
apply(
const Args &...args)
const {
3030 auto t = this->create<T>();
3063 template<
typename T>
3065 return Halide::cast<T>(e);
3070 template<
typename T>
3072 template<
typename T = void,
int D = -1>
3074 template<
typename T>
3090 namespace Internal {
3092 template<
typename... Args>
3098 template<
typename T,
typename... Args>
3100 static const bool value = !std::is_convertible<T, Realization>::value &&
NoRealizations<Args...>::value;
3109 std::set<std::string> names;
3112 std::vector<Internal::GeneratorParamBase *> filter_generator_params;
3115 std::vector<Internal::GeneratorInputBase *> filter_inputs;
3118 std::vector<Internal::GeneratorOutputBase *> filter_outputs;
3123 std::vector<std::unique_ptr<Internal::GeneratorParamBase>> owned_synthetic_params;
3126 std::vector<std::unique_ptr<Internal::GIOBase>> owned_extras;
3134 return filter_generator_params;
3136 const std::vector<Internal::GeneratorInputBase *> &
inputs()
const {
3137 return filter_inputs;
3139 const std::vector<Internal::GeneratorOutputBase *> &
outputs()
const {
3140 return filter_outputs;
3156 template<
typename data_t>
3171 template<
typename... Args>
3176 <<
"Expected exactly " << pi.
inputs().size()
3177 <<
" inputs but got " <<
sizeof...(args) <<
"\n";
3178 set_inputs_vector(build_inputs(std::forward_as_tuple<const Args &...>(args...), std::make_index_sequence<
sizeof...(Args)>{}));
3182 this->check_scheduled(
"realize");
3188 template<
typename... Args,
typename std::enable_if<
NoRealizations<Args...>::value>::type * =
nullptr>
3190 this->check_scheduled(
"realize");
3195 this->check_scheduled(
"realize");
3206 template<
typename T,
3207 typename std::enable_if<std::is_same<T, Halide::Func>::value>::type * =
nullptr>
3211 p->generator =
this;
3212 param_info_ptr->owned_extras.push_back(std::unique_ptr<Internal::GIOBase>(p));
3213 param_info_ptr->filter_inputs.push_back(p);
3218 template<
typename T,
3219 typename std::enable_if<!std::is_arithmetic<T>::value && !std::is_same<T, Halide::Func>::value>::type * =
nullptr>
3221 static_assert(!T::has_static_halide_type,
"You can only call this version of add_input() for a Buffer<T, D> where T is void or omitted .");
3222 static_assert(!T::has_static_dimensions,
"You can only call this version of add_input() for a Buffer<T, D> where D is -1 or omitted.");
3225 p->generator =
this;
3226 param_info_ptr->owned_extras.push_back(std::unique_ptr<Internal::GIOBase>(p));
3227 param_info_ptr->filter_inputs.push_back(p);
3232 template<
typename T,
3233 typename std::enable_if<!std::is_arithmetic<T>::value && !std::is_same<T, Halide::Func>::value>::type * =
nullptr>
3235 static_assert(T::has_static_halide_type,
"You can only call this version of add_input() for a Buffer<T, D> where T is not void.");
3236 static_assert(!T::has_static_dimensions,
"You can only call this version of add_input() for a Buffer<T, D> where D is -1 or omitted.");
3239 p->generator =
this;
3240 param_info_ptr->owned_extras.push_back(std::unique_ptr<Internal::GIOBase>(p));
3241 param_info_ptr->filter_inputs.push_back(p);
3246 template<
typename T,
3247 typename std::enable_if<!std::is_arithmetic<T>::value && !std::is_same<T, Halide::Func>::value>::type * =
nullptr>
3249 static_assert(T::has_static_halide_type,
"You can only call this version of add_input() for a Buffer<T, D> where T is not void.");
3250 static_assert(T::has_static_dimensions,
"You can only call this version of add_input() for a Buffer<T, D> where D is not -1.");
3253 p->generator =
this;
3254 param_info_ptr->owned_extras.push_back(std::unique_ptr<Internal::GIOBase>(p));
3255 param_info_ptr->filter_inputs.push_back(p);
3259 template<
typename T,
3260 typename std::enable_if<std::is_arithmetic<T>::value>::type * =
nullptr>
3264 p->generator =
this;
3265 param_info_ptr->owned_extras.push_back(std::unique_ptr<Internal::GIOBase>(p));
3266 param_info_ptr->filter_inputs.push_back(p);
3270 template<
typename T,
3271 typename std::enable_if<std::is_same<T, Expr>::value>::type * =
nullptr>
3275 p->generator =
this;
3277 param_info_ptr->owned_extras.push_back(std::unique_ptr<Internal::GIOBase>(p));
3278 param_info_ptr->filter_inputs.push_back(p);
3283 template<
typename T,
3284 typename std::enable_if<std::is_same<T, Halide::Func>::value>::type * =
nullptr>
3288 p->generator =
this;
3289 param_info_ptr->owned_extras.push_back(std::unique_ptr<Internal::GIOBase>(p));
3290 param_info_ptr->filter_outputs.push_back(p);
3295 template<
typename T,
3296 typename std::enable_if<!std::is_arithmetic<T>::value && !std::is_same<T, Halide::Func>::value>::type * =
nullptr>
3298 static_assert(!T::has_static_halide_type,
"You can only call this version of add_output() for a Buffer<T, D> where T is void or omitted .");
3299 static_assert(!T::has_static_dimensions,
"You can only call this version of add_output() for a Buffer<T, D> where D is -1 or omitted.");
3302 p->generator =
this;
3303 param_info_ptr->owned_extras.push_back(std::unique_ptr<Internal::GIOBase>(p));
3304 param_info_ptr->filter_outputs.push_back(p);
3309 template<
typename T,
3310 typename std::enable_if<!std::is_arithmetic<T>::value && !std::is_same<T, Halide::Func>::value>::type * =
nullptr>
3312 static_assert(T::has_static_halide_type,
"You can only call this version of add_output() for a Buffer<T, D> where T is not void.");
3313 static_assert(!T::has_static_dimensions,
"You can only call this version of add_output() for a Buffer<T, D> where D is -1 or omitted.");
3316 p->generator =
this;
3317 param_info_ptr->owned_extras.push_back(std::unique_ptr<Internal::GIOBase>(p));
3318 param_info_ptr->filter_outputs.push_back(p);
3323 template<
typename T,
3324 typename std::enable_if<!std::is_arithmetic<T>::value && !std::is_same<T, Halide::Func>::value>::type * =
nullptr>
3326 static_assert(T::has_static_halide_type,
"You can only call this version of add_output() for a Buffer<T, D> where T is not void.");
3327 static_assert(T::has_static_dimensions,
"You can only call this version of add_output() for a Buffer<T, D> where D is not -1.");
3330 p->generator =
this;
3331 param_info_ptr->owned_extras.push_back(std::unique_ptr<Internal::GIOBase>(p));
3332 param_info_ptr->filter_outputs.push_back(p);
3338 template<
typename... Args,
3341 std::vector<Expr> collected_args;
3351 GeneratorBase(
size_t size,
const void *introspection_helper);
3352 void set_generator_names(
const std::string ®istered_name,
const std::string &stub_name);
3424 template<
typename T>
3427 template<
typename T>
3483 std::unique_ptr<GeneratorParamInfo> param_info_ptr;
3485 std::string generator_registered_name, generator_stub_name;
3488 struct Requirement {
3490 std::vector<Expr> error_args;
3492 std::vector<Requirement> requirements;
3497 template<
typename T>
3498 T *find_by_name(
const std::string &
name,
const std::vector<T *> &v) {
3500 if (t->name() ==
name) {
3507 Internal::GeneratorInputBase *find_input_by_name(
const std::string &
name);
3508 Internal::GeneratorOutputBase *find_output_by_name(
const std::string &
name);
3510 void check_scheduled(
const char *m)
const;
3512 void build_params(
bool force =
false);
3517 void get_host_target();
3518 void get_jit_target_from_environment();
3519 void get_target_from_environment();
3521 void set_inputs_vector(
const std::vector<std::vector<StubInput>> &inputs);
3523 static void check_input_is_singular(Internal::GeneratorInputBase *in);
3524 static void check_input_is_array(Internal::GeneratorInputBase *in);
3531 template<
typename T,
int Dims>
3532 std::vector<StubInput> build_input(
size_t i,
const Buffer<T, Dims> &arg) {
3533 auto *in = param_info().
inputs().at(i);
3534 check_input_is_singular(in);
3535 const auto k = in->kind();
3538 StubInputBuffer<> sib(b);
3543 f(Halide::_) = arg(Halide::_);
3556 template<
typename T,
int Dims>
3557 std::vector<StubInput> build_input(
size_t i,
const GeneratorInput<Buffer<T, Dims>> &arg) {
3558 auto *in = param_info().
inputs().at(i);
3559 check_input_is_singular(in);
3560 const auto k = in->kind();
3562 StubInputBuffer<> sib = arg;
3576 std::vector<StubInput> build_input(
size_t i,
const Func &arg) {
3577 auto *in = param_info().
inputs().at(i);
3579 check_input_is_singular(in);
3586 std::vector<StubInput> build_input(
size_t i,
const std::vector<Func> &arg) {
3587 auto *in = param_info().
inputs().at(i);
3589 check_input_is_array(in);
3591 std::vector<StubInput> siv;
3592 siv.reserve(arg.size());
3593 for (
const auto &f : arg) {
3594 siv.emplace_back(f);
3600 std::vector<StubInput> build_input(
size_t i,
const Expr &arg) {
3601 auto *in = param_info().
inputs().at(i);
3603 check_input_is_singular(in);
3609 std::vector<StubInput> build_input(
size_t i,
const std::vector<Expr> &arg) {
3610 auto *in = param_info().
inputs().at(i);
3612 check_input_is_array(in);
3613 std::vector<StubInput> siv;
3614 siv.reserve(arg.size());
3615 for (
const auto &value : arg) {
3616 siv.emplace_back(value);
3623 template<
typename T,
3624 typename std::enable_if<std::is_arithmetic<T>::value>::type * =
nullptr>
3625 std::vector<StubInput> build_input(
size_t i,
const T &arg) {
3626 auto *in = param_info().
inputs().at(i);
3628 check_input_is_singular(in);
3636 template<
typename T,
3637 typename std::enable_if<std::is_arithmetic<T>::value>::type * =
nullptr>
3638 std::vector<StubInput> build_input(
size_t i,
const std::vector<T> &arg) {
3639 auto *in = param_info().
inputs().at(i);
3641 check_input_is_array(in);
3642 std::vector<StubInput> siv;
3643 siv.reserve(arg.size());
3644 for (
const auto &value : arg) {
3648 siv.emplace_back(e);
3653 template<
typename... Args,
size_t... Indices>
3654 std::vector<std::vector<StubInput>> build_inputs(
const std::tuple<const Args &...> &t, std::index_sequence<Indices...>) {
3655 return {build_input(Indices, std::get<Indices>(t))...};
3660 template<
typename T>
3661 static void get_arguments(std::vector<AbstractGenerator::ArgInfo> &args,
ArgInfoDirection dir,
const T &t) {
3663 args.push_back({e->name(),
3666 e->gio_types_defined() ? e->gio_types() : std::vector<Type>{},
3667 e->dims_defined() ? e->dims() : 0});
3673 std::string
name()
override;
3675 std::vector<ArgInfo>
arginfos()
override;
3686 void bind_input(
const std::string &
name,
const std::vector<Parameter> &v)
override;
3687 void bind_input(
const std::string &
name,
const std::vector<Func> &v)
override;
3688 void bind_input(
const std::string &
name,
const std::vector<Expr> &v)
override;
3690 bool emit_cpp_stub(
const std::string &stub_file_path)
override;
3702 static std::vector<std::string>
enumerate();
3709 using GeneratorFactoryMap = std::map<const std::string, GeneratorFactory>;
3711 GeneratorFactoryMap factories;
3739 auto g = std::make_unique<T>();
3740 g->init_from_context(
context);
3746 const std::string ®istered_name,
3747 const std::string &stub_name) {
3749 g->set_generator_names(registered_name, stub_name);
3753 template<
typename... Args>
3761 template<
typename T2>
3766 template<
typename T2,
typename... Args>
3767 inline std::unique_ptr<T2>
apply(
const Args &...args)
const {
3768 auto t = this->create<T2>();
3781 template<
typename T2,
typename =
void>
3782 struct has_configure_method : std::false_type {};
3784 template<
typename T2>
3785 struct has_configure_method<T2, typename type_sink<decltype(std::declval<T2>().configure())>::type> : std::true_type {};
3787 template<
typename T2,
typename =
void>
3788 struct has_generate_method : std::false_type {};
3790 template<
typename T2>
3791 struct has_generate_method<T2, typename type_sink<decltype(std::declval<T2>().generate())>::type> : std::true_type {};
3793 template<
typename T2,
typename =
void>
3794 struct has_schedule_method : std::false_type {};
3796 template<
typename T2>
3797 struct has_schedule_method<T2, typename type_sink<decltype(std::declval<T2>().
schedule())>::type> : std::true_type {};
3806 t->call_generate_impl();
3807 t->call_schedule_impl();
3811 void call_configure_impl() {
3813 if constexpr (has_configure_method<T>::value) {
3815 static_assert(std::is_void<decltype(t->configure())>::value,
"configure() must return void");
3821 void call_generate_impl() {
3823 static_assert(has_generate_method<T>::value,
"Expected a generate() method here.");
3825 static_assert(std::is_void<decltype(t->generate())>::value,
"generate() must return void");
3830 void call_schedule_impl() {
3832 if constexpr (has_schedule_method<T>::value) {
3834 static_assert(std::is_void<decltype(t->schedule())>::value,
"schedule() must return void");
3843 return this->build_pipeline_impl();
3847 this->call_configure_impl();
3851 this->call_generate_impl();
3855 this->call_schedule_impl();
3861 friend class ::Halide::GeneratorContext;
3964 const std::string &name,
3967 const std::string &name,
3977 struct halide_global_ns;
3980 #define _HALIDE_REGISTER_GENERATOR_IMPL(GEN_CLASS_NAME, GEN_REGISTRY_NAME, FULLY_QUALIFIED_STUB_NAME) \
3981 namespace halide_register_generator { \
3982 struct halide_global_ns; \
3983 namespace GEN_REGISTRY_NAME##_ns { \
3984 std::unique_ptr<Halide::Internal::AbstractGenerator> factory(const Halide::GeneratorContext &context); \
3985 std::unique_ptr<Halide::Internal::AbstractGenerator> factory(const Halide::GeneratorContext &context) { \
3986 using GenType = std::remove_pointer<decltype(new GEN_CLASS_NAME)>::type; \
3987 return GenType::create(context, #GEN_REGISTRY_NAME, #FULLY_QUALIFIED_STUB_NAME); \
3990 static auto reg_##GEN_REGISTRY_NAME = Halide::Internal::RegisterGenerator(#GEN_REGISTRY_NAME, GEN_REGISTRY_NAME##_ns::factory); \
3992 static_assert(std::is_same<::halide_register_generator::halide_global_ns, halide_register_generator::halide_global_ns>::value, \
3993 "HALIDE_REGISTER_GENERATOR must be used at global scope");
3995 #define _HALIDE_REGISTER_GENERATOR2(GEN_CLASS_NAME, GEN_REGISTRY_NAME) \
3996 _HALIDE_REGISTER_GENERATOR_IMPL(GEN_CLASS_NAME, GEN_REGISTRY_NAME, GEN_REGISTRY_NAME)
3998 #define _HALIDE_REGISTER_GENERATOR3(GEN_CLASS_NAME, GEN_REGISTRY_NAME, FULLY_QUALIFIED_STUB_NAME) \
3999 _HALIDE_REGISTER_GENERATOR_IMPL(GEN_CLASS_NAME, GEN_REGISTRY_NAME, FULLY_QUALIFIED_STUB_NAME)
4004 #define __HALIDE_REGISTER_ARGCOUNT_IMPL(_1, _2, _3, COUNT, ...) \
4007 #define _HALIDE_REGISTER_ARGCOUNT_IMPL(ARGS) \
4008 __HALIDE_REGISTER_ARGCOUNT_IMPL ARGS
4010 #define _HALIDE_REGISTER_ARGCOUNT(...) \
4011 _HALIDE_REGISTER_ARGCOUNT_IMPL((__VA_ARGS__, 3, 2, 1, 0))
4013 #define ___HALIDE_REGISTER_CHOOSER(COUNT) \
4014 _HALIDE_REGISTER_GENERATOR##COUNT
4016 #define __HALIDE_REGISTER_CHOOSER(COUNT) \
4017 ___HALIDE_REGISTER_CHOOSER(COUNT)
4019 #define _HALIDE_REGISTER_CHOOSER(COUNT) \
4020 __HALIDE_REGISTER_CHOOSER(COUNT)
4022 #define _HALIDE_REGISTER_GENERATOR_PASTE(A, B) \
4025 #define HALIDE_REGISTER_GENERATOR(...) \
4026 _HALIDE_REGISTER_GENERATOR_PASTE(_HALIDE_REGISTER_CHOOSER(_HALIDE_REGISTER_ARGCOUNT(__VA_ARGS__)), (__VA_ARGS__))
4042 #define HALIDE_REGISTER_GENERATOR_ALIAS(GEN_REGISTRY_NAME, ORIGINAL_REGISTRY_NAME, ...) \
4043 namespace halide_register_generator { \
4044 struct halide_global_ns; \
4045 namespace ORIGINAL_REGISTRY_NAME##_ns { \
4046 std::unique_ptr<Halide::Internal::AbstractGenerator> factory(const Halide::GeneratorContext &context); \
4048 namespace GEN_REGISTRY_NAME##_ns { \
4049 std::unique_ptr<Halide::Internal::AbstractGenerator> factory(const Halide::GeneratorContext &context) { \
4050 auto g = ORIGINAL_REGISTRY_NAME##_ns::factory(context); \
4051 const Halide::GeneratorParamsMap m = __VA_ARGS__; \
4052 g->set_generatorparam_values(m); \
4056 static auto reg_##GEN_REGISTRY_NAME = Halide::Internal::RegisterGenerator(#GEN_REGISTRY_NAME, GEN_REGISTRY_NAME##_ns::factory); \
4058 static_assert(std::is_same<::halide_register_generator::halide_global_ns, halide_register_generator::halide_global_ns>::value, \
4059 "HALIDE_REGISTER_GENERATOR_ALIAS must be used at global scope");
4064 #define HALIDE_GENERATOR_PYSTUB(GEN_REGISTRY_NAME, MODULE_NAME) \
4065 static_assert(PY_MAJOR_VERSION >= 3, "Python bindings for Halide require Python 3+"); \
4066 extern "C" PyObject *_halide_pystub_impl(const char *module_name, const Halide::Internal::GeneratorFactory &factory); \
4067 namespace halide_register_generator::GEN_REGISTRY_NAME##_ns { \
4068 extern std::unique_ptr<Halide::Internal::AbstractGenerator> factory(const Halide::GeneratorContext &context); \
4070 extern "C" HALIDE_EXPORT_SYMBOL PyObject *PyInit_##MODULE_NAME() { \
4071 const auto factory = halide_register_generator::GEN_REGISTRY_NAME##_ns::factory; \
4072 return _halide_pystub_impl(#MODULE_NAME, factory); \
4075 #endif // HALIDE_GENERATOR_H_