Halide 19.0.0
Halide compiler and libraries
Loading...
Searching...
No Matches
HalidePyTorchHelpers.h
Go to the documentation of this file.
1#ifndef HL_PYTORCH_WRAPPER_H
2#define HL_PYTORCH_WRAPPER_H
3
4/** \file
5 * Set of utility functions to wrap PyTorch tensors into Halide buffers,
6 * making sure the data in on the correct device (CPU/GPU). This header
7 * is included in each generated op by the PyTorch CodeGen.
8 */
9
10#include <exception>
11#include <iostream>
12#include <sstream>
13#include <string>
14#include <vector>
15
16#include "HalideBuffer.h"
17
18// Forward declare the cuda_device_interface, for tensor wrapper.
20
// Argument-validation helpers for generated ops: each aborts via AT_ASSERTM
// with a descriptive message naming the offending tensor argument.
#define HLPT_CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define HLPT_CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
// Fixed copy-pasted message: this check requires a specific device index,
// not merely "a CUDA tensor".
#define HLPT_CHECK_DEVICE(x, dev) AT_ASSERTM(x.is_cuda() && x.get_device() == dev, #x " must be a CUDA tensor on device " #dev)
24
25namespace Halide {
26namespace PyTorch {
27
29
30inline std::vector<int> get_dims(const at::Tensor tensor) {
31 int ndims = tensor.ndimension();
32 std::vector<int> dims(ndims, 0);
33 // PyTorch dim order is reverse of Halide
34 for (int dim = 0; dim < ndims; ++dim) {
35 dims[dim] = tensor.size(ndims - 1 - dim);
36 }
37 return dims;
38}
39
/** Primary template: reached only when no specialization matches.
 *
 * The specializations (generated below by HL_PT_DEFINE_TYPECHECK) assert
 * that the tensor's runtime scalar type equals the compile-time scalar_t;
 * this fallback reports the unsupported type and aborts via AT_ERROR.
 */
template<class scalar_t>
inline void check_type(at::Tensor &tensor) {
    AT_ERROR("Scalar type ", tensor.scalar_type(), " not handled by Halide's PyTorch wrapper");
}
44
45// TODO: if PyTorch exposes any variable with the API version,
46// I haven't found it in source or documentation; for now, we'll sniff
47// this macro's existence to infer that we are building with v1.3+ (vs 1.2)
48#ifdef AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS
49#define HL_PYTORCH_API_VERSION 13
50#else
51#define HL_PYTORCH_API_VERSION 12
52#endif
53
54#if HL_PYTORCH_API_VERSION >= 13
55
56// PyTorch 1.3+
57#define HL_PT_DEFINE_TYPECHECK(ctype, ttype) \
58 template<> \
59 inline void check_type<ctype>(at::Tensor & tensor) { \
60 AT_ASSERTM(tensor.scalar_type() == at::ScalarType::ttype, "scalar type do not match"); \
61 }
62
63AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(HL_PT_DEFINE_TYPECHECK);
64
65#undef HL_PT_DEFINE_TYPECHECK
66
67#else // HL_PYTORCH_API_VERSION < 13
68
69// PyTorch 1.2
70
71#define HL_PT_DEFINE_TYPECHECK(ctype, ttype, _3) \
72 template<> \
73 inline void check_type<ctype>(at::Tensor & tensor) { \
74 AT_ASSERTM(tensor.scalar_type() == at::ScalarType::ttype, "scalar type do not match"); \
75 }
76
78
79#undef HL_PT_DEFINE_TYPECHECK
80
81#endif // HL_PYTORCH_API_VERSION check
82
83template<class scalar_t>
84inline Buffer<scalar_t> wrap(at::Tensor &tensor) {
86 std::vector<int> dims = get_dims(tensor);
87#if HL_PYTORCH_API_VERSION >= 13
88 scalar_t *pData = tensor.data_ptr<scalar_t>();
89#else
90 scalar_t *pData = tensor.data<scalar_t>();
91#endif
92 return Buffer<scalar_t>(pData, dims);
93}
94
95template<class scalar_t>
96inline Buffer<scalar_t> wrap_cuda(at::Tensor &tensor) {
98 std::vector<int> dims = get_dims(tensor);
99#if HL_PYTORCH_API_VERSION >= 13
100 scalar_t *pData = tensor.data_ptr<scalar_t>();
101#else
102 scalar_t *pData = tensor.data<scalar_t>();
103#endif
104 AT_ASSERTM(tensor.is_cuda(), "expected input tensor to be on a CUDA device.");
105
106 Buffer<scalar_t> buffer(dims);
107
109 int err = buffer.device_wrap_native(cuda_interface, (uint64_t)pData);
110 AT_ASSERTM(err == 0, "(CUDA) halide_device_wrap failed");
111
112 buffer.set_device_dirty();
113
114 return buffer;
115}
116
117} // namespace PyTorch
118} // namespace Halide
119
120#endif // HL_PYTORCH_WRAPPER_H
Defines a Buffer type that wraps from halide_buffer_t and adds functionality, and methods for more co...
const halide_device_interface_t * halide_cuda_device_interface()
#define HL_PT_DEFINE_TYPECHECK(ctype, ttype, _3)
A templated Buffer class that wraps halide_buffer_t and adds functionality.
void set_device_dirty(bool v=true)
int device_wrap_native(const struct halide_device_interface_t *device_interface, uint64_t handle, void *ctx=nullptr)
std::vector< int > get_dims(const at::Tensor tensor)
Buffer< scalar_t > wrap(at::Tensor &tensor)
Buffer< scalar_t > wrap_cuda(at::Tensor &tensor)
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(HL_PT_DEFINE_TYPECHECK)
void check_type(at::Tensor &tensor)
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
unsigned __INT64_TYPE__ uint64_t
Each GPU API provides a halide_device_interface_t struct pointing to the code that manages device all...