Halide 19.0.0
Halide compiler and libraries
Loading...
Searching...
No Matches
HalidePyTorchCudaHelpers.h
Go to the documentation of this file.
1#ifndef HL_PYTORCH_CUDA_HELPERS_H
2#define HL_PYTORCH_CUDA_HELPERS_H
3
4/** \file
5 * Override Halide's CUDA hooks so that the Halide code called from PyTorch uses
6 * the correct GPU device and stream. This header should be included once in
7 * the PyTorch/C++ binding source file (see apps/HelloPyTorch/setup.py for an
8 * example).
9 */
10
11#include "HalideRuntimeCuda.h"
12#include "cuda.h"
13#include "cuda_runtime.h"
14
15namespace Halide {
16namespace PyTorch {
17
18typedef struct UserContext {
19 UserContext(int id, CUcontext *ctx, cudaStream_t *stream)
20 : device_id(id), cuda_context(ctx), stream(stream){};
21
23 CUcontext *cuda_context;
24 cudaStream_t *stream;
26
27} // namespace PyTorch
28} // namespace Halide
29
30// Replace Halide's weakly-linked CUDA hooks
31extern "C" {
32
33int halide_cuda_acquire_context(void *user_context, CUcontext *ctx, bool create = true) {
34 if (user_context != nullptr) {
36 *ctx = *user_ctx->cuda_context;
37 } else {
38 *ctx = nullptr;
39 }
41}
42
43int halide_cuda_get_stream(void *user_context, CUcontext ctx, CUstream *stream) {
44 if (user_context != nullptr) {
46 *stream = *user_ctx->stream;
47 } else {
48 *stream = 0;
49 }
51}
52
54 if (user_context != nullptr) {
56 return user_ctx->device_id;
57 } else {
58 return 0;
59 }
60}
61
62} // extern "C"
63
64#endif /* end of include guard: HL_PYTORCH_CUDA_HELPERS_H */
int halide_cuda_acquire_context(void *user_context, CUcontext *ctx, bool create=true)
int halide_cuda_get_stream(void *user_context, CUcontext ctx, CUstream *stream)
int halide_get_gpu_device(void *user_context)
Halide calls this to get the desired halide gpu device setting.
@ halide_error_code_success
There was no error.
Routines specific to the Halide Cuda runtime.
struct Halide::PyTorch::UserContext UserContext
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
UserContext(int id, CUcontext *ctx, cudaStream_t *stream)
void * user_context