Halide
HalidePyTorchCudaHelpers.h
Go to the documentation of this file.
1 #ifndef HL_PYTORCH_CUDA_HELPERS_H
2 #define HL_PYTORCH_CUDA_HELPERS_H
3 
4 /** \file
5  * Override Halide's CUDA runtime hooks so that Halide code called from PyTorch
6  * uses the GPU device and stream supplied by the caller. This header defines
7  * (not just declares) those hooks, so it must be included exactly once, in the
8  * PyTorch/C++ binding source file (see apps/HelloPyTorch/setup.py for an
9  */
10 
11 #include "HalideRuntimeCuda.h"
12 #include "cuda.h"
13 #include "cuda_runtime.h"
14 
15 namespace Halide {
16 namespace PyTorch {
17 
18 typedef struct UserContext {
19  UserContext(int id, CUcontext *ctx, cudaStream_t *stream)
21 
22  int device_id;
24  cudaStream_t *stream;
25 } UserContext;
26 
27 } // namespace PyTorch
28 } // namespace Halide
29 
30 // Override Halide's weakly-linked CUDA runtime hooks
31 extern "C" {
32 
33 int halide_cuda_acquire_context(void *user_context, CUcontext *ctx, bool create = true) {
34  if (user_context != nullptr) {
36  *ctx = *user_ctx->cuda_context;
37  } else {
38  *ctx = nullptr;
39  }
41 }
42 
43 int halide_cuda_get_stream(void *user_context, CUcontext ctx, CUstream *stream) {
44  if (user_context != nullptr) {
46  *stream = *user_ctx->stream;
47  } else {
48  *stream = 0;
49  }
51 }
52 
53 int halide_get_gpu_device(void *user_context) {
54  if (user_context != nullptr) {
56  return user_ctx->device_id;
57  } else {
58  return 0;
59  }
60 }
61 
62 } // extern "C"
63 
64 #endif /* end of include guard: HL_PYTORCH_CUDA_HELPERS_H */
Halide::PyTorch::UserContext
Definition: HalidePyTorchCudaHelpers.h:18
halide_cuda_acquire_context
int halide_cuda_acquire_context(void *user_context, CUcontext *ctx, bool create=true)
Definition: HalidePyTorchCudaHelpers.h:33
halide_get_gpu_device
int halide_get_gpu_device(void *user_context)
Halide calls this to get the desired halide gpu device setting.
Definition: HalidePyTorchCudaHelpers.h:53
halide_error_code_success
@ halide_error_code_success
There was no error.
Definition: HalideRuntime.h:1039
Halide::PyTorch::UserContext::cuda_context
CUcontext * cuda_context
Definition: HalidePyTorchCudaHelpers.h:23
HalideRuntimeCuda.h
Halide
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
Definition: AbstractGenerator.h:19
Halide::PyTorch::UserContext::stream
cudaStream_t * stream
Definition: HalidePyTorchCudaHelpers.h:24
halide_cuda_get_stream
int halide_cuda_get_stream(void *user_context, CUcontext ctx, CUstream *stream)
Definition: HalidePyTorchCudaHelpers.h:43
Halide::Runtime::Internal::Cuda::CUcontext
struct CUctx_st * CUcontext
CUDA context.
Definition: mini_cuda.h:22
Halide::PyTorch::UserContext::device_id
int device_id
Definition: HalidePyTorchCudaHelpers.h:22
Halide::PyTorch::UserContext
struct Halide::PyTorch::UserContext UserContext
Halide::Runtime::Internal::Cuda::CUstream
struct CUstream_st * CUstream
CUDA stream.
Definition: mini_cuda.h:25
Halide::PyTorch::UserContext::UserContext
UserContext(int id, CUcontext *ctx, cudaStream_t *stream)
Definition: HalidePyTorchCudaHelpers.h:19