Halide
JITModule.h
Go to the documentation of this file.
1 #ifndef HALIDE_JIT_MODULE_H
2 #define HALIDE_JIT_MODULE_H
3 
4 /** \file
5  * Defines the struct representing lifetime and dependencies of
6  * a JIT compiled halide pipeline
7  */
8 
9 #include <map>
10 #include <memory>
11 #include <vector>
12 
13 #include "IntrusivePtr.h"
14 #include "Target.h"
15 #include "Type.h"
16 #include "WasmExecutor.h"
17 #include "runtime/HalideRuntime.h"
18 
19 namespace llvm {
20 class Module;
21 }
22 
23 namespace Halide {
24 
25 struct ExternCFunction;
26 struct JITExtern;
27 class Module;
28 
29 struct JITUserContext;
30 
31 /** A set of custom overrides of runtime functions. These only apply
32  * when JIT-compiling code. If you are doing AOT compilation, see
33  * HalideRuntime.h for instructions on how to replace runtime
34  * functions. */
35 struct JITHandlers {
36  /** Set the function called to print messages from the runtime. */
37  void (*custom_print)(JITUserContext *, const char *){nullptr};
38 
39  /** A custom malloc and free for halide to use. Malloc should
40  * return 32-byte aligned chunks of memory, and it should be safe
41  * for Halide to read slightly out of bounds (up to 8 bytes before
42  * the start or beyond the end). */
43  // @{
44  void *(*custom_malloc)(JITUserContext *, size_t){nullptr};
45  void (*custom_free)(JITUserContext *, void *){nullptr};
46  // @}
47 
48  /** A custom task handler to be called by the parallel for
49  * loop. It is useful to set this if you want to do some
50  * additional bookkeeping at the granularity of parallel
51  * tasks. The default implementation does this:
52  \code
53  extern "C" int halide_do_task(JITUserContext *user_context,
54  int (*f)(void *, int, uint8_t *),
55  int idx, uint8_t *state) {
56  return f(user_context, idx, state);
57  }
58  \endcode
59  *
60  * If you're trying to use a custom parallel runtime, you probably
61  * don't want to call this. See instead custom_do_par_for.
62  */
63  int (*custom_do_task)(JITUserContext *, int (*)(JITUserContext *, int, uint8_t *), int, uint8_t *){nullptr};
64 
65  /** A custom parallel for loop launcher. Useful if your app
66  * already manages a thread pool. The default implementation is
67  * equivalent to this:
68  \code
69  extern "C" int halide_do_par_for(JITUserContext *user_context,
70  int (*f)(void *, int, uint8_t *),
71  int min, int extent, uint8_t *state) {
72  int exit_status = 0;
73  parallel for (int idx = min; idx < min+extent; idx++) {
74  int job_status = halide_do_task(user_context, f, idx, state);
75  if (job_status) exit_status = job_status;
76  }
77  return exit_status;
78  }
79  \endcode
80  *
81  * However, notwithstanding the above example code, if one task
82  * fails, we may skip over other tasks, and if two tasks return
83  * different error codes, we may select one arbitrarily to return.
84  */
85  int (*custom_do_par_for)(JITUserContext *, int (*)(JITUserContext *, int, uint8_t *), int, int, uint8_t *){nullptr};
86 
87  /** The error handler function that will be called in the case of
88  * runtime errors during halide pipelines. */
89  void (*custom_error)(JITUserContext *, const char *){nullptr};
90 
91  /** A custom routine to call when tracing is enabled. Call this
92  * on the output Func of your pipeline. This then sets custom
93  * routines for the entire pipeline, not just calls to this
94  * Func. */
 // NOTE(review): the custom_trace member that the comment above documents
 // (shown at line 95 in the symbol index) is missing from this listing;
 // confirm against the original header before relying on this text.
96 
97  /** A method to use for Halide to resolve symbol names dynamically
98  * in the calling process or library from within the Halide
99  * runtime. Equivalent to dlsym with a null first argument. */
100  void *(*custom_get_symbol)(const char *name){nullptr};
101 
102  /** A method to use for Halide to dynamically load libraries from
103  * within the runtime. Equivalent to dlopen. Returns a handle to
104  * the opened library. */
105  void *(*custom_load_library)(const char *name){nullptr};
106 
107  /** A method to use for Halide to dynamically find a symbol within
108  * an opened library. Equivalent to dlsym. Takes a handle
109  * returned by custom_load_library as the first argument. */
110  void *(*custom_get_library_symbol)(void *lib, const char *name){nullptr};
111 
112  /** A custom method for the Halide runtime to acquire a cuda
113  * context. The cuda context is treated as a void * to avoid a
114  * dependence on the cuda headers. If the create argument is set
115  * to true, a context should be created if one does not already
116  * exist. */
117  int32_t (*custom_cuda_acquire_context)(JITUserContext *user_context, void **cuda_context_ptr, bool create){nullptr};
118 
119  /** The Halide runtime calls this when it is done with a cuda
120  * context. The default implementation does nothing. */
 // NOTE(review): the custom_cuda_release_context member that the comment
 // above documents (shown at line 121 in the symbol index) is missing from
 // this listing; confirm against the original header.
122 
123  /** A custom method for the Halide runtime to acquire a cuda
124  * stream to use. The cuda context and stream are both modelled
125  * as a void *, to avoid a dependence on the cuda headers. */
126  int32_t (*custom_cuda_get_stream)(JITUserContext *user_context, void *cuda_context, void **stream_ptr){nullptr};
127 };
128 
129 namespace Internal {
130 struct JITErrorBuffer;
131 }
132 
133 /** A context to be passed to Pipeline::realize. Inherit from this to
134  * pass your own custom context object. Modify the handlers field to
135  * override runtime functions per-call to realize. */
139 };
140 
141 namespace Internal {
142 
143 class JITModuleContents;
144 struct LoweredFunc;
145 
/** Per the file-level comment, the struct representing lifetime and
 * dependencies of a JIT compiled halide pipeline. */
146 struct JITModule {
 // NOTE(review): line 147 of the original header (the
 // IntrusivePtr<JITModuleContents> jit_module member, per the symbol
 // index) is missing from this listing.
148 
 /** An exported symbol. As declared here it holds only an address,
  * which defaults to nullptr. */
149  struct Symbol {
150  void *address = nullptr;
151  Symbol() = default;
152  explicit Symbol(void *address)
153  : address(address) {
154  }
155  };
156 
157  JITModule();
158  JITModule(const Module &m, const LoweredFunc &fn,
159  const std::vector<JITModule> &dependencies = std::vector<JITModule>());
160 
161  /** Take a list of JITExterns and generate trampoline functions
162  * which can be called dynamically via a function pointer that
163  * takes an array of void *'s for each argument and the return
164  * value.
165  */
166  static JITModule make_trampolines_module(const Target &target,
167  const std::map<std::string, JITExtern> &externs,
168  const std::string &suffix,
169  const std::vector<JITModule> &deps);
170 
171  /** The exports map of a JITModule contains all symbols which are
172  * available to other JITModules which depend on this one. For
173  * runtime modules, this is all of the symbols exported from the
174  * runtime. For a JITted Func, it generally only contains the main
175  * result Func of the compilation, which takes its name directly
176  * from the Func declaration. One can also make a module which
177  * contains no code itself but is just an exports map providing
178  * arbitrary pointers to functions or global variables to JITted
179  * code. */
180  const std::map<std::string, Symbol> &exports() const;
181 
182  /** A pointer to the raw halide function. Its true type depends
183  * on the Argument vector passed to CodeGen_LLVM::compile. Image
184  * parameters become (halide_buffer_t *), and scalar parameters become
185  * pointers to the appropriate values. The final argument is a
186  * pointer to the halide_buffer_t defining the output. This will be nullptr for
187  * a JITModule which has not yet been compiled or one that is not
188  * a Halide Func compilation at all. */
189  void *main_function() const;
190 
191  /** Returns the Symbol structure for the routine documented in
192  * main_function. As declared in this struct, a Symbol carries
193  * only the address, which will be nullptr if the module has not
194  * been compiled. */
195  Symbol entrypoint_symbol() const;
196 
197  /** Returns the Symbol structure for the argv wrapper routine
198  * corresponding to the entrypoint. The argv wrapper is callable
199  * via an array of void * pointers to the arguments for the
200  * call. As declared in this struct, a Symbol carries only the
201  * address, which will be nullptr if the module has not been
202  * compiled. */
203  Symbol argv_entrypoint_symbol() const;
204 
205  /** A slightly more type-safe wrapper around the raw halide
206  * module. Takes its arguments as an array of pointers that
207  * correspond to the arguments to \ref main_function . This will
208  * be nullptr for a JITModule which has not yet been compiled or one
209  * that is not a Halide Func compilation at all. */
210  // @{
211  typedef int (*argv_wrapper)(const void *const *args);
212  argv_wrapper argv_function() const;
213  // @}
214 
215  /** Add another JITModule to the dependency chain. Dependencies
216  * are searched to resolve symbols not found in the current
217  * compilation unit while JITting. */
218  void add_dependency(JITModule &dep);
219  /** Registers a single Symbol as available to modules which depend
220  * on this one. As declared in this struct, the Symbol provides
221  * only the address of the external routine. NOTE(review): earlier
222  * wording also promised an LLVM type; confirm it is obsolete. */
223  void add_symbol_for_export(const std::string &name, const Symbol &extern_symbol);
224  /** Registers a single function as available to modules which
225  * depend on this one. This routine converts the ExternSignature
226  * info into an LLVM type, which allows type safe linkage of
227  * external routines. */
228  void add_extern_for_export(const std::string &name,
229  const ExternCFunction &extern_c_function);
230 
231  /** Look up a symbol by name in this module or its dependencies. */
232  Symbol find_symbol_by_name(const std::string &) const;
233 
234  /** Take an llvm module and compile it. The requested exports will
235  be available via the exports method. */
236  void compile_module(std::unique_ptr<llvm::Module> mod,
237  const std::string &function_name, const Target &target,
238  const std::vector<JITModule> &dependencies = std::vector<JITModule>(),
239  const std::vector<std::string> &requested_exports = std::vector<std::string>());
240 
241  /** See JITSharedRuntime::memoization_cache_set_size */
242  void memoization_cache_set_size(int64_t size) const;
243 
244  /** See JITSharedRuntime::memoization_cache_evict */
245  void memoization_cache_evict(uint64_t eviction_key) const;
246 
247  /** See JITSharedRuntime::reuse_device_allocations */
248  void reuse_device_allocations(bool) const;
249 
250  /** Return true if compile_module has been called on this module. */
251  bool compiled() const;
252 };
253 
255 public:
256  // Note only the first llvm::Module passed in here is used. The same shared runtime is used for all JIT.
257  static std::vector<JITModule> get(llvm::Module *m, const Target &target, bool create = true);
258  static void populate_jit_handlers(JITUserContext *jit_user_context, const JITHandlers &handlers);
259  static JITHandlers set_default_handlers(const JITHandlers &handlers);
260 
261  /** Set the maximum number of bytes used by memoization caching.
262  * If you are compiling statically, you should include HalideRuntime.h
263  * and call halide_memoization_cache_set_size() instead.
264  */
265  static void memoization_cache_set_size(int64_t size);
266 
267  /** Evict all cache entries that were tagged with the given
268  * eviction_key in the memoize scheduling directive. If you are
269  * compiling statically, you should include HalideRuntime.h and
270  * call halide_memoization_cache_evict() instead.
271  */
272  static void memoization_cache_evict(uint64_t eviction_key);
273 
274  /** Set whether or not Halide may hold onto and reuse device
275  * allocations to avoid calling expensive device API allocation
276  * functions. If you are compiling statically, you should include
277  * HalideRuntime.h and call halide_reuse_device_allocations
278  * instead. */
279  static void reuse_device_allocations(bool);
280 
281  static void release_all();
282 };
283 
// Presumably resolves the symbol named `s` to its address. The lookup
// scope and the value returned when the symbol is not found are not
// visible in this header; confirm against the implementation.
284 void *get_symbol_address(const char *s);
285 
286 struct JITCache {
288  // Arguments for all inputs and outputs
289  std::vector<Argument> arguments;
290  std::map<std::string, JITExtern> jit_externs;
293 
294  JITCache() = default;
296  std::vector<Argument> arguments,
297  std::map<std::string, JITExtern> jit_externs,
300 
302 
303  int call_jit_code(const Target &target, const void *const *args);
304 
305  void finish_profiling(JITUserContext *context);
306 };
307 
309  enum { MaxBufSize = 4096 };
311  std::atomic<size_t> end{0};
312 
313  void concat(const char *message);
314 
315  std::string str() const;
316 
317  static void handler(JITUserContext *ctx, const char *message);
318 };
319 
324 
325  JITFuncCallContext(JITUserContext *context, const JITHandlers &pipeline_handlers);
326 
327  void finalize(int exit_status);
328 };
329 
330 } // namespace Internal
331 } // namespace Halide
332 
333 #endif
Halide::Internal::JITErrorBuffer
Definition: JITModule.h:308
int32_t
signed __INT32_TYPE__ int32_t
Definition: runtime_internal.h:24
Halide::JITHandlers::custom_trace
int32_t(* custom_trace)(JITUserContext *, const halide_trace_event_t *)
A custom routine to call when tracing is enabled.
Definition: JITModule.h:95
Halide::Internal::JITModule::add_dependency
void add_dependency(JITModule &dep)
Add another JITModule to the dependency chain.
Halide::Internal::JITFuncCallContext::finalize
void finalize(int exit_status)
Halide::Internal::JITModule::compiled
bool compiled() const
Return true if compile_module has been called on this module.
Halide::Internal::IRMatcher::mod
HALIDE_ALWAYS_INLINE auto mod(A &&a, B &&b) -> decltype(IRMatcher::operator%(a, b))
Definition: IRMatch.h:1066
llvm
Definition: CodeGen_Internal.h:18
Halide::JITHandlers::custom_print
void(* custom_print)(JITUserContext *, const char *)
Set the function called to print messages from the runtime.
Definition: JITModule.h:37
halide_trace_event_t
Definition: HalideRuntime.h:562
Halide::Internal::JITSharedRuntime::release_all
static void release_all()
Halide::Internal::JITModule::add_symbol_for_export
void add_symbol_for_export(const std::string &name, const Symbol &extern_symbol)
Registers a single Symbol as available to modules which depend on this one.
Halide::Internal::JITModule::JITModule
JITModule()
uint8_t
unsigned __INT8_TYPE__ uint8_t
Definition: runtime_internal.h:29
Halide::Internal::JITModule::Symbol::Symbol
Symbol()=default
Halide::JITUserContext::handlers
JITHandlers handlers
Definition: JITModule.h:138
Halide::Internal::JITSharedRuntime::populate_jit_handlers
static void populate_jit_handlers(JITUserContext *jit_user_context, const JITHandlers &handlers)
Halide::Internal::JITModule::memoization_cache_evict
void memoization_cache_evict(uint64_t eviction_key) const
See JITSharedRuntime::memoization_cache_evict.
Halide::Internal::JITCache::jit_target
Target jit_target
Definition: JITModule.h:287
Halide::Internal::JITCache::call_jit_code
int call_jit_code(const Target &target, const void *const *args)
Halide::Internal::JITCache::finish_profiling
void finish_profiling(JITUserContext *context)
Halide::Internal::JITModule::add_extern_for_export
void add_extern_for_export(const std::string &name, const ExternCFunction &extern_c_function)
Registers a single function as available to modules which depend on this one.
Halide::Internal::JITSharedRuntime::memoization_cache_set_size
static void memoization_cache_set_size(int64_t size)
Set the maximum number of bytes used by memoization caching.
Halide::Internal::JITModule::compile_module
void compile_module(std::unique_ptr< llvm::Module > mod, const std::string &function_name, const Target &target, const std::vector< JITModule > &dependencies=std::vector< JITModule >(), const std::vector< std::string > &requested_exports=std::vector< std::string >())
Take an llvm module and compile it.
Halide::JITHandlers::custom_cuda_get_stream
int32_t(* custom_cuda_get_stream)(JITUserContext *user_context, void *cuda_context, void **stream_ptr)
A custom method for the Halide runtime to acquire a cuda stream to use.
Definition: JITModule.h:126
Target.h
Halide::Internal::JITCache::jit_externs
std::map< std::string, JITExtern > jit_externs
Definition: JITModule.h:290
Halide::Internal::JITCache::jit_module
JITModule jit_module
Definition: JITModule.h:291
Halide::Internal::JITModule::jit_module
IntrusivePtr< JITModuleContents > jit_module
Definition: JITModule.h:147
Halide::Internal::JITModule::memoization_cache_set_size
void memoization_cache_set_size(int64_t size) const
See JITSharedRuntime::memoization_cache_set_size.
Halide::Internal::IntrusivePtr< JITModuleContents >
uint64_t
unsigned __INT64_TYPE__ uint64_t
Definition: runtime_internal.h:23
Halide::JITUserContext
A context to be passed to Pipeline::realize.
Definition: JITModule.h:136
Halide::Module
A halide module.
Definition: Module.h:138
Halide::ExternCFunction
Definition: Pipeline.h:554
Halide
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
Definition: AbstractGenerator.h:19
Halide::Internal::JITSharedRuntime::reuse_device_allocations
static void reuse_device_allocations(bool)
Set whether or not Halide may hold onto and reuse device allocations to avoid calling expensive devic...
Halide::Internal::JITErrorBuffer::end
std::atomic< size_t > end
Definition: JITModule.h:311
Halide::JITHandlers::custom_cuda_acquire_context
int32_t(* custom_cuda_acquire_context)(JITUserContext *user_context, void **cuda_context_ptr, bool create)
A custom method for the Halide runtime to acquire a cuda context.
Definition: JITModule.h:117
Halide::LinkageType::Internal
@ Internal
Not visible externally, similar to 'static' linkage in C.
Halide::Internal::JITModule::argv_entrypoint_symbol
Symbol argv_entrypoint_symbol() const
Returns the Symbol structure for the argv wrapper routine corresponding to the entrypoint.
Halide::JITHandlers
A set of custom overrides of runtime functions.
Definition: JITModule.h:35
size_t
__SIZE_TYPE__ size_t
Definition: runtime_internal.h:31
Halide::Internal::JITModule::Symbol::address
void * address
Definition: JITModule.h:150
Halide::Internal::JITCache::get_compiled_jit_target
Target get_compiled_jit_target() const
Halide::Internal::JITModule::argv_function
argv_wrapper argv_function() const
Halide::Internal::JITCache
Definition: JITModule.h:286
Halide::Internal::JITErrorBuffer::str
std::string str() const
Halide::Internal::JITFuncCallContext::custom_error_handler
bool custom_error_handler
Definition: JITModule.h:323
Halide::Internal::JITModule::entrypoint_symbol
Symbol entrypoint_symbol() const
Returns the Symbol structure for the routine documented in main_function.
Halide::JITHandlers::custom_cuda_release_context
int32_t(* custom_cuda_release_context)(JITUserContext *user_context)
The Halide runtime calls this when it is done with a cuda context.
Definition: JITModule.h:121
Halide::Internal::JITModule::argv_wrapper
int(* argv_wrapper)(const void *const *args)
A slightly more type-safe wrapper around the raw halide module.
Definition: JITModule.h:211
Halide::Internal::JITFuncCallContext::context
JITUserContext * context
Definition: JITModule.h:322
int64_t
signed __INT64_TYPE__ int64_t
Definition: runtime_internal.h:22
Halide::Internal::JITModule::Symbol
Definition: JITModule.h:149
Halide::Internal::JITModule::exports
const std::map< std::string, Symbol > & exports() const
The exports map of a JITModule contains all symbols which are available to other JITModules which dep...
Halide::Internal::LoweredFunc
Definition of a lowered function.
Definition: Module.h:97
Type.h
Halide::Internal::JITSharedRuntime::memoization_cache_evict
static void memoization_cache_evict(uint64_t eviction_key)
Evict all cache entries that were tagged with the given eviction_key in the memoize scheduling direct...
Halide::Internal::JITModule::main_function
void * main_function() const
A pointer to the raw halide function.
Halide::JITUserContext::error_buffer
Internal::JITErrorBuffer * error_buffer
Definition: JITModule.h:137
Halide::Internal::JITFuncCallContext::JITFuncCallContext
JITFuncCallContext(JITUserContext *context, const JITHandlers &pipeline_handlers)
Halide::Internal::get_symbol_address
void * get_symbol_address(const char *s)
Halide::JITHandlers::custom_free
void(* custom_free)(JITUserContext *, void *)
Definition: JITModule.h:45
Halide::Internal::JITModule::find_symbol_by_name
Symbol find_symbol_by_name(const std::string &) const
Look up a symbol by name in this module or its dependencies.
Halide::Internal::JITSharedRuntime
Definition: JITModule.h:254
Halide::JITHandlers::custom_do_task
int(* custom_do_task)(JITUserContext *, int(*)(JITUserContext *, int, uint8_t *), int, uint8_t *)
A custom task handler to be called by the parallel for loop.
Definition: JITModule.h:63
Halide::Internal::JITCache::JITCache
JITCache()=default
HalideRuntime.h
Halide::Internal::JITErrorBuffer::buf
char buf[MaxBufSize]
Definition: JITModule.h:310
Halide::Internal::JITModule::reuse_device_allocations
void reuse_device_allocations(bool) const
See JITSharedRuntime::reuse_device_allocations.
Halide::JITHandlers::custom_do_par_for
int(* custom_do_par_for)(JITUserContext *, int(*)(JITUserContext *, int, uint8_t *), int, int, uint8_t *)
A custom parallel for loop launcher.
Definition: JITModule.h:85
Halide::Internal::JITCache::arguments
std::vector< Argument > arguments
Definition: JITModule.h:289
Halide::Internal::JITModule::make_trampolines_module
static JITModule make_trampolines_module(const Target &target, const std::map< std::string, JITExtern > &externs, const std::string &suffix, const std::vector< JITModule > &deps)
Take a list of JITExterns and generate trampoline functions which can be called dynamically via a fun...
Halide::Internal::JITErrorBuffer::concat
void concat(const char *message)
Halide::Internal::JITErrorBuffer::MaxBufSize
@ MaxBufSize
Definition: JITModule.h:309
WasmExecutor.h
Halide::Internal::JITErrorBuffer::handler
static void handler(JITUserContext *ctx, const char *message)
Halide::Internal::JITModule::Symbol::Symbol
Symbol(void *address)
Definition: JITModule.h:152
Halide::JITHandlers::custom_error
void(* custom_error)(JITUserContext *, const char *)
The error handler function that will be called in the case of runtime errors during halide pipelines.
Definition: JITModule.h:89
Halide::Internal::WasmModule
Handle to compiled wasm code which can be called later.
Definition: WasmExecutor.h:32
IntrusivePtr.h
Halide::Internal::JITFuncCallContext::error_buffer
JITErrorBuffer error_buffer
Definition: JITModule.h:321
Halide::Internal::JITSharedRuntime::set_default_handlers
static JITHandlers set_default_handlers(const JITHandlers &handlers)
Halide::Target
A struct representing a target machine and os to generate code for.
Definition: Target.h:19
Halide::Internal::JITFuncCallContext
Definition: JITModule.h:320
Halide::Internal::JITCache::wasm_module
WasmModule wasm_module
Definition: JITModule.h:292
Halide::Internal::JITSharedRuntime::get
static std::vector< JITModule > get(llvm::Module *m, const Target &target, bool create=true)
Halide::Internal::JITModule
Definition: JITModule.h:146