Halide 19.0.0
Halide compiler and libraries
Loading...
Searching...
No Matches
gpu_context_common.h
Go to the documentation of this file.
1#ifndef HALIDE_RUNTIME_GPU_CONTEXT_COMMON_H_
2#define HALIDE_RUNTIME_GPU_CONTEXT_COMMON_H_
3
4#include "HalideRuntime.h"
5#include "printer.h"
6#include "scoped_mutex_lock.h"
7
8namespace Halide {
9namespace Internal {
10
11template<typename ContextT, typename ModuleStateT>
13 struct CachedCompilation {
14 ContextT context{};
15 ModuleStateT module_state{};
16 uintptr_t kernel_id{0};
17 uintptr_t use_count{0};
18 };
19
20 halide_mutex mutex;
21
22 static constexpr float kLoadFactor{.5f};
23 static constexpr int kInitialTableBits{7};
24 int log2_compilations_size{0}; // number of bits in index into compilations table.
25 CachedCompilation *compilations{nullptr};
26 int count{0};
27
28 static constexpr uintptr_t kInvalidId{0};
29 static constexpr uintptr_t kDeletedId{1};
30
31 uintptr_t unique_id{2}; // zero is an invalid id
32
33 static ALWAYS_INLINE uintptr_t kernel_hash(ContextT context, uintptr_t id, int bits) {
34 uintptr_t addr = (uintptr_t)context + id;
35 // Fibonacci hashing. The golden ratio is 1.9E3779B97F4A7C15F39...
36 // in hexadecimal.
37 if constexpr (sizeof(uintptr_t) >= 8) {
38 return (addr * (uintptr_t)0x9E3779B97F4A7C15) >> (64 - bits);
39 } else {
40 return (addr * (uintptr_t)0x9E3779B9) >> (32 - bits);
41 }
42 }
43
44 HALIDE_MUST_USE_RESULT bool insert(const CachedCompilation &entry) {
45 if (log2_compilations_size == 0) {
46 if (!resize_table(kInitialTableBits)) {
47 return false;
48 }
49 }
50 if ((count + 1) > (1 << log2_compilations_size) * kLoadFactor) {
51 if (!resize_table(log2_compilations_size + 1)) {
52 return false;
53 }
54 }
55 count += 1;
56 uintptr_t index = kernel_hash(entry.context, entry.kernel_id, log2_compilations_size);
57 for (int i = 0; i < (1 << log2_compilations_size); i++) {
58 uintptr_t effective_index = (index + i) & ((1 << log2_compilations_size) - 1);
59 if (compilations[effective_index].kernel_id <= kDeletedId) {
60 compilations[effective_index] = entry;
61 return true;
62 }
63 }
64 // This is a logic error that should never occur. It means the table is
65 // full, but it should have been resized.
66 halide_debug_assert(nullptr, false);
67 return false;
68 }
69
70 HALIDE_MUST_USE_RESULT bool find_internal(ContextT context, uintptr_t id,
71 ModuleStateT *&module_state, int increment) {
72 if (log2_compilations_size == 0) {
73 return false;
74 }
75 uintptr_t index = kernel_hash(context, id, log2_compilations_size);
76 for (int i = 0; i < (1 << log2_compilations_size); i++) {
77 uintptr_t effective_index = (index + i) & ((1 << log2_compilations_size) - 1);
78
79 if (compilations[effective_index].kernel_id == kInvalidId) {
80 return false;
81 }
82 if (compilations[effective_index].context == context &&
83 compilations[effective_index].kernel_id == id) {
84 module_state = &compilations[effective_index].module_state;
85 if (increment != 0) {
86 compilations[effective_index].use_count += increment;
87 }
88 return true;
89 }
90 }
91 return false;
92 }
93
94 HALIDE_MUST_USE_RESULT bool resize_table(int size_bits) {
95 if (size_bits != log2_compilations_size) {
96 int new_size = (1 << size_bits);
97 int old_size = (1 << log2_compilations_size);
98 CachedCompilation *new_table = (CachedCompilation *)malloc(new_size * sizeof(CachedCompilation));
99 if (new_table == nullptr) {
100 // signal error.
101 return false;
102 }
103 memset(new_table, 0, new_size * sizeof(CachedCompilation));
104 CachedCompilation *old_table = compilations;
105 compilations = new_table;
106 log2_compilations_size = size_bits;
107
108 if (count > 0) { // Mainly to catch empty initial table case
109 for (int32_t i = 0; i < old_size; i++) {
110 if (old_table[i].kernel_id != kInvalidId &&
111 old_table[i].kernel_id != kDeletedId) {
112 bool result = insert(old_table[i]);
113 halide_debug_assert(nullptr, result); // Resizing the table while resizing the table is a logic error.
114 (void)result;
115 }
116 }
117 }
118 free(old_table);
119 }
120 return true;
121 }
122
123 template<typename FreeModuleT>
124 void release_context_already_locked(void *user_context, bool all, ContextT context, FreeModuleT &f) {
125 if (count == 0) {
126 return;
127 }
128
129 for (int i = 0; i < (1 << log2_compilations_size); i++) {
130 if (compilations[i].kernel_id > kInvalidId &&
131 (all || (compilations[i].context == context)) &&
132 compilations[i].use_count == 0) {
133 debug(user_context) << "Releasing cached compilation: " << compilations[i].module_state
134 << " id " << compilations[i].kernel_id
135 << " context " << compilations[i].context << "\n";
136 f(compilations[i].module_state);
137 compilations[i].module_state = nullptr;
138 compilations[i].kernel_id = kDeletedId;
139 count--;
140 }
141 }
142 }
143
144public:
145 HALIDE_MUST_USE_RESULT bool lookup(ContextT context, void *state_ptr, ModuleStateT &module_state) {
146 ScopedMutexLock lock_guard(&mutex);
147
148 uintptr_t id = (uintptr_t)state_ptr;
149 ModuleStateT *mod_ptr;
150 if (find_internal(context, id, mod_ptr, 0)) {
151 module_state = *mod_ptr;
152 return true;
153 }
154 return false;
155 }
156
157 template<typename FreeModuleT>
158 void release_context(void *user_context, bool all, ContextT context, FreeModuleT &f) {
159 ScopedMutexLock lock_guard(&mutex);
160
161 release_context_already_locked(user_context, all, context, f);
162 }
163
164 template<typename FreeModuleT>
165 void delete_context(void *user_context, ContextT context, FreeModuleT &f) {
166 ScopedMutexLock lock_guard(&mutex);
167
168 release_context_already_locked(user_context, false, context, f);
169 }
170
171 template<typename FreeModuleT>
172 void release_all(void *user_context, FreeModuleT &f) {
173 ScopedMutexLock lock_guard(&mutex);
174
175 release_context_already_locked(user_context, true, nullptr, f);
176 // Some items may have been in use, so can't free.
177 if (count == 0) {
178 free(compilations);
179 compilations = nullptr;
180 log2_compilations_size = 0;
181 }
182 }
183
184 template<typename CompileModuleT, typename... Args>
186 ContextT context, ModuleStateT &result,
187 CompileModuleT f,
188 Args... args) {
189 ScopedMutexLock lock_guard(&mutex);
190
191 uintptr_t *id_ptr = (uintptr_t *)state_ptr_ptr;
192 if (*id_ptr == 0) {
193 *id_ptr = unique_id++;
194 if (unique_id == (uintptr_t)-1) {
195 // Sorry, out of ids
196 return false;
197 }
198 }
199
200 ModuleStateT *mod;
201 if (find_internal(context, *id_ptr, mod, 1)) {
202 result = *mod;
203 return true;
204 }
205
206 // TODO(zvookin): figure out the calling signature here...
207 ModuleStateT compiled_module = f(args...);
208 debug(user_context) << "Caching compiled kernel: " << compiled_module
209 << " id " << *id_ptr << " context " << context << "\n";
210 if (compiled_module == nullptr) {
211 return false;
212 }
213
214 if (!insert({context, compiled_module, *id_ptr, 1})) {
215 return false;
216 }
217 result = compiled_module;
218
219 return true;
220 }
221
222 void release_hold(void *user_context, ContextT context, void *state_ptr) {
223 ScopedMutexLock lock_guard(&mutex);
224
225 ModuleStateT *mod;
226 uintptr_t id = (uintptr_t)state_ptr;
227 bool result = find_internal(context, id, mod, -1);
228 halide_debug_assert(user_context, result); // Value must be in cache to be released
229 (void)result;
230 }
231};
232
233} // namespace Internal
234} // namespace Halide
235
236#endif // HALIDE_RUNTIME_GPU_CONTEXT_COMMON_H_
This file declares the routines used by Halide internally in its runtime.
#define HALIDE_MUST_USE_RESULT
HALIDE_MUST_USE_RESULT bool lookup(ContextT context, void *state_ptr, ModuleStateT &module_state)
void release_hold(void *user_context, ContextT context, void *state_ptr)
void release_all(void *user_context, FreeModuleT &f)
void delete_context(void *user_context, ContextT context, FreeModuleT &f)
HALIDE_MUST_USE_RESULT bool kernel_state_setup(void *user_context, void **state_ptr_ptr, ContextT context, ModuleStateT &result, CompileModuleT f, Args... args)
void release_context(void *user_context, bool all, ContextT context, FreeModuleT &f)
For optional debugging during codegen, use the debug class as follows:
Definition Debug.h:49
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
@ Internal
Not visible externally, similar to 'static' linkage in C.
#define halide_debug_assert(user_context, cond)
halide_debug_assert() is like halide_assert(), but only expands into a check when DEBUG_RUNTIME is defined.
__UINTPTR_TYPE__ uintptr_t
void * malloc(size_t)
signed __INT32_TYPE__ int32_t
#define ALWAYS_INLINE
void * memset(void *s, int val, size_t n)
void free(void *)
Cross-platform mutex.
void * user_context