Halide
HalidePyTorchHelpers.h File Reference
#include <exception>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
#include "torch/extension.h"
#include "HalideBuffer.h"


Namespaces

 Halide
 This file defines the class FunctionDAG, which is our representation of a Halide pipeline, and contains methods for using Halide's bounds tools to query its properties.
 
 Halide::PyTorch
 

Macros

#define HLPT_CHECK_CONTIGUOUS(x)   AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
 
#define HLPT_CHECK_CUDA(x)   AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
 
#define HLPT_CHECK_DEVICE(x, dev)   AT_ASSERTM(x.is_cuda() && x.get_device() == dev, #x " must be a CUDA tensor")
 
#define HL_PYTORCH_API_VERSION   12
 
#define HL_PT_DEFINE_TYPECHECK(ctype, ttype, _3)
 

Functions

std::vector< int > Halide::PyTorch::get_dims (const at::Tensor tensor)
 
template<class scalar_t >
void Halide::PyTorch::check_type (at::Tensor &tensor)
 
 Halide::PyTorch::AT_FORALL_SCALAR_TYPES_WITH_COMPLEX (HL_PT_DEFINE_TYPECHECK)
 
template<class scalar_t >
Buffer< scalar_t > Halide::PyTorch::wrap (at::Tensor &tensor)
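The reference gives only the signature of `get_dims`. PyTorch tensors are row-major, so the last PyTorch dimension is the innermost (stride 1), whereas dimension 0 of a Halide buffer is the innermost. A plausible reading, assumed here rather than stated in the header, is that `get_dims` returns the tensor sizes in reverse order. The sketch below takes a plain `std::vector<int64_t>` of sizes in place of `at::Tensor` so it compiles without PyTorch; `reversed_dims` is a hypothetical name.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for get_dims: takes PyTorch-style sizes
// (outermost first) and returns Halide-style extents (innermost first).
// Assumption: get_dims reverses the dimension order; the reference
// documents only its signature.
std::vector<int> reversed_dims(const std::vector<int64_t> &torch_sizes) {
    const int ndims = static_cast<int>(torch_sizes.size());
    std::vector<int> dims(ndims, 0);
    for (int dim = 0; dim < ndims; ++dim) {
        // Halide dimension `dim` corresponds to PyTorch dimension ndims - 1 - dim.
        dims[dim] = static_cast<int>(torch_sizes[ndims - 1 - dim]);
    }
    return dims;
}
```

For an NCHW tensor of shape {2, 3, 4, 5}, this yields {5, 4, 3, 2}: width, height, channels, batch in Halide's innermost-first order.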
 

Detailed Description

Set of utility functions to wrap PyTorch tensors into Halide buffers, making sure the data is on the correct device (CPU/GPU).

Definition in file HalidePyTorchHelpers.h.

Macro Definition Documentation

◆ HLPT_CHECK_CONTIGUOUS

#define HLPT_CHECK_CONTIGUOUS(x)   AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")

Definition at line 25 of file HalidePyTorchHelpers.h.
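A contiguity check matters here because a wrapped buffer presumably assumes dense row-major storage; a transposed or sliced view with other strides would be read incorrectly. The sketch below approximates what such a test examines, using plain size and stride vectors instead of `at::Tensor`. `is_dense_row_major` is a hypothetical helper, not part of this header or of PyTorch.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: returns true if (sizes, strides) describe dense
// row-major storage, i.e. the last non-trivial dimension has stride 1 and
// each outer stride equals the product of the inner sizes. Size-1
// dimensions are skipped because their stride never affects addressing.
bool is_dense_row_major(const std::vector<int64_t> &sizes,
                        const std::vector<int64_t> &strides) {
    assert(sizes.size() == strides.size());
    // A tensor with zero elements is treated as trivially contiguous.
    for (int64_t s : sizes) {
        if (s == 0) return true;
    }
    int64_t expected = 1;
    for (std::size_t i = sizes.size(); i-- > 0;) {
        if (sizes[i] == 1) continue;
        if (strides[i] != expected) return false;
        expected *= sizes[i];
    }
    return true;
}
```

A 2x3 tensor with strides {3, 1} passes, while its transpose (sizes {3, 2}, strides {1, 3}) fails and would need a `.contiguous()` copy before wrapping.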

◆ HLPT_CHECK_CUDA

#define HLPT_CHECK_CUDA(x)   AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")

Definition at line 26 of file HalidePyTorchHelpers.h.
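All three check macros build their error message with the preprocessor's stringizing operator: `#x` turns the macro argument into a string literal, which the compiler then concatenates with the adjacent literal. The sketch below reproduces that pattern without PyTorch. `DEMO_ASSERTM` is a hypothetical stand-in for `AT_ASSERTM` that simply throws `std::runtime_error`, and `DEMO_CHECK_POSITIVE` and `check_message` are likewise illustrative names.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for AT_ASSERTM: throws if the condition is false.
#define DEMO_ASSERTM(cond, msg)                      \
    do {                                             \
        if (!(cond)) throw std::runtime_error(msg);  \
    } while (0)

// Same shape as HLPT_CHECK_CUDA: #x becomes the argument's source text,
// and the adjacent string literals are concatenated at compile time.
#define DEMO_CHECK_POSITIVE(x) DEMO_ASSERTM((x) > 0, #x " must be positive")

// Returns the error message produced by a failing check, or "" on success.
std::string check_message(int input_tensor) {
    try {
        DEMO_CHECK_POSITIVE(input_tensor);
    } catch (const std::runtime_error &e) {
        return e.what();
    }
    return "";
}
```

Because the message names the argument's source text, a failing `HLPT_CHECK_CUDA(input)` reports `input must be a CUDA tensor` without any extra code at the call site.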

◆ HLPT_CHECK_DEVICE

#define HLPT_CHECK_DEVICE(x, dev)   AT_ASSERTM(x.is_cuda() && x.get_device() == dev, #x " must be a CUDA tensor")

Definition at line 27 of file HalidePyTorchHelpers.h.

◆ HL_PYTORCH_API_VERSION

#define HL_PYTORCH_API_VERSION   12

Definition at line 55 of file HalidePyTorchHelpers.h.
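`HL_PYTORCH_API_VERSION` is a compile-time constant, which suggests it is meant for `#if` tests that select between code written for different PyTorch API generations (for example, the older `Tensor::data<T>()` accessor versus the newer `Tensor::data_ptr<T>()`). This reading is an assumption; the reference does not say where the macro is used. The sketch shows only the mechanism, with a hypothetical function name and an assumed threshold.

```cpp
#include <cassert>
#include <string>

// Same value as documented above; defined here so the sketch is self-contained.
#ifndef HL_PYTORCH_API_VERSION
#define HL_PYTORCH_API_VERSION 12
#endif

// Hypothetical: report which accessor family this build would compile against.
std::string accessor_name() {
#if HL_PYTORCH_API_VERSION >= 12
    return "data_ptr";  // newer-API branch (threshold is an assumption)
#else
    return "data";      // older-API branch
#endif
}
```

Since the branch is resolved by the preprocessor, the unused path is never compiled, so it may reference calls that the installed PyTorch no longer provides.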

◆ HL_PT_DEFINE_TYPECHECK

#define HL_PT_DEFINE_TYPECHECK(ctype, ttype, _3)
Value:
template<> \
inline void check_type<ctype>(at::Tensor & tensor) { \
    AT_ASSERTM(tensor.scalar_type() == at::ScalarType::ttype, "scalar type do not match"); \
}

Definition at line 75 of file HalidePyTorchHelpers.h.
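`HL_PT_DEFINE_TYPECHECK` is passed to `AT_FORALL_SCALAR_TYPES_WITH_COMPLEX`, a PyTorch X-macro that invokes its argument once per scalar type. Judging by its three parameters, each invocation supplies the C++ type, the `ScalarType` enumerator name, and a third value this macro ignores as `_3`; the result is one explicit specialization of `check_type<ctype>` per type. The sketch below reproduces the pattern with a hypothetical three-entry type list and stand-in `ScalarType`/`FakeTensor` types so it builds without PyTorch. It returns a `bool` rather than asserting.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-ins for at::ScalarType and at::Tensor.
enum class ScalarType { Byte, Int, Float };
struct FakeTensor {
    ScalarType type;
    ScalarType scalar_type() const { return type; }
};

// Hypothetical X-macro list in the (ctype, enumerator, unused) shape.
#define DEMO_FORALL_TYPES(X) \
    X(uint8_t, Byte, 0)      \
    X(int32_t, Int, 0)       \
    X(float, Float, 0)

// Primary template: declared only, so an unlisted type fails to link.
template<class scalar_t>
bool check_type(const FakeTensor &tensor);

// One specialization per list entry, mirroring HL_PT_DEFINE_TYPECHECK.
#define DEMO_DEFINE_TYPECHECK(ctype, ttype, _3)                \
    template<>                                                 \
    inline bool check_type<ctype>(const FakeTensor &tensor) {  \
        return tensor.scalar_type() == ScalarType::ttype;      \
    }

DEMO_FORALL_TYPES(DEMO_DEFINE_TYPECHECK)
```

Keeping the type list in one macro means adding a scalar type updates every generated specialization at once, which is the main reason to prefer the X-macro over hand-written specializations.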