Halide
JITModule.h
Go to the documentation of this file.
1 #ifndef HALIDE_JIT_MODULE_H
2 #define HALIDE_JIT_MODULE_H
3 
4 /** \file
5  * Defines the struct representing lifetime and dependencies of
6  * a JIT compiled halide pipeline
7  */
8 
9 #include <map>
10 #include <memory>
11 
12 #include "IntrusivePtr.h"
13 #include "Type.h"
14 #include "runtime/HalideRuntime.h"
15 
16 namespace llvm {
17 class Module;
18 }
19 
20 namespace Halide {
21 
22 struct ExternCFunction;
23 struct JITExtern;
24 struct Target;
25 class Module;
26 
27 namespace Internal {
28 
29 class JITModuleContents;
30 struct LoweredFunc;
31 
32 struct JITModule {
34 
35  struct Symbol {
36  void *address = nullptr;
37  Symbol() = default;
38  explicit Symbol(void *address)
39  : address(address) {
40  }
41  };
42 
43  JITModule();
44  JITModule(const Module &m, const LoweredFunc &fn,
45  const std::vector<JITModule> &dependencies = std::vector<JITModule>());
46 
47  /** Take a list of JITExterns and generate trampoline functions
48  * which can be called dynamically via a function pointer that
49  * takes an array of void *'s for each argument and the return
50  * value.
51  */
52  static JITModule make_trampolines_module(const Target &target,
53  const std::map<std::string, JITExtern> &externs,
54  const std::string &suffix,
55  const std::vector<JITModule> &deps);
56 
57  /** The exports map of a JITModule contains all symbols which are
58  * available to other JITModules which depend on this one. For
59  * runtime modules, this is all of the symbols exported from the
60  * runtime. For a JITted Func, it generally only contains the main
61  * result Func of the compilation, which takes its name directly
62  * from the Func declaration. One can also make a module which
63  * contains no code itself but is just an exports maps providing
64  * arbitrary pointers to functions or global variables to JITted
65  * code. */
66  const std::map<std::string, Symbol> &exports() const;
67 
68  /** A pointer to the raw halide function. Its true type depends
69  * on the Argument vector passed to CodeGen_LLVM::compile. Image
70  * parameters become (halide_buffer_t *), and scalar parameters become
71  * pointers to the appropriate values. The final argument is a
72  * pointer to the halide_buffer_t defining the output. This will be nullptr for
73  * a JITModule which has not yet been compiled or one that is not
74  * a Halide Func compilation at all. */
75  void *main_function() const;
76 
77  /** Returns the Symbol structure for the routine documented in
78  * main_function. Returning a Symbol allows access to the LLVM
79  * type as well as the address. The address and type will be nullptr
80  * if the module has not been compiled. */
81  Symbol entrypoint_symbol() const;
82 
83  /** Returns the Symbol structure for the argv wrapper routine
84  * corresponding to the entrypoint. The argv wrapper is callable
85  * via an array of void * pointers to the arguments for the
86  * call. Returning a Symbol allows access to the LLVM type as well
87  * as the address. The address and type will be nullptr if the module
88  * has not been compiled. */
89  Symbol argv_entrypoint_symbol() const;
90 
91  /** A slightly more type-safe wrapper around the raw halide
92  * module. Takes it arguments as an array of pointers that
93  * correspond to the arguments to \ref main_function . This will
94  * be nullptr for a JITModule which has not yet been compiled or one
95  * that is not a Halide Func compilation at all. */
96  // @{
97  typedef int (*argv_wrapper)(const void **args);
99  // @}
100 
101  /** Add another JITModule to the dependency chain. Dependencies
102  * are searched to resolve symbols not found in the current
103  * compilation unit while JITting. */
104  void add_dependency(JITModule &dep);
105  /** Registers a single Symbol as available to modules which depend
106  * on this one. The Symbol structure provides both the address and
107  * the LLVM type for the function, which allows type safe linkage of
108  * extenal routines. */
109  void add_symbol_for_export(const std::string &name, const Symbol &extern_symbol);
110  /** Registers a single function as available to modules which
111  * depend on this one. This routine converts the ExternSignature
112  * info into an LLVM type, which allows type safe linkage of
113  * external routines. */
114  void add_extern_for_export(const std::string &name,
115  const ExternCFunction &extern_c_function);
116 
117  /** Look up a symbol by name in this module or its dependencies. */
118  Symbol find_symbol_by_name(const std::string &) const;
119 
120  /** Take an llvm module and compile it. The requested exports will
121  be available via the exports method. */
122  void compile_module(std::unique_ptr<llvm::Module> mod,
123  const std::string &function_name, const Target &target,
124  const std::vector<JITModule> &dependencies = std::vector<JITModule>(),
125  const std::vector<std::string> &requested_exports = std::vector<std::string>());
126 
127  /** See JITSharedRuntime::memoization_cache_set_size */
128  void memoization_cache_set_size(int64_t size) const;
129 
130  /** See JITSharedRuntime::reuse_device_allocations */
131  void reuse_device_allocations(bool) const;
132 
133  /** Return true if compile_module has been called on this module. */
134  bool compiled() const;
135 };
136 
137 using halide_task = int (*)(void *user_context, int, uint8_t *);
138 
139 struct JITHandlers {
140  void (*custom_print)(void *, const char *) = nullptr;
141  void *(*custom_malloc)(void *, size_t) = nullptr;
142  void (*custom_free)(void *, void *) = nullptr;
143  int (*custom_do_task)(void *, halide_task, int, uint8_t *) = nullptr;
144  int (*custom_do_par_for)(void *, halide_task, int, int, uint8_t *) = nullptr;
145  void (*custom_error)(void *, const char *) = nullptr;
146  int32_t (*custom_trace)(void *, const halide_trace_event_t *) = nullptr;
147  void *(*custom_get_symbol)(const char *name) = nullptr;
148  void *(*custom_load_library)(const char *name) = nullptr;
149  void *(*custom_get_library_symbol)(void *lib, const char *name) = nullptr;
150 };
151 
152 struct JITUserContext {
153  void *user_context;
154  JITHandlers handlers;
155 };
156 
157 class JITSharedRuntime {
158 public:
159  // Note only the first llvm::Module passed in here is used. The same shared runtime is used for all JIT.
160  static std::vector<JITModule> get(llvm::Module *m, const Target &target, bool create = true);
161  static void init_jit_user_context(JITUserContext &jit_user_context, void *user_context, const JITHandlers &handlers);
162  static JITHandlers set_default_handlers(const JITHandlers &handlers);
163 
164  /** Set the maximum number of bytes used by memoization caching.
165  * If you are compiling statically, you should include HalideRuntime.h
166  * and call halide_memoization_cache_set_size() instead.
167  */
168  static void memoization_cache_set_size(int64_t size);
169 
170  /** Set whether or not Halide may hold onto and reuse device
171  * allocations to avoid calling expensive device API allocation
172  * functions. If you are compiling statically, you should include
173  * HalideRuntime.h and call halide_reuse_device_allocations
174  * instead. */
175  static void reuse_device_allocations(bool);
176 
177  static void release_all();
178 };
179 
180 void *get_symbol_address(const char *symbol_name); // NOTE(review): presumably resolves symbol_name to an address; semantics not visible here — confirm in JITModule.cpp
181 
182 } // namespace Internal
183 } // namespace Halide
184 
185 #endif
int32_t
signed __INT32_TYPE__ int32_t
Definition: runtime_internal.h:20
Halide::Internal::JITHandlers::custom_print
void(* custom_print)(void *, const char *)
Definition: JITModule.h:140
Halide::Internal::JITModule::add_dependency
void add_dependency(JITModule &dep)
Add another JITModule to the dependency chain.
Halide::Internal::JITModule::compiled
bool compiled() const
Return true if compile_module has been called on this module.
llvm
Definition: CodeGen_Internal.h:19
halide_trace_event_t
Definition: HalideRuntime.h:501
Halide::Internal::JITSharedRuntime::release_all
static void release_all()
Halide::Internal::JITModule::add_symbol_for_export
void add_symbol_for_export(const std::string &name, const Symbol &extern_symbol)
Registers a single Symbol as available to modules which depend on this one.
uint8_t
unsigned __INT8_TYPE__ uint8_t
Definition: runtime_internal.h:25
Halide::Internal::JITModule::JITModule
JITModule()
Halide::Internal::JITModule::Symbol::Symbol
Symbol()=default
Halide::Internal::JITHandlers::custom_error
void(* custom_error)(void *, const char *)
Definition: JITModule.h:145
Halide::Internal::JITHandlers::custom_trace
int32_t(* custom_trace)(void *, const halide_trace_event_t *)
Definition: JITModule.h:146
Halide::Internal::JITModule::argv_wrapper
int(* argv_wrapper)(const void **args)
A slightly more type-safe wrapper around the raw halide module.
Definition: JITModule.h:97
Halide::Internal::JITModule::add_extern_for_export
void add_extern_for_export(const std::string &name, const ExternCFunction &extern_c_function)
Registers a single function as available to modules which depend on this one.
Halide::Internal::JITSharedRuntime::memoization_cache_set_size
static void memoization_cache_set_size(int64_t size)
Set the maximum number of bytes used by memoization caching.
Halide::Internal::JITModule::compile_module
void compile_module(std::unique_ptr< llvm::Module > mod, const std::string &function_name, const Target &target, const std::vector< JITModule > &dependencies=std::vector< JITModule >(), const std::vector< std::string > &requested_exports=std::vector< std::string >())
Take an llvm module and compile it.
Halide::Internal::JITModule::jit_module
IntrusivePtr< JITModuleContents > jit_module
Definition: JITModule.h:33
Halide::Internal::JITModule::memoization_cache_set_size
void memoization_cache_set_size(int64_t size) const
See JITSharedRuntime::memoization_cache_set_size.
Halide::Internal::IntrusivePtr< JITModuleContents >
Halide::Module
A halide module.
Definition: Module.h:136
Halide::ExternCFunction
Definition: Pipeline.h:636
Halide
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
Definition: AddAtomicMutex.h:21
Halide::Internal::JITSharedRuntime::reuse_device_allocations
static void reuse_device_allocations(bool)
Set whether or not Halide may hold onto and reuse device allocations to avoid calling expensive device API allocation functions.
Halide::LinkageType::Internal
@ Internal
Not visible externally, similar to 'static' linkage in C.
Halide::Internal::JITModule::argv_entrypoint_symbol
Symbol argv_entrypoint_symbol() const
Returns the Symbol structure for the argv wrapper routine corresponding to the entrypoint.
size_t
__SIZE_TYPE__ size_t
Definition: runtime_internal.h:27
Halide::Internal::JITHandlers
Definition: JITModule.h:139
Halide::Internal::JITModule::Symbol::address
void * address
Definition: JITModule.h:36
Halide::Internal::JITModule::argv_function
argv_wrapper argv_function() const
Halide::Internal::IRMatcher::mod
HALIDE_ALWAYS_INLINE auto mod(A a, B b) -> decltype(IRMatcher::operator%(a, b))
Definition: IRMatch.h:1032
Halide::Internal::JITModule::entrypoint_symbol
Symbol entrypoint_symbol() const
Returns the Symbol structure for the routine documented in main_function.
Halide::Internal::JITSharedRuntime::init_jit_user_context
static void init_jit_user_context(JITUserContext &jit_user_context, void *user_context, const JITHandlers &handlers)
Halide::Internal::JITUserContext::user_context
void * user_context
Definition: JITModule.h:153
Halide::Internal::halide_task
int(* halide_task)(void *user_context, int, uint8_t *)
Definition: JITModule.h:137
int64_t
signed __INT64_TYPE__ int64_t
Definition: runtime_internal.h:18
Halide::Internal::JITModule::Symbol
Definition: JITModule.h:35
Halide::Internal::JITModule::exports
const std::map< std::string, Symbol > & exports() const
The exports map of a JITModule contains all symbols which are available to other JITModules which depend on this one.
Halide::Internal::JITUserContext
Definition: JITModule.h:152
Halide::Internal::LoweredFunc
Definition of a lowered function.
Definition: Module.h:97
Type.h
Halide::Internal::JITModule::main_function
void * main_function() const
A pointer to the raw halide function.
Halide::Internal::JITHandlers::custom_free
void(* custom_free)(void *, void *)
Definition: JITModule.h:142
Halide::Internal::JITHandlers::custom_do_task
int(* custom_do_task)(void *, halide_task, int, uint8_t *)
Definition: JITModule.h:143
Halide::Internal::get_symbol_address
void * get_symbol_address(const char *s)
Halide::Internal::JITModule::find_symbol_by_name
Symbol find_symbol_by_name(const std::string &) const
Look up a symbol by name in this module or its dependencies.
Halide::Internal::JITSharedRuntime
Definition: JITModule.h:157
HalideRuntime.h
Halide::Internal::JITModule::reuse_device_allocations
void reuse_device_allocations(bool) const
See JITSharedRuntime::reuse_device_allocations.
Halide::Internal::JITUserContext::handlers
JITHandlers handlers
Definition: JITModule.h:154
Halide::Internal::JITModule::make_trampolines_module
static JITModule make_trampolines_module(const Target &target, const std::map< std::string, JITExtern > &externs, const std::string &suffix, const std::vector< JITModule > &deps)
Take a list of JITExterns and generate trampoline functions which can be called dynamically via a function pointer that takes an array of void *'s for each argument and the return value.
Halide::Internal::JITModule::Symbol::Symbol
Symbol(void *address)
Definition: JITModule.h:38
IntrusivePtr.h
Halide::Internal::JITSharedRuntime::set_default_handlers
static JITHandlers set_default_handlers(const JITHandlers &handlers)
user_context
void * user_context
Definition: printer.h:33
Halide::Target
A struct representing a target machine and os to generate code for.
Definition: Target.h:19
Halide::Internal::JITHandlers::custom_do_par_for
int(* custom_do_par_for)(void *, halide_task, int, int, uint8_t *)
Definition: JITModule.h:144
Halide::Internal::JITSharedRuntime::get
static std::vector< JITModule > get(llvm::Module *m, const Target &target, bool create=true)
Halide::Internal::JITModule
Definition: JITModule.h:32