Halide 19.0.0
Halide compiler and libraries
CodeGen_PyTorch.h
#ifndef HALIDE_CODEGEN_PYTORCH_H
#define HALIDE_CODEGEN_PYTORCH_H

/** \file
 *
 * Defines an IRPrinter that emits C++ code that:
 * 1. wraps PyTorch's C++ tensors into Halide buffers,
 * 2. calls the corresponding Halide operator, and
 * 3. maps the output buffer back to a PyTorch tensor.
 *
 * The generated code checks for runtime errors and raises PyTorch exceptions
 * accordingly. It also makes sure the GPU device and stream are consistent with
 * those of the PyTorch inputs, when applicable.
 */

#include "IRPrinter.h"

namespace Halide {

class Module;

namespace Internal {

struct LoweredFunc;

/** This class emits C++ code to wrap a Halide pipeline so that it can
 * be used as a C++ extension operator in PyTorch.
 */
class CodeGen_PyTorch : public IRPrinter {
public:
    CodeGen_PyTorch(std::ostream &dest);
    ~CodeGen_PyTorch() override = default;

    /** Emit the PyTorch C++ wrapper for the Halide pipeline. */
    void compile(const Module &module);

private:
    void compile(const LoweredFunc &func, bool is_cuda);
};

} // namespace Internal
} // namespace Halide

#endif // HALIDE_CODEGEN_PYTORCH_H
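The file comment lists three steps: wrap the tensors, call the operator, and map the result back. It also says errors are turned into exceptions. The sketch below illustrates that pattern without depending on Halide or PyTorch headers. Every name in it is hypothetical: FakeTensor, Dim, FakeBuffer, wrap, scale_by_two and scale_by_two_th. FakeBuffer is modeled loosely on halide_dimension_t's min, extent and stride fields. This is not the code CodeGen_PyTorch actually emits. The sketch assumes two things. First, a PyTorch-style tensor lists its dimensions outermost first, while Halide treats dimension 0 as the innermost one, so wrapping reverses the dimension order. Second, like a Halide AOT pipeline, the operator returns 0 on success.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-in for a PyTorch tensor: sizes and strides (in elements)
// are listed outermost dimension first.
struct FakeTensor {
    std::vector<int64_t> sizes;
    std::vector<int64_t> strides;
    float *data;
};

// Hypothetical dimension record, loosely modeled on halide_dimension_t.
struct Dim {
    int32_t min, extent, stride;
};

// Hypothetical buffer view; dims[0] is the innermost dimension.
struct FakeBuffer {
    float *host;
    std::vector<Dim> dims;
};

// Step 1: wrap the tensor's memory without copying, reversing dimension order.
FakeBuffer wrap(const FakeTensor &t) {
    assert(t.sizes.size() == t.strides.size());
    FakeBuffer b{t.data, {}};
    for (size_t i = t.sizes.size(); i-- > 0;) {
        b.dims.push_back({0, static_cast<int32_t>(t.sizes[i]),
                          static_cast<int32_t>(t.strides[i])});
    }
    return b;
}

// Step 2: hypothetical 2-D "Halide operator" computing out = 2 * in.
// It returns a nonzero error code on shape mismatch and ignores `min`
// (assumed 0 here).
int scale_by_two(const FakeBuffer &in, FakeBuffer &out) {
    if (in.dims.size() != 2 || out.dims.size() != 2) return -1;
    for (size_t d = 0; d < 2; ++d) {
        if (in.dims[d].extent != out.dims[d].extent) return -1;
    }
    for (int32_t y = 0; y < in.dims[1].extent; ++y) {
        for (int32_t x = 0; x < in.dims[0].extent; ++x) {
            float v = in.host[x * in.dims[0].stride + y * in.dims[1].stride];
            out.host[x * out.dims[0].stride + y * out.dims[1].stride] = 2.0f * v;
        }
    }
    return 0;
}

// Hypothetical wrapper: wraps both tensors, calls the operator, and raises an
// exception on failure. Step 3 is free here: out_buf aliases out.data, so the
// result already lives in the output tensor.
void scale_by_two_th(const FakeTensor &in, FakeTensor &out) {
    FakeBuffer in_buf = wrap(in);
    FakeBuffer out_buf = wrap(out);
    int err = scale_by_two(in_buf, out_buf);
    if (err != 0) {
        throw std::runtime_error("Halide pipeline failed with error " +
                                 std::to_string(err));
    }
}
```

For a contiguous 2x3 input holding 1 through 6 and an output of the same shape, scale_by_two_th fills the output with 2, 4, 6, 8, 10, 12. Passing an output whose shape does not match makes it throw std::runtime_error. Aliasing the output memory rather than copying is a choice this sketch makes to keep step 3 trivial.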
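CodeGen_PyTorch is an IRPrinter constructed with a std::ostream &dest. Its private compile(const LoweredFunc &func, bool is_cuda) suggests that one wrapper is emitted per lowered function, with a flag for CUDA targets. The sketch below shows that emit-into-a-stream shape under stated assumptions. FuncSig, emit_wrapper and emit_all are hypothetical, and the emitted text is illustrative only, not CodeGen_PyTorch's real output. The emitted checks use TORCH_CHECK together with at::Tensor's is_cuda() and is_contiguous(), which PyTorch's C++ API provides. The "_th_" suffix is an assumption of this sketch.

```cpp
#include <cassert>
#include <ostream>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical, much-simplified stand-in for a lowered function: a name plus
// the names of its buffer arguments.
struct FuncSig {
    std::string name;
    std::vector<std::string> buffer_args;
};

// Emits one illustrative C++ wrapper for `f` into `dest`. Device checks
// depend on whether the pipeline targets CUDA.
void emit_wrapper(std::ostream &dest, const FuncSig &f, bool is_cuda) {
    assert(!f.buffer_args.empty());
    dest << "int " << f.name << "_th_(";
    for (size_t i = 0; i < f.buffer_args.size(); ++i) {
        if (i > 0) dest << ", ";
        dest << "at::Tensor &" << f.buffer_args[i];
    }
    dest << ") {\n";
    for (const std::string &arg : f.buffer_args) {
        // Each tensor must live on the device the pipeline was compiled for.
        dest << "    TORCH_CHECK(" << (is_cuda ? "" : "!") << arg
             << ".is_cuda(), \"" << arg << " is on the wrong device\");\n";
        dest << "    TORCH_CHECK(" << arg << ".is_contiguous(), \"" << arg
             << " must be contiguous\");\n";
    }
    dest << "    // wrap tensors as Halide buffers and call " << f.name << " here\n";
    dest << "    return 0;\n";
    dest << "}\n";
}

// Emits one wrapper per function, mirroring a loop over a module's functions.
std::string emit_all(const std::vector<FuncSig> &funcs, bool is_cuda) {
    std::ostringstream os;
    for (const FuncSig &f : funcs) {
        emit_wrapper(os, f, is_cuda);
    }
    return os.str();
}
```

Writing to a caller-supplied std::ostream, as the constructor's dest parameter does, lets the same emitter target a file, std::cout, or a std::ostringstream in tests.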