Halide
CodeGen_PyTorch.h File Reference
#include "IRPrinter.h"
#include "Module.h"


Classes

class  Halide::Internal::CodeGen_PyTorch
 This class emits C++ code to wrap a Halide pipeline so that it can be used as a C++ extension operator in PyTorch.
 

Namespaces

 Halide
 This file defines the class FunctionDAG, which is our representation of a Halide pipeline, and contains methods for using Halide's bounds tools to query its properties.
 
 Halide::Internal
 

Detailed Description

Defines an IRPrinter that emits C++ code that:

  1. wraps PyTorch's C++ tensors into Halide buffers,
  2. calls the corresponding Halide operator, and
  3. maps the output buffer back to a PyTorch tensor.
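The three steps above can be sketched in plain C++. Everything in this sketch is hypothetical: `FakeTensor` and `FakeHalideBuffer` stand in for PyTorch's tensor and Halide's buffer types, `fake_double_pipeline` stands in for a compiled Halide pipeline, and `double_op_wrapper` is an illustrative name, not what CodeGen_PyTorch actually emits. It assumes the caller preallocates the output tensor, so step 3 reduces to having the output buffer alias that tensor's storage.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a PyTorch tensor; NOT the real at::Tensor API.
struct FakeTensor {
    std::vector<float> data;  // contiguous element storage
    std::vector<int> sizes;   // shape
};

// Hypothetical, simplified stand-in for a Halide buffer: a non-owning
// view over existing memory plus its extents (no copy is made).
struct FakeHalideBuffer {
    float *host;
    std::vector<int> extents;
};

// Step 1: wrap a tensor's storage in a buffer view.
FakeHalideBuffer wrap(FakeTensor &t) {
    return FakeHalideBuffer{t.data.data(), t.sizes};
}

// Hypothetical compiled Halide pipeline that doubles every element.
// Like Halide-generated pipeline functions, it returns 0 on success.
int fake_double_pipeline(const FakeHalideBuffer &in, FakeHalideBuffer &out) {
    assert(in.extents == out.extents);
    std::size_t n = 1;
    for (int e : in.extents) n *= static_cast<std::size_t>(e);
    for (std::size_t i = 0; i < n; ++i) out.host[i] = 2.0f * in.host[i];
    return 0;
}

// The wrapper an operator would expose to PyTorch (hypothetical name).
int double_op_wrapper(FakeTensor &input, FakeTensor &output) {
    // Step 1: wrap both tensors; the caller preallocates `output`.
    FakeHalideBuffer in_buf = wrap(input);
    FakeHalideBuffer out_buf = wrap(output);
    // Step 2: call the Halide operator.
    int err = fake_double_pipeline(in_buf, out_buf);
    // Step 3: out_buf aliases output's storage, so the result is already
    // visible through the output tensor; no copy back is needed.
    return err;
}
```

Wrapping by aliasing rather than copying keeps the wrapper cheap; the trade-off, in this sketch, is that each tensor must stay alive and contiguous for as long as its buffer view is in use.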

The generated code checks for runtime errors and raises a PyTorch exception accordingly. When applicable, it also makes sure the GPU device and stream used by Halide are consistent with those of the PyTorch input.
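The error and consistency checks described above might look like the following minimal sketch. It assumes a simplified model in which each tensor's placement is just a device index plus a stream identifier; `Placement`, `resolve_placement`, and `check_halide_result` are hypothetical names, and `std::runtime_error` stands in for a PyTorch exception.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical record of where a tensor lives; NOT a real PyTorch type.
struct Placement {
    int device_index;  // -1 stands for CPU in this sketch
    int stream_id;     // hypothetical stream identifier
};

// Choose the device/stream for the Halide call from the first tensor and
// refuse to run if any other tensor lives elsewhere.
Placement resolve_placement(const std::vector<Placement> &tensors) {
    if (tensors.empty()) {
        throw std::invalid_argument("resolve_placement: no tensors given");
    }
    const Placement &first = tensors.front();
    for (const Placement &p : tensors) {
        if (p.device_index != first.device_index || p.stream_id != first.stream_id) {
            throw std::runtime_error("tensors are on different devices or streams");
        }
    }
    return first;
}

// Convert a nonzero Halide-style error code into an exception; real
// extension code would raise a PyTorch error instead of std::runtime_error.
void check_halide_result(int err, const std::string &op_name) {
    if (err != 0) {
        throw std::runtime_error(op_name + ": Halide pipeline returned error code " +
                                 std::to_string(err));
    }
}
```

Failing fast before the pipeline runs keeps a device or stream mismatch from surfacing later as a harder-to-diagnose error.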

Definition in file CodeGen_PyTorch.h.