// One slot of the compilation cache's table. Code in this file reads and
// writes its context, kernel_id, module_state and use_count members.
13 struct CachedCompilation {
// Value-initialized by default.
15 ModuleStateT module_state{};
// Fraction of the table's slots that may be occupied: an insertion that would
// push count + 1 above (1 << log2_compilations_size) * kLoadFactor first grows
// the table by one bit.
22 static constexpr float kLoadFactor{.5f};
// log2 of the slot count requested when the table is first allocated.
23 static constexpr int kInitialTableBits{7};
// log2 of the current slot count. A value of 0 is treated as "no table yet".
24 int log2_compilations_size{0};
// Slot array; indexed with values in [0, 1 << log2_compilations_size).
25 CachedCompilation *compilations{
nullptr};
// Multiplicative hash: scale addr by a fixed constant and keep the high-order
// bits by shifting right by (64 - bits).
// NOTE(review): assumes uintptr_t is 64 bits wide on this path (the selecting
// condition is not visible here), and the constant looks like the 64-bit
// golden-ratio (Fibonacci hashing) multiplier — confirm. A shift count of 64
// (bits == 0) would be undefined behavior, so callers must pass bits >= 1.
38 return (addr * (
uintptr_t)0x9E3779B97F4A7C15) >> (64 - bits);
// Same scheme with a 32-bit constant and a shift of (32 - bits).
40 return (addr * (
uintptr_t)0x9E3779B9) >> (32 - bits);
// A log2 size of 0 means no table exists yet: request the initial allocation.
45 if (log2_compilations_size == 0) {
46 if (!resize_table(kInitialTableBits)) {
// One more entry would exceed the load factor: request a table with one more
// index bit (twice as many slots).
50 if ((count + 1) > (1 << log2_compilations_size) * kLoadFactor) {
51 if (!resize_table(log2_compilations_size + 1)) {
// Linear probing: start at the hashed slot and step forward one slot at a
// time, wrapping with a power-of-two mask, for at most one pass over the table.
56 uintptr_t index = kernel_hash(entry.context, entry.kernel_id, log2_compilations_size);
57 for (
int i = 0; i < (1 << log2_compilations_size); i++) {
58 uintptr_t effective_index = (index + i) & ((1 << log2_compilations_size) - 1);
// A slot whose kernel_id is <= kDeletedId is taken as free for reuse
// (presumably kInvalidId or kDeletedId — confirm their ordering).
59 if (compilations[effective_index].kernel_id <= kDeletedId) {
60 compilations[effective_index] = entry;
71 ModuleStateT *&module_state,
int increment) {
// Special-case an unallocated table (log2 size of 0); the branch body is not
// visible here.
72 if (log2_compilations_size == 0) {
// Probe linearly from the hashed slot, using the same index sequence as the
// insertion path so that a stored entry is reachable.
75 uintptr_t index = kernel_hash(context,
id, log2_compilations_size);
76 for (
int i = 0; i < (1 << log2_compilations_size); i++) {
77 uintptr_t effective_index = (index + i) & ((1 << log2_compilations_size) - 1);
// NOTE(review): presumably a kInvalidId slot ends the probe as "not found"
// — branch body not visible, confirm.
79 if (compilations[effective_index].kernel_id == kInvalidId) {
// An entry matches only when both the context and the kernel id agree.
82 if (compilations[effective_index].context == context &&
83 compilations[effective_index].kernel_id ==
id) {
// Hand back a pointer into the table slot itself, not a copy.
84 module_state = &compilations[effective_index].module_state;
// Adjust the use count by the caller-supplied delta; callers in this file
// pass 0, 1 or -1.
86 compilations[effective_index].use_count += increment;
// Only do work when the requested size differs from the current one.
95 if (size_bits != log2_compilations_size) {
96 int new_size = (1 << size_bits);
// NOTE(review): this evaluates to 1, not 0, when log2_compilations_size is 0
// (the "no table" state); the copy loop below needs a guard against a null
// old table in that case — the guard is not visible here, confirm.
97 int old_size = (1 << log2_compilations_size);
98 CachedCompilation *new_table = (CachedCompilation *)
malloc(new_size *
sizeof(CachedCompilation));
// Allocation failure is handled in this branch (body not visible here).
99 if (new_table ==
nullptr) {
// Zero every slot of the new table.
// NOTE(review): presumably a zeroed slot reads as kInvalidId — confirm.
103 memset(new_table, 0, new_size *
sizeof(CachedCompilation));
// Install the new table and size before re-inserting, so that insert() hashes
// and probes against the new array.
104 CachedCompilation *old_table = compilations;
105 compilations = new_table;
106 log2_compilations_size = size_bits;
// Re-insert every old entry that is neither invalid nor deleted.
109 for (
int32_t i = 0; i < old_size; i++) {
110 if (old_table[i].kernel_id != kInvalidId &&
111 old_table[i].kernel_id != kDeletedId) {
112 bool result = insert(old_table[i]);
// Frees cached modules that are not in use. Walks every table slot and, for
// each entry with a kernel_id above kInvalidId and a use_count of zero that
// belongs to `context` (or to any context when `all` is true), calls `f` on
// its module_state, then nulls the module_state and marks the slot deleted.
//
// Does not take the mutex itself; the wrappers in this file acquire it before
// calling. In the visible lines user_context is used only for debug output.
123 template<
typename FreeModuleT>
124 void release_context_already_locked(
void *user_context,
bool all, ContextT context, FreeModuleT &f) {
129 for (
int i = 0; i < (1 << log2_compilations_size); i++) {
// Entries with a non-zero use_count are skipped and stay cached.
130 if (compilations[i].kernel_id > kInvalidId &&
131 (all || (compilations[i].context == context)) &&
132 compilations[i].use_count == 0) {
133 debug(user_context) <<
"Releasing cached compilation: " << compilations[i].module_state
134 <<
" id " << compilations[i].kernel_id
135 <<
" context " << compilations[i].context <<
"\n";
136 f(compilations[i].module_state);
137 compilations[i].module_state =
nullptr;
// Marked kDeletedId rather than kInvalidId.
// NOTE(review): presumably so lookups keep probing past this slot — confirm.
138 compilations[i].kernel_id = kDeletedId;
// Take the cache mutex before touching the table.
146 ScopedMutexLock lock_guard(&mutex);
149 ModuleStateT *mod_ptr;
// Look the entry up with an increment of 0, leaving its use count unchanged,
// and copy the stored module state out on success.
150 if (find_internal(context,
id, mod_ptr, 0)) {
151 module_state = *mod_ptr;
// Locking wrapper: takes the cache mutex, then forwards all four arguments
// unchanged to release_context_already_locked.
157 template<
typename FreeModuleT>
158 void release_context(
void *user_context,
bool all, ContextT context, FreeModuleT &f) {
159 ScopedMutexLock lock_guard(&mutex);
161 release_context_already_locked(user_context, all, context, f);
164 template<
typename FreeModuleT>
// Take the cache mutex, then release only entries belonging to `context`
// (all == false).
166 ScopedMutexLock lock_guard(&mutex);
168 release_context_already_locked(user_context,
false, context, f);
171 template<
typename FreeModuleT>
173 ScopedMutexLock lock_guard(&mutex);
// Release across every context: with all == true the context argument is not
// consulted, so nullptr is passed.
175 release_context_already_locked(user_context,
true,
nullptr, f);
// Return to the "no table" state.
// NOTE(review): whether the old array is freed is on lines not visible here
// — confirm.
179 compilations =
nullptr;
180 log2_compilations_size = 0;
184 template<
typename CompileModuleT,
typename... Args>
186 ContextT context, ModuleStateT &result,
189 ScopedMutexLock lock_guard(&mutex);
// Assign the next id from the running counter (the condition guarding this
// assignment is not visible here).
193 *id_ptr = unique_id++;
// Look up an existing entry; the increment of 1 also raises its use count.
201 if (find_internal(context, *id_ptr, mod, 1)) {
// Otherwise compile by invoking f with the forwarded arguments.
207 ModuleStateT compiled_module = f(args...);
208 debug(user_context) <<
"Caching compiled kernel: " << compiled_module
209 <<
" id " << *id_ptr <<
" context " << context <<
"\n";
// A null module is handled separately (presumably a compilation failure —
// branch body not visible here).
210 if (compiled_module ==
nullptr) {
// Store the new module; the trailing 1 is presumably its initial use_count.
214 if (!insert({context, compiled_module, *id_ptr, 1})) {
// Hand the freshly compiled module back through the out-parameter.
217 result = compiled_module;
// Drops one use of a cached entry: under the cache mutex, looks the entry up
// with an increment of -1, which lowers its use_count by one when it is found.
// How `id` and `mod` are obtained (presumably from state_ptr) is on lines not
// visible here.
222 void release_hold(
void *user_context, ContextT context,
void *state_ptr) {
223 ScopedMutexLock lock_guard(&mutex);
227 bool result = find_internal(context,
id, mod, -1);