Halide 19.0.0
Halide compiler and libraries
Loading...
Searching...
No Matches
PerfectHashMap.h
Go to the documentation of this file.
1#ifndef PERFECT_HASH_MAP_H
2#define PERFECT_HASH_MAP_H
3
4#include <algorithm>
5#include <iostream>
6#include <vector>
7
8// Avoid a dependence on libHalide by defining a local variant we can use
10 const bool c;
11
13 : c(c) {
14 }
15
16 template<typename T>
18 if (!c) {
19 std::cerr << t;
20 }
21 return *this;
22 }
24 if (!c) {
25 exit(1);
26 }
27 }
28};
29
30// A specialized hash map used in the autoscheduler. It can only grow,
31// and it requires a perfect hash in the form of "id" and "max_id"
32// fields on each key. If the keys don't all have a consistent max_id,
33// or if you call make_large with the wrong max_id, you get UB. If you
34// think that might be happening, uncomment the assertions below for
35// some extra checking.
36
37template<typename K, typename T, int max_small_size = 4, typename phm_assert = PerfectHashMapAsserter>
39
40 using storage_type = std::vector<std::pair<const K *, T>>;
41
42 storage_type storage;
43
44 int occupied = 0;
45
46 // Equivalent to storage[i], but broken out into a separate method
47 // to allow for bounds checks when debugging this.
48 std::pair<const K *, T> &storage_bucket(int i) {
49 /*
50 phm_assert(i >= 0 && i < (int)storage.size())
51 << "Out of bounds access: " << i << " " << storage.size() << "\n";
52 */
53 return storage[i];
54 }
55
56 const std::pair<const K *, T> &storage_bucket(int i) const {
57 /*
58 phm_assert(i >= 0 && i < (int)storage.size())
59 << "Out of bounds access: " << i << " " << storage.size() << "\n";
60 */
61 return storage[i];
62 }
63
// Which storage layout is active. The upgrade_* helpers only move forward:
// Empty -> Small, Empty -> Large, Small -> Large.
enum {
    Empty,  // No storage allocated
    Small,  // Storage is just an array of key/value pairs
    Large   // Storage is an array with empty slots, indexed by the 'id' field of each key
} state = Empty;
69
70 void upgrade_from_empty_to_small() {
71 storage.resize(max_small_size);
72 state = Small;
73 }
74
75 void upgrade_from_empty_to_large(int n) {
76 storage.resize(n);
77 state = Large;
78 }
79
80 void upgrade_from_small_to_large(int n) {
81 phm_assert(occupied <= max_small_size) << occupied << " " << max_small_size << "\n";
82 storage_type tmp(n);
83 state = Large;
84 tmp.swap(storage);
85 int o = occupied;
86 for (int i = 0; i < o; i++) {
87 emplace_large(tmp[i].first, std::move(tmp[i].second));
88 }
89 occupied = o;
90 }
91
92 // Methods when the map is in the empty state
93 T &emplace_empty(const K *n, T &&t) {
94 upgrade_from_empty_to_small();
95 storage_bucket(0).first = n;
96 storage_bucket(0).second = std::move(t);
97 occupied = 1;
98 return storage_bucket(0).second;
99 }
100
101 const T &get_empty(const K *n) const {
102 phm_assert(0) << "Calling get on an empty PerfectHashMap";
103 return unreachable_value();
104 }
105
106 T &get_empty(const K *n) {
107 phm_assert(0) << "Calling get on an empty PerfectHashMap";
108 return unreachable_value();
109 }
110
111 T &get_or_create_empty(const K *n) {
112 occupied = 1;
113 return emplace_empty(n, T());
114 }
115
116 bool contains_empty(const K *n) const {
117 return false;
118 }
119
120 // Methods when the map is in the small state
121 int find_index_small(const K *n) const {
122 int i;
123 for (i = 0; i < (int)occupied; i++) {
124 if (storage_bucket(i).first == n) {
125 return i;
126 }
127 }
128 return i;
129 }
130
131 T &emplace_small(const K *n, T &&t) {
132 int idx = find_index_small(n);
133 if (idx >= max_small_size) {
134 upgrade_from_small_to_large((int)(n->max_id));
135 return emplace_large(n, std::move(t));
136 }
137 auto &p = storage_bucket(idx);
138 if (p.first == nullptr) {
139 occupied++;
140 p.first = n;
141 }
142 p.second = std::move(t);
143 return p.second;
144 }
145
146 const T &get_small(const K *n) const {
147 int idx = find_index_small(n);
148 return storage_bucket(idx).second;
149 }
150
151 T &get_small(const K *n) {
152 int idx = find_index_small(n);
153 return storage_bucket(idx).second;
154 }
155
156 T &get_or_create_small(const K *n) {
157 int idx = find_index_small(n);
158 if (idx >= max_small_size) {
159 upgrade_from_small_to_large((int)(n->max_id));
160 return get_or_create_large(n);
161 }
162 auto &p = storage_bucket(idx);
163 if (p.first == nullptr) {
164 occupied++;
165 p.first = n;
166 }
167 return p.second;
168 }
169
170 bool contains_small(const K *n) const {
171 int idx = find_index_small(n);
172 return (idx < max_small_size) && (storage_bucket(idx).first == n);
173 }
174
175 // Methods when the map is in the large state
176 T &emplace_large(const K *n, T &&t) {
177 auto &p = storage_bucket(n->id);
178 if (!p.first) {
179 occupied++;
180 }
181 p.first = n;
182 p.second = std::move(t);
183 return p.second;
184 }
185
186 const T &get_large(const K *n) const {
187 return storage_bucket(n->id).second;
188 }
189
190 T &get_large(const K *n) {
191 return storage_bucket(n->id).second;
192 }
193
194 T &get_or_create_large(const K *n) {
195 auto &p = storage_bucket(n->id);
196 if (p.first == nullptr) {
197 occupied++;
198 p.first = n;
199 }
200 return storage_bucket(n->id).second;
201 }
202
203 bool contains_large(const K *n) const {
204 return storage_bucket(n->id).first != nullptr;
205 }
206
207 void check_key(const K *n) const {
208 /*
209 phm_assert(n->id >= 0 && n->id < n->max_id)
210 << "Invalid hash key: " << n->id << " " << n->max_id << "\n";
211 phm_assert(state != Large || (int)storage.size() == n->max_id)
212 << "Inconsistent key count: " << n->max_id << " vs " << storage.size() << "\n";
213 */
214 }
215
216 // Helpers used to pacify compilers
217 T &unreachable_value() {
218 return storage_bucket(0).second;
219 }
220
221 const T &unreachable_value() const {
222 return storage_bucket(0).second;
223 }
224
225public:
226 // Jump straight to the large state
227 void make_large(int n) {
228 if (state == Empty) {
229 upgrade_from_empty_to_large(n);
230 } else if (state == Small) {
231 upgrade_from_small_to_large(n);
232 }
233 }
234
235 T &emplace(const K *n, T &&t) {
236 check_key(n);
237 switch (state) {
238 case Empty:
239 return emplace_empty(n, std::move(t));
240 case Small:
241 return emplace_small(n, std::move(t));
242 case Large:
243 return emplace_large(n, std::move(t));
244 }
245 return unreachable_value();
246 }
247
248 T &insert(const K *n, const T &t) {
249 check_key(n);
250 T tmp(t);
251 switch (state) {
252 case Empty:
253 return emplace_empty(n, std::move(tmp));
254 case Small:
255 return emplace_small(n, std::move(tmp));
256 case Large:
257 return emplace_large(n, std::move(tmp));
258 }
259 return unreachable_value();
260 }
261
262 const T &get(const K *n) const {
263 check_key(n);
264 switch (state) {
265 case Empty:
266 return get_empty(n);
267 case Small:
268 return get_small(n);
269 case Large:
270 return get_large(n);
271 }
272 return unreachable_value();
273 }
274
275 T &get(const K *n) {
276 check_key(n);
277 switch (state) {
278 case Empty:
279 return get_empty(n);
280 case Small:
281 return get_small(n);
282 case Large:
283 return get_large(n);
284 }
285 return unreachable_value();
286 }
287
288 T &get_or_create(const K *n) {
289 check_key(n);
290 switch (state) {
291 case Empty:
292 return get_or_create_empty(n);
293 case Small:
294 return get_or_create_small(n);
295 case Large:
296 return get_or_create_large(n);
297 }
298 return unreachable_value();
299 }
300
301 bool contains(const K *n) const {
302 check_key(n);
303 switch (state) {
304 case Empty:
305 return contains_empty(n);
306 case Small:
307 return contains_small(n);
308 case Large:
309 return contains_large(n);
310 }
311 return false; // Unreachable
312 }
313
314 size_t size() const {
315 return occupied;
316 }
317
318 struct iterator {
319 std::pair<const K *, T> *iter, *end;
320
321 void operator++(int) {
322 do {
323 iter++;
324 } while (iter != end && iter->first == nullptr);
325 }
326
327 void operator++() {
328 (*this)++;
329 }
330
331 const K *key() const {
332 return iter->first;
333 }
334
335 T &value() const {
336 return iter->second;
337 }
338
339 bool operator!=(const iterator &other) const {
340 return iter != other.iter;
341 }
342
343 bool operator==(const iterator &other) const {
344 return iter == other.iter;
345 }
346
347 std::pair<const K *, T> &operator*() {
348 return *iter;
349 }
350 };
351
353 const std::pair<const K *, T> *iter, *end;
354
355 void operator++(int) {
356 do {
357 iter++;
358 } while (iter != end && iter->first == nullptr);
359 }
360
361 void operator++() {
362 (*this)++;
363 }
364
365 const K *key() const {
366 return iter->first;
367 }
368
369 const T &value() const {
370 return iter->second;
371 }
372
373 bool operator!=(const const_iterator &other) const {
374 return iter != other.iter;
375 }
376
377 bool operator==(const const_iterator &other) const {
378 return iter == other.iter;
379 }
380
381 const std::pair<const K *, T> &operator*() const {
382 return *iter;
383 }
384 };
385
387 if (state == Empty) {
388 return end();
389 }
390 iterator it;
391 it.iter = storage.data();
392 it.end = it.iter + storage.size();
393 if (it.key() == nullptr) {
394 it++;
395 }
396 phm_assert(it.iter == it.end || it.key());
397 return it;
398 }
399
401 iterator it;
402 it.iter = it.end = storage.data() + storage.size();
403 return it;
404 }
405
407 if (storage.empty()) {
408 return end();
409 }
411 it.iter = storage.data();
412 it.end = it.iter + storage.size();
413 if (it.key() == nullptr) {
414 it++;
415 }
416 phm_assert(it.iter == it.end || it.key());
417 return it;
418 }
419
422 it.iter = it.end = storage.data() + storage.size();
423 return it;
424 }
425};
426
427#endif
T & get_or_create(const K *n)
bool contains(const K *n) const
T & get(const K *n)
T & insert(const K *n, const T &t)
const T & get(const K *n) const
const_iterator begin() const
void make_large(int n)
size_t size() const
T & emplace(const K *n, T &&t)
const_iterator end() const
bool operator==(const const_iterator &other) const
const std::pair< const K *, T > & operator*() const
bool operator!=(const const_iterator &other) const
const std::pair< const K *, T > * iter
const std::pair< const K *, T > * end
bool operator!=(const iterator &other) const
std::pair< const K *, T > * iter
std::pair< const K *, T > & operator*()
bool operator==(const iterator &other) const
std::pair< const K *, T > * end
PerfectHashMapAsserter & operator<<(T &&t)