Halide
CodeGen_PyTorch.h
Go to the documentation of this file.
1 #ifndef HALIDE_CODEGEN_PYTORCH_H
2 #define HALIDE_CODEGEN_PYTORCH_H
3 
4 /** \file
5  *
6  * Defines an IRPrinter that emits C++ code that:
7  * 1. wraps PyTorch's C++ tensors into Halide buffers,
8  * 2. calls the corresponding Halide operator,
9  * 3. maps the output buffer back to a PyTorch tensor.
10  *
11  * The generated code checks for runtime errors and raises PyTorch exceptions
12  * accordingly. It also makes sure that, when applicable, the GPU device and
13  * stream are consistent with those of the PyTorch input.
14  */
15 
16 #include "IRPrinter.h"
17 #include "Module.h"
18 
19 namespace Halide {
20 namespace Internal {
21 
22 /** This class emits C++ code to wrap a Halide pipeline so that it can
23  * be used as a C++ extension operator in PyTorch.
24  */
25 class CodeGen_PyTorch : public IRPrinter {
26 public:
27  CodeGen_PyTorch(std::ostream &dest);
28  ~CodeGen_PyTorch() override = default;
29 
30  /** Emit the PyTorch C++ wrapper for the Halide pipeline. */
31  void compile(const Module &module);
32 
33  static void test();
34 
35 private:
36  void compile(const LoweredFunc &func, bool is_cuda);
37 };
38 
39 } // namespace Internal
40 } // namespace Halide
41 
42 #endif // HALIDE_CODEGEN_PYTORCH_H
IRPrinter.h
Halide::Internal::IRPrinter
An IRVisitor that emits IR to the given output stream in a human readable form.
Definition: IRPrinter.h:98
Halide::Internal::CodeGen_PyTorch::compile
void compile(const Module &module)
Emit the PyTorch C++ wrapper for the Halide pipeline.
Halide::Module
A halide module.
Definition: Module.h:136
Halide::Internal::CodeGen_PyTorch::CodeGen_PyTorch
CodeGen_PyTorch(std::ostream &dest)
Halide
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
Definition: AddAtomicMutex.h:21
Halide::LinkageType::Internal
@ Internal
Not visible externally, similar to 'static' linkage in C.
Halide::Internal::CodeGen_PyTorch::~CodeGen_PyTorch
~CodeGen_PyTorch() override=default
Halide::Internal::LoweredFunc
Definition of a lowered function.
Definition: Module.h:97
Halide::Internal::CodeGen_PyTorch
This class emits C++ code to wrap a Halide pipeline so that it can be used as a C++ extension operato...
Definition: CodeGen_PyTorch.h:25
Module.h
Halide::Internal::CodeGen_PyTorch::test
static void test()