Halide
HalidePyTorchCudaHelpers.h
Go to the documentation of this file.
1
#ifndef HL_PYTORCH_CUDA_HELPERS_H
2
#define HL_PYTORCH_CUDA_HELPERS_H
3
4
/** \file
5
* Override Halide's CUDA hooks so that the Halide code called from PyTorch uses
6
* the correct GPU device and stream. This header should be included once in
7
* the PyTorch/C++ binding source file (see apps/HelloPyTorch/setup.py for an
8
* example).
9
*/
10
11
#include "
HalideRuntimeCuda.h
"
12
#include "cuda.h"
13
#include "cuda_runtime.h"
14
15
namespace
Halide
{
16
namespace
PyTorch {
17
18
typedef
struct
UserContext
{
19
UserContext
(
int
id
,
CUcontext
*ctx, cudaStream_t *
stream
)
20
:
device_id
(id),
cuda_context
(ctx),
stream
(
stream
){};
21
22
int
device_id
;
23
CUcontext
*
cuda_context
;
24
cudaStream_t *
stream
;
25
}
UserContext
;
26
27
}
// namespace PyTorch
28
}
// namespace Halide
29
30
// Replace Halide weakly-linked CUDA handles
31
extern
"C"
{
32
33
int
halide_cuda_acquire_context
(
void
*user_context,
CUcontext
*ctx,
bool
create =
true
) {
34
if
(user_context !=
nullptr
) {
35
Halide::PyTorch::UserContext
*user_ctx = (
Halide::PyTorch::UserContext
*)user_context;
36
*ctx = *user_ctx->
cuda_context
;
37
}
else
{
38
*ctx =
nullptr
;
39
}
40
return
halide_error_code_success
;
41
}
42
43
int
halide_cuda_get_stream
(
void
*user_context,
CUcontext
ctx,
CUstream
*stream) {
44
if
(user_context !=
nullptr
) {
45
Halide::PyTorch::UserContext
*user_ctx = (
Halide::PyTorch::UserContext
*)user_context;
46
*stream = *user_ctx->
stream
;
47
}
else
{
48
*stream = 0;
49
}
50
return
halide_error_code_success
;
51
}
52
53
int
halide_get_gpu_device
(
void
*user_context) {
54
if
(user_context !=
nullptr
) {
55
Halide::PyTorch::UserContext
*user_ctx = (
Halide::PyTorch::UserContext
*)user_context;
56
return
user_ctx->
device_id
;
57
}
else
{
58
return
0;
59
}
60
}
61
62
}
// extern "C"
63
64
#endif
/* end of include guard: HL_PYTORCH_CUDA_HELPERS_H */
Halide::PyTorch::UserContext
Definition:
HalidePyTorchCudaHelpers.h:18
halide_cuda_acquire_context
int halide_cuda_acquire_context(void *user_context, CUcontext *ctx, bool create=true)
Definition:
HalidePyTorchCudaHelpers.h:33
halide_get_gpu_device
int halide_get_gpu_device(void *user_context)
Halide calls this to get the desired halide gpu device setting.
Definition:
HalidePyTorchCudaHelpers.h:53
halide_error_code_success
@ halide_error_code_success
There was no error.
Definition:
HalideRuntime.h:1039
Halide::PyTorch::UserContext::cuda_context
CUcontext * cuda_context
Definition:
HalidePyTorchCudaHelpers.h:23
HalideRuntimeCuda.h
Halide
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
Definition:
AbstractGenerator.h:19
Halide::PyTorch::UserContext::stream
cudaStream_t * stream
Definition:
HalidePyTorchCudaHelpers.h:24
halide_cuda_get_stream
int halide_cuda_get_stream(void *user_context, CUcontext ctx, CUstream *stream)
Definition:
HalidePyTorchCudaHelpers.h:43
Halide::Runtime::Internal::Cuda::CUcontext
struct CUctx_st * CUcontext
CUDA context.
Definition:
mini_cuda.h:22
Halide::PyTorch::UserContext::device_id
int device_id
Definition:
HalidePyTorchCudaHelpers.h:20
Halide::PyTorch::UserContext
struct Halide::PyTorch::UserContext UserContext
Halide::Runtime::Internal::Cuda::CUstream
struct CUstream_st * CUstream
CUDA stream.
Definition:
mini_cuda.h:25
Halide::PyTorch::UserContext::UserContext
UserContext(int id, CUcontext *ctx, cudaStream_t *stream)
Definition:
HalidePyTorchCudaHelpers.h:19
src
runtime
HalidePyTorchCudaHelpers.h
Generated by
1.8.17