Halide
gpu_context_common.h
#ifndef HALIDE_RUNTIME_GPU_CONTEXT_COMMON_H_
#define HALIDE_RUNTIME_GPU_CONTEXT_COMMON_H_

#include "HalideRuntime.h"
#include "printer.h"
#include "scoped_mutex_lock.h"

namespace Halide {
namespace Internal {

template<typename ContextT, typename ModuleStateT>
class GPUCompilationCache {
    struct CachedCompilation {
        ContextT context{};
        ModuleStateT module_state{};
        uintptr_t kernel_id{0};
        uintptr_t use_count{0};
    };

    halide_mutex mutex;

    static constexpr float kLoadFactor{.5f};
    static constexpr int kInitialTableBits{7};
    int log2_compilations_size{0};  // number of bits in index into compilations table.
    CachedCompilation *compilations{nullptr};
    int count{0};

    static constexpr uintptr_t kInvalidId{0};
    static constexpr uintptr_t kDeletedId{1};
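    // Ids 0 and 1 are reserved as slot markers: kInvalidId tags a never-used
    // slot and kDeletedId tags a tombstone left behind by a released entry,
    // which is why real kernel ids start at 2 below.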

    uintptr_t unique_id{2};  // zero is an invalid id

    static ALWAYS_INLINE uintptr_t kernel_hash(ContextT context, uintptr_t id, int bits) {
        uintptr_t addr = (uintptr_t)context + id;
        // Fibonacci hashing. The golden ratio is 1.9E3779B97F4A7C15F39...
        // in hexadecimal.
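        // For example, with bits == 7 on a 64-bit target, the multiply below
        // scatters the key and the shift keeps the top 7 bits of the product,
        // giving an index in [0, 127] for a 128-entry table.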
        if constexpr (sizeof(uintptr_t) >= 8) {
            return (addr * (uintptr_t)0x9E3779B97F4A7C15) >> (64 - bits);
        } else {
            return (addr * (uintptr_t)0x9E3779B9) >> (32 - bits);
        }
    }

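    // Adds entry to the open-addressed table, probing linearly from its hashed
    // slot and reusing the first empty or tombstoned slot found; grows the
    // table first whenever the insertion would push the load factor past
    // kLoadFactor.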
    HALIDE_MUST_USE_RESULT bool insert(const CachedCompilation &entry) {
        if (log2_compilations_size == 0) {
            if (!resize_table(kInitialTableBits)) {
                return false;
            }
        }
        if ((count + 1) > (1 << log2_compilations_size) * kLoadFactor) {
            if (!resize_table(log2_compilations_size + 1)) {
                return false;
            }
        }
        count += 1;
        uintptr_t index = kernel_hash(entry.context, entry.kernel_id, log2_compilations_size);
        for (int i = 0; i < (1 << log2_compilations_size); i++) {
            uintptr_t effective_index = (index + i) & ((1 << log2_compilations_size) - 1);
            if (compilations[effective_index].kernel_id <= kDeletedId) {
                compilations[effective_index] = entry;
                return true;
            }
        }
        // This is a logic error that should never occur. It means the table is
        // full, but it should have been resized.
        halide_debug_assert(nullptr, false);
        return false;
    }

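    // Looks up (context, id) by linear probing; on a hit, points module_state
    // at the cached entry and adjusts its reference count by increment (0 for
    // a pure lookup, +1 to take a hold, -1 to release one).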
    HALIDE_MUST_USE_RESULT bool find_internal(ContextT context, uintptr_t id,
                                              ModuleStateT *&module_state, int increment) {
        if (log2_compilations_size == 0) {
            return false;
        }
        uintptr_t index = kernel_hash(context, id, log2_compilations_size);
        for (int i = 0; i < (1 << log2_compilations_size); i++) {
            uintptr_t effective_index = (index + i) & ((1 << log2_compilations_size) - 1);

            if (compilations[effective_index].kernel_id == kInvalidId) {
                return false;
            }
            if (compilations[effective_index].context == context &&
                compilations[effective_index].kernel_id == id) {
                module_state = &compilations[effective_index].module_state;
                if (increment != 0) {
                    compilations[effective_index].use_count += increment;
                }
                return true;
            }
        }
        return false;
    }

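    // Reallocates the table at 2^size_bits entries and rehashes every live
    // entry into it; tombstones are dropped during the rehash, so deleted
    // slots do not accumulate across resizes.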
    HALIDE_MUST_USE_RESULT bool resize_table(int size_bits) {
        if (size_bits != log2_compilations_size) {
            int new_size = (1 << size_bits);
            int old_size = (1 << log2_compilations_size);
            CachedCompilation *new_table = (CachedCompilation *)malloc(new_size * sizeof(CachedCompilation));
            if (new_table == nullptr) {
                // signal error.
                return false;
            }
            memset(new_table, 0, new_size * sizeof(CachedCompilation));
            CachedCompilation *old_table = compilations;
            compilations = new_table;
            log2_compilations_size = size_bits;

            if (count > 0) {  // Mainly to catch empty initial table case
                for (int32_t i = 0; i < old_size; i++) {
                    if (old_table[i].kernel_id != kInvalidId &&
                        old_table[i].kernel_id != kDeletedId) {
                        bool result = insert(old_table[i]);
                        halide_debug_assert(nullptr, result);  // Resizing the table while resizing the table is a logic error.
                        (void)result;
                    }
                }
            }
            free(old_table);
        }
        return true;
    }

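    // Frees every idle (use_count == 0) entry belonging to context, or to any
    // context when all is set, leaving tombstones behind; entries that are
    // still held are skipped. The caller must already hold mutex.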
    template<typename FreeModuleT>
    void release_context_already_locked(void *user_context, bool all, ContextT context, FreeModuleT &f) {
        if (count == 0) {
            return;
        }

        for (int i = 0; i < (1 << log2_compilations_size); i++) {
            if (compilations[i].kernel_id > kInvalidId &&
                (all || (compilations[i].context == context)) &&
                compilations[i].use_count == 0) {
                debug(user_context) << "Releasing cached compilation: " << compilations[i].module_state
                                    << " id " << compilations[i].kernel_id
                                    << " context " << compilations[i].context << "\n";
                f(compilations[i].module_state);
                compilations[i].module_state = nullptr;
                compilations[i].kernel_id = kDeletedId;
                count--;
            }
        }
    }

public:
    HALIDE_MUST_USE_RESULT bool lookup(ContextT context, void *state_ptr, ModuleStateT &module_state) {
        ScopedMutexLock lock_guard(&mutex);

        uintptr_t id = (uintptr_t)state_ptr;
        ModuleStateT *mod_ptr;
        if (find_internal(context, id, mod_ptr, 0)) {
            module_state = *mod_ptr;
            return true;
        }
        return false;
    }

    template<typename FreeModuleT>
    void release_context(void *user_context, bool all, ContextT context, FreeModuleT &f) {
        ScopedMutexLock lock_guard(&mutex);

        release_context_already_locked(user_context, all, context, f);
    }

    template<typename FreeModuleT>
    void delete_context(void *user_context, ContextT context, FreeModuleT &f) {
        ScopedMutexLock lock_guard(&mutex);

        release_context_already_locked(user_context, false, context, f);
    }

    template<typename FreeModuleT>
    void release_all(void *user_context, FreeModuleT &f) {
        ScopedMutexLock lock_guard(&mutex);

        release_context_already_locked(user_context, true, nullptr, f);
        // Some entries may still have been in use, so the table itself can
        // only be freed once it is empty.
        if (count == 0) {
            free(compilations);
            compilations = nullptr;
            log2_compilations_size = 0;
        }
    }

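    // Main entry point for back ends: assigns *state_ptr_ptr a fresh id on
    // first use, returns the cached module for (context, id) when present, and
    // otherwise calls f(args...) to compile one and caches the result. Either
    // way the entry's use count is incremented; balance with release_hold().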
    template<typename CompileModuleT, typename... Args>
    HALIDE_MUST_USE_RESULT bool kernel_state_setup(void *user_context, void **state_ptr_ptr,
                                                   ContextT context, ModuleStateT &result,
                                                   CompileModuleT f,
                                                   Args... args) {
        ScopedMutexLock lock_guard(&mutex);

        uintptr_t *id_ptr = (uintptr_t *)state_ptr_ptr;
        if (*id_ptr == 0) {
            *id_ptr = unique_id++;
            if (unique_id == (uintptr_t)-1) {
                // Sorry, out of ids
                return false;
            }
        }

        ModuleStateT *mod;
        if (find_internal(context, *id_ptr, mod, 1)) {
            result = *mod;
            return true;
        }

        // TODO(zvookin): figure out the calling signature here...
        ModuleStateT compiled_module = f(args...);
        debug(user_context) << "Caching compiled kernel: " << compiled_module
                            << " id " << *id_ptr << " context " << context << "\n";
        if (compiled_module == nullptr) {
            return false;
        }

        if (!insert({context, compiled_module, *id_ptr, 1})) {
            return false;
        }
        result = compiled_module;

        return true;
    }


    void release_hold(void *user_context, ContextT context, void *state_ptr) {
        ScopedMutexLock lock_guard(&mutex);

        ModuleStateT *mod;
        uintptr_t id = (uintptr_t)state_ptr;
        bool result = find_internal(context, id, mod, -1);
        halide_debug_assert(user_context, result);  // Value must be in cache to be released
        (void)result;
    }
};

}  // namespace Internal
}  // namespace Halide

#endif  // HALIDE_RUNTIME_GPU_CONTEXT_COMMON_H_
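As a rough usage sketch (not part of the header): a GPU back end typically keeps one GPUCompilationCache instance keyed by its API's context and module handle types, calls kernel_state_setup when a pipeline initializes its kernels, and balances the resulting hold with release_hold when the pipeline finishes. The handle aliases, the compile_ptx callback, and the function names below are hypothetical stand-ins, not real Halide runtime entry points.

// Hypothetical back-end glue, assuming the aliases and compile_ptx() below;
// a real back end (e.g. CUDA) would use its driver's context/module handles.
using ContextHandle = void *;  // stand-in for something like a CUcontext
using ModuleHandle = void *;   // stand-in for something like a CUmodule

Halide::Internal::GPUCompilationCache<ContextHandle, ModuleHandle> compilation_cache;

ModuleHandle compile_ptx(const char *src);  // assumed compile step, defined elsewhere

int my_gpu_initialize_kernels(void *user_context, void **state_ptr, const char *src) {
    ContextHandle ctx = /* acquire the current API context */ nullptr;
    ModuleHandle module;
    // Compiles src on a miss, caches it keyed by (ctx, *state_ptr), and takes
    // a hold (use_count += 1) that must be balanced by release_hold().
    if (!compilation_cache.kernel_state_setup(user_context, state_ptr, ctx, module,
                                              compile_ptx, src)) {
        return halide_error_code_generic_error;
    }
    return halide_error_code_success;
}

void my_gpu_finalize_kernels(void *user_context, void *state_ptr) {
    ContextHandle ctx = /* acquire the current API context */ nullptr;
    compilation_cache.release_hold(user_context, ctx, state_ptr);
}

On device shutdown, release_all(user_context, free_fn) frees every idle module through the supplied free function and releases the table itself once it is empty.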