HalidePyTorchHelpers.h File Reference
#include <exception>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
#include "torch/extension.h"
#include "HalideBuffer.h"



Utility functions for wrapping PyTorch tensors in Halide buffers.


#define HLPT_CHECK_CONTIGUOUS(x)   AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define HLPT_CHECK_CUDA(x)   AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define HLPT_CHECK_DEVICE(x, dev)   AT_ASSERTM(x.is_cuda() && x.get_device() == dev, #x " must be a CUDA tensor")
#define HL_PT_DEFINE_TYPECHECK(ctype, ttype, _3)


std::vector< int > Halide::PyTorch::get_dims (const at::Tensor tensor)
template<class scalar_t >
void Halide::PyTorch::check_type (at::Tensor &tensor)
template<class scalar_t >
Buffer< scalar_t > Halide::PyTorch::wrap (at::Tensor &tensor)

Detailed Description

A set of utility functions that wrap PyTorch tensors in Halide buffers, making sure the data is on the correct device (CPU or GPU).

Definition in file HalidePyTorchHelpers.h.
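
One detail that wrapping a tensor usually involves is dimension order: a contiguous PyTorch tensor lists its outermost dimension first, while a Halide buffer lists its innermost dimension first. The sketch below illustrates that reversal with a hypothetical helper, to_halide_extents, operating on a plain vector of sizes. It is not part of this header, and whether Halide::PyTorch::get_dims behaves this way is an assumption that should be checked against the source.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical helper (not part of HalidePyTorchHelpers.h): converts a
// PyTorch-style shape (outermost dimension first) into Halide-style
// extents (innermost dimension first) by reversing the order.
std::vector<int> to_halide_extents(const std::vector<std::int64_t> &torch_sizes) {
    const std::size_t n = torch_sizes.size();
    std::vector<int> extents(n);
    for (std::size_t i = 0; i < n; ++i) {
        const std::int64_t size = torch_sizes[n - 1 - i];
        // The result stores int extents, so each size must fit in an int.
        assert(size >= 0 && size <= std::numeric_limits<int>::max());
        extents[i] = static_cast<int>(size);
    }
    return extents;
}
```

For a shape of {2, 3, 480, 640}, this helper returns {640, 480, 3, 2}.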

Macro Definition Documentation


#define HLPT_CHECK_CONTIGUOUS(x)   AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")

Definition at line 25 of file HalidePyTorchHelpers.h.


#define HLPT_CHECK_CUDA(x)   AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")

Definition at line 26 of file HalidePyTorchHelpers.h.


#define HLPT_CHECK_DEVICE(x, dev)   AT_ASSERTM(x.is_cuda() && x.get_device() == dev, #x " must be a CUDA tensor")

Definition at line 27 of file HalidePyTorchHelpers.h.
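
The check macros above rely on the preprocessor's stringizing operator: #x turns the macro argument into a string literal, which is then concatenated with the message so the error names the offending variable. The following sketch reproduces that pattern without PyTorch. FakeTensor, SKETCH_ASSERTM, the SKETCH_CHECK_* macros, and validate are all hypothetical stand-ins, and SKETCH_ASSERTM throws std::runtime_error, which may differ from what AT_ASSERTM does.

```cpp
#include <stdexcept>
#include <string>

// Hypothetical stand-in for at::Tensor exposing only the queries used below.
struct FakeTensor {
    bool contiguous;
    bool cuda;
    int device;
    bool is_contiguous() const { return contiguous; }
    bool is_cuda() const { return cuda; }
    int get_device() const { return device; }
};

// Hypothetical stand-in for AT_ASSERTM: throws when the condition is false.
#define SKETCH_ASSERTM(cond, msg) \
    do { if (!(cond)) throw std::runtime_error(msg); } while (0)

// #x expands to the argument's spelling, e.g. "input", and adjacent string
// literals are concatenated into a single message.
#define SKETCH_CHECK_CONTIGUOUS(x) \
    SKETCH_ASSERTM((x).is_contiguous(), #x " must be contiguous")
#define SKETCH_CHECK_DEVICE(x, dev) \
    SKETCH_ASSERTM((x).is_cuda() && (x).get_device() == (dev), \
                   #x " must be a CUDA tensor on the expected device")

// Runs both checks and returns the first error message, or an empty string.
std::string validate(const FakeTensor &input, int expected_device) {
    try {
        SKETCH_CHECK_CONTIGUOUS(input);
        SKETCH_CHECK_DEVICE(input, expected_device);
    } catch (const std::runtime_error &e) {
        return e.what();
    }
    return "";
}
```

In this sketch, validating a non-contiguous tensor yields the message "input must be contiguous", because the argument is spelled input at the macro call site.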



Definition at line 55 of file HalidePyTorchHelpers.h.


#define HL_PT_DEFINE_TYPECHECK(ctype, ttype, _3) \
    template<> \
    inline void check_type<ctype>(at::Tensor & tensor) { \
        AT_ASSERTM(tensor.scalar_type() == at::ScalarType::ttype, "scalar type do not match"); \
    }

Definition at line 75 of file HalidePyTorchHelpers.h.
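
The macro above stamps out one explicit specialization of check_type per C++ type, each comparing the tensor's runtime scalar type against the matching enumerator. The sketch below shows the same technique in isolation. ScalarKind, FakeTensor, SKETCH_DEFINE_TYPECHECK, and the throwing behavior of the primary template are hypothetical stand-ins chosen for illustration, not the definitions used by this header.

```cpp
#include <stdexcept>

// Hypothetical stand-ins for at::ScalarType and at::Tensor.
enum class ScalarKind { Float, Double, Int };

struct FakeTensor {
    ScalarKind scalar_kind;
    ScalarKind scalar_type() const { return scalar_kind; }
};

// Primary template: reached only for C++ types without a specialization.
template<class scalar_t>
void check_type(const FakeTensor &) {
    throw std::runtime_error("unsupported scalar type");
}

// Generates an explicit specialization that checks the runtime tag of the
// tensor against the enumerator paired with the C++ type.
#define SKETCH_DEFINE_TYPECHECK(ctype, tag)                          \
    template<>                                                       \
    inline void check_type<ctype>(const FakeTensor &tensor) {        \
        if (tensor.scalar_type() != ScalarKind::tag) {               \
            throw std::runtime_error("scalar types do not match");   \
        }                                                            \
    }

SKETCH_DEFINE_TYPECHECK(float, Float)
SKETCH_DEFINE_TYPECHECK(double, Double)
SKETCH_DEFINE_TYPECHECK(int, Int)

#undef SKETCH_DEFINE_TYPECHECK
```

With these definitions, check_type<float> accepts a tensor tagged ScalarKind::Float, while check_type<double> on the same tensor throws.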