Halide 19.0.0
Halide compiler and libraries
Loading...
Searching...
No Matches
SpirvIR.h
Go to the documentation of this file.
1#ifndef HALIDE_SPIRV_IR_H
2#define HALIDE_SPIRV_IR_H
3
/** \file
 * Defines methods for constructing and encoding instructions into the Khronos
 * format specification known as the Standard Portable Intermediate Representation
 * for Vulkan (SPIR-V). These class interfaces adopt Halide's conventions for its
 * own IR, but are implemented as a stand-alone optional component that can be
 * enabled as required for certain runtimes (e.g. Vulkan).
 *
 * NOTE: This file is only used internally for CodeGen! *DO NOT* add this file
 * to the list of exported Halide headers in src/CMakeLists.txt or the
 * top-level Makefile.
 */
15#ifdef WITH_SPIRV
16
17#include <map>
18#include <set>
19#include <stack>
20#include <unordered_map>
21#include <vector>
22
23#include "IntrusivePtr.h"
24#include "Type.h"
25
26#include <spirv/unified1/GLSL.std.450.h> // GLSL extended instructions for common intrinsics
27#include <spirv/unified1/spirv.h> // Use v1.6 headers but only use the minimal viable format version (for maximum compatiblity)
28
29namespace Halide {
30namespace Internal {
31
/** Precision requirement for return values (SpvFunction also uses it
 * for individual function parameters). */
enum SpvPrecision {
    SpvFullPrecision = 0,     ///< full precision is required
    SpvRelaxedPrecision = 1,  ///< reduced precision is acceptable
};
37
/** Scope qualifiers for execution and memory operations.
 *
 * The explicit values mirror the SPIR-V `Scope` enumerants, so they can be
 * emitted directly as constant operands of barrier / atomic instructions. */
enum SpvScope {
    SpvCrossDeviceScope = 0,  ///< SPIR-V ScopeCrossDevice
    SpvDeviceScope = 1,       ///< SPIR-V ScopeDevice
    SpvWorkgroupScope = 2,    ///< SPIR-V ScopeWorkgroup
    SpvSubgroupScope = 3,     ///< SPIR-V ScopeSubgroup
    SpvInvocationScope = 4,   ///< SPIR-V ScopeInvocation
};
46
/** The well-known constants that have a dedicated representation
 * (null, boolean true, boolean false). */
enum SpvPredefinedConstant {
    SpvNullConstant = 0,
    SpvTrueConstant = 1,
    SpvFalseConstant = 2,
};
53
/** Classifies what a SPIR-V <id> refers to.
 *
 * SpvBuilder::reserve_id() tags every id it hands out with one of these
 * kinds, and SpvBuilder::kind_of() looks the kind back up. */
enum SpvKind {
    SpvInvalidItem,

    // Type declarations
    SpvTypeId,
    SpvVoidTypeId,
    SpvBoolTypeId,
    SpvIntTypeId,
    SpvUIntTypeId,
    SpvFloatTypeId,
    SpvVectorTypeId,
    SpvArrayTypeId,
    SpvRuntimeArrayTypeId,
    SpvStringTypeId,
    SpvPointerTypeId,
    SpvStructTypeId,
    SpvFunctionTypeId,

    // Access chains and constants
    SpvAccessChainId,
    SpvConstantId,
    SpvBoolConstantId,
    SpvIntConstantId,
    SpvFloatConstantId,
    SpvStringConstantId,
    SpvCompositeConstantId,

    // Results, variables and structural items
    SpvResultId,
    SpvVariableId,
    SpvInstructionId,
    SpvFunctionId,
    SpvBlockId,
    SpvLabelId,
    SpvParameterId,
    SpvImportId,
    SpvModuleId,

    SpvUnknownItem,
};
88
/** Describes how an individual operand slot of an SpvInstruction is to be
 * interpreted (see SpvInstruction::value_type()). */
enum SpvValueType {
    SpvInvalidValueType,

    // An <id> reference to another SPIR-V item
    SpvOperandId,

    // Literal immediates
    SpvBitMaskLiteral,
    SpvIntegerLiteral,

    // Raw data payloads (cf. SpvInstruction::add_data / add_string)
    SpvIntegerData,
    SpvFloatData,
    SpvStringData,

    SpvUnknownValueType,
};
100
101/** SPIR-V requires all IDs to be 32-bit unsigned integers */
102using SpvId = uint32_t;
103using SpvBinary = std::vector<uint32_t>;
104
105static constexpr SpvStorageClass SpvInvalidStorageClass = SpvStorageClassMax; // sentinel for invalid storage class
106static constexpr SpvId SpvInvalidId = SpvId(-1);
107static constexpr SpvId SpvNoResult = 0;
108static constexpr SpvId SpvNoType = 0;
109
110/** Pre-declarations for SPIR-V IR classes */
111class SpvModule;
112class SpvFunction;
113class SpvBlock;
114class SpvInstruction;
115class SpvBuilder;
116class SpvContext;
117struct SpvFactory;
118
119/** Pre-declarations for SPIR-V IR data structures */
120struct SpvModuleContents;
121struct SpvFunctionContents;
122struct SpvBlockContents;
123struct SpvInstructionContents;
124
125/** Intrusive pointer types for SPIR-V IR data */
126using SpvModuleContentsPtr = IntrusivePtr<SpvModuleContents>;
127using SpvFunctionContentsPtr = IntrusivePtr<SpvFunctionContents>;
128using SpvBlockContentsPtr = IntrusivePtr<SpvBlockContents>;
129using SpvInstructionContentsPtr = IntrusivePtr<SpvInstructionContents>;
130
131/** General interface for representing a SPIR-V Instruction */
132class SpvInstruction {
133public:
134 using LiteralValue = std::pair<uint32_t, SpvValueType>;
135 using Immediates = std::vector<LiteralValue>;
136 using Operands = std::vector<SpvId>;
137 using ValueTypes = std::vector<SpvValueType>;
138
139 SpvInstruction() = default;
140 ~SpvInstruction();
141
142 SpvInstruction(const SpvInstruction &) = default;
143 SpvInstruction &operator=(const SpvInstruction &) = default;
144 SpvInstruction(SpvInstruction &&) = default;
145 SpvInstruction &operator=(SpvInstruction &&) = default;
146
147 void set_result_id(SpvId id);
148 void set_type_id(SpvId id);
149 void set_op_code(SpvOp opcode);
150 void add_operand(SpvId id);
151 void add_operands(const Operands &operands);
152 void add_immediate(SpvId id, SpvValueType type);
153 void add_immediates(const Immediates &Immediates);
154 void add_data(uint32_t bytes, const void *data, SpvValueType type);
155 void add_string(const std::string &str);
156
157 template<typename T>
158 void append(const T &operands_or_immediates_or_strings);
159
160 SpvId result_id() const;
161 SpvId type_id() const;
162 SpvOp op_code() const;
163 SpvId operand(uint32_t index) const;
164 const void *data(uint32_t index = 0) const;
165 SpvValueType value_type(uint32_t index) const;
166 const Operands &operands() const;
167
168 bool has_type() const;
169 bool has_result() const;
170 bool is_defined() const;
171 bool is_immediate(uint32_t index) const;
172 uint32_t length() const;
173 void check_defined() const;
174 void clear();
175
176 void encode(SpvBinary &binary) const;
177
178 static SpvInstruction make(SpvOp op_code);
179
180protected:
181 SpvInstructionContentsPtr contents;
182};
183
184/** General interface for representing a SPIR-V Block */
185class SpvBlock {
186public:
187 using Instructions = std::vector<SpvInstruction>;
188 using Variables = std::vector<SpvInstruction>;
189 using Blocks = std::vector<SpvBlock>;
190
191 SpvBlock() = default;
192 ~SpvBlock();
193
194 SpvBlock(const SpvBlock &) = default;
195 SpvBlock &operator=(const SpvBlock &) = default;
196 SpvBlock(SpvBlock &&) = default;
197 SpvBlock &operator=(SpvBlock &&) = default;
198
199 void add_instruction(SpvInstruction inst);
200 void add_variable(SpvInstruction var);
201 const Instructions &instructions() const;
202 const Variables &variables() const;
203 bool is_reachable() const;
204 bool is_terminated() const;
205 bool is_defined() const;
206 SpvId id() const;
207 void check_defined() const;
208 void clear();
209
210 void encode(SpvBinary &binary) const;
211
212 static SpvBlock make(SpvId block_id);
213
214protected:
215 SpvBlockContentsPtr contents;
216};
217
218/** General interface for representing a SPIR-V Function */
219class SpvFunction {
220public:
221 using Blocks = std::vector<SpvBlock>;
222 using Parameters = std::vector<SpvInstruction>;
223
224 SpvFunction() = default;
225 ~SpvFunction();
226
227 SpvFunction(const SpvFunction &) = default;
228 SpvFunction &operator=(const SpvFunction &) = default;
229 SpvFunction(SpvFunction &&) = default;
230 SpvFunction &operator=(SpvFunction &&) = default;
231
232 SpvBlock create_block(SpvId block_id);
233 void add_block(SpvBlock block);
234 void add_parameter(SpvInstruction param);
235 void set_return_precision(SpvPrecision precision);
236 void set_parameter_precision(uint32_t index, SpvPrecision precision);
237 bool is_defined() const;
238 void clear();
239
240 const Blocks &blocks() const;
241 SpvBlock entry_block() const;
242 SpvBlock tail_block() const;
243 SpvPrecision return_precision() const;
244 const Parameters &parameters() const;
245 SpvPrecision parameter_precision(uint32_t index) const;
246 uint32_t parameter_count() const;
247 uint32_t control_mask() const;
248 SpvInstruction declaration() const;
249 SpvId return_type_id() const;
250 SpvId type_id() const;
251 SpvId id() const;
252 void check_defined() const;
253
254 void encode(SpvBinary &binary) const;
255
256 static SpvFunction make(SpvId func_id, SpvId func_type_id, SpvId return_type_id, uint32_t control_mask = SpvFunctionControlMaskNone);
257
258protected:
259 SpvFunctionContentsPtr contents;
260};
261
262/** General interface for representing a SPIR-V code module */
263class SpvModule {
264public:
265 using ImportDefinition = std::pair<SpvId, std::string>;
266 using ImportNames = std::vector<std::string>;
267 using EntryPointNames = std::vector<std::string>;
268 using Instructions = std::vector<SpvInstruction>;
269 using Functions = std::vector<SpvFunction>;
270 using Capabilities = std::vector<SpvCapability>;
271 using Extensions = std::vector<std::string>;
272 using Imports = std::vector<ImportDefinition>;
273
274 SpvModule() = default;
275 ~SpvModule();
276
277 SpvModule(const SpvModule &) = default;
278 SpvModule &operator=(const SpvModule &) = default;
279 SpvModule(SpvModule &&) = default;
280 SpvModule &operator=(SpvModule &&) = default;
281
282 void add_debug_string(SpvId result_id, const std::string &string);
283 void add_debug_symbol(SpvId id, const std::string &symbol);
284 void add_annotation(SpvInstruction val);
285 void add_type(SpvInstruction val);
286 void add_constant(SpvInstruction val);
287 void add_global(SpvInstruction val);
288 void add_execution_mode(SpvInstruction val);
289 void add_function(SpvFunction val);
290 void add_instruction(SpvInstruction val);
291 void add_entry_point(const std::string &name, SpvInstruction entry_point);
292
293 void import_instruction_set(SpvId id, const std::string &instruction_set);
294 void require_capability(SpvCapability val);
295 void require_extension(const std::string &val);
296
297 void set_version_format(uint32_t version);
298 void set_source_language(SpvSourceLanguage val);
299 void set_addressing_model(SpvAddressingModel val);
300 void set_memory_model(SpvMemoryModel val);
301 void set_binding_count(SpvId count);
302
303 uint32_t version_format() const;
304 SpvSourceLanguage source_language() const;
305 SpvAddressingModel addressing_model() const;
306 SpvMemoryModel memory_model() const;
307 SpvInstruction entry_point(const std::string &name) const;
308 EntryPointNames entry_point_names() const;
309 ImportNames import_names() const;
310 SpvId lookup_import(const std::string &Instruction_set) const;
311 uint32_t entry_point_count() const;
312
313 Imports imports() const;
314 Extensions extensions() const;
315 Capabilities capabilities() const;
316 Instructions entry_points() const;
317 const Instructions &execution_modes() const;
318 const Instructions &debug_source() const;
319 const Instructions &debug_symbols() const;
320 const Instructions &annotations() const;
321 const Instructions &type_definitions() const;
322 const Instructions &global_constants() const;
323 const Instructions &global_variables() const;
324 const Functions &function_definitions() const;
325
326 uint32_t binding_count() const;
327 SpvModule module() const;
328
329 bool is_imported(const std::string &instruction_set) const;
330 bool is_capability_required(SpvCapability val) const;
331 bool is_extension_required(const std::string &val) const;
332 bool is_defined() const;
333 SpvId id() const;
334 void check_defined() const;
335 void clear();
336
337 void encode(SpvBinary &binary) const;
338
339 static SpvModule make(SpvId module_id,
340 SpvSourceLanguage source_language = SpvSourceLanguageUnknown,
341 SpvAddressingModel addressing_model = SpvAddressingModelLogical,
342 SpvMemoryModel memory_model = SpvMemoryModelSimple);
343
344protected:
345 SpvModuleContentsPtr contents;
346};
347
348/** Builder interface for constructing a SPIR-V code module and
349 * all associated types, declarations, blocks, functions &
350 * instructions */
351class SpvBuilder {
352public:
353 using ParamTypes = std::vector<SpvId>;
354 using Components = std::vector<SpvId>;
355 using StructMemberTypes = std::vector<SpvId>;
356 using Variables = std::vector<SpvId>;
357 using Indices = std::vector<uint32_t>;
358 using Literals = std::vector<uint32_t>;
359
360 SpvBuilder();
361 ~SpvBuilder() = default;
362
363 SpvBuilder(const SpvBuilder &) = delete;
364 SpvBuilder &operator=(const SpvBuilder &) = delete;
365
366 // Reserve a unique ID to use for identifying a specifc kind of SPIR-V result **/
367 SpvId reserve_id(SpvKind = SpvResultId);
368
369 // Look up the specific kind of SPIR-V item from its unique ID
370 SpvKind kind_of(SpvId id) const;
371
372 // Get a human readable name for a specific kind of SPIR-V item
373 std::string kind_name(SpvKind kind) const;
374
375 // Look up the ID associated with the type for a given variable ID
376 SpvId type_of(SpvId variable_id) const;
377
378 // Top-Level declaration methods ... each of these is a convenvience
379 // function that checks to see if the requested thing has already been
380 // declared, in which case it returns its existing id, otherwise it
381 // adds a new declaration, and returns the new id. This avoids all
382 // the logic checks in the calling code, and also ensures that
383 // duplicates aren't created.
384
385 SpvId declare_void_type();
386 SpvId declare_type(const Type &type, uint32_t array_size = 1);
387 SpvId declare_pointer_type(const Type &type, SpvStorageClass storage_class);
388 SpvId declare_pointer_type(SpvId type_id, SpvStorageClass storage_class);
389 SpvId declare_constant(const Type &type, const void *data, bool is_specialization = false);
390 SpvId declare_null_constant(const Type &type);
391 SpvId declare_bool_constant(bool value);
392 SpvId declare_string_constant(const std::string &str);
393 SpvId declare_integer_constant(const Type &type, int64_t value);
394 SpvId declare_float_constant(const Type &type, double value);
395 SpvId declare_scalar_constant(const Type &type, const void *data);
396 SpvId declare_vector_constant(const Type &type, const void *data);
397 SpvId declare_specialization_constant(const Type &type, const void *data);
398 SpvId declare_access_chain(SpvId ptr_type_id, SpvId base_id, const Indices &indices);
399 SpvId declare_pointer_access_chain(SpvId ptr_type_id, SpvId base_id, SpvId element_id, const Indices &indices);
400 SpvId declare_function_type(SpvId return_type, const ParamTypes &param_types = {});
401 SpvId declare_function(const std::string &name, SpvId function_type);
402 SpvId declare_struct(const std::string &name, const StructMemberTypes &member_types);
403 SpvId declare_variable(const std::string &name, SpvId type_id, SpvStorageClass storage_class, SpvId initializer_id = SpvInvalidId);
404 SpvId declare_global_variable(const std::string &name, SpvId type_id, SpvStorageClass storage_class, SpvId initializer_id = SpvInvalidId);
405 SpvId declare_symbol(const std::string &symbol, SpvId id, SpvId scope_id);
406
407 // Top level creation methods for adding new items ... these have a limited
408 // number of checks and the caller must ensure that duplicates aren't created
409 SpvId add_type(const Type &type, uint32_t array_size = 1);
410 SpvId add_struct(const std::string &name, const StructMemberTypes &member_types);
411 SpvId add_array_with_default_size(SpvId base_type_id, SpvId array_size_id);
412 SpvId add_runtime_array(SpvId base_type_id);
413 SpvId add_pointer_type(const Type &type, SpvStorageClass storage_class);
414 SpvId add_pointer_type(SpvId base_type_id, SpvStorageClass storage_class);
415 SpvId add_constant(const Type &type, const void *data, bool is_specialization = false);
416 SpvId add_function_type(SpvId return_type_id, const ParamTypes &param_type_ids);
417 SpvId add_function(const std::string &name, SpvId return_type, const ParamTypes &param_types = {});
418 SpvId add_instruction(SpvInstruction val);
419
420 void add_annotation(SpvId target_id, SpvDecoration decoration_type, const Literals &literals = {});
421 void add_struct_annotation(SpvId struct_type_id, uint32_t member_index, SpvDecoration decoration_type, const Literals &literals = {});
422 void add_symbol(const std::string &symbol, SpvId id, SpvId scope_id);
423
424 void add_entry_point(SpvId func_id, SpvExecutionModel exec_model,
425 const Variables &variables = {});
426
427 // Define the execution mode with a fixed local size for the workgroup (using literal values)
428 void add_execution_mode_local_size(SpvId entry_point_id, uint32_t local_size_x, uint32_t local_size_y, uint32_t local_size_z);
429
430 // Same as above but uses id's for the local size (to allow specialization constants to be used)
431 void add_execution_mode_local_size_id(SpvId entry_point_id, SpvId local_size_x, SpvId local_size_y, SpvId local_size_z);
432
433 // Assigns a specific SPIR-V version format for output (needed for compatibility)
434 void set_version_format(uint32_t version);
435
436 // Assigns a specific source language hint to the module
437 void set_source_language(SpvSourceLanguage val);
438
439 // Sets the addressing model to use for the module
440 void set_addressing_model(SpvAddressingModel val);
441
442 // Sets the memory model to use for the module
443 void set_memory_model(SpvMemoryModel val);
444
445 // Returns the source language hint for the module
446 SpvSourceLanguage source_language() const;
447
448 // Returns the addressing model used for the module
449 SpvAddressingModel addressing_model() const;
450
451 // Returns the memory model used for the module
452 SpvMemoryModel memory_model() const;
453
454 // Import the GLSL.std.450 external instruction set. Returns its corresponding ID.
455 SpvId import_glsl_intrinsics();
456
457 // Import an external instruction set bby name. Returns its corresponding ID.
458 SpvId import_instruction_set(const std::string &instruction_set);
459
460 // Add an extension string to the list of required extensions for the module
461 void require_extension(const std::string &extension);
462
463 // Add a specific capability to the list of requirements for the module
464 void require_capability(SpvCapability);
465
466 // Returns true if the given instruction set has been imported
467 bool is_imported(const std::string &instruction_set) const;
468
469 // Returns true if the given extension string is required by the module
470 bool is_extension_required(const std::string &extension) const;
471
472 // Returns true if the given capability is required by the module
473 bool is_capability_required(SpvCapability) const;
474
475 // Change the current build location to the given block. All local
476 // declarations and instructions will be added here.
477 void enter_block(const SpvBlock &block);
478
479 // Create a new block with the given ID
480 SpvBlock create_block(SpvId block_id);
481
482 // Returns the current block (the active scope for building)
483 SpvBlock current_block() const;
484
485 // Resets the block build scope, and unassigns the current block
486 SpvBlock leave_block();
487
488 // Change the current build scope to be within the given function
489 void enter_function(const SpvFunction &func);
490
491 // Returns the function object for the given ID (or an invalid function if none is found)
492 SpvFunction lookup_function(SpvId func_id) const;
493
494 // Returns the current function being used as the active build scope
495 SpvFunction current_function() const;
496
497 // Resets the function build scope, and unassigns the current function
498 SpvFunction leave_function();
499
500 // Returns the current id being used for building (ie the last item created)
501 SpvId current_id() const;
502
503 // Updates the current id being used for building
504 void update_id(SpvId id);
505
506 // Returns true if the given id is of the corresponding type
507 bool is_pointer_type(SpvId id) const;
508 bool is_struct_type(SpvId id) const;
509 bool is_vector_type(SpvId id) const;
510 bool is_scalar_type(SpvId id) const;
511 bool is_array_type(SpvId id) const;
512 bool is_constant(SpvId id) const;
513
514 // Looks up the given pointer type id and returns a corresponding base type id (or an invalid id if none is found)
515 SpvId lookup_base_type(SpvId pointer_type) const;
516
517 // Returns the storage class for the given variable id (or invalid if none is found)
518 SpvStorageClass lookup_storage_class(SpvId id) const;
519
520 // Returns the item id for the given symbol name (or an invalid id if none is found)
521 SpvId lookup_id(const std::string &symbol) const;
522
523 // Returns the build scope id for the item id (or an invalid id if none is found)
524 SpvId lookup_scope(SpvId id) const;
525
526 // Returns the id for the imported instruction set (or an invalid id if none is found)
527 SpvId lookup_import(const std::string &instruction_set) const;
528
529 // Returns the symbol string for the given id (or an empty string if none is found)
530 std::string lookup_symbol(SpvId id) const;
531
532 // Returns the current module being used for building
533 SpvModule current_module() const;
534
535 // Appends the given instruction to the current build location
536 void append(SpvInstruction inst);
537
538 // Finalizes the module and prepares it for encoding (must be called before module can be used)
539 void finalize();
540
541 // Encodes the current module to the given binary
542 void encode(SpvBinary &binary) const;
543
544 // Resets the builder and all internal state
545 void reset();
546
547protected:
548 using TypeKey = uint64_t;
549 using TypeMap = std::unordered_map<TypeKey, SpvId>;
550 using KindMap = std::unordered_map<SpvId, SpvKind>;
551 using PointerTypeKey = std::pair<SpvId, SpvStorageClass>;
552 using PointerTypeMap = std::map<PointerTypeKey, SpvId>;
553 using BaseTypeMap = std::unordered_map<SpvId, SpvId>;
554 using VariableTypeMap = std::unordered_map<SpvId, SpvId>;
555 using StorageClassMap = std::unordered_map<SpvId, SpvStorageClass>;
556 using ConstantKey = uint64_t;
557 using ConstantMap = std::unordered_map<ConstantKey, SpvId>;
558 using StringMap = std::unordered_map<ConstantKey, SpvId>;
559 using ScopeMap = std::unordered_map<SpvId, SpvId>;
560 using IdSymbolMap = std::unordered_map<SpvId, std::string>;
561 using SymbolIdMap = std::unordered_map<std::string, SpvId>;
562 using FunctionTypeKey = uint64_t;
563 using FunctionTypeMap = std::unordered_map<FunctionTypeKey, SpvId>;
564 using FunctionMap = std::unordered_map<SpvId, SpvFunction>;
565
566 // Internal methods for creating ids, keys, and look ups
567
568 SpvId make_id(SpvKind kind);
569
570 TypeKey make_type_key(const Type &type, uint32_t array_size = 1) const;
571 SpvId lookup_type(const Type &type, uint32_t array_size = 1) const;
572
573 TypeKey make_struct_type_key(const StructMemberTypes &member_types) const;
574 SpvId lookup_struct(const std::string &name, const StructMemberTypes &member_types) const;
575
576 PointerTypeKey make_pointer_type_key(const Type &type, SpvStorageClass storage_class) const;
577 SpvId lookup_pointer_type(const Type &type, SpvStorageClass storage_class) const;
578
579 PointerTypeKey make_pointer_type_key(SpvId base_type_id, SpvStorageClass storage_class) const;
580 SpvId lookup_pointer_type(SpvId base_type_id, SpvStorageClass storage_class) const;
581
582 template<typename T>
583 SpvId declare_scalar_constant_of_type(const Type &scalar_type, const T *data);
584
585 template<typename T>
586 SpvId declare_specialization_constant_of_type(const Type &scalar_type, const T *data);
587
588 template<typename T>
589 SpvBuilder::Components declare_constants_for_each_lane(Type type, const void *data);
590
591 ConstantKey make_bool_constant_key(bool value) const;
592 ConstantKey make_string_constant_key(const std::string &value) const;
593 ConstantKey make_constant_key(uint8_t code, uint8_t bits, int lanes, size_t bytes, const void *data, bool is_specialization = false) const;
594 ConstantKey make_constant_key(const Type &type, const void *data, bool is_specialization = false) const;
595 SpvId lookup_constant(const Type &type, const void *data, bool is_specialization = false) const;
596
597 ConstantKey make_null_constant_key(const Type &type) const;
598 SpvId lookup_null_constant(const Type &type) const;
599
600 SpvId lookup_variable(const std::string &name, SpvId type_id, SpvStorageClass storage_class, SpvId scope_id) const;
601 bool has_variable(const std::string &name, SpvId type_id, SpvStorageClass storage_class, SpvId scope_id) const;
602
603 FunctionTypeKey make_function_type_key(SpvId return_type_id, const ParamTypes &param_type_ids) const;
604 SpvId lookup_function_type(SpvId return_type_id, const ParamTypes &param_type_ids) const;
605
606 SpvId active_id = SpvInvalidId;
607 SpvFunction active_function;
608 SpvBlock active_block;
609 SpvModule module;
610 KindMap kind_map;
611 TypeMap type_map;
612 TypeMap struct_map;
613 ScopeMap scope_map;
614 StringMap string_map;
615 ConstantMap constant_map;
616 FunctionMap function_map;
617 IdSymbolMap id_symbol_map;
618 SymbolIdMap symbol_id_map;
619 BaseTypeMap base_type_map;
620 StorageClassMap storage_class_map;
621 PointerTypeMap pointer_type_map;
622 VariableTypeMap variable_type_map;
623 FunctionTypeMap function_type_map;
624};
625
626/** Factory interface for constructing specific SPIR-V instructions */
627struct SpvFactory {
628 using Indices = std::vector<uint32_t>;
629 using Literals = std::vector<uint32_t>;
630 using BranchWeights = std::vector<uint32_t>;
631 using Components = std::vector<SpvId>;
632 using ParamTypes = std::vector<SpvId>;
633 using MemberTypeIds = std::vector<SpvId>;
634 using Operands = std::vector<SpvId>;
635 using Variables = std::vector<SpvId>;
636 using VariableBlockIdPair = std::pair<SpvId, SpvId>; // (Variable Id, Block Id)
637 using BlockVariables = std::vector<VariableBlockIdPair>;
638
639 static SpvInstruction no_op(SpvId result_id);
640 static SpvInstruction capability(const SpvCapability &capability);
641 static SpvInstruction extension(const std::string &extension);
642 static SpvInstruction import(SpvId instruction_set_id, const std::string &instruction_set_name);
643 static SpvInstruction label(SpvId result_id);
644 static SpvInstruction debug_line(SpvId string_id, uint32_t line, uint32_t column);
645 static SpvInstruction debug_string(SpvId result_id, const std::string &string);
646 static SpvInstruction debug_symbol(SpvId target_id, const std::string &symbol);
647 static SpvInstruction decorate(SpvId target_id, SpvDecoration decoration_type, const Literals &literals = {});
648 static SpvInstruction decorate_member(SpvId struct_type_id, uint32_t member_index, SpvDecoration decoration_type, const Literals &literals = {});
649 static SpvInstruction void_type(SpvId void_type_id);
650 static SpvInstruction bool_type(SpvId bool_type_id);
651 static SpvInstruction integer_type(SpvId int_type_id, uint32_t bits, uint32_t signedness);
652 static SpvInstruction float_type(SpvId float_type_id, uint32_t bits);
653 static SpvInstruction vector_type(SpvId vector_type_id, SpvId element_type_id, uint32_t vector_size);
654 static SpvInstruction array_type(SpvId array_type_id, SpvId element_type_id, SpvId array_size_id);
655 static SpvInstruction struct_type(SpvId result_id, const MemberTypeIds &member_type_ids);
656 static SpvInstruction runtime_array_type(SpvId result_type_id, SpvId base_type_id);
657 static SpvInstruction pointer_type(SpvId pointer_type_id, SpvStorageClass storage_class, SpvId base_type_id);
658 static SpvInstruction function_type(SpvId function_type_id, SpvId return_type_id, const ParamTypes &param_type_ids);
659 static SpvInstruction constant(SpvId result_id, SpvId type_id, size_t bytes, const void *data, SpvValueType value_type);
660 static SpvInstruction null_constant(SpvId result_id, SpvId type_id);
661 static SpvInstruction bool_constant(SpvId result_id, SpvId type_id, bool value);
662 static SpvInstruction string_constant(SpvId result_id, const std::string &value);
663 static SpvInstruction composite_constant(SpvId result_id, SpvId type_id, const Components &components);
664 static SpvInstruction specialization_constant(SpvId result_id, SpvId type_id, size_t bytes, const void *data, SpvValueType value_type);
665 static SpvInstruction variable(SpvId result_id, SpvId result_type_id, uint32_t storage_class, SpvId initializer_id = SpvInvalidId);
666 static SpvInstruction function(SpvId return_type_id, SpvId func_id, uint32_t control_mask, SpvId func_type_id);
667 static SpvInstruction function_parameter(SpvId param_type_id, SpvId param_id);
668 static SpvInstruction function_end();
669 static SpvInstruction return_stmt(SpvId return_value_id = SpvInvalidId);
670 static SpvInstruction entry_point(SpvId exec_model, SpvId func_id, const std::string &name, const Variables &variables);
671 static SpvInstruction memory_model(SpvAddressingModel addressing_model, SpvMemoryModel memory_model);
672 static SpvInstruction exec_mode_local_size(SpvId function_id, uint32_t local_size_size_x, uint32_t local_size_size_y, uint32_t local_size_size_z);
673 static SpvInstruction exec_mode_local_size_id(SpvId function_id, SpvId local_size_x_id, SpvId local_size_y_id, SpvId local_size_z_id); // only avail in 1.2
674 static SpvInstruction memory_barrier(SpvId memory_scope_id, SpvId semantics_mask_id);
675 static SpvInstruction control_barrier(SpvId execution_scope_id, SpvId memory_scope_id, SpvId semantics_mask_id);
676 static SpvInstruction bitwise_not(SpvId type_id, SpvId result_id, SpvId src_id);
677 static SpvInstruction bitwise_and(SpvId type_id, SpvId result_id, SpvId src_a_id, SpvId src_b_id);
678 static SpvInstruction logical_not(SpvId type_id, SpvId result_id, SpvId src_id);
679 static SpvInstruction logical_and(SpvId type_id, SpvId result_id, SpvId src_a_id, SpvId src_b_id);
680 static SpvInstruction shift_right_logical(SpvId type_id, SpvId result_id, SpvId src_id, SpvId shift_id);
681 static SpvInstruction shift_right_arithmetic(SpvId type_id, SpvId result_id, SpvId src_id, SpvId shift_id);
682 static SpvInstruction multiply_extended(SpvId type_id, SpvId result_id, SpvId src_a_id, SpvId src_b_id, bool is_signed);
683 static SpvInstruction select(SpvId type_id, SpvId result_id, SpvId condition_id, SpvId true_id, SpvId false_id);
684 static SpvInstruction in_bounds_access_chain(SpvId type_id, SpvId result_id, SpvId base_id, const Indices &indices);
685 static SpvInstruction pointer_access_chain(SpvId type_id, SpvId result_id, SpvId base_id, SpvId element_id, const Indices &indices);
686 static SpvInstruction load(SpvId type_id, SpvId result_id, SpvId ptr_id, uint32_t access_mask = 0x0);
687 static SpvInstruction store(SpvId ptr_id, SpvId obj_id, uint32_t access_mask = 0x0);
688 static SpvInstruction vector_insert_dynamic(SpvId type_id, SpvId result_id, SpvId vector_id, SpvId value_id, SpvId index_id);
689 static SpvInstruction vector_extract_dynamic(SpvId type_id, SpvId result_id, SpvId vector_id, SpvId value_id, SpvId index_id);
690 static SpvInstruction vector_shuffle(SpvId type_id, SpvId result_id, SpvId src_a_id, SpvId src_b_id, const Indices &indices);
691 static SpvInstruction composite_insert(SpvId type_id, SpvId result_id, SpvId object_id, SpvId composite_id, const SpvFactory::Indices &indices);
692 static SpvInstruction composite_extract(SpvId type_id, SpvId result_id, SpvId composite_id, const Indices &indices);
693 static SpvInstruction composite_construct(SpvId type_id, SpvId result_id, const Components &constituents);
694 static SpvInstruction is_inf(SpvId type_id, SpvId result_id, SpvId src_id);
695 static SpvInstruction is_nan(SpvId type_id, SpvId result_id, SpvId src_id);
696 static SpvInstruction bitcast(SpvId type_id, SpvId result_id, SpvId src_id);
697 static SpvInstruction float_add(SpvId type_id, SpvId result_id, SpvId src_a_id, SpvId src_b_id);
698 static SpvInstruction integer_add(SpvId type_id, SpvId result_id, SpvId src_a_id, SpvId src_b_id);
699 static SpvInstruction integer_equal(SpvId type_id, SpvId result_id, SpvId src_a_id, SpvId src_b_id);
700 static SpvInstruction integer_not_equal(SpvId type_id, SpvId result_id, SpvId src_a_id, SpvId src_b_id);
701 static SpvInstruction integer_less_than(SpvId type_id, SpvId result_id, SpvId src_a_id, SpvId src_b_id, bool is_signed);
702 static SpvInstruction integer_less_than_equal(SpvId type_id, SpvId result_id, SpvId src_a_id, SpvId src_b_id, bool is_signed);
703 static SpvInstruction integer_greater_than(SpvId type_id, SpvId result_id, SpvId src_a_id, SpvId src_b_id, bool is_signed);
704 static SpvInstruction integer_greater_than_equal(SpvId type_id, SpvId result_id, SpvId src_a_id, SpvId src_b_id, bool is_signed);
705 static SpvInstruction branch(SpvId target_label_id);
706 static SpvInstruction conditional_branch(SpvId condition_label_id, SpvId true_label_id, SpvId false_label_id, const BranchWeights &weights = {});
707 static SpvInstruction loop_merge(SpvId merge_label_id, SpvId continue_label_id, uint32_t loop_control_mask = SpvLoopControlMaskNone);
708 static SpvInstruction selection_merge(SpvId merge_label_id, uint32_t selection_control_mask = SpvSelectionControlMaskNone);
709 static SpvInstruction phi(SpvId type_id, SpvId result_id, const BlockVariables &block_vars);
710 static SpvInstruction unary_op(SpvOp op_code, SpvId type_id, SpvId result_id, SpvId src_id);
711 static SpvInstruction binary_op(SpvOp op_code, SpvId type_id, SpvId result_id, SpvId src_a_id, SpvId src_b_id);
712 static SpvInstruction convert(SpvOp op_code, SpvId type_id, SpvId result_id, SpvId src_id);
713 static SpvInstruction extended(SpvId instruction_set_id, SpvId instruction_number, SpvId type_id, SpvId result_id, const SpvFactory::Operands &operands);
714};
715
716/** Contents of a SPIR-V Instruction */
717struct SpvInstructionContents {
718 using Operands = std::vector<SpvId>;
719 using ValueTypes = std::vector<SpvValueType>;
720 mutable RefCount ref_count;
721 SpvOp op_code = SpvOpNop;
722 SpvId result_id = SpvNoResult;
723 SpvId type_id = SpvNoType;
724 Operands operands;
725 ValueTypes value_types;
726};
727
728/** Contents of a SPIR-V code block */
729struct SpvBlockContents {
730 using Instructions = std::vector<SpvInstruction>;
731 using Variables = std::vector<SpvInstruction>;
732 using Blocks = std::vector<SpvBlock>;
733 mutable RefCount ref_count;
734 SpvId block_id = SpvInvalidId;
735 Instructions instructions;
736 Variables variables;
737 Blocks before;
738 Blocks after;
739 bool reachable = true;
740};
741
742/** Contents of a SPIR-V function */
743struct SpvFunctionContents {
744 using PrecisionMap = std::unordered_map<SpvId, SpvPrecision>;
745 using Parameters = std::vector<SpvInstruction>;
746 using Blocks = std::vector<SpvBlock>;
747 mutable RefCount ref_count;
748 SpvId function_id;
749 SpvId function_type_id;
750 SpvId return_type_id;
751 uint32_t control_mask;
752 SpvInstruction declaration;
753 Parameters parameters;
754 PrecisionMap precision;
755 Blocks blocks;
756};
757
758/** Contents of a SPIR-V code module */
759struct SpvModuleContents {
760 using Capabilities = std::set<SpvCapability>;
761 using Extensions = std::set<std::string>;
762 using Imports = std::unordered_map<std::string, SpvId>;
763 using Functions = std::vector<SpvFunction>;
764 using Instructions = std::vector<SpvInstruction>;
765 using EntryPoints = std::unordered_map<std::string, SpvInstruction>;
766
767 mutable RefCount ref_count;
768 SpvId module_id = SpvInvalidId;
769 SpvId version_format = SpvVersion;
770 SpvId binding_count = 0;
771 SpvSourceLanguage source_language = SpvSourceLanguageUnknown;
772 SpvAddressingModel addressing_model = SpvAddressingModelLogical;
773 SpvMemoryModel memory_model = SpvMemoryModelSimple;
774 Capabilities capabilities;
775 Extensions extensions;
776 Imports imports;
777 EntryPoints entry_points;
778 Instructions execution_modes;
779 Instructions debug_source;
780 Instructions debug_symbols;
781 Instructions annotations;
782 Instructions types;
783 Instructions constants;
784 Instructions globals;
785 Functions functions;
786 Instructions instructions;
787};
788
789/** Helper functions for determining calling convention of GLSL builtins **/
790bool is_glsl_unary_op(SpvId glsl_op_code);
791bool is_glsl_binary_op(SpvId glsl_op_code);
792uint32_t glsl_operand_count(SpvId glsl_op_code);
793
794/** Output the contents of a SPIR-V module in human-readable form **/
795std::ostream &operator<<(std::ostream &stream, const SpvModule &);
796
797/** Output the definition of a SPIR-V function in human-readable form **/
798std::ostream &operator<<(std::ostream &stream, const SpvFunction &);
799
800/** Output the contents of a SPIR-V block in human-readable form **/
801std::ostream &operator<<(std::ostream &stream, const SpvBlock &);
802
803/** Output a SPIR-V instruction in human-readable form **/
804std::ostream &operator<<(std::ostream &stream, const SpvInstruction &);
805
806} // namespace Internal
807} // namespace Halide
808
809#endif // WITH_SPIRV
810
namespace Halide {
namespace Internal {

/** Internal test for SPIR-V IR.
 *
 * Declared outside the WITH_SPIRV guard, so the symbol is visible in every
 * build configuration (presumably so the internal test driver can always
 * call it). **/
void spirv_ir_test();

} // namespace Internal
} // namespace Halide
819
820#endif // HALIDE_SPIRV_IR_H
Support classes for reference-counting via intrusive shared pointers.
Defines halide types.
void * lookup_symbol(const char *sym, const known_symbol *map)
HALIDE_ALWAYS_INLINE auto select(C &&c, T &&t, F &&f) noexcept -> SelectOp< decltype(pattern_arg(c)), decltype(pattern_arg(t)), decltype(pattern_arg(f))>
Definition IRMatch.h:1719
RefCount & ref_count(const T *t) noexcept
Because in this header we don't yet know how client classes store their RefCount (and we don't want t...
void spirv_ir_test()
Internal test for SPIR-V IR.
ConstantInterval operator<<(const ConstantInterval &a, const ConstantInterval &b)
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
@ Internal
Not visible externally, similar to 'static' linkage in C.
Type type_of()
Construct the halide equivalent of a C type.
Definition Type.h:572
Expr is_nan(Expr x)
Returns true if the argument is a Not a Number (NaN).
Expr is_inf(Expr x)
Returns true if the argument is Inf or -Inf.
unsigned __INT64_TYPE__ uint64_t
signed __INT64_TYPE__ int64_t
unsigned __INT8_TYPE__ uint8_t
unsigned __INT32_TYPE__ uint32_t