Halide
19.0.0
Halide compiler and libraries
Loading...
Searching...
No Matches
HalidePyTorchCudaHelpers.h
Go to the documentation of this file.
1
#ifndef HL_PYTORCH_CUDA_HELPERS_H
2
#define HL_PYTORCH_CUDA_HELPERS_H
3
4
/** \file
 * Override Halide's CUDA hooks so that the Halide code called from PyTorch uses
 * the correct GPU device and stream. Each hook reads its values from a
 * Halide::PyTorch::UserContext passed to Halide as the user_context pointer.
 * This header should be included once in the PyTorch/C++ binding source file
 * (see apps/HelloPyTorch/setup.py for an example).
 */
10
11
#include "
HalideRuntimeCuda.h
"
12
#include "cuda.h"
13
#include "cuda_runtime.h"
14
15
namespace
Halide
{
16
namespace
PyTorch {
17
18
typedef
struct
UserContext
{
19
UserContext
(
int
id
, CUcontext *ctx, cudaStream_t *
stream
)
20
:
device_id
(id),
cuda_context
(ctx),
stream
(
stream
){};
21
22
int
device_id
;
23
CUcontext *
cuda_context
;
24
cudaStream_t *
stream
;
25
}
UserContext
;
26
27
}
// namespace PyTorch
28
}
// namespace Halide
29
30
// Replace Halide's weakly-linked CUDA hooks with versions driven by UserContext
31
extern
"C"
{
32
33
int
halide_cuda_acquire_context
(
void
*user_context, CUcontext *ctx,
bool
create =
true
) {
34
if
(user_context !=
nullptr
) {
35
Halide::PyTorch::UserContext
*user_ctx = (
Halide::PyTorch::UserContext
*)user_context;
36
*ctx = *user_ctx->
cuda_context
;
37
}
else
{
38
*ctx =
nullptr
;
39
}
40
return
halide_error_code_success
;
41
}
42
43
int
halide_cuda_get_stream
(
void
*user_context, CUcontext ctx, CUstream *stream) {
44
if
(user_context !=
nullptr
) {
45
Halide::PyTorch::UserContext
*user_ctx = (
Halide::PyTorch::UserContext
*)user_context;
46
*stream = *user_ctx->
stream
;
47
}
else
{
48
*stream = 0;
49
}
50
return
halide_error_code_success
;
51
}
52
53
int
halide_get_gpu_device
(
void
*user_context) {
54
if
(user_context !=
nullptr
) {
55
Halide::PyTorch::UserContext
*user_ctx = (
Halide::PyTorch::UserContext
*)user_context;
56
return
user_ctx->
device_id
;
57
}
else
{
58
return
0;
59
}
60
}
61
62
}
// extern "C"
63
64
#endif
/* end of include guard: HL_PYTORCH_CUDA_HELPERS_H */
halide_cuda_acquire_context
int halide_cuda_acquire_context(void *user_context, CUcontext *ctx, bool create=true)
Definition
HalidePyTorchCudaHelpers.h:33
halide_cuda_get_stream
int halide_cuda_get_stream(void *user_context, CUcontext ctx, CUstream *stream)
Definition
HalidePyTorchCudaHelpers.h:43
halide_get_gpu_device
int halide_get_gpu_device(void *user_context)
Halide calls this to get the desired halide gpu device setting.
Definition
HalidePyTorchCudaHelpers.h:53
halide_error_code_success
@ halide_error_code_success
There was no error.
Definition
HalideRuntime.h:1072
HalideRuntimeCuda.h
Routines specific to the Halide Cuda runtime.
Halide::PyTorch::UserContext
struct Halide::PyTorch::UserContext UserContext
Halide
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
Definition
AbstractGenerator.h:19
Halide::PyTorch::UserContext
Definition
HalidePyTorchCudaHelpers.h:18
Halide::PyTorch::UserContext::cuda_context
CUcontext * cuda_context
Definition
HalidePyTorchCudaHelpers.h:23
Halide::PyTorch::UserContext::device_id
int device_id
Definition
HalidePyTorchCudaHelpers.h:22
Halide::PyTorch::UserContext::UserContext
UserContext(int id, CUcontext *ctx, cudaStream_t *stream)
Definition
HalidePyTorchCudaHelpers.h:19
Halide::PyTorch::UserContext::stream
cudaStream_t * stream
Definition
HalidePyTorchCudaHelpers.h:24
src
runtime
HalidePyTorchCudaHelpers.h
Generated by
1.12.0