1#ifndef CHECK_CALL_GRAPHS_H
2#define CHECK_CALL_GRAPHS_H
14typedef std::map<std::string, std::vector<std::string>>
CallGraphs;
43 if (std::find(callees.begin(), callees.end(), op->
name) == callees.end()) {
44 callees.push_back(op->
name);
60 if (result.size() != expected.size()) {
61 printf(
"Expect %d callers instead of %d\n", (
int)expected.size(), (
int)result.size());
64 for (
auto &iter : expected) {
65 if (result.count(iter.first) == 0) {
66 printf(
"Expect %s to be in the call graphs\n", iter.first.c_str());
69 std::vector<std::string> &expected_callees = iter.second;
70 std::vector<std::string> &result_callees = result[iter.first];
71 std::sort(expected_callees.begin(), expected_callees.end());
72 std::sort(result_callees.begin(), result_callees.end());
73 if (expected_callees != result_callees) {
74 std::string expected_str = std::accumulate(
75 expected_callees.begin(), expected_callees.end(), std::string{},
76 [](
const std::string &a,
const std::string &b) {
77 return a.empty() ? b : a +
", " + b;
79 std::string result_str = std::accumulate(
80 result_callees.begin(), result_callees.end(), std::string{},
81 [](
const std::string &a,
const std::string &b) {
82 return a.empty() ? b : a +
", " + b;
85 printf(
"Expect callees of %s to be (%s); got (%s) instead\n",
86 iter.first.c_str(), expected_str.c_str(), result_str.c_str());
93template<
typename T,
typename F>
95 for (
int y = 0; y < im.height(); y++) {
96 for (
int x = 0; x < im.width(); x++) {
97 T correct = func(x, y);
98 if (im(x, y) != correct) {
99 std::cout <<
"im(" << x <<
", " << y <<
") = " << im(x, y)
100 <<
" instead of " << correct <<
"\n";
108template<
typename T,
typename F>
110 for (
int z = 0; z < im.channels(); z++) {
111 for (
int y = 0; y < im.height(); y++) {
112 for (
int x = 0; x < im.width(); x++) {
113 T correct = func(x, y, z);
114 if (im(x, y, z) != correct) {
115 std::cout <<
"im(" << x <<
", " << y <<
", " << z <<
") = "
116 << im(x, y, z) <<
" instead of " << correct <<
"\n";
125template<
typename T,
typename F>
131template<
typename T,
typename F>
int check_call_graphs(Halide::Pipeline p, CallGraphs &expected)
int check_image2(const Halide::Buffer< T > &im, const F &func)
auto check_image(const Halide::Buffer< T > &im, const F &func) -> decltype(std::declval< F >()(0, 0), int())
int check_image3(const Halide::Buffer< T > &im, const F &func)
std::map< std::string, std::vector< std::string > > CallGraphs
A Halide::Buffer is a named shared reference to a Halide::Runtime::Buffer.
A base class for passes over the IR which modify it (e.g. replacing a variable with a value, or constant-folding).
virtual Expr visit(const IntImm *)
virtual Expr mutate(const Expr &expr)
This is the main interface for using a mutator.
A class representing a Halide pipeline.
std::vector< Argument > infer_arguments(const Internal::Stmt &body)
Module compile_to_module(const std::vector< Argument > &args, const std::string &fn_name, const Target &target=get_target_from_environment(), LinkageType linkage_type=LinkageType::ExternalPlusMetadata)
Create an internal representation of lowered code as a self-contained Module suitable for further compilation.
void add_custom_lowering_pass(T *pass)
Add a custom pass to be used during lowering.
A fragment of Halide syntax.
Load a value from a named symbol if predicate is true.
This node is a helpful annotation to do with permissions.
A reference-counted handle to a statement node.