Halide 19.0.0
Halide compiler and libraries
Loading...
Searching...
No Matches
JITModule.h
Go to the documentation of this file.
1#ifndef HALIDE_JIT_MODULE_H
2#define HALIDE_JIT_MODULE_H
3
4/** \file
5 * Defines the struct representing lifetime and dependencies of
6 * a JIT compiled halide pipeline
7 */
8
9#include <map>
10#include <memory>
11#include <vector>
12
13#include "IntrusivePtr.h"
14#include "Target.h"
15#include "Type.h"
16#include "WasmExecutor.h"
18
19namespace llvm {
20class Module;
21}
22
23namespace Halide {
24
25struct ExternCFunction;
26struct JITExtern;
27class Module;
28
29struct JITUserContext;
30
31/** A set of custom overrides of runtime functions. These only apply
32 * when JIT-compiling code. If you are doing AOT compilation, see
33 * HalideRuntime.h for instructions on how to replace runtime
34 * functions. */
36 /** Set the function called to print messages from the runtime. */
37 void (*custom_print)(JITUserContext *, const char *){nullptr};
38
39 /** A custom malloc and free for halide to use. Malloc should
40 * return 32-byte aligned chunks of memory, and it should be safe
41 * for Halide to read slightly out of bounds (up to 8 bytes before
42 * the start or beyond the end). */
43 // @{
44 void *(*custom_malloc)(JITUserContext *, size_t){nullptr};
45 void (*custom_free)(JITUserContext *, void *){nullptr};
46 // @}
47
48 /** A custom task handler to be called by the parallel for
49 * loop. It is useful to set this if you want to do some
50 * additional bookkeeping at the granularity of parallel
51 * tasks. The default implementation does this:
52 \code
53 extern "C" int halide_do_task(JITUserContext *user_context,
54 int (*f)(void *, int, uint8_t *),
55 int idx, uint8_t *state) {
56 return f(user_context, idx, state);
57 }
58 \endcode
59 *
60 * If you're trying to use a custom parallel runtime, you probably
61 * don't want to call this. See instead custom_do_par_for.
62 */
63 int (*custom_do_task)(JITUserContext *, int (*)(JITUserContext *, int, uint8_t *), int, uint8_t *){nullptr};
64
65 /** A custom parallel for loop launcher. Useful if your app
66 * already manages a thread pool. The default implementation is
67 * equivalent to this:
68 \code
69 extern "C" int halide_do_par_for(JITUserContext *user_context,
70 int (*f)(void *, int, uint8_t *),
71 int min, int extent, uint8_t *state) {
72 int exit_status = 0;
73 parallel for (int idx = min; idx < min+extent; idx++) {
74 int job_status = halide_do_task(user_context, f, idx, state);
75 if (job_status) exit_status = job_status;
76 }
77 return exit_status;
78 }
79 \endcode
80 *
81 * However, notwithstanding the above example code, if one task
82 * fails, we may skip over other tasks, and if two tasks return
83 * different error codes, we may select one arbitrarily to return.
84 */
85 int (*custom_do_par_for)(JITUserContext *, int (*)(JITUserContext *, int, uint8_t *), int, int, uint8_t *){nullptr};
86
87 /** The error handler function that will be called in the case of
88 * runtime errors during halide pipelines. */
89 void (*custom_error)(JITUserContext *, const char *){nullptr};
90
91 /** A custom routine to call when tracing is enabled. Call this
92 * on the output Func of your pipeline. This then sets custom
93 * routines for the entire pipeline, not just calls to this
94 * Func. */
96
97 /** A method to use for Halide to resolve symbol names dynamically
98 * in the calling process or library from within the Halide
99 * runtime. Equivalent to dlsym with a null first argument. */
100 void *(*custom_get_symbol)(const char *name){nullptr};
101
102 /** A method to use for Halide to dynamically load libraries from
103 * within the runtime. Equivalent to dlopen. Returns a handle to
104 * the opened library. */
105 void *(*custom_load_library)(const char *name){nullptr};
106
107 /** A method to use for Halide to dynamically find a symbol within
108 * an opened library. Equivalent to dlsym. Takes a handle
109 * returned by custom_load_library as the first argument. */
110 void *(*custom_get_library_symbol)(void *lib, const char *name){nullptr};
111
112 /** A custom method for the Halide runtime to acquire a cuda
113 * context. The cuda context is treated as a void * to avoid a
114 * dependence on the cuda headers. If the create argument is set
115 * to true, a context should be created if one does not already
116 * exist. */
117 int32_t (*custom_cuda_acquire_context)(JITUserContext *user_context, void **cuda_context_ptr, bool create){nullptr};
118
119 /** The Halide runtime calls this when it is done with a cuda
120 * context. The default implementation does nothing. */
122
123 /** A custom method for the Halide runtime to acquire a cuda
124 * stream to use. The cuda context and stream are both modelled
125 * as a void *, to avoid a dependence on the cuda headers. */
126 int32_t (*custom_cuda_get_stream)(JITUserContext *user_context, void *cuda_context, void **stream_ptr){nullptr};
127};
128
129namespace Internal {
130struct JITErrorBuffer;
131}
132
133/** A context to be passed to Pipeline::realize. Inherit from this to
134 * pass your own custom context object. Modify the handlers field to
135 * override runtime functions per-call to realize. */
140
141namespace Internal {
142
143class JITModuleContents;
144struct LoweredFunc;
145
146struct JITModule {
148
149 struct Symbol {
150 void *address = nullptr;
151 Symbol() = default;
152 explicit Symbol(void *address)
153 : address(address) {
154 }
155 };
156
158 JITModule(const Module &m, const LoweredFunc &fn,
159 const std::vector<JITModule> &dependencies = std::vector<JITModule>());
160
161 /** Take a list of JITExterns and generate trampoline functions
162 * which can be called dynamically via a function pointer that
163 * takes an array of void *'s for each argument and the return
164 * value.
165 */
167 const std::map<std::string, JITExtern> &externs,
168 const std::string &suffix,
169 const std::vector<JITModule> &deps);
170
171 /** The exports map of a JITModule contains all symbols which are
172 * available to other JITModules which depend on this one. For
173 * runtime modules, this is all of the symbols exported from the
174 * runtime. For a JITted Func, it generally only contains the main
175 * result Func of the compilation, which takes its name directly
176 * from the Func declaration. One can also make a module which
177 * contains no code itself but is just an exports map providing
178 * arbitrary pointers to functions or global variables to JITted
179 * code. */
180 const std::map<std::string, Symbol> &exports() const;
181
182 /** A pointer to the raw halide function. Its true type depends
183 * on the Argument vector passed to CodeGen_LLVM::compile. Image
184 * parameters become (halide_buffer_t *), and scalar parameters become
185 * pointers to the appropriate values. The final argument is a
186 * pointer to the halide_buffer_t defining the output. This will be nullptr for
187 * a JITModule which has not yet been compiled or one that is not
188 * a Halide Func compilation at all. */
189 void *main_function() const;
190
191 /** Returns the Symbol structure for the routine documented in
192 * main_function. Returning a Symbol allows access to the LLVM
193 * type as well as the address. The address and type will be nullptr
194 * if the module has not been compiled. */
196
197 /** Returns the Symbol structure for the argv wrapper routine
198 * corresponding to the entrypoint. The argv wrapper is callable
199 * via an array of void * pointers to the arguments for the
200 * call. Returning a Symbol allows access to the LLVM type as well
201 * as the address. The address and type will be nullptr if the module
202 * has not been compiled. */
204
205 /** A slightly more type-safe wrapper around the raw halide
206 * module. Takes its arguments as an array of pointers that
207 * correspond to the arguments to \ref main_function . This will
208 * be nullptr for a JITModule which has not yet been compiled or one
209 * that is not a Halide Func compilation at all. */
210 // @{
211 typedef int (*argv_wrapper)(const void *const *args);
213 // @}
214
215 /** Add another JITModule to the dependency chain. Dependencies
216 * are searched to resolve symbols not found in the current
217 * compilation unit while JITting. */
219 /** Registers a single Symbol as available to modules which depend
220 * on this one. The Symbol structure provides both the address and
221 * the LLVM type for the function, which allows type safe linkage of
222 * external routines. */
223 void add_symbol_for_export(const std::string &name, const Symbol &extern_symbol);
224 /** Registers a single function as available to modules which
225 * depend on this one. This routine converts the ExternSignature
226 * info into an LLVM type, which allows type safe linkage of
227 * external routines. */
228 void add_extern_for_export(const std::string &name,
229 const ExternCFunction &extern_c_function);
230
231 /** Look up a symbol by name in this module or its dependencies. */
232 Symbol find_symbol_by_name(const std::string &) const;
233
234 /** Take an llvm module and compile it. The requested exports will
235 be available via the exports method. */
236 void compile_module(std::unique_ptr<llvm::Module> mod,
237 const std::string &function_name, const Target &target,
238 const std::vector<JITModule> &dependencies = std::vector<JITModule>(),
239 const std::vector<std::string> &requested_exports = std::vector<std::string>());
240
241 /** See JITSharedRuntime::memoization_cache_set_size */
243
244 /** See JITSharedRuntime::memoization_cache_evict */
245 void memoization_cache_evict(uint64_t eviction_key) const;
246
247 /** See JITSharedRuntime::reuse_device_allocations */
248 void reuse_device_allocations(bool) const;
249
250 /** See JITSharedRuntime::get_num_threads */
251 int get_num_threads() const;
252
253 /** See JITSharedRuntime::set_num_threads */
254 int set_num_threads(int) const;
255
256 /** Return true if compile_module has been called on this module. */
257 bool compiled() const;
258};
259
261public:
262 // Note only the first llvm::Module passed in here is used. The same shared runtime is used for all JIT.
263 static std::vector<JITModule> get(llvm::Module *m, const Target &target, bool create = true);
264 static void populate_jit_handlers(JITUserContext *jit_user_context, const JITHandlers &handlers);
266
267 /** Set the maximum number of bytes used by memoization caching.
268 * If you are compiling statically, you should include HalideRuntime.h
269 * and call halide_memoization_cache_set_size() instead.
270 */
272
273 /** Evict all cache entries that were tagged with the given
274 * eviction_key in the memoize scheduling directive. If you are
275 * compiling statically, you should include HalideRuntime.h and
276 * call halide_memoization_cache_evict() instead.
277 */
278 static void memoization_cache_evict(uint64_t eviction_key);
279
280 /** Set whether or not Halide may hold onto and reuse device
281 * allocations to avoid calling expensive device API allocation
282 * functions. If you are compiling statically, you should include
283 * HalideRuntime.h and call halide_reuse_device_allocations
284 * instead. */
285 static void reuse_device_allocations(bool);
286
287 static void release_all();
288
289 /** Get the number of threads in the Halide thread pool. Includes the
290 * calling thread. Meaningless if a custom_do_par_for has been set. */
291 static int get_num_threads();
292
293 /** Set the number of threads to use in the Halide thread pool, inclusive of
294 * the calling thread. Pass zero to use a reasonable default (typically the
295 * number of CPUs online). Calling this is meaningless if custom_do_par_for
296 * has been set. Halide may launch more threads than this if necessary to
297 * avoid deadlock when using the async scheduling directive. Returns the old
298 * number. */
299 static int set_num_threads(int);
300};
301
302void *get_symbol_address(const char *s);
303
304struct JITCache {
306 // Arguments for all inputs and outputs
307 std::vector<Argument> arguments;
308 std::map<std::string, JITExtern> jit_externs;
311
312 JITCache() = default;
314 std::vector<Argument> arguments,
315 std::map<std::string, JITExtern> jit_externs,
318
320
321 int call_jit_code(const void *const *args);
322
324};
325
327 enum { MaxBufSize = 4096 };
329 std::atomic<size_t> end{0};
330
331 void concat(const char *message);
332
333 std::string str() const;
334
335 static void handler(JITUserContext *ctx, const char *message);
336};
337
347
348} // namespace Internal
349} // namespace Halide
350
351#endif
This file declares the routines used by Halide internally in its runtime.
Support classes for reference-counting via intrusive shared pointers.
Defines the structure that describes a Halide target.
Defines halide types.
Support for running Halide-compiled Wasm code in-process.
static int get_num_threads()
Get the number of threads in the Halide thread pool.
static void memoization_cache_evict(uint64_t eviction_key)
Evict all cache entries that were tagged with the given eviction_key in the memoize scheduling direct...
static void memoization_cache_set_size(int64_t size)
Set the maximum number of bytes used by memoization caching.
static int set_num_threads(int)
Set the number of threads to use in the Halide thread pool, inclusive of the calling thread.
static JITHandlers set_default_handlers(const JITHandlers &handlers)
static void reuse_device_allocations(bool)
Set whether or not Halide may hold onto and reuse device allocations to avoid calling expensive devic...
static void populate_jit_handlers(JITUserContext *jit_user_context, const JITHandlers &handlers)
static std::vector< JITModule > get(llvm::Module *m, const Target &target, bool create=true)
A halide module.
Definition Module.h:142
void * get_symbol_address(const char *s)
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
@ Internal
Not visible externally, similar to 'static' linkage in C.
unsigned __INT64_TYPE__ uint64_t
signed __INT64_TYPE__ int64_t
signed __INT32_TYPE__ int32_t
unsigned __INT8_TYPE__ uint8_t
__SIZE_TYPE__ size_t
Intrusive shared pointers have a reference count (a RefCount object) stored in the class itself.
int call_jit_code(const void *const *args)
Target get_compiled_jit_target() const
std::map< std::string, JITExtern > jit_externs
Definition JITModule.h:308
std::vector< Argument > arguments
Definition JITModule.h:307
JITCache(Target jit_target, std::vector< Argument > arguments, std::map< std::string, JITExtern > jit_externs, JITModule jit_module, WasmModule wasm_module)
void finish_profiling(JITUserContext *context)
void concat(const char *message)
static void handler(JITUserContext *ctx, const char *message)
std::atomic< size_t > end
Definition JITModule.h:329
JITFuncCallContext(JITUserContext *context, const JITHandlers &pipeline_handlers)
void memoization_cache_evict(uint64_t eviction_key) const
See JITSharedRuntime::memoization_cache_evict.
int set_num_threads(int) const
See JITSharedRuntime::set_num_threads.
void memoization_cache_set_size(int64_t size) const
See JITSharedRuntime::memoization_cache_set_size.
void add_symbol_for_export(const std::string &name, const Symbol &extern_symbol)
Registers a single Symbol as available to modules which depend on this one.
void compile_module(std::unique_ptr< llvm::Module > mod, const std::string &function_name, const Target &target, const std::vector< JITModule > &dependencies=std::vector< JITModule >(), const std::vector< std::string > &requested_exports=std::vector< std::string >())
Take an llvm module and compile it.
int(*) argv_wrapper(const void *const *args)
A slightly more type-safe wrapper around the raw halide module.
Definition JITModule.h:211
int get_num_threads() const
See JITSharedRuntime::get_num_threads.
void add_extern_for_export(const std::string &name, const ExternCFunction &extern_c_function)
Registers a single function as available to modules which depend on this one.
const std::map< std::string, Symbol > & exports() const
The exports map of a JITModule contains all symbols which are available to other JITModules which dep...
void reuse_device_allocations(bool) const
See JITSharedRuntime::reuse_device_allocations.
void add_dependency(JITModule &dep)
Add another JITModule to the dependency chain.
Symbol find_symbol_by_name(const std::string &) const
Look up a symbol by name in this module or its dependencies.
static JITModule make_trampolines_module(const Target &target, const std::map< std::string, JITExtern > &externs, const std::string &suffix, const std::vector< JITModule > &deps)
Take a list of JITExterns and generate trampoline functions which can be called dynamically via a fun...
void * main_function() const
A pointer to the raw halide function.
Symbol argv_entrypoint_symbol() const
Returns the Symbol structure for the argv wrapper routine corresponding to the entrypoint.
bool compiled() const
Return true if compile_module has been called on this module.
Symbol entrypoint_symbol() const
Returns the Symbol structure for the routine documented in main_function.
argv_wrapper argv_function() const
IntrusivePtr< JITModuleContents > jit_module
Definition JITModule.h:147
JITModule(const Module &m, const LoweredFunc &fn, const std::vector< JITModule > &dependencies=std::vector< JITModule >())
Definition of a lowered function.
Definition Module.h:101
Handle to compiled wasm code which can be called later.
A set of custom overrides of runtime functions.
Definition JITModule.h:35
int(* custom_do_par_for)(JITUserContext *, int(*)(JITUserContext *, int, uint8_t *), int, int, uint8_t *)
A custom parallel for loop launcher.
Definition JITModule.h:85
int32_t(* custom_cuda_acquire_context)(JITUserContext *user_context, void **cuda_context_ptr, bool create)
A custom method for the Halide runtime to acquire a cuda context.
Definition JITModule.h:117
int32_t(* custom_cuda_get_stream)(JITUserContext *user_context, void *cuda_context, void **stream_ptr)
A custom method for the Halide runtime to acquire a cuda stream to use.
Definition JITModule.h:126
void(* custom_error)(JITUserContext *, const char *)
The error handler function that will be called in the case of runtime errors during halide pipelines.
Definition JITModule.h:89
int32_t(* custom_cuda_release_context)(JITUserContext *user_context)
The Halide runtime calls this when it is done with a cuda context.
Definition JITModule.h:121
void(* custom_free)(JITUserContext *, void *)
Definition JITModule.h:45
int32_t(* custom_trace)(JITUserContext *, const halide_trace_event_t *)
A custom routine to call when tracing is enabled.
Definition JITModule.h:95
int(* custom_do_task)(JITUserContext *, int(*)(JITUserContext *, int, uint8_t *), int, uint8_t *)
A custom task handler to be called by the parallel for loop.
Definition JITModule.h:63
void(* custom_print)(JITUserContext *, const char *)
Set the function called to print messages from the runtime.
Definition JITModule.h:37
A context to be passed to Pipeline::realize.
Definition JITModule.h:136
Internal::JITErrorBuffer * error_buffer
Definition JITModule.h:137
A struct representing a target machine and os to generate code for.
Definition Target.h:19
void * user_context