#ifndef HL_PYTORCH_WRAPPER_H
#define HL_PYTORCH_WRAPPER_H

/** \file
 * Utility functions to wrap PyTorch tensors into Halide buffers, making sure
 * the data lives on the correct device (CPU/GPU).
 */

#include <vector>

#include "HalideBuffer.h"
#include "HalideRuntimeCuda.h"  // declares halide_cuda_device_interface()

#include "torch/extension.h"
#define HLPT_CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define HLPT_CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define HLPT_CHECK_DEVICE(x, dev) AT_ASSERTM(x.is_cuda() && x.get_device() == dev, #x " must be a CUDA tensor on the expected device")
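// Illustrative use (a sketch; `my_op` and its signature are hypothetical,
// not part of this header): these checks typically guard the entry point of
// a generated operator, e.g.
//
//   void my_op(int device_id, at::Tensor &input) {
//       HLPT_CHECK_CONTIGUOUS(input);
//       HLPT_CHECK_DEVICE(input, device_id);
//       // ... wrap the tensor and run the Halide pipeline ...
//   }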
namespace Halide {
namespace PyTorch {

using Halide::Runtime::Buffer;

/** Read a tensor's shape into a vector ordered the way Halide expects. */
inline std::vector<int> get_dims(const at::Tensor tensor) {
    int ndims = tensor.ndimension();
    std::vector<int> dims(ndims, 0);
    // PyTorch's dimension order is the reverse of Halide's.
    for (int dim = 0; dim < ndims; ++dim) {
        dims[dim] = tensor.size(ndims - 1 - dim);
    }
    return dims;
}
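// Example (illustrative, not part of the original header): a tensor indexed
// as (N, C, H, W) in PyTorch yields {W, H, C, N}, because Halide lists the
// innermost (fastest-varying) dimension first:
//
//   at::Tensor t = torch::zeros({2, 3, 4, 5});             // N=2, C=3, H=4, W=5
//   std::vector<int> dims = Halide::PyTorch::get_dims(t);  // {5, 4, 3, 2}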
/** Fallback: fail loudly for any scalar type that has no specialization
 * generated below. */
template<class scalar_t>
inline void check_type(at::Tensor &tensor) {
    AT_ERROR("Scalar type ", tensor.scalar_type(),
             " not handled by Halide's PyTorch wrapper");
}
// Detect the ATen API version from the scalar-type macros it exposes.
#ifdef AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS
#define HL_PYTORCH_API_VERSION 13
#else
#define HL_PYTORCH_API_VERSION 12
#endif

#if HL_PYTORCH_API_VERSION >= 13
#define HL_PT_DEFINE_TYPECHECK(ctype, ttype)                                                    \
    template<>                                                                                  \
    inline void check_type<ctype>(at::Tensor &tensor) {                                         \
        AT_ASSERTM(tensor.scalar_type() == at::ScalarType::ttype, "scalar types do not match"); \
    }

AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(HL_PT_DEFINE_TYPECHECK);

#undef HL_PT_DEFINE_TYPECHECK
#else  // HL_PYTORCH_API_VERSION < 13
#define HL_PT_DEFINE_TYPECHECK(ctype, ttype, _3)                                                \
    template<>                                                                                  \
    inline void check_type<ctype>(at::Tensor &tensor) {                                         \
        AT_ASSERTM(tensor.scalar_type() == at::ScalarType::ttype, "scalar types do not match"); \
    }

AT_FORALL_SCALAR_TYPES_WITH_HALF(HL_PT_DEFINE_TYPECHECK);

#undef HL_PT_DEFINE_TYPECHECK
#endif  // HL_PYTORCH_API_VERSION check
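// For illustration (not part of the original header): each FORALL invocation
// above stamps out one `check_type` specialization per scalar type; for
// `float` the macro expands roughly to:
//
//   template<>
//   inline void check_type<float>(at::Tensor &tensor) {
//       AT_ASSERTM(tensor.scalar_type() == at::ScalarType::Float,
//                  "scalar types do not match");
//   }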
/** Wrap a CPU PyTorch tensor as a Halide buffer that shares its memory. */
template<class scalar_t>
inline Buffer<scalar_t> wrap(at::Tensor &tensor) {
    check_type<scalar_t>(tensor);
    std::vector<int> dims = get_dims(tensor);
#if HL_PYTORCH_API_VERSION >= 13
    scalar_t *pData = tensor.data_ptr<scalar_t>();
#else
    scalar_t *pData = tensor.data<scalar_t>();
#endif
    return Buffer<scalar_t>(pData, dims);
}
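// Example usage (a sketch, assuming a contiguous float32 CPU tensor):
//
//   at::Tensor t = torch::rand({8, 16});
//   Halide::Runtime::Buffer<float> buf = Halide::PyTorch::wrap<float>(t);
//   // `buf` aliases t's storage; no copy is made.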
/** Wrap a CUDA PyTorch tensor as a Halide buffer backed by the same device
 * memory. */
template<class scalar_t>
inline Buffer<scalar_t> wrap_cuda(at::Tensor &tensor) {
    check_type<scalar_t>(tensor);
    std::vector<int> dims = get_dims(tensor);
#if HL_PYTORCH_API_VERSION >= 13
    scalar_t *pData = tensor.data_ptr<scalar_t>();
#else
    scalar_t *pData = tensor.data<scalar_t>();
#endif
    AT_ASSERTM(tensor.is_cuda(), "expected input tensor to be on a CUDA device.");

    Buffer<scalar_t> buffer(dims);

    const halide_device_interface_t *cuda_interface = halide_cuda_device_interface();
    int err = buffer.device_wrap_native(cuda_interface, (uint64_t)pData);
    AT_ASSERTM(err == 0, "(CUDA) halide_device_wrap failed");

    buffer.set_device_dirty();

    return buffer;
}
}  // namespace PyTorch
}  // namespace Halide

#endif  // HL_PYTORCH_WRAPPER_H