Halide 19.0.0
Halide compiler and libraries
HalidePyTorchCudaHelpers.h File Reference

Override Halide's CUDA hooks so that the Halide code called from PyTorch uses the correct GPU device and stream.

#include "HalideRuntimeCuda.h"
#include "cuda.h"
#include "cuda_runtime.h"


Classes

struct  Halide::PyTorch::UserContext
 

Namespaces

namespace  Halide
 This file defines the class FunctionDAG, which is our representation of a Halide pipeline, and contains methods for using Halide's bounds tools to query its properties.
 
namespace  Halide::PyTorch
 

Typedefs

typedef struct Halide::PyTorch::UserContext Halide::PyTorch::UserContext
 

Functions

int halide_cuda_acquire_context (void *user_context, CUcontext *ctx, bool create=true)
 
int halide_cuda_get_stream (void *user_context, CUcontext ctx, CUstream *stream)
 
int halide_get_gpu_device (void *user_context)
 Halide calls this to get the desired Halide GPU device setting.
 

Detailed Description

Override Halide's CUDA hooks so that the Halide code called from PyTorch uses the correct GPU device and stream.

This header should be included once in the PyTorch/C++ binding source file (see apps/HelloPyTorch/setup.py for an example).
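As a rough illustration of the mechanism, the sketch below shows how a binding might package the caller's device and stream into a context object and pass its address as the opaque `void *user_context` argument, so that hooks invoked during the call can recover them. `DemoUserContext`, `fake_pipeline`, `call_with_current_device`, and the stand-in handle types are hypothetical, chosen so the example compiles without CUDA or PyTorch; the real `Halide::PyTorch::UserContext` may differ beyond the `device_id` member documented on this page.

```cpp
#include <cassert>

// Stand-ins for the CUDA driver handles CUcontext and CUstream, so this
// sketch compiles without cuda.h. Real code would use the CUDA types.
struct FakeCtxState {};
struct FakeStreamState {};
using StandInContext = FakeCtxState *;
using StandInStream = FakeStreamState *;

// Hypothetical per-call context, loosely modeled on Halide::PyTorch::UserContext.
struct DemoUserContext {
    int device_id;
    StandInContext *cuda_context;  // points at the caller's current context
    StandInStream *stream;         // points at the caller's current stream
};

// Hypothetical stand-in for a pipeline compiled with the user_context target
// feature: it receives the pointer as an opaque void * and would forward it to
// runtime hooks. Here it only recovers the device to show the round-trip.
int fake_pipeline(void *user_context) {
    assert(user_context != nullptr);
    auto *ctx = static_cast<DemoUserContext *>(user_context);
    return ctx->device_id;
}

// What a binding wrapper might do: capture the caller's device and stream in a
// stack object that stays alive for the whole (synchronous) pipeline call.
int call_with_current_device(int device_id, StandInContext ctx, StandInStream stream) {
    DemoUserContext user_ctx{device_id, &ctx, &stream};
    return fake_pipeline(&user_ctx);
}
```

Keeping the context on the wrapper's stack is a design choice that assumes the pipeline call returns before the wrapper does; anything that retains the pointer past that point would need a longer-lived object.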

Definition in file HalidePyTorchCudaHelpers.h.

Function Documentation

◆ halide_cuda_acquire_context()

int halide_cuda_acquire_context(void *user_context, CUcontext *ctx, bool create = true)
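A minimal sketch of what an override of this hook might do, assuming the user context stores a pointer to the caller's current CUDA context: copy that context into `*ctx` and report success. The fallback to a null context when no user context is supplied, the `DemoUserContext` layout, and the name `demo_acquire_context` are hypothetical, and stand-in types replace `CUcontext` so the sketch compiles without CUDA; the actual header may behave differently.

```cpp
#include <cassert>

// Stand-in for CUcontext (an opaque pointer in the CUDA driver API).
struct FakeCtxState {};
using StandInContext = FakeCtxState *;

// Hypothetical user context holding a pointer to the caller's current context.
struct DemoUserContext {
    StandInContext *cuda_context;
};

// Sketch of an acquire-context override: hand Halide the caller's context
// instead of letting the runtime create or look up its own.
int demo_acquire_context(void *user_context, StandInContext *ctx, bool create = true) {
    (void)create;            // the caller's context already exists, so it is never created here
    assert(ctx != nullptr);  // the caller must provide somewhere to write the result
    if (user_context != nullptr) {
        auto *uc = static_cast<DemoUserContext *>(user_context);
        *ctx = *uc->cuda_context;
    } else {
        *ctx = nullptr;  // assumed fallback: report that no context is available
    }
    return 0;  // 0 signals success, matching halide_error_code_success
}
```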

◆ halide_cuda_get_stream()

int halide_cuda_get_stream(void *user_context, CUcontext ctx, CUstream *stream)
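Similarly, a get-stream override might return the stream recorded by the caller, so that Halide's kernels are ordered with the framework's other work on that stream. This is a sketch under assumptions: `DemoUserContext`, `demo_get_stream`, and the stand-in handle types are hypothetical, and the fallback relies on the CUDA convention that a null `CUstream` denotes the default stream.

```cpp
#include <cassert>

// Stand-ins for CUcontext and CUstream (opaque pointers in the CUDA driver API).
struct FakeCtxState {};
struct FakeStreamState {};
using StandInContext = FakeCtxState *;
using StandInStream = FakeStreamState *;

// Hypothetical user context holding a pointer to the caller's current stream.
struct DemoUserContext {
    StandInStream *stream;
};

// Sketch of a get-stream override: enqueue Halide's work on the caller's stream.
int demo_get_stream(void *user_context, StandInContext ctx, StandInStream *stream) {
    (void)ctx;                  // the stream comes from the user context, not from ctx
    assert(stream != nullptr);  // the caller must provide somewhere to write the result
    if (user_context != nullptr) {
        auto *uc = static_cast<DemoUserContext *>(user_context);
        *stream = *uc->stream;
    } else {
        *stream = nullptr;  // a null stream selects the default stream
    }
    return 0;  // 0 signals success
}
```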

◆ halide_get_gpu_device()

int halide_get_gpu_device(void *user_context)

Halide calls this to get the desired Halide GPU device setting.

Implement this yourself to use a different GPU device per user_context. The default implementation returns the value set by halide_set_gpu_device, or the value of the environment variable HL_GPU_DEVICE.

Definition at line 53 of file HalidePyTorchCudaHelpers.h.

References Halide::PyTorch::UserContext::device_id, and user_context.
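Given that this override references `UserContext::device_id`, a per-call device selection might look like the sketch below. Only the `device_id` member is taken from this page; `DemoUserContext`, `demo_get_gpu_device`, and the fallback to device 0 when no user context is passed are assumptions made for illustration.

```cpp
#include <cassert>

// Hypothetical user context; only device_id is confirmed by this page.
struct DemoUserContext {
    int device_id;
};

// Sketch of a get-gpu-device override: select the device recorded by the
// caller (for example, the device that holds the input tensors).
int demo_get_gpu_device(void *user_context) {
    if (user_context != nullptr) {
        auto *uc = static_cast<DemoUserContext *>(user_context);
        assert(uc->device_id >= 0);  // CUDA device ordinals are non-negative
        return uc->device_id;
    }
    return 0;  // assumed fallback: device 0 when no user context is given
}
```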