Halide
HalidePyTorchCudaHelpers.h
Go to the documentation of this file.
1 #ifndef HL_PYTORCH_CUDA_HELPERS_H
2 #define HL_PYTORCH_CUDA_HELPERS_H
3 
4 /** \file
5  * Override Halide's CUDA hooks so that the Halide code called from PyTorch uses
6  * the correct GPU device and stream.
7  */
8 
9 #ifdef HL_PT_CUDA
10 #include "HalideRuntimeCuda.h"
11 #include "cuda.h"
12 
13 namespace Halide {
14 namespace PyTorch {
15 
16 typedef struct UserContext {
17  UserContext(int id, CUcontext *ctx, cudaStream_t *stream)
18  : device_id(id), cuda_context(ctx), stream(stream){};
19 
20  int device_id;
21  CUcontext *cuda_context;
22  cudaStream_t *stream;
23 } UserContext;
24 
25 } // namespace PyTorch
26 } // namespace Halide
27 
28 // Replace Halide weakly-linked CUDA handles
29 extern "C" {
30 
31 int halide_cuda_acquire_context(void *user_context, CUcontext *ctx, bool create = true) {
32  if (user_context != NULL) {
33  Halide::PyTorch::UserContext *user_ctx = (Halide::PyTorch::UserContext *)user_context;
34  *ctx = *user_ctx->cuda_context;
35  } else {
36  *ctx = NULL;
37  }
38  return 0;
39 }
40 
41 int halide_cuda_get_stream(void *user_context, CUcontext ctx, CUstream *stream) {
42  if (user_context != NULL) {
43  Halide::PyTorch::UserContext *user_ctx = (Halide::PyTorch::UserContext *)user_context;
44  *stream = *user_ctx->stream;
45  } else {
46  *stream = 0;
47  }
48  return 0;
49 }
50 
52  if (user_context != NULL) {
53  Halide::PyTorch::UserContext *user_ctx = (Halide::PyTorch::UserContext *)user_context;
54  return user_ctx->device_id;
55  } else {
56  return 0;
57  }
58 }
59 
60 } // extern "C"
61 
62 #endif // HL_PT_CUDA
63 
64 #endif /* end of include guard: HL_PYTORCH_CUDA_HELPERS_H */
NULL
#define NULL
Definition: runtime_internal.h:32
HalideRuntimeCuda.h
Halide
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
Definition: AddAtomicMutex.h:21
Halide::Runtime::Internal::Cuda::CUcontext
struct CUctx_st * CUcontext
CUDA context.
Definition: mini_cuda.h:22
Halide::Runtime::Internal::Cuda::CUstream
struct CUstream_st * CUstream
CUDA stream.
Definition: mini_cuda.h:25
halide_get_gpu_device
int halide_get_gpu_device(void *user_context)
Halide calls this to get the desired halide gpu device setting.
user_context
void * user_context
Definition: printer.h:33