86 friend class PythonBindings::PyCallable;
100 static constexpr QuickCallCheckInfo _make_qcci(
uint8_t code,
uint8_t bits) {
104 static constexpr QuickCallCheckInfo make_scalar_qcci(
halide_type_t t) {
108 static constexpr QuickCallCheckInfo make_buffer_qcci() {
109 constexpr uint8_t fake_bits_buffer_cci = 3;
113 static constexpr QuickCallCheckInfo make_ucon_qcci() {
114 constexpr uint8_t fake_bits_ucon_cci = 5;
119 static constexpr QuickCallCheckInfo make_qcci() {
120 using T0 =
typename std::remove_const<typename std::remove_reference<T>::type>::type;
121 if constexpr (std::is_same<T0, JITUserContext *>::value) {
122 return make_ucon_qcci();
123 }
else if constexpr (Internal::IsHalideBuffer<T0>::value) {
125 return make_buffer_qcci();
126 }
else if constexpr (std::is_arithmetic<T0>::value || std::is_pointer<T0>::value) {
127 return make_scalar_qcci(halide_type_of<T0>());
132 static_assert(!
sizeof(T),
"Illegal type passed to Callable.");
136 template<
typename... Args>
137 static constexpr std::array<QuickCallCheckInfo,
sizeof...(Args)> make_qcci_array() {
138 return std::array<QuickCallCheckInfo,
sizeof...(Args)>{make_qcci<Args>()...};
151 static constexpr FullCallCheckInfo _make_fcci(
halide_type_t type,
int dims,
bool is_buffer) {
152 return type.with_lanes(((
uint16_t)dims << 1) | (
uint16_t)(is_buffer ? 1 : 0));
155 static constexpr FullCallCheckInfo make_scalar_fcci(
halide_type_t t) {
156 return _make_fcci(t, 0,
false);
159 static constexpr FullCallCheckInfo make_buffer_fcci(
halide_type_t t,
int dims) {
160 return _make_fcci(t, dims,
true);
163 static bool is_compatible_fcci(FullCallCheckInfo actual, FullCallCheckInfo expected) {
164 if (actual == expected) {
169 const bool a_is_buffer = (actual.lanes & 1) != 0;
170 const int a_dims = (((
int16_t)actual.lanes) >> 1);
173 const bool e_is_buffer = (expected.lanes & 1) != 0;
174 const int e_dims = (((
int16_t)expected.lanes) >> 1);
181 const bool dims_match = a_dims < 0 ||
185 return a_is_buffer == e_is_buffer && types_match && dims_match;
189 static constexpr FullCallCheckInfo make_fcci() {
190 using T0 =
typename std::remove_const<typename std::remove_reference<T>::type>::type;
191 if constexpr (Internal::IsHalideBuffer<T0>::value) {
192 using TypeAndDims = Internal::HalideBufferStaticTypeAndDims<T0>;
193 return make_buffer_fcci(TypeAndDims::type(), TypeAndDims::dims());
194 }
else if constexpr (std::is_arithmetic<T0>::value || std::is_pointer<T0>::value) {
195 return make_scalar_fcci(halide_type_of<T0>());
200 static_assert(!
sizeof(T),
"Illegal type passed to Callable.");
204 template<
typename... Args>
205 static constexpr std::array<FullCallCheckInfo,
sizeof...(Args)> make_fcci_array() {
206 return std::array<FullCallCheckInfo,
sizeof...(Args)>{make_fcci<Args>()...};
213 const void *argv[Size];
220 template<
typename... Args>
221 explicit ArgvStorage(Args &&...args) {
222 fill_slots(0, std::forward<Args>(args)...);
226 template<
typename T,
int Dims>
232 argv[idx] = value.defined() ? value.get()->raw_buffer() :
nullptr;
235 template<
typename T,
int Dims>
236 HALIDE_ALWAYS_INLINE void fill_slot(
size_t idx, const ::Halide::Runtime::Buffer<T, Dims> &value) {
237 argv[idx] = value.raw_buffer();
251 void fill_slot(
size_t idx, JITUserContext *value) {
252 auto *dest = &argv_scalar_store[idx];
259 auto *dest = &argv_scalar_store[idx];
266 fill_slot(idx, value);
269 template<
typename First,
typename Second,
typename... Rest>
271 fill_slots<First>(idx, std::forward<First>(first));
272 fill_slots<Second, Rest...>(idx + 1, std::forward<Second>(second), std::forward<Rest>(rest)...);
277 const JITHandlers &jit_handlers,
278 const std::map<std::string, JITExtern> &jit_externs,
279 Internal::JITCache &&jit_cache);
282 int call_argv_checked(
size_t argc,
const void *
const *argv,
const QuickCallCheckInfo *actual_cci)
const;
284 using FailureFn = std::function<int(JITUserContext *)>;
286 FailureFn do_check_fail(
int bad_idx,
size_t argc,
const char *verb)
const;
287 FailureFn check_qcci(
size_t argc,
const QuickCallCheckInfo *actual_cci)
const;
288 FailureFn check_fcci(
size_t argc,
const FullCallCheckInfo *actual_cci)
const;
290 template<
typename... Args>
291 int call(JITUserContext *context, Args &&...args)
const {
293 static constexpr auto actual_arg_types = make_qcci_array<JITUserContext *, Args...>();
295 constexpr size_t count =
sizeof...(args) + 1;
296 ArgvStorage<count> argv(context, std::forward<Args>(args)...);
297 return call_argv_checked(count, &argv.argv[0], actual_arg_types.data());
302 const std::vector<Argument> &arguments()
const;
312 template<
typename... Args>
315 return call(context, std::forward<Args>(args)...);
318 template<
typename... Args>
322 return call(&empty, std::forward<Args>(args)...);
345 template<
typename First,
typename... Rest>
346 std::function<int(First, Rest...)>
348 if constexpr (std::is_same_v<First, JITUserContext *>) {
349 constexpr auto actual_arg_types = make_fcci_array<First, Rest...>();
350 const auto failure_fn = check_fcci(actual_arg_types.size(), actual_arg_types.data());
354 return [*
this, failure_fn](
auto &&first,
auto &&...rest) ->
int {
355 return failure_fn(std::forward<First>(first));
360 return [*
this](
auto &&first,
auto &&...rest) ->
int {
361 constexpr size_t count = 1 +
sizeof...(rest);
362 ArgvStorage<count> argv(std::forward<First>(first), std::forward<Rest>(rest)...);
367 constexpr auto actual_arg_types = make_fcci_array<
JITUserContext *, First, Rest...>();
368 const auto failure_fn = check_fcci(actual_arg_types.size(), actual_arg_types.data());
372 return [*
this, failure_fn](
auto &&first,
auto &&...rest) ->
int {
374 return failure_fn(&empty);
379 return [*
this](
auto &&first,
auto &&...rest) ->
int {
382 constexpr size_t count = 1 + 1 +
sizeof...(rest);
383 ArgvStorage<count> argv(&empty, std::forward<First>(first), std::forward<Rest>(rest)...);
A context to be passed to Pipeline::realize.