Halide
Halide::PyTorch Namespace Reference

Functions

std::vector< int > get_dims (const at::Tensor tensor)
 
template<class scalar_t >
void check_type (at::Tensor &tensor)
 
 AT_FORALL_SCALAR_TYPES_WITH_COMPLEX (HL_PT_DEFINE_TYPECHECK)
 
template<class scalar_t >
Buffer< scalar_t > wrap (at::Tensor &tensor)
 

Function Documentation

◆ get_dims()

std::vector<int> Halide::PyTorch::get_dims(const at::Tensor tensor)
inline

Definition at line 34 of file HalidePyTorchHelpers.h.

Referenced by wrap().
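This page documents only the signature of get_dims(). Because wrap() uses it, its job is presumably to turn a tensor's sizes into the std::vector<int> that a Halide buffer expects. The sketch below assumes that conversion reverses the order: PyTorch lists sizes outermost-first (for a contiguous tensor, the last dimension has stride 1), while Halide numbers dimensions innermost-first (dimension 0 has stride 1 in a dense buffer). A plain std::vector<int64_t> stands in for at::Tensor so the code compiles without PyTorch. The name halide_dims_from_torch_sizes is hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper: takes sizes in PyTorch order (outermost first,
// as at::Tensor::sizes() would report them) and returns them in Halide
// order (innermost first). Assumes the real get_dims() performs this reversal.
std::vector<int> halide_dims_from_torch_sizes(const std::vector<int64_t> &torch_sizes) {
    const int ndims = static_cast<int>(torch_sizes.size());
    std::vector<int> dims(ndims, 0);
    for (int d = 0; d < ndims; ++d) {
        // Halide dimension d corresponds to PyTorch dimension ndims - 1 - d.
        dims[d] = static_cast<int>(torch_sizes[ndims - 1 - d]);
    }
    return dims;
}
```

For example, a PyTorch tensor of shape (2, 3, 4) would map to Halide extents {4, 3, 2}, so the Halide x dimension walks along the tensor's last, contiguous axis.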

◆ check_type()

template<class scalar_t >
void Halide::PyTorch::check_type(at::Tensor &tensor)
inline

Definition at line 45 of file HalidePyTorchHelpers.h.
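check_type() is a function template over scalar_t with no documented behavior. One common way to write such a check is to have a primary template that rejects every type and explicit specializations that compare the tensor's runtime dtype with the requested C++ type. The sketch below assumes that design. FakeTensor, ScalarTag and check_type_sketch are hypothetical stand-ins so the code compiles without PyTorch, and std::runtime_error replaces whatever error mechanism the real header uses.

```cpp
#include <cassert>
#include <stdexcept>

// Hypothetical runtime dtype tag, standing in for at::ScalarType.
enum class ScalarTag { Float, Int };

// Hypothetical stand-in for at::Tensor: only the runtime dtype matters here.
struct FakeTensor {
    ScalarTag scalar_type;
};

// Primary template: any C++ type without a specialization is rejected.
template<class scalar_t>
inline void check_type_sketch(FakeTensor &) {
    throw std::runtime_error("scalar type not handled by this wrapper");
}

// Explicit specialization: float is accepted only if the runtime tag matches.
template<>
inline void check_type_sketch<float>(FakeTensor &tensor) {
    if (tensor.scalar_type != ScalarTag::Float) {
        throw std::runtime_error("scalar type does not match");
    }
}
```

With this design, asking for a type the wrapper never specialized fails the same way as a dtype mismatch, so a caller cannot silently reinterpret tensor memory as the wrong element type.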

◆ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX()

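Halide::PyTorch::AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(HL_PT_DEFINE_TYPECHECK)

Although Doxygen lists this entry under Functions, it is not a function: it is an invocation of the PyTorch (ATen) macro AT_FORALL_SCALAR_TYPES_WITH_COMPLEX, which applies the macro HL_PT_DEFINE_TYPECHECK once per scalar type in its list. As the name suggests, this presumably generates one explicit specialization of check_type() for each of those types.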
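AT_FORALL_SCALAR_TYPES_WITH_COMPLEX is an X-macro: it receives another macro and invokes it once per scalar type, passing at least the C++ type and the matching ScalarType name (the exact argument list has varied across PyTorch versions). Assuming HL_PT_DEFINE_TYPECHECK expands to a check_type() specialization, the pattern can be reproduced in miniature. FORALL_SKETCH_TYPES, DEFINE_TYPECHECK_SKETCH and the other names below are hypothetical and cover only three types.

```cpp
#include <cassert>
#include <stdexcept>

// Hypothetical runtime dtype tag, standing in for at::ScalarType.
enum class ScalarTag { Float, Double, Int };

struct FakeTensor {
    ScalarTag scalar_type;
};

// Primary template: types the list below does not cover are rejected.
template<class scalar_t>
inline void check_type_sketch(FakeTensor &) {
    throw std::runtime_error("scalar type not handled");
}

// Hypothetical X-macro in the style of AT_FORALL_SCALAR_TYPES_WITH_COMPLEX:
// calls the macro passed as `_` once per (C++ type, tag name) pair.
#define FORALL_SKETCH_TYPES(_) \
    _(float, Float)            \
    _(double, Double)          \
    _(int, Int)

// Per-type macro in the style of HL_PT_DEFINE_TYPECHECK: expands to one
// explicit specialization that compares the tensor's runtime tag.
#define DEFINE_TYPECHECK_SKETCH(ctype, tag)                         \
    template<>                                                      \
    inline void check_type_sketch<ctype>(FakeTensor & tensor) {     \
        if (tensor.scalar_type != ScalarTag::tag) {                 \
            throw std::runtime_error("scalar type does not match"); \
        }                                                           \
    }

// Generates check_type_sketch<float>, <double> and <int>.
FORALL_SKETCH_TYPES(DEFINE_TYPECHECK_SKETCH)

#undef DEFINE_TYPECHECK_SKETCH
#undef FORALL_SKETCH_TYPES
```

Keeping the type list in one macro means a new scalar type only has to be added in one place; every per-type definition is then regenerated from it.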

◆ wrap()

template<class scalar_t >
Buffer<scalar_t> Halide::PyTorch::wrap(at::Tensor &tensor)
inline

Definition at line 88 of file HalidePyTorchHelpers.h.

References get_dims() and halide_cuda_device_interface().
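wrap() depends on get_dims() and on halide_cuda_device_interface(), the latter presumably for tensors that live on a CUDA device. For the CPU case, the idea of wrapping rather than copying can be sketched with a hypothetical BorrowedView type. It points at existing storage, lists extents innermost-first (as get_dims() is assumed to), and derives dense strides. The sketch assumes a contiguous tensor and does not model the GPU path or the real Halide::Runtime::Buffer API.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical minimal stand-in for a Halide buffer that wraps existing
// memory without copying. Dimension 0 is innermost (stride 1), as in Halide.
template<class T>
struct BorrowedView {
    T *data;                   // not owned: the original storage keeps ownership
    std::vector<int> extents;  // Halide order, innermost first
    std::vector<int> strides;  // in elements, dense layout assumed

    // Build a dense view from sizes given in PyTorch order (outermost first).
    BorrowedView(T *ptr, const std::vector<int> &torch_sizes) : data(ptr) {
        int stride = 1;
        for (auto it = torch_sizes.rbegin(); it != torch_sizes.rend(); ++it) {
            extents.push_back(*it);
            strides.push_back(stride);
            stride *= *it;
        }
    }

    // Access by Halide-ordered coordinates: coords[0] is the innermost index.
    T &operator()(const std::vector<int> &coords) {
        std::ptrdiff_t offset = 0;
        for (std::size_t d = 0; d < coords.size(); ++d) {
            offset += static_cast<std::ptrdiff_t>(coords[d]) * strides[d];
        }
        return data[offset];
    }
};
```

For a contiguous 2x3 array, PyTorch element [1][2] is Halide coordinate (x = 2, y = 1). Because the view borrows the storage, writes through it are visible in the original array, and the array must outlive the view; the same lifetime caveat presumably applies to a Buffer returned by wrap() and the tensor it came from.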