Halide 19.0.0
Halide compiler and libraries
HalidePyTorchHelpers.h File Reference

Set of utility functions to wrap PyTorch tensors into Halide buffers, making sure the data is on the correct device (CPU/GPU).

#include <exception>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
#include "HalideBuffer.h"


Namespaces

namespace  Halide
 This file defines the class FunctionDAG, which is our representation of a Halide pipeline, and contains methods that use Halide's bounds tools to query its properties.
 
namespace  Halide::PyTorch
 

Macros

#define HLPT_CHECK_CONTIGUOUS(x)
 
#define HLPT_CHECK_CUDA(x)
 
#define HLPT_CHECK_DEVICE(x, dev)
 
#define HL_PYTORCH_API_VERSION   12
 
#define HL_PT_DEFINE_TYPECHECK(ctype, ttype, _3)
 

Functions

const halide_device_interface_t * halide_cuda_device_interface ()
 
std::vector< int > Halide::PyTorch::get_dims (const at::Tensor tensor)
 
template<class scalar_t >
void Halide::PyTorch::check_type (at::Tensor &tensor)
 
 Halide::PyTorch::AT_FORALL_SCALAR_TYPES_WITH_COMPLEX (HL_PT_DEFINE_TYPECHECK)
 
template<class scalar_t >
Buffer< scalar_t > Halide::PyTorch::wrap (at::Tensor &tensor)
 
template<class scalar_t >
Buffer< scalar_t > Halide::PyTorch::wrap_cuda (at::Tensor &tensor)
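get_dims turns a tensor's sizes into the extent list used to build a Halide buffer. PyTorch lists sizes outermost dimension first, whereas Halide's dimension 0 is the innermost (fastest-varying) one, so for densely packed data the order has to be reversed. The sketch below assumes that this reversal is all get_dims does; get_dims_sketch is a hypothetical stand-in that takes a plain size vector instead of an at::Tensor.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical stand-in for Halide::PyTorch::get_dims.
// Input: sizes in PyTorch order (outermost dimension first).
// Output: extents in Halide order (innermost dimension first).
std::vector<int> get_dims_sketch(const std::vector<std::int64_t> &torch_sizes) {
    const int ndims = static_cast<int>(torch_sizes.size());
    std::vector<int> dims(ndims, 0);
    for (int d = 0; d < ndims; ++d) {
        // Halide dimension d corresponds to PyTorch dimension ndims - 1 - d.
        dims[d] = static_cast<int>(torch_sizes[ndims - 1 - d]);
    }
    return dims;
}
```

For an NCHW tensor of shape (2, 3, 4, 5) the sketch returns {5, 4, 3, 2}, so Halide's first coordinate indexes the width.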
 

Detailed Description

Set of utility functions to wrap PyTorch tensors into Halide buffers, making sure the data is on the correct device (CPU/GPU).

This header is included in each generated op by the PyTorch CodeGen.

Definition in file HalidePyTorchHelpers.h.
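Making sure the data is on the correct device comes down to a branch on where the tensor's memory lives. The sketch below is hypothetical and deliberately simplified: MockTensor and MockBuffer stand in for at::Tensor and Halide::Runtime::Buffer, and dimension handling is omitted. In the real helpers the GPU branch presumably goes through Buffer::device_wrap_native with the interface returned by halide_cuda_device_interface(); that pairing is an assumption, not something this page states.

```cpp
#include <cstdint>

// Hypothetical stand-in for at::Tensor: only where the data lives matters here.
struct MockTensor {
    float *data;   // host pointer or raw CUDA device pointer
    bool is_cuda;  // true if data points to GPU memory
};

// Hypothetical stand-in for a Halide buffer's host/device state.
struct MockBuffer {
    float *host = nullptr;            // aliased CPU memory, if any
    std::uint64_t device_handle = 0;  // native device pointer, if any
    bool device_dirty = false;        // true if the GPU copy is the current one
};

// CPU tensors are aliased as host memory; CUDA tensors are recorded as a
// native device handle and marked dirty so the GPU copy is treated as current.
// Neither branch copies the tensor's data.
MockBuffer wrap_sketch(const MockTensor &t) {
    MockBuffer b;
    if (t.is_cuda) {
        b.device_handle =
            static_cast<std::uint64_t>(reinterpret_cast<std::uintptr_t>(t.data));
        b.device_dirty = true;
    } else {
        b.host = t.data;
    }
    return b;
}
```

Because this sketch aliases rather than copies, the tensor must outlive any buffer that wraps it.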

Macro Definition Documentation

◆ HLPT_CHECK_CONTIGUOUS

#define HLPT_CHECK_CONTIGUOUS ( x)
Value:
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")

Definition at line 21 of file HalidePyTorchHelpers.h.
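HLPT_CHECK_CONTIGUOUS guards against tensors whose elements are not densely packed in row-major order, such as transposed views or strided slices. A Halide buffer constructed from just a data pointer and sizes assumes dense strides, so a non-contiguous view wrapped that way would be read with the wrong layout. The sketch below spells out the row-major contiguity rule on explicit size and stride vectors (PyTorch order, strides in elements). It is modeled on PyTorch's check for the default memory format; is_contiguous_sketch is a hypothetical name.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Row-major contiguity test on PyTorch-ordered sizes and strides.
// Size-1 dimensions are skipped because their stride never affects
// addressing; tensors with a zero-sized dimension count as contiguous.
bool is_contiguous_sketch(const std::vector<std::int64_t> &sizes,
                          const std::vector<std::int64_t> &strides) {
    for (std::int64_t s : sizes) {
        if (s == 0) return true;
    }
    std::int64_t expected = 1;  // stride a packed layout would have here
    for (std::size_t i = sizes.size(); i-- > 0;) {
        if (sizes[i] == 1) continue;
        if (strides[i] != expected) return false;
        expected *= sizes[i];
    }
    return true;
}
```

A freshly allocated (2, 3) tensor has strides (3, 1) and passes; its transpose has sizes (3, 2) and strides (1, 3) and fails, so it would need to be made contiguous (for example with .contiguous()) before wrapping.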

◆ HLPT_CHECK_CUDA

#define HLPT_CHECK_CUDA ( x)
Value:
AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")

Definition at line 22 of file HalidePyTorchHelpers.h.

◆ HLPT_CHECK_DEVICE

#define HLPT_CHECK_DEVICE ( x, dev )
Value:
AT_ASSERTM(x.is_cuda() && x.get_device() == dev, #x " must be a CUDA tensor")

Definition at line 23 of file HalidePyTorchHelpers.h.
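All three HLPT_CHECK_* macros share one pattern: #x stringizes the argument expression, and adjacent string literals are concatenated at compile time, so the failure message names the offending argument. In the self-contained sketch below, SKETCH_ASSERTM is a hypothetical stand-in for AT_ASSERTM that throws std::runtime_error, and MockCudaTensor is a hypothetical mock exposing the is_cuda() and get_device() methods used by HLPT_CHECK_DEVICE. The message text is also illustrative.

```cpp
#include <stdexcept>
#include <string>

// Hypothetical stand-in for AT_ASSERTM: throws with the given message.
#define SKETCH_ASSERTM(cond, msg) \
    do { if (!(cond)) throw std::runtime_error(msg); } while (0)

// Same shape as HLPT_CHECK_DEVICE: #x becomes a string literal that is
// concatenated with the rest of the message.
#define SKETCH_CHECK_DEVICE(x, dev)                            \
    SKETCH_ASSERTM((x).is_cuda() && (x).get_device() == (dev), \
                   #x " must be a CUDA tensor on the expected device")

// Hypothetical mock with the two methods the check relies on.
struct MockCudaTensor {
    bool cuda;
    int device;
    bool is_cuda() const { return cuda; }
    int get_device() const { return device; }
};

// Returns the error message, or an empty string if the check passes.
std::string check_device_message(const MockCudaTensor &input, int dev) {
    try {
        SKETCH_CHECK_DEVICE(input, dev);
    } catch (const std::runtime_error &e) {
        return e.what();
    }
    return "";
}
```

A tensor on the wrong GPU and a CPU tensor both yield "input must be a CUDA tensor on the expected device", because the stringized argument is the parameter name input.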

◆ HL_PYTORCH_API_VERSION

#define HL_PYTORCH_API_VERSION   12

Definition at line 51 of file HalidePyTorchHelpers.h.
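Because HL_PYTORCH_API_VERSION is an ordinary integer macro, code built on these helpers can gate on it with the preprocessor. A sketch, assuming a hypothetical extension that requires at least revision 12; the fallback #define, the >= 13 branch, kUsesNewerHelpers and helper_api_version are illustrative only and are not taken from the header.

```cpp
// In a real build the macro comes from HalidePyTorchHelpers.h; the fallback
// definition only makes this sketch self-contained.
#ifndef HL_PYTORCH_API_VERSION
#define HL_PYTORCH_API_VERSION 12
#endif

// Hypothetical compile-time gate: refuse to build against older helpers.
static_assert(HL_PYTORCH_API_VERSION >= 12,
              "HalidePyTorchHelpers.h is too old for this extension");

#if HL_PYTORCH_API_VERSION >= 13
// Hypothetical code path that relies on a newer helper revision.
constexpr bool kUsesNewerHelpers = true;
#else
constexpr bool kUsesNewerHelpers = false;
#endif

// Exposes the revision the translation unit was compiled against.
constexpr int helper_api_version() { return HL_PYTORCH_API_VERSION; }
```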

◆ HL_PT_DEFINE_TYPECHECK

#define HL_PT_DEFINE_TYPECHECK ( ctype, ttype, _3 )
Value:
template<> \
inline void check_type<ctype>(at::Tensor & tensor) { \
AT_ASSERTM(tensor.scalar_type() == at::ScalarType::ttype, "scalar type do not match"); \
}

Definition at line 71 of file HalidePyTorchHelpers.h.
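HL_PT_DEFINE_TYPECHECK is written to be passed to PyTorch's AT_FORALL_SCALAR_TYPES_WITH_COMPLEX (see the function list above), which invokes it once per scalar type with the C++ type, the at::ScalarType enumerator name and a third argument that the macro ignores. Each invocation produces one explicit specialization of check_type. The sketch below reproduces this X-macro pattern with a hypothetical three-type list SKETCH_FORALL_TYPES and hypothetical mocks of at::ScalarType and at::Tensor. The rejecting primary template is an assumption about how unlisted types are handled.

```cpp
#include <cstdint>
#include <stdexcept>

// Hypothetical mocks for at::ScalarType and at::Tensor.
enum class ScalarType { Float, Double, Int };
struct Tensor {
    ScalarType type;
    ScalarType scalar_type() const { return type; }
};

// Hypothetical type list: applies `_` to each (C++ type, enumerator, extra).
#define SKETCH_FORALL_TYPES(_)     \
    _(float, Float, unused)        \
    _(double, Double, unused)      \
    _(std::int32_t, Int, unused)

// Primary template: a type without a specialization is rejected.
template <class scalar_t>
void check_type(Tensor &) {
    throw std::runtime_error("scalar type not handled");
}

// Same shape as HL_PT_DEFINE_TYPECHECK, with a throwing assert stand-in.
#define SKETCH_DEFINE_TYPECHECK(ctype, ttype, _3)                  \
    template <>                                                    \
    inline void check_type<ctype>(Tensor & tensor) {               \
        if (tensor.scalar_type() != ScalarType::ttype)             \
            throw std::runtime_error("scalar types do not match"); \
    }

// Expands to three explicit specializations of check_type.
SKETCH_FORALL_TYPES(SKETCH_DEFINE_TYPECHECK)

// Returns true if check_type<T> accepts the tensor.
template <class T>
bool accepts(Tensor t) {
    try {
        check_type<T>(t);
        return true;
    } catch (const std::runtime_error &) {
        return false;
    }
}
```

Adding a type then only requires one more entry in the list, not another hand-written specialization.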

Function Documentation

◆ halide_cuda_device_interface()

const halide_device_interface_t * halide_cuda_device_interface ( )