Halide
SpirvIR.h
Go to the documentation of this file.
1 #ifndef HALIDE_SPIRV_IR_H
2 #define HALIDE_SPIRV_IR_H
3 
/** \file
 * Defines methods for constructing and encoding instructions into the Khronos
 * binary format known as SPIR-V (Standard Portable Intermediate Representation),
 * the intermediate representation consumed by Vulkan. These class interfaces
 * adopt Halide's conventions for its own IR, but are implemented as a
 * stand-alone optional component that can be enabled as required for certain
 * runtimes (e.g. Vulkan).
 *
 * NOTE: This file is only used internally for CodeGen! *DO NOT* add this file
 * to the list of exported Halide headers in src/CMakeLists.txt or the
 * top-level Makefile.
 */
15 #ifdef WITH_SPIRV
16 
17 #include <map>
18 #include <set>
19 #include <stack>
20 #include <unordered_map>
21 #include <vector>
22 
23 #include "IntrusivePtr.h"
24 #include "Type.h"
25 
26 #include <spirv/1.6/GLSL.std.450.h> // GLSL extended instructions for common intrinsics
27 #include <spirv/1.6/spirv.h> // Use v1.6 headers but only use the minimal viable format version (for maximum compatiblity)
28 
29 namespace Halide {
30 namespace Internal {
31 
/** Precision requirement for return values (and parameters) of a SpvFunction.
 *
 * SpvFullPrecision is the default requirement; SpvRelaxedPrecision marks a
 * value whose precision may be relaxed by the consumer. */
enum SpvPrecision {
    SpvFullPrecision,
    SpvRelaxedPrecision,
};
37 
/** Scope qualifiers used by Execution & Memory operations (e.g. barriers).
 *
 * The explicit numeric values are significant: they are intended to line up
 * with the encoding of SPIR-V's Scope operand, so keep them unchanged. */
enum SpvScope {
    SpvCrossDeviceScope = 0,  // visible across all devices
    SpvDeviceScope = 1,       // visible within the current device
    SpvWorkgroupScope = 2,    // visible within the current workgroup
    SpvSubgroupScope = 3,     // visible within the current subgroup
    SpvInvocationScope = 4,   // visible only to the current invocation
};
46 
/** Identifies one of the well-known constants (null / true / false). */
enum SpvPredefinedConstant {
    SpvNullConstant = 0,
    SpvTrueConstant = 1,
    SpvFalseConstant = 2,
};
53 
/** Classifies what a reserved SPIR-V object id refers to.
 *
 * SpvInvalidItem and SpvUnknownItem bracket the real kinds. */
enum SpvKind {
    SpvInvalidItem = 0,

    // Type declarations
    SpvTypeId = 1,
    SpvVoidTypeId = 2,
    SpvBoolTypeId = 3,
    SpvIntTypeId = 4,
    SpvUIntTypeId = 5,
    SpvFloatTypeId = 6,
    SpvVectorTypeId = 7,
    SpvArrayTypeId = 8,
    SpvRuntimeArrayTypeId = 9,
    SpvStringTypeId = 10,
    SpvPointerTypeId = 11,
    SpvStructTypeId = 12,
    SpvFunctionTypeId = 13,

    // Access chains
    SpvAccessChainId = 14,

    // Constants
    SpvConstantId = 15,
    SpvBoolConstantId = 16,
    SpvIntConstantId = 17,
    SpvFloatConstantId = 18,
    SpvStringConstantId = 19,
    SpvCompositeConstantId = 20,

    // Results, variables & structural items
    SpvResultId = 21,
    SpvVariableId = 22,
    SpvInstructionId = 23,
    SpvFunctionId = 24,
    SpvBlockId = 25,
    SpvLabelId = 26,
    SpvParameterId = 27,
    SpvImportId = 28,
    SpvModuleId = 29,

    SpvUnknownItem = 30,
};
88 
/** Classifies the kind of value stored in an instruction operand slot. */
enum SpvValueType {
    SpvInvalidValueType = 0,
    SpvOperandId = 1,        // operand is an <id>
    SpvBitMaskLiteral = 2,   // operand is a literal bit mask
    SpvIntegerLiteral = 3,   // operand is a literal integer
    SpvIntegerData = 4,      // operand words hold raw integer data
    SpvFloatData = 5,        // operand words hold raw floating-point data
    SpvStringData = 6,       // operand words hold string data
    SpvUnknownValueType = 7
};
100 
101 /** SPIR-V requires all IDs to be 32-bit unsigned integers */
102 using SpvId = uint32_t;
103 using SpvBinary = std::vector<uint32_t>;
104 
105 static constexpr SpvStorageClass SpvInvalidStorageClass = SpvStorageClassMax; // sentinel for invalid storage class
106 static constexpr SpvId SpvInvalidId = SpvId(-1);
107 static constexpr SpvId SpvNoResult = 0;
108 static constexpr SpvId SpvNoType = 0;
109 
110 /** Pre-declarations for SPIR-V IR classes */
111 class SpvModule;
112 class SpvFunction;
113 class SpvBlock;
114 class SpvInstruction;
115 class SpvBuilder;
116 struct SpvFactory;
117 
118 /** Pre-declarations for SPIR-V IR data structures */
119 struct SpvModuleContents;
120 struct SpvFunctionContents;
121 struct SpvBlockContents;
122 struct SpvInstructionContents;
123 
124 /** Intrusive pointer types for SPIR-V IR data */
125 using SpvModuleContentsPtr = IntrusivePtr<SpvModuleContents>;
126 using SpvFunctionContentsPtr = IntrusivePtr<SpvFunctionContents>;
127 using SpvBlockContentsPtr = IntrusivePtr<SpvBlockContents>;
128 using SpvInstructionContentsPtr = IntrusivePtr<SpvInstructionContents>;
129 
130 /** General interface for representing a SPIR-V Instruction */
131 class SpvInstruction {
132 public:
133  using LiteralValue = std::pair<uint32_t, SpvValueType>;
134  using Immediates = std::vector<LiteralValue>;
135  using Operands = std::vector<SpvId>;
136  using ValueTypes = std::vector<SpvValueType>;
137 
138  SpvInstruction() = default;
139  ~SpvInstruction() = default;
140 
141  SpvInstruction(const SpvInstruction &) = default;
142  SpvInstruction &operator=(const SpvInstruction &) = default;
143  SpvInstruction(SpvInstruction &&) = default;
144  SpvInstruction &operator=(SpvInstruction &&) = default;
145 
146  void set_block(SpvBlock block);
147  void set_result_id(SpvId id);
148  void set_type_id(SpvId id);
149  void set_op_code(SpvOp opcode);
150  void add_operand(SpvId id);
151  void add_operands(const Operands &operands);
152  void add_immediate(SpvId id, SpvValueType type);
153  void add_immediates(const Immediates &Immediates);
154  void add_data(uint32_t bytes, const void *data, SpvValueType type);
155  void add_string(const std::string &str);
156 
157  template<typename T>
158  void append(const T &operands_or_immediates_or_strings);
159 
160  SpvId result_id() const;
161  SpvId type_id() const;
162  SpvOp op_code() const;
163  SpvId operand(uint32_t index) const;
164  const void *data(uint32_t index = 0) const;
165  SpvValueType value_type(uint32_t index) const;
166  const Operands &operands() const;
167 
168  bool has_type() const;
169  bool has_result() const;
170  bool is_defined() const;
171  bool is_immediate(uint32_t index) const;
172  uint32_t length() const;
173  SpvBlock block() const;
174  void check_defined() const;
175 
176  void encode(SpvBinary &binary) const;
177 
178  static SpvInstruction make(SpvOp op_code);
179 
180 protected:
181  SpvInstructionContentsPtr contents;
182 };
183 
184 /** General interface for representing a SPIR-V Block */
185 class SpvBlock {
186 public:
187  using Instructions = std::vector<SpvInstruction>;
188  using Variables = std::vector<SpvInstruction>;
189  using Blocks = std::vector<SpvBlock>;
190 
191  SpvBlock() = default;
192  ~SpvBlock() = default;
193 
194  SpvBlock(const SpvBlock &) = default;
195  SpvBlock &operator=(const SpvBlock &) = default;
196  SpvBlock(SpvBlock &&) = default;
197  SpvBlock &operator=(SpvBlock &&) = default;
198 
199  void add_instruction(SpvInstruction inst);
200  void add_variable(SpvInstruction var);
201  void set_function(SpvFunction func);
202  const Instructions &instructions() const;
203  const Variables &variables() const;
204  SpvFunction function() const;
205  bool is_reachable() const;
206  bool is_terminated() const;
207  bool is_defined() const;
208  SpvId id() const;
209  void check_defined() const;
210 
211  void encode(SpvBinary &binary) const;
212 
213  static SpvBlock make(SpvFunction func, SpvId id);
214 
215 protected:
216  SpvBlockContentsPtr contents;
217 };
218 
219 /** General interface for representing a SPIR-V Function */
220 class SpvFunction {
221 public:
222  using Blocks = std::vector<SpvBlock>;
223  using Parameters = std::vector<SpvInstruction>;
224 
225  SpvFunction() = default;
226  ~SpvFunction() = default;
227 
228  SpvFunction(const SpvFunction &) = default;
229  SpvFunction &operator=(const SpvFunction &) = default;
230  SpvFunction(SpvFunction &&) = default;
231  SpvFunction &operator=(SpvFunction &&) = default;
232 
233  SpvBlock create_block(SpvId block_id);
234  void add_block(const SpvBlock &block);
235  void add_parameter(const SpvInstruction &param);
236  void set_module(SpvModule module);
237  void set_return_precision(SpvPrecision precision);
238  void set_parameter_precision(uint32_t index, SpvPrecision precision);
239  bool is_defined() const;
240 
241  const Blocks &blocks() const;
242  SpvBlock entry_block() const;
243  SpvBlock tail_block() const;
244  SpvPrecision return_precision() const;
245  const Parameters &parameters() const;
246  SpvPrecision parameter_precision(uint32_t index) const;
247  uint32_t parameter_count() const;
248  uint32_t control_mask() const;
249  SpvInstruction declaration() const;
250  SpvModule module() const;
251  SpvId return_type_id() const;
252  SpvId type_id() const;
253  SpvId id() const;
254  void check_defined() const;
255 
256  void encode(SpvBinary &binary) const;
257 
258  static SpvFunction make(SpvId func_id, SpvId func_type_id, SpvId return_type_id, uint32_t control_mask = SpvFunctionControlMaskNone);
259 
260 protected:
261  SpvFunctionContentsPtr contents;
262 };
263 
264 /** General interface for representing a SPIR-V code module */
265 class SpvModule {
266 public:
267  using ImportDefinition = std::pair<SpvId, std::string>;
268  using ImportNames = std::vector<std::string>;
269  using EntryPointNames = std::vector<std::string>;
270  using Instructions = std::vector<SpvInstruction>;
271  using Functions = std::vector<SpvFunction>;
272  using Capabilities = std::vector<SpvCapability>;
273  using Extensions = std::vector<std::string>;
274  using Imports = std::vector<ImportDefinition>;
275 
276  SpvModule() = default;
277  ~SpvModule() = default;
278 
279  SpvModule(const SpvModule &) = default;
280  SpvModule &operator=(const SpvModule &) = default;
281  SpvModule(SpvModule &&) = default;
282  SpvModule &operator=(SpvModule &&) = default;
283 
284  void add_debug_string(SpvId result_id, const std::string &string);
285  void add_debug_symbol(SpvId id, const std::string &symbol);
286  void add_annotation(const SpvInstruction &val);
287  void add_type(const SpvInstruction &val);
288  void add_constant(const SpvInstruction &val);
289  void add_global(const SpvInstruction &val);
290  void add_execution_mode(const SpvInstruction &val);
291  void add_function(SpvFunction val);
292  void add_instruction(const SpvInstruction &val);
293  void add_entry_point(const std::string &name, SpvInstruction entry_point);
294 
295  void import_instruction_set(SpvId id, const std::string &instruction_set);
296  void require_capability(SpvCapability val);
297  void require_extension(const std::string &val);
298 
299  void set_version_format(uint32_t version);
300  void set_source_language(SpvSourceLanguage val);
301  void set_addressing_model(SpvAddressingModel val);
302  void set_memory_model(SpvMemoryModel val);
303  void set_binding_count(SpvId count);
304 
305  uint32_t version_format() const;
306  SpvSourceLanguage source_language() const;
307  SpvAddressingModel addressing_model() const;
308  SpvMemoryModel memory_model() const;
309  SpvInstruction entry_point(const std::string &name) const;
310  EntryPointNames entry_point_names() const;
311  ImportNames import_names() const;
312  SpvId lookup_import(const std::string &Instruction_set) const;
313  uint32_t entry_point_count() const;
314 
315  Imports imports() const;
316  Extensions extensions() const;
317  Capabilities capabilities() const;
318  Instructions entry_points() const;
319  const Instructions &execution_modes() const;
320  const Instructions &debug_source() const;
321  const Instructions &debug_symbols() const;
322  const Instructions &annotations() const;
323  const Instructions &type_definitions() const;
324  const Instructions &global_constants() const;
325  const Instructions &global_variables() const;
326  const Functions &function_definitions() const;
327 
328  uint32_t binding_count() const;
329  SpvModule module() const;
330 
331  bool is_imported(const std::string &instruction_set) const;
332  bool is_capability_required(SpvCapability val) const;
333  bool is_extension_required(const std::string &val) const;
334  bool is_defined() const;
335  SpvId id() const;
336  void check_defined() const;
337 
338  void encode(SpvBinary &binary) const;
339 
340  static SpvModule make(SpvId module_id,
341  SpvSourceLanguage source_language = SpvSourceLanguageUnknown,
342  SpvAddressingModel addressing_model = SpvAddressingModelLogical,
343  SpvMemoryModel memory_model = SpvMemoryModelSimple);
344 
345 protected:
346  SpvModuleContentsPtr contents;
347 };
348 
349 /** Builder interface for constructing a SPIR-V code module and
350  * all associated types, declarations, blocks, functions &
351  * instructions */
352 class SpvBuilder {
353 public:
354  using ParamTypes = std::vector<SpvId>;
355  using Components = std::vector<SpvId>;
356  using StructMemberTypes = std::vector<SpvId>;
357  using Variables = std::vector<SpvId>;
358  using Indices = std::vector<uint32_t>;
359  using Literals = std::vector<uint32_t>;
360 
361  SpvBuilder();
362  ~SpvBuilder() = default;
363 
364  SpvBuilder(const SpvBuilder &) = delete;
365  SpvBuilder &operator=(const SpvBuilder &) = delete;
366 
367  // Reserve a unique ID to use for identifying a specifc kind of SPIR-V result **/
368  SpvId reserve_id(SpvKind = SpvResultId);
369 
370  // Look up the specific kind of SPIR-V item from its unique ID
371  SpvKind kind_of(SpvId id) const;
372 
373  // Get a human readable name for a specific kind of SPIR-V item
374  std::string kind_name(SpvKind kind) const;
375 
376  // Look up the ID associated with the type for a given variable ID
377  SpvId type_of(SpvId variable_id) const;
378 
379  // Top-Level declaration methods ... each of these is a convenvience
380  // function that checks to see if the requested thing has already been
381  // declared, in which case it returns its existing id, otherwise it
382  // adds a new declaration, and returns the new id. This avoids all
383  // the logic checks in the calling code, and also ensures that
384  // duplicates aren't created.
385 
386  SpvId declare_void_type();
387  SpvId declare_type(const Type &type, uint32_t array_size = 1);
388  SpvId declare_pointer_type(const Type &type, SpvStorageClass storage_class);
389  SpvId declare_pointer_type(SpvId type_id, SpvStorageClass storage_class);
390  SpvId declare_constant(const Type &type, const void *data, bool is_specialization = false);
391  SpvId declare_null_constant(const Type &type);
392  SpvId declare_bool_constant(bool value);
393  SpvId declare_string_constant(const std::string &str);
394  SpvId declare_integer_constant(const Type &type, int64_t value);
395  SpvId declare_float_constant(const Type &type, double value);
396  SpvId declare_scalar_constant(const Type &type, const void *data);
397  SpvId declare_vector_constant(const Type &type, const void *data);
398  SpvId declare_specialization_constant(const Type &type, const void *data);
399  SpvId declare_access_chain(SpvId ptr_type_id, SpvId base_id, const Indices &indices);
400  SpvId declare_pointer_access_chain(SpvId ptr_type_id, SpvId base_id, SpvId element_id, const Indices &indices);
401  SpvId declare_function_type(SpvId return_type, const ParamTypes &param_types = {});
402  SpvId declare_function(const std::string &name, SpvId function_type);
403  SpvId declare_struct(const std::string &name, const StructMemberTypes &member_types);
404  SpvId declare_variable(const std::string &name, SpvId type_id, SpvStorageClass storage_class, SpvId initializer_id = SpvInvalidId);
405  SpvId declare_global_variable(const std::string &name, SpvId type_id, SpvStorageClass storage_class, SpvId initializer_id = SpvInvalidId);
406  SpvId declare_symbol(const std::string &symbol, SpvId id, SpvId scope_id);
407 
408  // Top level creation methods for adding new items ... these have a limited
409  // number of checks and the caller must ensure that duplicates aren't created
410  SpvId add_type(const Type &type, uint32_t array_size = 1);
411  SpvId add_struct(const std::string &name, const StructMemberTypes &member_types);
412  SpvId add_array_with_default_size(SpvId base_type_id, SpvId array_size_id);
413  SpvId add_runtime_array(SpvId base_type_id);
414  SpvId add_pointer_type(const Type &type, SpvStorageClass storage_class);
415  SpvId add_pointer_type(SpvId base_type_id, SpvStorageClass storage_class);
416  SpvId add_constant(const Type &type, const void *data, bool is_specialization = false);
417  SpvId add_function_type(SpvId return_type_id, const ParamTypes &param_type_ids);
418  SpvId add_function(const std::string &name, SpvId return_type, const ParamTypes &param_types = {});
419  SpvId add_instruction(SpvInstruction val);
420 
421  void add_annotation(SpvId target_id, SpvDecoration decoration_type, const Literals &literals = {});
422  void add_struct_annotation(SpvId struct_type_id, uint32_t member_index, SpvDecoration decoration_type, const Literals &literals = {});
423  void add_symbol(const std::string &symbol, SpvId id, SpvId scope_id);
424 
425  void add_entry_point(SpvId func_id, SpvExecutionModel exec_model,
426  const Variables &variables = {});
427 
428  // Define the execution mode with a fixed local size for the workgroup (using literal values)
429  void add_execution_mode_local_size(SpvId entry_point_id, uint32_t local_size_x, uint32_t local_size_y, uint32_t local_size_z);
430 
431  // Same as above but uses id's for the local size (to allow specialization constants to be used)
432  void add_execution_mode_local_size_id(SpvId entry_point_id, SpvId local_size_x, SpvId local_size_y, SpvId local_size_z);
433 
434  // Assigns a specific SPIR-V version format for output (needed for compatibility)
435  void set_version_format(uint32_t version);
436 
437  // Assigns a specific source language hint to the module
438  void set_source_language(SpvSourceLanguage val);
439 
440  // Sets the addressing model to use for the module
441  void set_addressing_model(SpvAddressingModel val);
442 
443  // Sets the memory model to use for the module
444  void set_memory_model(SpvMemoryModel val);
445 
446  // Returns the source language hint for the module
447  SpvSourceLanguage source_language() const;
448 
449  // Returns the addressing model used for the module
450  SpvAddressingModel addressing_model() const;
451 
452  // Returns the memory model used for the module
453  SpvMemoryModel memory_model() const;
454 
455  // Import the GLSL.std.450 external instruction set. Returns its corresponding ID.
456  SpvId import_glsl_intrinsics();
457 
458  // Import an external instruction set bby name. Returns its corresponding ID.
459  SpvId import_instruction_set(const std::string &instruction_set);
460 
461  // Add an extension string to the list of required extensions for the module
462  void require_extension(const std::string &extension);
463 
464  // Add a specific capability to the list of requirements for the module
465  void require_capability(SpvCapability);
466 
467  // Returns true if the given instruction set has been imported
468  bool is_imported(const std::string &instruction_set) const;
469 
470  // Returns true if the given extension string is required by the module
471  bool is_extension_required(const std::string &extension) const;
472 
473  // Returns true if the given capability is required by the module
474  bool is_capability_required(SpvCapability) const;
475 
476  // Change the current build location to the given block. All local
477  // declarations and instructions will be added here.
478  void enter_block(const SpvBlock &block);
479 
480  // Create a new block with the given ID
481  SpvBlock create_block(SpvId block_id);
482 
483  // Returns the current block (the active scope for building)
484  SpvBlock current_block() const;
485 
486  // Resets the block build scope, and unassigns the current block
487  SpvBlock leave_block();
488 
489  // Change the current build scope to be within the given function
490  void enter_function(const SpvFunction &func);
491 
492  // Returns the function object for the given ID (or an invalid function if none is found)
493  SpvFunction lookup_function(SpvId func_id) const;
494 
495  // Returns the current function being used as the active build scope
496  SpvFunction current_function() const;
497 
498  // Resets the function build scope, and unassigns the current function
499  SpvFunction leave_function();
500 
501  // Returns the current id being used for building (ie the last item created)
502  SpvId current_id() const;
503 
504  // Updates the current id being used for building
505  void update_id(SpvId id);
506 
507  // Returns true if the given id is of the corresponding type
508  bool is_pointer_type(SpvId id) const;
509  bool is_struct_type(SpvId id) const;
510  bool is_vector_type(SpvId id) const;
511  bool is_scalar_type(SpvId id) const;
512  bool is_array_type(SpvId id) const;
513  bool is_constant(SpvId id) const;
514 
515  // Looks up the given pointer type id and returns a corresponding base type id (or an invalid id if none is found)
516  SpvId lookup_base_type(SpvId pointer_type) const;
517 
518  // Returns the storage class for the given variable id (or invalid if none is found)
519  SpvStorageClass lookup_storage_class(SpvId id) const;
520 
521  // Returns the item id for the given symbol name (or an invalid id if none is found)
522  SpvId lookup_id(const std::string &symbol) const;
523 
524  // Returns the build scope id for the item id (or an invalid id if none is found)
525  SpvId lookup_scope(SpvId id) const;
526 
527  // Returns the id for the imported instruction set (or an invalid id if none is found)
528  SpvId lookup_import(const std::string &instruction_set) const;
529 
530  // Returns the symbol string for the given id (or an empty string if none is found)
531  std::string lookup_symbol(SpvId id) const;
532 
533  // Returns the current module being used for building
534  SpvModule current_module() const;
535 
536  // Appends the given instruction to the current build location
537  void append(SpvInstruction inst);
538 
539  // Finalizes the module and prepares it for encoding (must be called before module can be used)
540  void finalize();
541 
542  // Encodes the current module to the given binary
543  void encode(SpvBinary &binary) const;
544 
545  // Resets the builder and all internal state
546  void reset();
547 
548 protected:
549  using TypeKey = uint64_t;
550  using TypeMap = std::unordered_map<TypeKey, SpvId>;
551  using KindMap = std::unordered_map<SpvId, SpvKind>;
552  using PointerTypeKey = std::pair<SpvId, SpvStorageClass>;
553  using PointerTypeMap = std::map<PointerTypeKey, SpvId>;
554  using BaseTypeMap = std::unordered_map<SpvId, SpvId>;
555  using VariableTypeMap = std::unordered_map<SpvId, SpvId>;
556  using StorageClassMap = std::unordered_map<SpvId, SpvStorageClass>;
557  using ConstantKey = uint64_t;
558  using ConstantMap = std::unordered_map<ConstantKey, SpvId>;
559  using StringMap = std::unordered_map<ConstantKey, SpvId>;
560  using ScopeMap = std::unordered_map<SpvId, SpvId>;
561  using IdSymbolMap = std::unordered_map<SpvId, std::string>;
562  using SymbolIdMap = std::unordered_map<std::string, SpvId>;
563  using FunctionTypeKey = uint64_t;
564  using FunctionTypeMap = std::unordered_map<FunctionTypeKey, SpvId>;
565  using FunctionMap = std::unordered_map<SpvId, SpvFunction>;
566 
567  // Internal methods for creating ids, keys, and look ups
568 
569  SpvId make_id(SpvKind kind);
570 
571  TypeKey make_type_key(const Type &type, uint32_t array_size = 1) const;
572  SpvId lookup_type(const Type &type, uint32_t array_size = 1) const;
573 
574  TypeKey make_struct_type_key(const StructMemberTypes &member_types) const;
575  SpvId lookup_struct(const std::string &name, const StructMemberTypes &member_types) const;
576 
577  PointerTypeKey make_pointer_type_key(const Type &type, SpvStorageClass storage_class) const;
578  SpvId lookup_pointer_type(const Type &type, SpvStorageClass storage_class) const;
579 
580  PointerTypeKey make_pointer_type_key(SpvId base_type_id, SpvStorageClass storage_class) const;
581  SpvId lookup_pointer_type(SpvId base_type_id, SpvStorageClass storage_class) const;
582 
583  template<typename T>
584  SpvId declare_scalar_constant_of_type(const Type &scalar_type, const T *data);
585 
586  template<typename T>
587  SpvId declare_specialization_constant_of_type(const Type &scalar_type, const T *data);
588 
589  template<typename T>
590  SpvBuilder::Components declare_constants_for_each_lane(Type type, const void *data);
591 
592  ConstantKey make_bool_constant_key(bool value) const;
593  ConstantKey make_string_constant_key(const std::string &value) const;
594  ConstantKey make_constant_key(uint8_t code, uint8_t bits, int lanes, size_t bytes, const void *data, bool is_specialization = false) const;
595  ConstantKey make_constant_key(const Type &type, const void *data, bool is_specialization = false) const;
596  SpvId lookup_constant(const Type &type, const void *data, bool is_specialization = false) const;
597 
598  ConstantKey make_null_constant_key(const Type &type) const;
599  SpvId lookup_null_constant(const Type &type) const;
600 
601  SpvId lookup_variable(const std::string &name, SpvId type_id, SpvStorageClass storage_class, SpvId scope_id) const;
602  bool has_variable(const std::string &name, SpvId type_id, SpvStorageClass storage_class, SpvId scope_id) const;
603 
604  FunctionTypeKey make_function_type_key(SpvId return_type_id, const ParamTypes &param_type_ids) const;
605  SpvId lookup_function_type(SpvId return_type_id, const ParamTypes &param_type_ids) const;
606 
607  SpvId active_id = SpvInvalidId;
608  SpvFunction active_function;
609  SpvBlock active_block;
610  SpvModule module;
611  KindMap kind_map;
612  TypeMap type_map;
613  TypeMap struct_map;
614  ScopeMap scope_map;
615  StringMap string_map;
616  ConstantMap constant_map;
617  FunctionMap function_map;
618  IdSymbolMap id_symbol_map;
619  SymbolIdMap symbol_id_map;
620  BaseTypeMap base_type_map;
621  StorageClassMap storage_class_map;
622  PointerTypeMap pointer_type_map;
623  VariableTypeMap variable_type_map;
624  FunctionTypeMap function_type_map;
625 };
626 
627 /** Factory interface for constructing specific SPIR-V instructions */
628 struct SpvFactory {
629  using Indices = std::vector<uint32_t>;
630  using Literals = std::vector<uint32_t>;
631  using BranchWeights = std::vector<uint32_t>;
632  using Components = std::vector<SpvId>;
633  using ParamTypes = std::vector<SpvId>;
634  using MemberTypeIds = std::vector<SpvId>;
635  using Operands = std::vector<SpvId>;
636  using Variables = std::vector<SpvId>;
637  using VariableBlockIdPair = std::pair<SpvId, SpvId>; // (Variable Id, Block Id)
638  using BlockVariables = std::vector<VariableBlockIdPair>;
639 
640  static SpvInstruction no_op(SpvId result_id);
641  static SpvInstruction capability(const SpvCapability &capability);
642  static SpvInstruction extension(const std::string &extension);
643  static SpvInstruction import(SpvId instruction_set_id, const std::string &instruction_set_name);
644  static SpvInstruction label(SpvId result_id);
645  static SpvInstruction debug_line(SpvId string_id, uint32_t line, uint32_t column);
646  static SpvInstruction debug_string(SpvId result_id, const std::string &string);
647  static SpvInstruction debug_symbol(SpvId target_id, const std::string &symbol);
648  static SpvInstruction decorate(SpvId target_id, SpvDecoration decoration_type, const Literals &literals = {});
649  static SpvInstruction decorate_member(SpvId struct_type_id, uint32_t member_index, SpvDecoration decoration_type, const Literals &literals = {});
650  static SpvInstruction void_type(SpvId void_type_id);
651  static SpvInstruction bool_type(SpvId bool_type_id);
652  static SpvInstruction integer_type(SpvId int_type_id, uint32_t bits, uint32_t signedness);
653  static SpvInstruction float_type(SpvId float_type_id, uint32_t bits);
654  static SpvInstruction vector_type(SpvId vector_type_id, SpvId element_type_id, uint32_t vector_size);
655  static SpvInstruction array_type(SpvId array_type_id, SpvId element_type_id, SpvId array_size_id);
656  static SpvInstruction struct_type(SpvId result_id, const MemberTypeIds &member_type_ids);
657  static SpvInstruction runtime_array_type(SpvId result_type_id, SpvId base_type_id);
658  static SpvInstruction pointer_type(SpvId pointer_type_id, SpvStorageClass storage_class, SpvId base_type_id);
659  static SpvInstruction function_type(SpvId function_type_id, SpvId return_type_id, const ParamTypes &param_type_ids);
660  static SpvInstruction constant(SpvId result_id, SpvId type_id, size_t bytes, const void *data, SpvValueType value_type);
661  static SpvInstruction null_constant(SpvId result_id, SpvId type_id);
662  static SpvInstruction bool_constant(SpvId result_id, SpvId type_id, bool value);
663  static SpvInstruction string_constant(SpvId result_id, const std::string &value);
664  static SpvInstruction composite_constant(SpvId result_id, SpvId type_id, const Components &components);
665  static SpvInstruction specialization_constant(SpvId result_id, SpvId type_id, size_t bytes, const void *data, SpvValueType value_type);
666  static SpvInstruction variable(SpvId result_id, SpvId result_type_id, uint32_t storage_class, SpvId initializer_id = SpvInvalidId);
667  static SpvInstruction function(SpvId return_type_id, SpvId func_id, uint32_t control_mask, SpvId func_type_id);
668  static SpvInstruction function_parameter(SpvId param_type_id, SpvId param_id);
669  static SpvInstruction function_end();
670  static SpvInstruction return_stmt(SpvId return_value_id = SpvInvalidId);
671  static SpvInstruction entry_point(SpvId exec_model, SpvId func_id, const std::string &name, const Variables &variables);
672  static SpvInstruction memory_model(SpvAddressingModel addressing_model, SpvMemoryModel memory_model);
673  static SpvInstruction exec_mode_local_size(SpvId function_id, uint32_t local_size_size_x, uint32_t local_size_size_y, uint32_t local_size_size_z);
674  static SpvInstruction exec_mode_local_size_id(SpvId function_id, SpvId local_size_x_id, SpvId local_size_y_id, SpvId local_size_z_id); // only avail in 1.2
675  static SpvInstruction memory_barrier(SpvId memory_scope_id, SpvId semantics_mask_id);
676  static SpvInstruction control_barrier(SpvId execution_scope_id, SpvId memory_scope_id, SpvId semantics_mask_id);
677  static SpvInstruction bitwise_not(SpvId type_id, SpvId result_id, SpvId src_id);
678  static SpvInstruction bitwise_and(SpvId type_id, SpvId result_id, SpvId src_a_id, SpvId src_b_id);
679  static SpvInstruction logical_not(SpvId type_id, SpvId result_id, SpvId src_id);
680  static SpvInstruction logical_and(SpvId type_id, SpvId result_id, SpvId src_a_id, SpvId src_b_id);
681  static SpvInstruction shift_right_logical(SpvId type_id, SpvId result_id, SpvId src_id, SpvId shift_id);
682  static SpvInstruction shift_right_arithmetic(SpvId type_id, SpvId result_id, SpvId src_id, SpvId shift_id);
683  static SpvInstruction multiply_extended(SpvId type_id, SpvId result_id, SpvId src_a_id, SpvId src_b_id, bool is_signed);
684  static SpvInstruction select(SpvId type_id, SpvId result_id, SpvId condition_id, SpvId true_id, SpvId false_id);
685  static SpvInstruction in_bounds_access_chain(SpvId type_id, SpvId result_id, SpvId base_id, const Indices &indices);
686  static SpvInstruction pointer_access_chain(SpvId type_id, SpvId result_id, SpvId base_id, SpvId element_id, const Indices &indices);
687  static SpvInstruction load(SpvId type_id, SpvId result_id, SpvId ptr_id, uint32_t access_mask = 0x0);
688  static SpvInstruction store(SpvId ptr_id, SpvId obj_id, uint32_t access_mask = 0x0);
689  static SpvInstruction vector_insert_dynamic(SpvId type_id, SpvId result_id, SpvId vector_id, SpvId value_id, SpvId index_id);
690  static SpvInstruction vector_extract_dynamic(SpvId type_id, SpvId result_id, SpvId vector_id, SpvId value_id, SpvId index_id);
691  static SpvInstruction vector_shuffle(SpvId type_id, SpvId result_id, SpvId src_a_id, SpvId src_b_id, const Indices &indices);
692  static SpvInstruction composite_insert(SpvId type_id, SpvId result_id, SpvId object_id, SpvId composite_id, const SpvFactory::Indices &indices);
693  static SpvInstruction composite_extract(SpvId type_id, SpvId result_id, SpvId composite_id, const Indices &indices);
694  static SpvInstruction composite_construct(SpvId type_id, SpvId result_id, const Components &constituents);
695  static SpvInstruction is_inf(SpvId type_id, SpvId result_id, SpvId src_id);
696  static SpvInstruction is_nan(SpvId type_id, SpvId result_id, SpvId src_id);
697  static SpvInstruction bitcast(SpvId type_id, SpvId result_id, SpvId src_id);
698  static SpvInstruction float_add(SpvId type_id, SpvId result_id, SpvId src_a_id, SpvId src_b_id);
699  static SpvInstruction integer_add(SpvId type_id, SpvId result_id, SpvId src_a_id, SpvId src_b_id);
700  static SpvInstruction integer_equal(SpvId type_id, SpvId result_id, SpvId src_a_id, SpvId src_b_id);
701  static SpvInstruction integer_not_equal(SpvId type_id, SpvId result_id, SpvId src_a_id, SpvId src_b_id);
702  static SpvInstruction integer_less_than(SpvId type_id, SpvId result_id, SpvId src_a_id, SpvId src_b_id, bool is_signed);
703  static SpvInstruction integer_less_than_equal(SpvId type_id, SpvId result_id, SpvId src_a_id, SpvId src_b_id, bool is_signed);
704  static SpvInstruction integer_greater_than(SpvId type_id, SpvId result_id, SpvId src_a_id, SpvId src_b_id, bool is_signed);
705  static SpvInstruction integer_greater_than_equal(SpvId type_id, SpvId result_id, SpvId src_a_id, SpvId src_b_id, bool is_signed);
706  static SpvInstruction branch(SpvId target_label_id);
/** NOTE(review): despite its name, `condition_label_id` is presumably the id of a
 * boolean condition value (OpBranchConditional's Condition operand), not a block
 * label. `true_label_id` / `false_label_id` name the two target blocks, and
 * `weights` defaults to empty (no branch weights). */
707  static SpvInstruction conditional_branch(SpvId condition_label_id, SpvId true_label_id, SpvId false_label_id, const BranchWeights &weights = {});
708  static SpvInstruction loop_merge(SpvId merge_label_id, SpvId continue_label_id, uint32_t loop_control_mask = SpvLoopControlMaskNone);
709  static SpvInstruction selection_merge(SpvId merge_label_id, uint32_t selection_control_mask = SpvSelectionControlMaskNone);
710  static SpvInstruction phi(SpvId type_id, SpvId result_id, const BlockVariables &block_vars);
711  static SpvInstruction unary_op(SpvOp op_code, SpvId type_id, SpvId result_id, SpvId src_id);
712  static SpvInstruction binary_op(SpvOp op_code, SpvId type_id, SpvId result_id, SpvId src_a_id, SpvId src_b_id);
713  static SpvInstruction convert(SpvOp op_code, SpvId type_id, SpvId result_id, SpvId src_id);
714  static SpvInstruction extended(SpvId instruction_set_id, SpvId instruction_number, SpvId type_id, SpvId result_id, const SpvFactory::Operands &operands);
715 };
716 
717 /** Contents of a SPIR-V Instruction */
718 struct SpvInstructionContents {
719  using Operands = std::vector<SpvId>;
720  using ValueTypes = std::vector<SpvValueType>;
721  mutable RefCount ref_count;
722  SpvOp op_code = SpvOpNop;
723  SpvId result_id = SpvNoResult;
724  SpvId type_id = SpvNoType;
725  Operands operands;
726  ValueTypes value_types;
727  SpvBlock block;
728 };
729 
730 /** Contents of a SPIR-V code block */
731 struct SpvBlockContents {
732  using Instructions = std::vector<SpvInstruction>;
733  using Variables = std::vector<SpvInstruction>;
734  using Blocks = std::vector<SpvBlock>;
735  mutable RefCount ref_count;
736  SpvId block_id = SpvInvalidId;
737  SpvFunction parent;
738  Instructions instructions;
739  Variables variables;
740  Blocks before;
741  Blocks after;
742  bool reachable = true;
743 };
744 
745 /** Contents of a SPIR-V function */
746 struct SpvFunctionContents {
747  using PrecisionMap = std::unordered_map<SpvId, SpvPrecision>;
748  using Parameters = std::vector<SpvInstruction>;
749  using Blocks = std::vector<SpvBlock>;
750  mutable RefCount ref_count;
751  SpvModule parent;
752  SpvId function_id;
753  SpvId function_type_id;
754  SpvId return_type_id;
755  uint32_t control_mask;
756  SpvInstruction declaration;
757  Parameters parameters;
758  PrecisionMap precision;
759  Blocks blocks;
760 };
761 
762 /** Contents of a SPIR-V code module */
763 struct SpvModuleContents {
764  using Capabilities = std::set<SpvCapability>;
765  using Extensions = std::set<std::string>;
766  using Imports = std::unordered_map<std::string, SpvId>;
767  using Functions = std::vector<SpvFunction>;
768  using Instructions = std::vector<SpvInstruction>;
769  using EntryPoints = std::unordered_map<std::string, SpvInstruction>;
770 
771  mutable RefCount ref_count;
772  SpvId module_id = SpvInvalidId;
773  SpvId version_format = SpvVersion;
774  SpvId binding_count = 0;
775  SpvSourceLanguage source_language = SpvSourceLanguageUnknown;
776  SpvAddressingModel addressing_model = SpvAddressingModelLogical;
777  SpvMemoryModel memory_model = SpvMemoryModelSimple;
778  Capabilities capabilities;
779  Extensions extensions;
780  Imports imports;
781  EntryPoints entry_points;
782  Instructions execution_modes;
783  Instructions debug_source;
784  Instructions debug_symbols;
785  Instructions annotations;
786  Instructions types;
787  Instructions constants;
788  Instructions globals;
789  Functions functions;
790  Instructions instructions;
791 };
792 
793 /** Helper functions for determining calling convention of GLSL builtins **/
794 bool is_glsl_unary_op(SpvId glsl_op_code);
795 bool is_glsl_binary_op(SpvId glsl_op_code);
796 uint32_t glsl_operand_count(SpvId glsl_op_code);
797 
798 /** Output the contents of a SPIR-V module in human-readable form **/
799 std::ostream &operator<<(std::ostream &stream, const SpvModule &);
800 
801 /** Output the definition of a SPIR-V function in human-readable form **/
802 std::ostream &operator<<(std::ostream &stream, const SpvFunction &);
803 
804 /** Output the contents of a SPIR-V block in human-readable form **/
805 std::ostream &operator<<(std::ostream &stream, const SpvBlock &);
806 
807 /** Output a SPIR-V instruction in human-readable form **/
808 std::ostream &operator<<(std::ostream &stream, const SpvInstruction &);
809 
810 } // namespace Internal
811 } // namespace Halide
812 
813 #endif // WITH_SPIRV
814 
namespace Halide {
namespace Internal {

/** Self-test for the SPIR-V IR builder. Declared after the WITH_SPIRV guard
 * has closed, so the symbol is declared in every build configuration. */
void spirv_ir_test();

}  // namespace Internal
}  // namespace Halide
823 
824 #endif // HALIDE_SPIRV_IR_H
uint8_t
unsigned __INT8_TYPE__ uint8_t
Definition: runtime_internal.h:29
Halide::Internal::operator<<
std::ostream & operator<<(std::ostream &stream, const Stmt &)
Emit a halide statement on an output stream (such as std::cout) in a human-readable form.
lookup_symbol
void * lookup_symbol(const char *sym, const known_symbol *map)
uint64_t
unsigned __INT64_TYPE__ uint64_t
Definition: runtime_internal.h:23
Halide
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
Definition: AbstractGenerator.h:19
Halide::select
Expr select(Expr condition, Expr true_value, Expr false_value)
Returns an expression similar to the ternary operator in C, except that it always evaluates all argum...
Halide::LinkageType::Internal
@ Internal
Not visible externally, similar to 'static' linkage in C.
Halide::Internal::ref_count
RefCount & ref_count(const T *t) noexcept
Because in this header we don't yet know how client classes store their RefCount (and we don't want t...
Halide::is_inf
Expr is_inf(Expr x)
Returns true if the argument is Inf or -Inf.
int64_t
signed __INT64_TYPE__ int64_t
Definition: runtime_internal.h:22
Type.h
Halide::is_nan
Expr is_nan(Expr x)
Returns true if the argument is a Not a Number (NaN).
Halide::type_of
Type type_of()
Construct the halide equivalent of a C type.
Definition: Type.h:557
uint32_t
unsigned __INT32_TYPE__ uint32_t
Definition: runtime_internal.h:25
Halide::Internal::spirv_ir_test
void spirv_ir_test()
Internal test for SPIR-V IR.
IntrusivePtr.h