Halide
HalidePyTorchHelpers.h
#ifndef HL_PYTORCH_WRAPPER_H
#define HL_PYTORCH_WRAPPER_H

/** \file
 * Set of utility functions to wrap PyTorch tensors into Halide buffers,
 * making sure the data is on the correct device (CPU/GPU). This header
 * is included in each generated op by the PyTorch CodeGen.
 */

#include <exception>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

#include "HalideBuffer.h"

// Forward declare the cuda_device_interface, for the tensor wrappers.
const halide_device_interface_t *halide_cuda_device_interface();

#define HLPT_CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define HLPT_CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define HLPT_CHECK_DEVICE(x, dev) AT_ASSERTM(x.is_cuda() && x.get_device() == dev, #x " must be a CUDA tensor")

namespace Halide {
namespace PyTorch {

using Halide::Runtime::Buffer;

// Build the Halide buffer extents from the tensor's sizes.
inline std::vector<int> get_dims(const at::Tensor tensor) {
    int ndims = tensor.ndimension();
    std::vector<int> dims(ndims, 0);
    // PyTorch dim order is the reverse of Halide's, e.g. an (N, C, H, W)
    // tensor maps to buffer extents {W, H, C, N}.
    for (int dim = 0; dim < ndims; ++dim) {
        dims[dim] = tensor.size(ndims - 1 - dim);
    }
    return dims;
}

// Fallback for scalar types with no specialization below.
template<class scalar_t>
inline void check_type(at::Tensor &tensor) {
    AT_ERROR("Scalar type ", tensor.scalar_type(), " not handled by Halide's PyTorch wrapper");
}

// TODO: PyTorch does not appear to expose its API version as a macro or
// variable (none found in the source or documentation), so for now we sniff
// for this macro's existence to infer that we are building against v1.3+
// (vs. 1.2).
#ifdef AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS
#define HL_PYTORCH_API_VERSION 13
#else
#define HL_PYTORCH_API_VERSION 12
#endif

#if HL_PYTORCH_API_VERSION >= 13

// PyTorch 1.3+
#define HL_PT_DEFINE_TYPECHECK(ctype, ttype)                                                    \
    template<>                                                                                  \
    inline void check_type<ctype>(at::Tensor &tensor) {                                         \
        AT_ASSERTM(tensor.scalar_type() == at::ScalarType::ttype, "scalar types do not match"); \
    }

AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(HL_PT_DEFINE_TYPECHECK);

#undef HL_PT_DEFINE_TYPECHECK

#else  // HL_PYTORCH_API_VERSION < 13

// PyTorch 1.2
#define HL_PT_DEFINE_TYPECHECK(ctype, ttype, _3)                                                \
    template<>                                                                                  \
    inline void check_type<ctype>(at::Tensor &tensor) {                                         \
        AT_ASSERTM(tensor.scalar_type() == at::ScalarType::ttype, "scalar types do not match"); \
    }

AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(HL_PT_DEFINE_TYPECHECK);

#undef HL_PT_DEFINE_TYPECHECK

#endif  // HL_PYTORCH_API_VERSION check

// Wrap a CPU tensor in a Halide buffer that shares the tensor's host memory.
template<class scalar_t>
inline Buffer<scalar_t> wrap(at::Tensor &tensor) {
    check_type<scalar_t>(tensor);
    std::vector<int> dims = get_dims(tensor);
#if HL_PYTORCH_API_VERSION >= 13
    scalar_t *pData = tensor.data_ptr<scalar_t>();
#else
    scalar_t *pData = tensor.data<scalar_t>();
#endif
    return Buffer<scalar_t>(pData, dims);
}

// Wrap a CUDA tensor in a Halide buffer that adopts the tensor's device
// pointer (no copy) and is marked device-dirty, so Halide treats the GPU
// copy as the most recent data.
template<class scalar_t>
inline Buffer<scalar_t> wrap_cuda(at::Tensor &tensor) {
    check_type<scalar_t>(tensor);
    std::vector<int> dims = get_dims(tensor);
#if HL_PYTORCH_API_VERSION >= 13
    scalar_t *pData = tensor.data_ptr<scalar_t>();
#else
    scalar_t *pData = tensor.data<scalar_t>();
#endif
    AT_ASSERTM(tensor.is_cuda(), "expected input tensor to be on a CUDA device.");

    Buffer<scalar_t> buffer(dims);
    const halide_device_interface_t *cuda_interface = halide_cuda_device_interface();
    int err = buffer.device_wrap_native(cuda_interface, (uint64_t)pData);
    AT_ASSERTM(err == 0, "(CUDA) halide_device_wrap failed");

    buffer.set_device_dirty();

    return buffer;
}

}  // namespace PyTorch
}  // namespace Halide

#endif  // HL_PYTORCH_WRAPPER_H
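
This header is normally included by the code the PyTorch CodeGen emits for each op, but the same helpers can be called from a hand-written extension. Below is a minimal sketch of that call pattern, assuming a hypothetical AOT-compiled Halide pipeline named my_generated_op (a stand-in, not part of Halide) that adds two float32 tensors: check the inputs, wrap them with Halide::PyTorch::wrap, and pass the resulting buffers to the pipeline.

// Hypothetical hand-written PyTorch extension using the helpers above.
// my_generated_op stands in for a Halide pipeline compiled ahead of time;
// its generated header would declare the same halide_buffer_t* signature.
#include <torch/extension.h>

#include "HalidePyTorchHelpers.h"

extern "C" int my_generated_op(halide_buffer_t *a, halide_buffer_t *b, halide_buffer_t *out);

at::Tensor add_forward(at::Tensor a, at::Tensor b) {
    HLPT_CHECK_CONTIGUOUS(a);
    HLPT_CHECK_CONTIGUOUS(b);

    at::Tensor out = at::empty_like(a);

    // Wrapping shares the tensors' host memory with Halide; nothing is copied.
    Halide::Runtime::Buffer<float> a_buf = Halide::PyTorch::wrap<float>(a);
    Halide::Runtime::Buffer<float> b_buf = Halide::PyTorch::wrap<float>(b);
    Halide::Runtime::Buffer<float> out_buf = Halide::PyTorch::wrap<float>(out);

    int err = my_generated_op(a_buf, b_buf, out_buf);
    AT_ASSERTM(err == 0, "Halide pipeline failed");

    return out;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("add_forward", &add_forward, "Add two tensors through a Halide pipeline");
}

For CUDA tensors the pattern is the same except that Halide::PyTorch::wrap_cuda<float> is used in place of wrap, and the pipeline must have been compiled for a CUDA target so that the wrapped device pointers are consumed directly on the GPU.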