Halide
HalidePyTorchHelpers.h
#ifndef HL_PYTORCH_WRAPPER_H
#define HL_PYTORCH_WRAPPER_H

/** \file
 * Set of utility functions to wrap PyTorch tensors into Halide buffers,
 * making sure the data is on the correct device (CPU/GPU).
 */

#include <exception>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

#include "torch/extension.h"

#include "HalideBuffer.h"

#ifdef HL_PT_CUDA
#include "HalideRuntimeCuda.h"
#include "cuda.h"
#include "cuda_runtime.h"
#endif

#define HLPT_CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define HLPT_CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define HLPT_CHECK_DEVICE(x, dev) AT_ASSERTM(x.is_cuda() && x.get_device() == dev, #x " must be a CUDA tensor")

namespace Halide {
namespace PyTorch {

using Halide::Runtime::Buffer;

inline std::vector<int> get_dims(const at::Tensor tensor) {
    int ndims = tensor.ndimension();
    std::vector<int> dims(ndims, 0);
    // PyTorch dim order is the reverse of Halide's
    for (int dim = 0; dim < ndims; ++dim) {
        dims[dim] = tensor.size(ndims - 1 - dim);
    }
    return dims;
}

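// For example (illustrative usage, assuming the tensor below):
//
//     at::Tensor t = torch::empty({2, 3, 16, 32});   // NCHW
//     std::vector<int> dims = get_dims(t);           // {32, 16, 3, 2}
//
// PyTorch's fastest-varying dimension (W) becomes Halide's innermost dimension.
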
template<class scalar_t>
inline void check_type(at::Tensor &tensor) {
    AT_ERROR("Scalar type ", tensor.scalar_type(), " not handled by Halide's PyTorch wrapper");
}

// TODO: if PyTorch exposes a variable with its API version, I haven't found
// it in the source or documentation; for now, sniff for this macro's
// existence to infer whether we are building against v1.3+ (vs. 1.2).
#ifdef AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS
#define HL_PYTORCH_API_VERSION 13
#else
#define HL_PYTORCH_API_VERSION 12
#endif

#if HL_PYTORCH_API_VERSION >= 13

// PyTorch 1.3+
#define HL_PT_DEFINE_TYPECHECK(ctype, ttype) \
    template<> \
    inline void check_type<ctype>(at::Tensor &tensor) { \
        AT_ASSERTM(tensor.scalar_type() == at::ScalarType::ttype, "scalar types do not match"); \
    }

AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(HL_PT_DEFINE_TYPECHECK);

#undef HL_PT_DEFINE_TYPECHECK

#else  // HL_PYTORCH_API_VERSION < 13

// PyTorch 1.2

#define HL_PT_DEFINE_TYPECHECK(ctype, ttype, _3) \
    template<> \
    inline void check_type<ctype>(at::Tensor &tensor) { \
        AT_ASSERTM(tensor.scalar_type() == at::ScalarType::ttype, "scalar types do not match"); \
    }

AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(HL_PT_DEFINE_TYPECHECK);

#undef HL_PT_DEFINE_TYPECHECK

#endif  // HL_PYTORCH_API_VERSION check

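// For instance (illustrative expansion), invoking the macro for float
// generates the specialization below; the AT_FORALL_SCALAR_TYPES_* lists
// above stamp out one such specialization per supported scalar type:
//
//     template<>
//     inline void check_type<float>(at::Tensor &tensor) {
//         AT_ASSERTM(tensor.scalar_type() == at::ScalarType::Float,
//                    "scalar types do not match");
//     }
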
template<class scalar_t>
inline Buffer<scalar_t> wrap(at::Tensor &tensor) {
    check_type<scalar_t>(tensor);
    std::vector<int> dims = get_dims(tensor);
#if HL_PYTORCH_API_VERSION >= 13
    scalar_t *pData = tensor.data_ptr<scalar_t>();
#else
    scalar_t *pData = tensor.data<scalar_t>();
#endif
    Buffer<scalar_t> buffer;

    // TODO(mgharbi): force Halide to put input/output on GPU?
    if (tensor.is_cuda()) {
#ifdef HL_PT_CUDA
        buffer = Buffer<scalar_t>(dims);
        const halide_device_interface_t *cuda_interface = halide_cuda_device_interface();
        int err = buffer.device_wrap_native(cuda_interface, (uint64_t)pData);
        AT_ASSERTM(err == 0, "halide_device_wrap failed");
        buffer.set_device_dirty();
#else
        AT_ERROR("Trying to wrap a CUDA tensor, but HL_PT_CUDA was not defined: CUDA is not available");
#endif
    } else {
        buffer = Buffer<scalar_t>(pData, dims);
    }

    return buffer;
}

}  // namespace PyTorch
}  // namespace Halide

#endif  // HL_PYTORCH_WRAPPER_H
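A minimal usage sketch, assuming a Halide AOT-generated pipeline named my_pipeline and float tensors (both hypothetical, not part of this header), showing how a PyTorch C++ extension operator might wrap its tensors before invoking the pipeline:

#include "torch/extension.h"
#include "HalidePyTorchHelpers.h"

// `my_pipeline(in, out)` stands in for a Halide AOT-generated function.
at::Tensor run_my_pipeline(at::Tensor input) {
    HLPT_CHECK_CONTIGUOUS(input);

    at::Tensor output = torch::empty_like(input);

    // wrap() validates the scalar type and, for CUDA tensors, wraps the
    // existing device allocation instead of copying.
    Halide::Runtime::Buffer<float> in_buf = Halide::PyTorch::wrap<float>(input);
    Halide::Runtime::Buffer<float> out_buf = Halide::PyTorch::wrap<float>(output);

    int err = my_pipeline(in_buf, out_buf);  // hypothetical generated call
    AT_ASSERTM(err == 0, "my_pipeline failed");

    return output;
}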