Halide
CodeGen_PyTorch.h
Go to the documentation of this file.
1 #ifndef HALIDE_CODEGEN_PYTORCH_H
2 #define HALIDE_CODEGEN_PYTORCH_H
3 
4 /** \file
5  *
6  * Defines an IRPrinter that emits C++ code that:
7  * 1. wraps PyTorch's C++ tensors into Halide buffers,
8  * 2. calls the corresponding Halide operator,
9  * 3. maps the output buffer back to a PyTorch tensor.
10  *
11  * The generated code checks for runtime errors and raises PyTorch exceptions
12  * accordingly. It also makes sure the GPU device and stream are consistent
13  * with the PyTorch input, when applicable.
14  */
15 
16 #include "IRPrinter.h"
17 
18 namespace Halide {
19 
20 class Module;
21 
22 namespace Internal {
23 
24 struct LoweredFunc;
25 
26 /** This class emits C++ code to wrap a Halide pipeline so that it can
27  * be used as a C++ extension operator in PyTorch.
28  */
29 class CodeGen_PyTorch : public IRPrinter {
30 public:
31  CodeGen_PyTorch(std::ostream &dest);
32  ~CodeGen_PyTorch() override = default;
33 
34  /** Emit the PyTorch C++ wrapper for the Halide pipeline. */
35  void compile(const Module &module);
36 
37 private:
38  void compile(const LoweredFunc &func, bool is_cuda);
39 };
40 
41 } // namespace Internal
42 } // namespace Halide
43 
44 #endif // HALIDE_CODEGEN_PYTORCH_H
IRPrinter.h
Halide::Internal::IRPrinter
An IRVisitor that emits IR to the given output stream in a human readable form.
Definition: IRPrinter.h:102
Halide::Internal::CodeGen_PyTorch::compile
void compile(const Module &module)
Emit the PyTorch C++ wrapper for the Halide pipeline.
Halide::Module
A halide module.
Definition: Module.h:138
Halide::Internal::CodeGen_PyTorch::CodeGen_PyTorch
CodeGen_PyTorch(std::ostream &dest)
Halide
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
Definition: AbstractGenerator.h:19
Halide::LinkageType::Internal
@ Internal
Not visible externally, similar to 'static' linkage in C.
Halide::Internal::CodeGen_PyTorch::~CodeGen_PyTorch
~CodeGen_PyTorch() override=default
Halide::Internal::LoweredFunc
Definition of a lowered function.
Definition: Module.h:97
Halide::Internal::CodeGen_PyTorch
This class emits C++ code to wrap a Halide pipeline so that it can be used as a C++ extension operator in PyTorch.
Definition: CodeGen_PyTorch.h:29