Halide 19.0.0
Halide compiler and libraries
Loading...
Searching...
No Matches
Callable.h
Go to the documentation of this file.
1#ifndef HALIDE_CALLABLE_H
2#define HALIDE_CALLABLE_H
3
4/** \file
5 *
6 * Defines the front-end class representing a jitted, callable Halide pipeline.
7 */
8
#include <array>
#include <cstdint>
#include <cstring>
#include <map>
11
12#include "Buffer.h"
13#include "IntrusivePtr.h"
14#include "JITModule.h"
15
16namespace Halide {
17
18struct Argument;
19struct CallableContents;
20
21namespace PythonBindings {
22class PyCallable;
23}
24
25namespace Internal {
26
27template<typename>
28struct IsHalideBuffer : std::false_type {};
29
30template<typename T, int Dims>
31struct IsHalideBuffer<::Halide::Buffer<T, Dims>> : std::true_type {};
32
33template<typename T, int Dims>
34struct IsHalideBuffer<::Halide::Runtime::Buffer<T, Dims>> : std::true_type {};
35
36template<>
37struct IsHalideBuffer<halide_buffer_t *> : std::true_type {};
38
39template<>
40struct IsHalideBuffer<const halide_buffer_t *> : std::true_type {};
41
42template<typename>
44 static constexpr halide_type_t type() {
45 return halide_type_t();
46 }
47 static constexpr int dims() {
48 return -1;
49 }
50};
51
52template<typename T, int Dims>
54 static constexpr halide_type_t type() {
55 if constexpr (std::is_void_v<T>) {
56 return halide_type_t();
57 } else {
58 return halide_type_of<T>();
59 }
60 }
61 static constexpr int dims() {
62 return Dims;
63 }
64};
65
66template<typename T, int Dims>
68 static constexpr halide_type_t type() {
69 if constexpr (std::is_void_v<T>) {
70 return halide_type_t();
71 } else {
72 return halide_type_of<T>();
73 }
74 }
75 static constexpr int dims() {
76 return Dims;
77 }
78};
79
80} // namespace Internal
81
82class Callable {
83private:
84 friend class Pipeline;
85 friend struct CallableContents;
86 friend class PythonBindings::PyCallable;
87
89
90 // ---------------------------------
91
92 // This value is constructed so we can do the necessary runtime check
93 // with a single 16-bit compare. It's designed to to the minimal checking
94 // necessary to ensure that the arguments are well-formed, but not necessarily
95 // "correct"; in particular, it deliberately skips checking type-and-dim
96 // of Buffer arguments, since the generated code has assertions to check
97 // for that anyway.
98 using QuickCallCheckInfo = uint16_t;
99
100 static constexpr QuickCallCheckInfo _make_qcci(uint8_t code, uint8_t bits) {
101 return (((uint16_t)code) << 8) | (uint16_t)bits;
102 }
103
104 static constexpr QuickCallCheckInfo make_scalar_qcci(halide_type_t t) {
105 return _make_qcci(t.code, t.bits);
106 }
107
108 static constexpr QuickCallCheckInfo make_buffer_qcci() {
109 constexpr uint8_t fake_bits_buffer_cci = 3;
110 return _make_qcci(halide_type_handle, fake_bits_buffer_cci);
111 }
112
113 static constexpr QuickCallCheckInfo make_ucon_qcci() {
114 constexpr uint8_t fake_bits_ucon_cci = 5;
115 return _make_qcci(halide_type_handle, fake_bits_ucon_cci);
116 }
117
118 template<typename T>
119 static constexpr QuickCallCheckInfo make_qcci() {
120 using T0 = typename std::remove_const<typename std::remove_reference<T>::type>::type;
121 if constexpr (std::is_same<T0, JITUserContext *>::value) {
122 return make_ucon_qcci();
123 } else if constexpr (Internal::IsHalideBuffer<T0>::value) {
124 // Don't bother checking type-and-dimensions here (the callee will do that)
125 return make_buffer_qcci();
126 } else if constexpr (std::is_arithmetic<T0>::value || std::is_pointer<T0>::value) {
127 return make_scalar_qcci(halide_type_of<T0>());
128 } else {
129 // static_assert(false) will fail all the time, even inside constexpr,
130 // but gating on sizeof(T) is a nice trick that ensures we will always
131 // fail here (since no T is ever size 0).
132 static_assert(!sizeof(T), "Illegal type passed to Callable.");
133 }
134 }
135
136 template<typename... Args>
137 static constexpr std::array<QuickCallCheckInfo, sizeof...(Args)> make_qcci_array() {
138 return std::array<QuickCallCheckInfo, sizeof...(Args)>{make_qcci<Args>()...};
139 }
140
141 // ---------------------------------
142
143 // This value is constructed so we can do a complete type-and-dim check
144 // of Buffers, and is used for the make_std_function() method, to ensure
145 // that if we specify static type-and-dims for Buffers, the ones we specify
146 // actually match the underlying code. We take horrible liberties with halide_type_t
147 // to make this happen -- specifically, encoding dimensionality and buffer-vs-scalar
148 // into the 'lanes' field -- but that's ok since this never escapes into other usage.
149 using FullCallCheckInfo = halide_type_t;
150
151 static constexpr FullCallCheckInfo _make_fcci(halide_type_t type, int dims, bool is_buffer) {
152 return type.with_lanes(((uint16_t)dims << 1) | (uint16_t)(is_buffer ? 1 : 0));
153 }
154
155 static constexpr FullCallCheckInfo make_scalar_fcci(halide_type_t t) {
156 return _make_fcci(t, 0, false);
157 }
158
159 static constexpr FullCallCheckInfo make_buffer_fcci(halide_type_t t, int dims) {
160 return _make_fcci(t, dims, true);
161 }
162
163 static bool is_compatible_fcci(FullCallCheckInfo actual, FullCallCheckInfo expected) {
164 if (actual == expected) {
165 return true; // my, that was easy
166 }
167
168 // Might still be compatible
169 const bool a_is_buffer = (actual.lanes & 1) != 0;
170 const int a_dims = (((int16_t)actual.lanes) >> 1);
171 const halide_type_t a_type = actual.with_lanes(0);
172
173 const bool e_is_buffer = (expected.lanes & 1) != 0;
174 const int e_dims = (((int16_t)expected.lanes) >> 1);
175 const halide_type_t e_type = expected.with_lanes(0);
176
177 const bool types_match = (a_type == halide_type_t()) ||
178 (e_type == halide_type_t()) ||
179 (a_type == e_type);
180
181 const bool dims_match = a_dims < 0 ||
182 e_dims < 0 ||
183 a_dims == e_dims;
184
185 return a_is_buffer == e_is_buffer && types_match && dims_match;
186 }
187
188 template<typename T>
189 static constexpr FullCallCheckInfo make_fcci() {
190 using T0 = typename std::remove_const<typename std::remove_reference<T>::type>::type;
191 if constexpr (Internal::IsHalideBuffer<T0>::value) {
192 using TypeAndDims = Internal::HalideBufferStaticTypeAndDims<T0>;
193 return make_buffer_fcci(TypeAndDims::type(), TypeAndDims::dims());
194 } else if constexpr (std::is_arithmetic<T0>::value || std::is_pointer<T0>::value) {
195 return make_scalar_fcci(halide_type_of<T0>());
196 } else {
197 // static_assert(false) will fail all the time, even inside constexpr,
198 // but gating on sizeof(T) is a nice trick that ensures we will always
199 // fail here (since no T is ever size 0).
200 static_assert(!sizeof(T), "Illegal type passed to Callable.");
201 }
202 }
203
204 template<typename... Args>
205 static constexpr std::array<FullCallCheckInfo, sizeof...(Args)> make_fcci_array() {
206 return std::array<FullCallCheckInfo, sizeof...(Args)>{make_fcci<Args>()...};
207 }
208
209 // ---------------------------------
210
211 template<int Size>
212 struct ArgvStorage {
213 const void *argv[Size];
214 // We need a place to store the scalar inputs, since we need a pointer
215 // to them and it's better to avoid relying on stack spill of arguments.
216 // Note that this will usually have unused slots, but it's cheap and easy
217 // compile-time allocation on the stack.
218 uintptr_t argv_scalar_store[Size];
219
220 template<typename... Args>
221 explicit ArgvStorage(Args &&...args) {
222 fill_slots(0, std::forward<Args>(args)...);
223 }
224
225 private:
226 template<typename T, int Dims>
227 HALIDE_ALWAYS_INLINE void fill_slot(size_t idx, const ::Halide::Buffer<T, Dims> &value) {
228 // Don't call ::Halide::Buffer::raw_buffer(): it includes "user_assert(defined())"
229 // as part of the wrapper code, and we want this lean-and-mean. Instead, stick in a null
230 // value for undefined buffers, and let the Halide pipeline fail with the usual null-ptr
231 // check. (Note that H::R::B::get() *never* returns null; you must check defined() first.)
232 argv[idx] = value.defined() ? value.get()->raw_buffer() : nullptr;
233 }
234
235 template<typename T, int Dims>
236 HALIDE_ALWAYS_INLINE void fill_slot(size_t idx, const ::Halide::Runtime::Buffer<T, Dims> &value) {
237 argv[idx] = value.raw_buffer();
238 }
239
241 void fill_slot(size_t idx, halide_buffer_t *value) {
242 argv[idx] = value;
243 }
244
246 void fill_slot(size_t idx, const halide_buffer_t *value) {
247 argv[idx] = value;
248 }
249
251 void fill_slot(size_t idx, JITUserContext *value) {
252 auto *dest = &argv_scalar_store[idx];
253 *dest = (uintptr_t)value;
254 argv[idx] = dest;
255 }
256
257 template<typename T>
258 HALIDE_ALWAYS_INLINE void fill_slot(size_t idx, const T &value) {
259 auto *dest = &argv_scalar_store[idx];
260 *(T *)dest = value;
261 argv[idx] = dest;
262 }
263
264 template<typename T>
265 HALIDE_ALWAYS_INLINE void fill_slots(size_t idx, const T &value) {
266 fill_slot(idx, value);
267 }
268
269 template<typename First, typename Second, typename... Rest>
270 HALIDE_ALWAYS_INLINE void fill_slots(int idx, First &&first, Second &&second, Rest &&...rest) {
271 fill_slots<First>(idx, std::forward<First>(first));
272 fill_slots<Second, Rest...>(idx + 1, std::forward<Second>(second), std::forward<Rest>(rest)...);
273 }
274 };
275
276 Callable(const std::string &name,
277 const JITHandlers &jit_handlers,
278 const std::map<std::string, JITExtern> &jit_externs,
279 Internal::JITCache &&jit_cache);
280
281 // Note that the first entry in argv must always be a JITUserContext*.
282 int call_argv_checked(size_t argc, const void *const *argv, const QuickCallCheckInfo *actual_cci) const;
283
284 using FailureFn = std::function<int(JITUserContext *)>;
285
286 FailureFn do_check_fail(int bad_idx, size_t argc, const char *verb) const;
287 FailureFn check_qcci(size_t argc, const QuickCallCheckInfo *actual_cci) const;
288 FailureFn check_fcci(size_t argc, const FullCallCheckInfo *actual_cci) const;
289
290 template<typename... Args>
291 int call(JITUserContext *context, Args &&...args) const {
292 // This is built at compile time!
293 static constexpr auto actual_arg_types = make_qcci_array<JITUserContext *, Args...>();
294
295 constexpr size_t count = sizeof...(args) + 1;
296 ArgvStorage<count> argv(context, std::forward<Args>(args)...);
297 return call_argv_checked(count, &argv.argv[0], actual_arg_types.data());
298 }
299
300 /** Return the expected Arguments for this Callable, in the order they must be specified, including all outputs.
301 * Note that the first entry will *always* specify a JITUserContext. */
302 const std::vector<Argument> &arguments() const;
303
304public:
305 /** Construct a default Callable. This is not usable (trying to call it will fail).
306 * The defined() method will return false. */
308
309 /** Return true if the Callable is well-defined and usable, false if it is a default-constructed empty Callable. */
310 bool defined() const;
311
312 template<typename... Args>
314 operator()(JITUserContext *context, Args &&...args) const {
315 return call(context, std::forward<Args>(args)...);
316 }
317
318 template<typename... Args>
320 operator()(Args &&...args) const {
321 JITUserContext empty;
322 return call(&empty, std::forward<Args>(args)...);
323 }
324
325 /** This allows us to construct a std::function<> that wraps the Callable.
326 * This is nice in that it is, well, just a std::function, but also in that
327 * since the argument-count-and-type checking are baked into the language,
328 * we can do the relevant checking only once -- when we first create the std::function --
329 * and skip it on all actual *calls* to the function, making it slightly more efficient.
330 * It's also more type-forgiving, in that the usual C++ numeric coercion rules apply here.
331 *
332 * The downside is that there isn't (currently) any way to automatically infer
333 * the static types reliably, since we may be using (e.g.) a Param<void>, where the
334 * type in question isn't available to the C++ compiler. This means that the coder
335 * must supply the correct type signature when calling this function -- but the good news
336 * is that if you get it wrong, this function will fail when you call it. (In other words:
337 * it can't choose the right thing for you, but it can tell you when you do the wrong thing.)
338 *
339 * TODO: it's possible that we could infer the correct signatures in some cases,
340 * and only fail for the ambiguous cases, but that would require a lot more template-fu
341 * here and elsewhere. I think this is good enough for now.
342 *
343 * TODO: is it possible to annotate the result of a std::function<> with HALIDE_FUNCTION_ATTRS?
344 */
345 template<typename First, typename... Rest>
346 std::function<int(First, Rest...)>
348 if constexpr (std::is_same_v<First, JITUserContext *>) {
349 constexpr auto actual_arg_types = make_fcci_array<First, Rest...>();
350 const auto failure_fn = check_fcci(actual_arg_types.size(), actual_arg_types.data());
351 if (failure_fn) {
352 // Return a wrapper for the failure_fn in case the error handler is a no-op,
353 // so that subsequent calls won't attempt to use possibly-wrong argv packing.
354 return [*this, failure_fn](auto &&first, auto &&...rest) -> int {
355 return failure_fn(std::forward<First>(first));
356 };
357 }
358
359 // Capture *this to ensure that the CallableContents stay valid as long as the std::function does
360 return [*this](auto &&first, auto &&...rest) -> int {
361 constexpr size_t count = 1 + sizeof...(rest);
362 ArgvStorage<count> argv(std::forward<First>(first), std::forward<Rest>(rest)...);
363 return call_argv_fast(count, &argv.argv[0]);
364 };
365 } else {
366 // Explicitly prepend JITUserContext* as first actual-arg-type.
367 constexpr auto actual_arg_types = make_fcci_array<JITUserContext *, First, Rest...>();
368 const auto failure_fn = check_fcci(actual_arg_types.size(), actual_arg_types.data());
369 if (failure_fn) {
370 // Return a wrapper for the failure_fn in case the error handler is a no-op,
371 // so that subsequent calls won't attempt to use possibly-wrong argv packing.
372 return [*this, failure_fn](auto &&first, auto &&...rest) -> int {
373 JITUserContext empty;
374 return failure_fn(&empty);
375 };
376 }
377
378 // Capture *this to ensure that the CallableContents stay valid as long as the std::function does
379 return [*this](auto &&first, auto &&...rest) -> int {
380 // Explicitly prepend an (empty) JITUserContext to the args.
381 JITUserContext empty;
382 constexpr size_t count = 1 + 1 + sizeof...(rest);
383 ArgvStorage<count> argv(&empty, std::forward<First>(first), std::forward<Rest>(rest)...);
384 return call_argv_fast(count, &argv.argv[0]);
385 };
386 }
387 }
388
389 /** Unsafe low-overhead way of invoking the Callable.
390 *
391 * This function relies on the same calling convention as the argv-based
392 * functions generated for ahead-of-time compiled Halide pilelines.
393 *
394 * Very rough specifications of the calling convention (but check the source
395 * code to be sure):
396 *
397 * * Arguments are passed in the same order as they appear in the C
398 * function argument list.
399 * * The first entry in argv must always be a JITUserContext*. Please,
400 * note that this means that argv[0] actually contains JITUserContext**.
401 * * All scalar arguments are passed by pointer, not by value, regardless of size.
402 * * All buffer arguments (input or output) are passed as halide_buffer_t*.
403 *
404 */
405 int call_argv_fast(size_t argc, const void *const *argv) const;
406};
407
408} // namespace Halide
409
410#endif
#define HALIDE_FUNCTION_ATTRS
@ halide_type_handle
opaque pointer type (void *)
#define HALIDE_ALWAYS_INLINE
Support classes for reference-counting via intrusive shared pointers.
Defines the struct representing lifetime and dependencies of a JIT compiled halide pipeline.
A Halide::Buffer is a named shared reference to a Halide::Runtime::Buffer.
Definition RDom.h:21
int call_argv_fast(size_t argc, const void *const *argv) const
Unsafe low-overhead way of invoking the Callable.
std::function< int(First, Rest...)> make_std_function() const
This allows us to construct a std::function<> that wraps the Callable.
Definition Callable.h:347
Callable()
Construct a default Callable.
friend struct CallableContents
Definition Callable.h:85
HALIDE_FUNCTION_ATTRS int operator()(Args &&...args) const
Definition Callable.h:320
bool defined() const
Return true if the Callable is well-defined and usable, false if it is a default-constructed empty Callable.
HALIDE_FUNCTION_ATTRS int operator()(JITUserContext *context, Args &&...args) const
Definition Callable.h:314
A class representing a Halide pipeline.
Definition Pipeline.h:107
A templated Buffer class that wraps halide_buffer_t and adds functionality.
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
@ Internal
Not visible externally, similar to 'static' linkage in C.
__UINTPTR_TYPE__ uintptr_t
unsigned __INT8_TYPE__ uint8_t
unsigned __INT16_TYPE__ uint16_t
signed __INT16_TYPE__ int16_t
static constexpr halide_type_t type()
Definition Callable.h:44
Intrusive shared pointers have a reference count (a RefCount object) stored in the class itself.
A context to be passed to Pipeline::realize.
Definition JITModule.h:136
The raw representation of an image passed around by generated Halide code.
A runtime tag for a type in the halide type system.
uint8_t bits
The number of bits of precision of a single scalar value of this type.
uint8_t code
The basic type code: signed integer, unsigned integer, or floating point.