Halide 19.0.0
Halide compiler and libraries
Loading...
Searching...
No Matches
check_call_graphs.h
Go to the documentation of this file.
1#ifndef CHECK_CALL_GRAPHS_H
2#define CHECK_CALL_GRAPHS_H
3
#include <algorithm>
#include <assert.h>
#include <functional>
#include <iostream>
#include <map>
#include <numeric>
#include <stdio.h>
#include <string.h>
#include <type_traits>
11
12#include "Halide.h"
13
14typedef std::map<std::string, std::vector<std::string>> CallGraphs;
15
16// For each producer node, find all functions that it calls.
18public:
19 CallGraphs calls; // Caller -> vector of callees
20 std::string producer = "";
21
22private:
24
26 if (op->is_producer) {
27 std::string old_producer = producer;
28 producer = op->name;
29 calls[producer]; // Make sure each producer is allocated a slot
30 // Group the callees of the 'produce' and 'update' together
31 auto new_stmt = mutate(op->body);
32 producer = old_producer;
33 return new_stmt;
34 } else {
36 }
37 }
38
39 Halide::Expr visit(const Halide::Internal::Load *op) override {
40 if (!producer.empty()) {
41 assert(calls.count(producer) > 0);
42 std::vector<std::string> &callees = calls[producer];
43 if (std::find(callees.begin(), callees.end(), op->name) == callees.end()) {
44 callees.push_back(op->name);
45 }
46 }
48 }
49};
50
51// These are declared "inline" to avoid "unused function" warnings
53 // Add a custom lowering pass that scrapes the call graph. We give ownership
54 // of it to the Pipeline, whose lifetime escapes this function.
55 CheckCalls *checker = new CheckCalls;
56 p.add_custom_lowering_pass(checker);
58 CallGraphs &result = checker->calls;
59
60 if (result.size() != expected.size()) {
61 printf("Expect %d callers instead of %d\n", (int)expected.size(), (int)result.size());
62 return 1;
63 }
64 for (auto &iter : expected) {
65 if (result.count(iter.first) == 0) {
66 printf("Expect %s to be in the call graphs\n", iter.first.c_str());
67 return 1;
68 }
69 std::vector<std::string> &expected_callees = iter.second;
70 std::vector<std::string> &result_callees = result[iter.first];
71 std::sort(expected_callees.begin(), expected_callees.end());
72 std::sort(result_callees.begin(), result_callees.end());
73 if (expected_callees != result_callees) {
74 std::string expected_str = std::accumulate(
75 expected_callees.begin(), expected_callees.end(), std::string{},
76 [](const std::string &a, const std::string &b) {
77 return a.empty() ? b : a + ", " + b;
78 });
79 std::string result_str = std::accumulate(
80 result_callees.begin(), result_callees.end(), std::string{},
81 [](const std::string &a, const std::string &b) {
82 return a.empty() ? b : a + ", " + b;
83 });
84
85 printf("Expect callees of %s to be (%s); got (%s) instead\n",
86 iter.first.c_str(), expected_str.c_str(), result_str.c_str());
87 return 1;
88 }
89 }
90 return 0;
91}
92
93template<typename T, typename F>
94inline int check_image2(const Halide::Buffer<T> &im, const F &func) {
95 for (int y = 0; y < im.height(); y++) {
96 for (int x = 0; x < im.width(); x++) {
97 T correct = func(x, y);
98 if (im(x, y) != correct) {
99 std::cout << "im(" << x << ", " << y << ") = " << im(x, y)
100 << " instead of " << correct << "\n";
101 return 1;
102 }
103 }
104 }
105 return 0;
106}
107
108template<typename T, typename F>
109inline int check_image3(const Halide::Buffer<T> &im, const F &func) {
110 for (int z = 0; z < im.channels(); z++) {
111 for (int y = 0; y < im.height(); y++) {
112 for (int x = 0; x < im.width(); x++) {
113 T correct = func(x, y, z);
114 if (im(x, y, z) != correct) {
115 std::cout << "im(" << x << ", " << y << ", " << z << ") = "
116 << im(x, y, z) << " instead of " << correct << "\n";
117 return 1;
118 }
119 }
120 }
121 }
122 return 0;
123}
124
125template<typename T, typename F>
126inline auto // SFINAE: returns int if F has arity of 2
127check_image(const Halide::Buffer<T> &im, const F &func) -> decltype(std::declval<F>()(0, 0), int()) {
128 return check_image2(im, func);
129}
130
131template<typename T, typename F>
132inline auto // SFINAE: returns int if F has arity of 3
133check_image(const Halide::Buffer<T> &im, const F &func) -> decltype(std::declval<F>()(0, 0, 0), int()) {
134 return check_image3(im, func);
135}
136
137#endif
int check_call_graphs(Halide::Pipeline p, CallGraphs &expected)
int check_image2(const Halide::Buffer< T > &im, const F &func)
auto check_image(const Halide::Buffer< T > &im, const F &func) -> decltype(std::declval< F >()(0, 0), int())
int check_image3(const Halide::Buffer< T > &im, const F &func)
std::map< std::string, std::vector< std::string > > CallGraphs
std::string producer
CallGraphs calls
A Halide::Buffer is a named shared reference to a Halide::Runtime::Buffer.
Definition RDom.h:21
A base class for passes over the IR which modify it (e.g. replacing a variable with a value (Substitute.h), or constant-folding).
Definition IRMutator.h:26
virtual Expr visit(const IntImm *)
virtual Expr mutate(const Expr &expr)
This is the main interface for using a mutator.
A class representing a Halide pipeline.
Definition Pipeline.h:107
std::vector< Argument > infer_arguments(const Internal::Stmt &body)
Module compile_to_module(const std::vector< Argument > &args, const std::string &fn_name, const Target &target=get_target_from_environment(), LinkageType linkage_type=LinkageType::ExternalPlusMetadata)
Create an internal representation of lowered code as a self-contained Module suitable for further compilation.
void add_custom_lowering_pass(T *pass)
Add a custom pass to be used during lowering.
Definition Pipeline.h:387
A fragment of Halide syntax.
Definition Expr.h:258
Load a value from a named symbol if predicate is true.
Definition IR.h:217
std::string name
Definition IR.h:218
This node is a helpful annotation to do with permissions.
Definition IR.h:315
A reference-counted handle to a statement node.
Definition Expr.h:427