Halide
Callable.h
Go to the documentation of this file.
1 #ifndef HALIDE_CALLABLE_H
2 #define HALIDE_CALLABLE_H
3 
4 /** \file
5  *
6  * Defines the front-end class representing a jitted, callable Halide pipeline.
7  */
8 
9 #include <array>
10 #include <map>
11 
12 #include "Buffer.h"
13 #include "IntrusivePtr.h"
14 #include "JITModule.h"
15 
16 namespace Halide {
17 
18 struct Argument;
19 struct CallableContents;
20 
21 namespace PythonBindings {
22 class PyCallable;
23 }
24 
25 namespace Internal {
26 
27 template<typename>
28 struct IsHalideBuffer : std::false_type {};
29 
30 template<typename T, int Dims>
31 struct IsHalideBuffer<::Halide::Buffer<T, Dims>> : std::true_type {};
32 
33 template<typename T, int Dims>
34 struct IsHalideBuffer<::Halide::Runtime::Buffer<T, Dims>> : std::true_type {};
35 
36 template<>
37 struct IsHalideBuffer<halide_buffer_t *> : std::true_type {};
38 
39 template<>
40 struct IsHalideBuffer<const halide_buffer_t *> : std::true_type {};
41 
42 template<typename>
44  static constexpr halide_type_t type() {
45  return halide_type_t();
46  }
47  static constexpr int dims() {
48  return -1;
49  }
50 };
51 
52 template<typename T, int Dims>
54  static constexpr halide_type_t type() {
55  if constexpr (std::is_void_v<T>) {
56  return halide_type_t();
57  } else {
58  return halide_type_of<T>();
59  }
60  }
61  static constexpr int dims() {
62  return Dims;
63  }
64 };
65 
66 template<typename T, int Dims>
68  static constexpr halide_type_t type() {
69  if constexpr (std::is_void_v<T>) {
70  return halide_type_t();
71  } else {
72  return halide_type_of<T>();
73  }
74  }
75  static constexpr int dims() {
76  return Dims;
77  }
78 };
79 
80 } // namespace Internal
81 
82 class Callable {
83 private:
84  friend class Pipeline;
85  friend struct CallableContents;
86  friend class PythonBindings::PyCallable;
87 
89 
90  // ---------------------------------
91 
92  // This value is constructed so we can do the necessary runtime check
93  // with a single 16-bit compare. It's designed to to the minimal checking
94  // necessary to ensure that the arguments are well-formed, but not necessarily
95  // "correct"; in particular, it deliberately skips checking type-and-dim
96  // of Buffer arguments, since the generated code has assertions to check
97  // for that anyway.
98  using QuickCallCheckInfo = uint16_t;
99 
100  static constexpr QuickCallCheckInfo _make_qcci(uint8_t code, uint8_t bits) {
101  return (((uint16_t)code) << 8) | (uint16_t)bits;
102  }
103 
104  static constexpr QuickCallCheckInfo make_scalar_qcci(halide_type_t t) {
105  return _make_qcci(t.code, t.bits);
106  }
107 
108  static constexpr QuickCallCheckInfo make_buffer_qcci() {
109  constexpr uint8_t fake_bits_buffer_cci = 3;
110  return _make_qcci(halide_type_handle, fake_bits_buffer_cci);
111  }
112 
113  static constexpr QuickCallCheckInfo make_ucon_qcci() {
114  constexpr uint8_t fake_bits_ucon_cci = 5;
115  return _make_qcci(halide_type_handle, fake_bits_ucon_cci);
116  }
117 
118  template<typename T>
119  static constexpr QuickCallCheckInfo make_qcci() {
120  using T0 = typename std::remove_const<typename std::remove_reference<T>::type>::type;
121  if constexpr (std::is_same<T0, JITUserContext *>::value) {
122  return make_ucon_qcci();
123  } else if constexpr (Internal::IsHalideBuffer<T0>::value) {
124  // Don't bother checking type-and-dimensions here (the callee will do that)
125  return make_buffer_qcci();
126  } else if constexpr (std::is_arithmetic<T0>::value || std::is_pointer<T0>::value) {
127  return make_scalar_qcci(halide_type_of<T0>());
128  } else {
129  // static_assert(false) will fail all the time, even inside constexpr,
130  // but gating on sizeof(T) is a nice trick that ensures we will always
131  // fail here (since no T is ever size 0).
132  static_assert(!sizeof(T), "Illegal type passed to Callable.");
133  }
134  }
135 
136  template<typename... Args>
137  static constexpr std::array<QuickCallCheckInfo, sizeof...(Args)> make_qcci_array() {
138  return std::array<QuickCallCheckInfo, sizeof...(Args)>{make_qcci<Args>()...};
139  }
140 
141  // ---------------------------------
142 
143  // This value is constructed so we can do a complete type-and-dim check
144  // of Buffers, and is used for the make_std_function() method, to ensure
145  // that if we specify static type-and-dims for Buffers, the ones we specify
146  // actually match the underlying code. We take horrible liberties with halide_type_t
147  // to make this happen -- specifically, encoding dimensionality and buffer-vs-scalar
148  // into the 'lanes' field -- but that's ok since this never escapes into other usage.
149  using FullCallCheckInfo = halide_type_t;
150 
151  static constexpr FullCallCheckInfo _make_fcci(halide_type_t type, int dims, bool is_buffer) {
152  return type.with_lanes(((uint16_t)dims << 1) | (uint16_t)(is_buffer ? 1 : 0));
153  }
154 
155  static constexpr FullCallCheckInfo make_scalar_fcci(halide_type_t t) {
156  return _make_fcci(t, 0, false);
157  }
158 
159  static constexpr FullCallCheckInfo make_buffer_fcci(halide_type_t t, int dims) {
160  return _make_fcci(t, dims, true);
161  }
162 
163  static bool is_compatible_fcci(FullCallCheckInfo actual, FullCallCheckInfo expected) {
164  if (actual == expected) {
165  return true; // my, that was easy
166  }
167 
168  // Might still be compatible
169  const bool a_is_buffer = (actual.lanes & 1) != 0;
170  const int a_dims = (((int16_t)actual.lanes) >> 1);
171  const halide_type_t a_type = actual.with_lanes(0);
172 
173  const bool e_is_buffer = (expected.lanes & 1) != 0;
174  const int e_dims = (((int16_t)expected.lanes) >> 1);
175  const halide_type_t e_type = expected.with_lanes(0);
176 
177  const bool types_match = (a_type == halide_type_t()) ||
178  (e_type == halide_type_t()) ||
179  (a_type == e_type);
180 
181  const bool dims_match = a_dims < 0 ||
182  e_dims < 0 ||
183  a_dims == e_dims;
184 
185  return a_is_buffer == e_is_buffer && types_match && dims_match;
186  }
187 
188  template<typename T>
189  static constexpr FullCallCheckInfo make_fcci() {
190  using T0 = typename std::remove_const<typename std::remove_reference<T>::type>::type;
191  if constexpr (Internal::IsHalideBuffer<T0>::value) {
192  using TypeAndDims = Internal::HalideBufferStaticTypeAndDims<T0>;
193  return make_buffer_fcci(TypeAndDims::type(), TypeAndDims::dims());
194  } else if constexpr (std::is_arithmetic<T0>::value || std::is_pointer<T0>::value) {
195  return make_scalar_fcci(halide_type_of<T0>());
196  } else {
197  // static_assert(false) will fail all the time, even inside constexpr,
198  // but gating on sizeof(T) is a nice trick that ensures we will always
199  // fail here (since no T is ever size 0).
200  static_assert(!sizeof(T), "Illegal type passed to Callable.");
201  }
202  }
203 
204  template<typename... Args>
205  static constexpr std::array<FullCallCheckInfo, sizeof...(Args)> make_fcci_array() {
206  return std::array<FullCallCheckInfo, sizeof...(Args)>{make_fcci<Args>()...};
207  }
208 
209  // ---------------------------------
210 
211  template<int Size>
212  struct ArgvStorage {
213  const void *argv[Size];
214  // We need a place to store the scalar inputs, since we need a pointer
215  // to them and it's better to avoid relying on stack spill of arguments.
216  // Note that this will usually have unused slots, but it's cheap and easy
217  // compile-time allocation on the stack.
218  uintptr_t argv_scalar_store[Size];
219 
220  template<typename... Args>
221  explicit ArgvStorage(Args &&...args) {
222  fill_slots(0, std::forward<Args>(args)...);
223  }
224 
225  private:
226  template<typename T, int Dims>
227  HALIDE_ALWAYS_INLINE void fill_slot(size_t idx, const ::Halide::Buffer<T, Dims> &value) {
228  // Don't call ::Halide::Buffer::raw_buffer(): it includes "user_assert(defined())"
229  // as part of the wrapper code, and we want this lean-and-mean. Instead, stick in a null
230  // value for undefined buffers, and let the Halide pipeline fail with the usual null-ptr
231  // check. (Note that H::R::B::get() *never* returns null; you must check defined() first.)
232  argv[idx] = value.defined() ? value.get()->raw_buffer() : nullptr;
233  }
234 
235  template<typename T, int Dims>
236  HALIDE_ALWAYS_INLINE void fill_slot(size_t idx, const ::Halide::Runtime::Buffer<T, Dims> &value) {
237  argv[idx] = value.raw_buffer();
238  }
239 
241  void fill_slot(size_t idx, halide_buffer_t *value) {
242  argv[idx] = value;
243  }
244 
246  void fill_slot(size_t idx, const halide_buffer_t *value) {
247  argv[idx] = value;
248  }
249 
251  void fill_slot(size_t idx, JITUserContext *value) {
252  auto *dest = &argv_scalar_store[idx];
253  *dest = (uintptr_t)value;
254  argv[idx] = dest;
255  }
256 
257  template<typename T>
258  HALIDE_ALWAYS_INLINE void fill_slot(size_t idx, const T &value) {
259  auto *dest = &argv_scalar_store[idx];
260  *(T *)dest = value;
261  argv[idx] = dest;
262  }
263 
264  template<typename T>
265  HALIDE_ALWAYS_INLINE void fill_slots(size_t idx, const T &value) {
266  fill_slot(idx, value);
267  }
268 
269  template<typename First, typename Second, typename... Rest>
270  HALIDE_ALWAYS_INLINE void fill_slots(int idx, First &&first, Second &&second, Rest &&...rest) {
271  fill_slots<First>(idx, std::forward<First>(first));
272  fill_slots<Second, Rest...>(idx + 1, std::forward<Second>(second), std::forward<Rest>(rest)...);
273  }
274  };
275 
    // Private constructor. Takes the JIT handlers, the map of JIT externs, and
    // (by rvalue reference) the JITCache. NOTE(review): presumably invoked only
    // by friends such as Pipeline -- confirm in Pipeline.cpp.
    Callable(const std::string &name,
             const JITHandlers &jit_handlers,
             const std::map<std::string, JITExtern> &jit_externs,
             Internal::JITCache &&jit_cache);

    // Checked entry point used by call(): actual_cci holds one QuickCallCheckInfo
    // per argv slot (built by make_qcci_array() at compile time). The int result
    // is what call() returns to the user.
    // Note that the first entry in argv must always be a JITUserContext*.
    int call_argv_checked(size_t argc, const void *const *argv, const QuickCallCheckInfo *actual_cci) const;

    // Callback produced when an argument check fails. make_std_function() treats
    // a non-empty FailureFn as a signature mismatch and routes every call through
    // it instead of invoking the pipeline.
    using FailureFn = std::function<int(JITUserContext *)>;

    // NOTE(review): presumably builds the FailureFn for a mismatch at argument
    // index bad_idx, with 'verb' used in the error text -- confirm in Callable.cpp.
    FailureFn do_check_fail(int bad_idx, size_t argc, const char *verb) const;
    // Validate argc entries of quick / full check info against this Callable's
    // arguments; an empty FailureFn means the check passed (see make_std_function()).
    FailureFn check_qcci(size_t argc, const QuickCallCheckInfo *actual_cci) const;
    FailureFn check_fcci(size_t argc, const FullCallCheckInfo *actual_cci) const;
289 
290  template<typename... Args>
291  int call(JITUserContext *context, Args &&...args) const {
292  // This is built at compile time!
293  static constexpr auto actual_arg_types = make_qcci_array<JITUserContext *, Args...>();
294 
295  constexpr size_t count = sizeof...(args) + 1;
296  ArgvStorage<count> argv(context, std::forward<Args>(args)...);
297  return call_argv_checked(count, &argv.argv[0], actual_arg_types.data());
298  }
299 
    /** Return the expected Arguments for this Callable, in the order they must be specified, including all outputs.
     * Note that the first entry will *always* specify a JITUserContext.
     * (Private; reachable from friends such as PythonBindings::PyCallable.) */
    const std::vector<Argument> &arguments() const;

public:
    /** Construct a default Callable. This is not usable (trying to call it will fail).
     * The defined() method will return false. */
    Callable();

    /** Return true if the Callable is well-defined and usable, false if it is a default-constructed empty Callable. */
    bool defined() const;
311 
312  template<typename... Args>
314  operator()(JITUserContext *context, Args &&...args) const {
315  return call(context, std::forward<Args>(args)...);
316  }
317 
318  template<typename... Args>
320  operator()(Args &&...args) const {
321  JITUserContext empty;
322  return call(&empty, std::forward<Args>(args)...);
323  }
324 
325  /** This allows us to construct a std::function<> that wraps the Callable.
326  * This is nice in that it is, well, just a std::function, but also in that
327  * since the argument-count-and-type checking are baked into the language,
328  * we can do the relevant checking only once -- when we first create the std::function --
329  * and skip it on all actual *calls* to the function, making it slightly more efficient.
330  * It's also more type-forgiving, in that the usual C++ numeric coercion rules apply here.
331  *
332  * The downside is that there isn't (currently) any way to automatically infer
333  * the static types reliably, since we may be using (e.g.) a Param<void>, where the
334  * type in question isn't available to the C++ compiler. This means that the coder
335  * must supply the correct type signature when calling this function -- but the good news
336  * is that if you get it wrong, this function will fail when you call it. (In other words:
337  * it can't choose the right thing for you, but it can tell you when you do the wrong thing.)
338  *
339  * TODO: it's possible that we could infer the correct signatures in some cases,
340  * and only fail for the ambiguous cases, but that would require a lot more template-fu
341  * here and elsewhere. I think this is good enough for now.
342  *
343  * TODO: is it possible to annotate the result of a std::function<> with HALIDE_FUNCTION_ATTRS?
344  */
345  template<typename First, typename... Rest>
346  std::function<int(First, Rest...)>
348  if constexpr (std::is_same_v<First, JITUserContext *>) {
349  constexpr auto actual_arg_types = make_fcci_array<First, Rest...>();
350  const auto failure_fn = check_fcci(actual_arg_types.size(), actual_arg_types.data());
351  if (failure_fn) {
352  // Return a wrapper for the failure_fn in case the error handler is a no-op,
353  // so that subsequent calls won't attempt to use possibly-wrong argv packing.
354  return [*this, failure_fn](auto &&first, auto &&...rest) -> int {
355  return failure_fn(std::forward<First>(first));
356  };
357  }
358 
359  // Capture *this to ensure that the CallableContents stay valid as long as the std::function does
360  return [*this](auto &&first, auto &&...rest) -> int {
361  constexpr size_t count = 1 + sizeof...(rest);
362  ArgvStorage<count> argv(std::forward<First>(first), std::forward<Rest>(rest)...);
363  return call_argv_fast(count, &argv.argv[0]);
364  };
365  } else {
366  // Explicitly prepend JITUserContext* as first actual-arg-type.
367  constexpr auto actual_arg_types = make_fcci_array<JITUserContext *, First, Rest...>();
368  const auto failure_fn = check_fcci(actual_arg_types.size(), actual_arg_types.data());
369  if (failure_fn) {
370  // Return a wrapper for the failure_fn in case the error handler is a no-op,
371  // so that subsequent calls won't attempt to use possibly-wrong argv packing.
372  return [*this, failure_fn](auto &&first, auto &&...rest) -> int {
373  JITUserContext empty;
374  return failure_fn(&empty);
375  };
376  }
377 
378  // Capture *this to ensure that the CallableContents stay valid as long as the std::function does
379  return [*this](auto &&first, auto &&...rest) -> int {
380  // Explicitly prepend an (empty) JITUserContext to the args.
381  JITUserContext empty;
382  constexpr size_t count = 1 + 1 + sizeof...(rest);
383  ArgvStorage<count> argv(&empty, std::forward<First>(first), std::forward<Rest>(rest)...);
384  return call_argv_fast(count, &argv.argv[0]);
385  };
386  }
387  }
388 
    /** Unsafe low-overhead way of invoking the Callable.
     *
     * No argument-count or argument-type checking is done on this path. The
     * std::function wrappers returned by make_std_function() use it only after
     * checking the signature once, when the std::function is created.
     *
     * This function relies on the same calling convention as the argv-based
     * functions generated for ahead-of-time compiled Halide pipelines.
     *
     * Very rough specifications of the calling convention (but check the source
     * code to be sure):
     *
     *  * Arguments are passed in the same order as they appear in the C
     *    function argument list.
     *  * argc is the number of argv entries, including the leading user-context slot.
     *  * The first entry in argv must always be a JITUserContext*. Note that this
     *    means argv[0] actually holds a JITUserContext** (a pointer to the
     *    JITUserContext* value).
     *  * All scalar arguments are passed by pointer, not by value, regardless of size.
     *  * All buffer arguments (input or output) are passed as halide_buffer_t*.
     *
     */
    int call_argv_fast(size_t argc, const void *const *argv) const;
406 };
407 
408 } // namespace Halide
409 
410 #endif
Halide::Callable
Definition: Callable.h:82
halide_type_handle
@ halide_type_handle
opaque pointer type (void *)
Definition: HalideRuntime.h:457
uint8_t
unsigned __INT8_TYPE__ uint8_t
Definition: runtime_internal.h:29
Halide::Internal::HalideBufferStaticTypeAndDims
Definition: Callable.h:43
Halide::Internal::HalideBufferStaticTypeAndDims<::Halide::Runtime::Buffer< T, Dims > >::dims
static constexpr int dims()
Definition: Callable.h:75
uint16_t
unsigned __INT16_TYPE__ uint16_t
Definition: runtime_internal.h:27
Halide::Internal::IsHalideBuffer
Definition: Callable.h:28
Halide::Callable::Callable
Callable()
Construct a default Callable.
Halide::Callable::operator()
HALIDE_FUNCTION_ATTRS int operator()(Args &&...args) const
Definition: Callable.h:320
Halide::Internal::HalideBufferStaticTypeAndDims::type
static constexpr halide_type_t type()
Definition: Callable.h:44
halide_type_t::bits
uint8_t bits
The number of bits of precision of a single scalar value of this type.
Definition: HalideRuntime.h:488
halide_type_t
A runtime tag for a type in the halide type system.
Definition: HalideRuntime.h:476
Halide::Internal::HalideBufferStaticTypeAndDims<::Halide::Runtime::Buffer< T, Dims > >::type
static constexpr halide_type_t type()
Definition: Callable.h:68
Halide::Internal::IntrusivePtr< CallableContents >
uintptr_t
__UINTPTR_TYPE__ uintptr_t
Definition: runtime_internal.h:73
HALIDE_FUNCTION_ATTRS
#define HALIDE_FUNCTION_ATTRS
Definition: HalideRuntime.h:67
Halide::Callable::make_std_function
std::function< int(First, Rest...)> make_std_function() const
This allows us to construct a std::function<> that wraps the Callable.
Definition: Callable.h:347
Halide::JITUserContext
A context to be passed to Pipeline::realize.
Definition: JITModule.h:136
Halide::Pipeline
A class representing a Halide pipeline.
Definition: Pipeline.h:108
Halide
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
Definition: AbstractGenerator.h:19
Halide::LinkageType::Internal
@ Internal
Not visible externally, similar to 'static' linkage in C.
Halide::Callable::defined
bool defined() const
Return true if the Callable is well-defined and usable, false if it is a default-constructed empty Callable.
Halide::Internal::HalideBufferStaticTypeAndDims<::Halide::Buffer< T, Dims > >::dims
static constexpr int dims()
Definition: Callable.h:61
Halide::Internal::HalideBufferStaticTypeAndDims<::Halide::Buffer< T, Dims > >::type
static constexpr halide_type_t type()
Definition: Callable.h:54
JITModule.h
Halide::Buffer
A Halide::Buffer is a named shared reference to a Halide::Runtime::Buffer.
Definition: Argument.h:16
Buffer.h
Halide::Runtime::Buffer
A templated Buffer class that wraps halide_buffer_t and adds functionality.
Definition: HalideBuffer.h:121
HALIDE_ALWAYS_INLINE
#define HALIDE_ALWAYS_INLINE
Definition: HalideRuntime.h:40
Halide::Callable::call_argv_fast
int call_argv_fast(size_t argc, const void *const *argv) const
Unsafe low-overhead way of invoking the Callable.
halide_buffer_t
The raw representation of an image passed around by generated Halide code.
Definition: HalideRuntime.h:1490
Halide::Internal::HalideBufferStaticTypeAndDims::dims
static constexpr int dims()
Definition: Callable.h:47
int16_t
signed __INT16_TYPE__ int16_t
Definition: runtime_internal.h:26
IntrusivePtr.h
Halide::Callable::operator()
HALIDE_FUNCTION_ATTRS int operator()(JITUserContext *context, Args &&...args) const
Definition: Callable.h:314
Halide::Callable::CallableContents
friend struct CallableContents
Definition: Callable.h:85
halide_type_t::code
uint8_t code
The basic type code: signed integer, unsigned integer, or floating point.
Definition: HalideRuntime.h:483