Halide
check_call_graphs.h
Go to the documentation of this file.
1 #ifndef CHECK_CALL_GRAPHS_H
2 #define CHECK_CALL_GRAPHS_H
3 
4 #include <algorithm>
5 #include <assert.h>
6 #include <functional>
7 #include <map>
8 #include <numeric>
9 #include <stdio.h>
10 #include <string.h>
11 
12 #include "Halide.h"
13 
// Call graph keyed by caller (producer) name; each value lists the names the
// caller was observed to load from.
using CallGraphs = std::map<std::string, std::vector<std::string>>;
15 
// CheckCalls: walks Halide IR and records, for every producer node, the names
// of the symbols loaded (Load::name) inside that producer's body.
// NOTE(review): this Doxygen listing omits source lines 17, 23, 34 and 39
// (lines containing hyperlinked identifiers were dropped). The cross-reference
// section of this page says CheckCalls is defined at check_call_graphs.h:17 and
// references Halide::Internal::IRVisitor and IRVisitor::visit, so line 17 is
// presumably `class CheckCalls : public Halide::Internal::IRVisitor {`.
// Confirm against the real header before relying on this listing.
16 // For each producer node, find all functions that it calls.
18 public:
19  CallGraphs calls; // Caller -> vector of callees
// Name of the producer whose body is currently being visited; empty while the
// visitor is outside every producer, in which case Loads are not recorded.
20  std::string producer = "";
21 
22 private:
// NOTE(review): source line 23 is missing from this listing (presumably a
// using-declaration that un-hides the base-class visit overloads; confirm).
24 
// Producer nodes: make op->name the current producer, guarantee it has an
// entry in `calls` even if it loads nothing, visit its body so Loads inside it
// are attributed to it, then restore the enclosing producer. A nested producer
// therefore owns the Loads in its own body; the outer one does not see them.
25  void visit(const Halide::Internal::ProducerConsumer *op) override {
26  if (op->is_producer) {
27  std::string old_producer = producer;
28  producer = op->name;
29  calls[producer]; // Make sure each producer is allocated a slot
30  // Group the callees of the 'produce' and 'update' together
31  op->body.accept(this);
32  producer = old_producer;
33  } else {
// NOTE(review): source line 34 (the consumer branch) is missing here. As
// listed, the branch is empty, which would mean consumer bodies are never
// traversed. Presumably it forwards to the base IRVisitor::visit; confirm.
35  }
36  }
37 
// Load nodes: if inside a producer, append op->name to that producer's callee
// list unless it is already present, so each callee appears at most once.
// NOTE(review): source line 39 is missing here; presumably it forwards to the
// base visitor so the Load's sub-expressions are also walked. Confirm.
38  void visit(const Halide::Internal::Load *op) override {
40  if (!producer.empty()) {
41  assert(calls.count(producer) > 0);
42  std::vector<std::string> &callees = calls[producer];
43  if (std::find(callees.begin(), callees.end(), op->name) == callees.end()) {
44  callees.push_back(op->name);
45  }
46  }
47  }
48 };
49 
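50 // These are declared "inline" so this header can be included from several translation units without duplicate-definition link errors; unlike "static", "inline" also avoids "unused function" warnings.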
51 inline int check_call_graphs(CallGraphs &result, CallGraphs &expected) {
52  if (result.size() != expected.size()) {
53  printf("Expect %d callers instead of %d\n", (int)expected.size(), (int)result.size());
54  return -1;
55  }
56  for (auto &iter : expected) {
57  if (result.count(iter.first) == 0) {
58  printf("Expect %s to be in the call graphs\n", iter.first.c_str());
59  return -1;
60  }
61  std::vector<std::string> &expected_callees = iter.second;
62  std::vector<std::string> &result_callees = result[iter.first];
63  std::sort(expected_callees.begin(), expected_callees.end());
64  std::sort(result_callees.begin(), result_callees.end());
65  if (expected_callees != result_callees) {
66  std::string expected_str = std::accumulate(
67  expected_callees.begin(), expected_callees.end(), std::string{},
68  [](const std::string &a, const std::string &b) {
69  return a.empty() ? b : a + ", " + b;
70  });
71  std::string result_str = std::accumulate(
72  result_callees.begin(), result_callees.end(), std::string{},
73  [](const std::string &a, const std::string &b) {
74  return a.empty() ? b : a + ", " + b;
75  });
76 
77  printf("Expect calless of %s to be (%s); got (%s) instead\n",
78  iter.first.c_str(), expected_str.c_str(), result_str.c_str());
79  return -1;
80  }
81  }
82  return 0;
83 }
84 
85 template<typename T, typename F>
86 inline int check_image2(const Halide::Buffer<T> &im, const F &func) {
87  for (int y = 0; y < im.height(); y++) {
88  for (int x = 0; x < im.width(); x++) {
89  T correct = func(x, y);
90  if (im(x, y) != correct) {
91  std::cout << "im(" << x << ", " << y << ") = " << im(x, y)
92  << " instead of " << correct << "\n";
93  return -1;
94  }
95  }
96  }
97  return 0;
98 }
99 
100 template<typename T, typename F>
101 inline int check_image3(const Halide::Buffer<T> &im, const F &func) {
102  for (int z = 0; z < im.channels(); z++) {
103  for (int y = 0; y < im.height(); y++) {
104  for (int x = 0; x < im.width(); x++) {
105  T correct = func(x, y, z);
106  if (im(x, y, z) != correct) {
107  std::cout << "im(" << x << ", " << y << ", " << z << ") = "
108  << im(x, y, z) << " instead of " << correct << "\n";
109  return -1;
110  }
111  }
112  }
113  }
114  return 0;
115 }
116 
117 template<typename T, typename F>
118 inline auto // SFINAE: returns int if F has arity of 2
119 check_image(const Halide::Buffer<T> &im, const F &func) -> decltype(std::declval<F>()(0, 0), int()) {
120  return check_image2(im, func);
121 }
122 
123 template<typename T, typename F>
124 inline auto // SFINAE: returns int if F has arity of 3
125 check_image(const Halide::Buffer<T> &im, const F &func) -> decltype(std::declval<F>()(0, 0, 0), int()) {
126  return check_image3(im, func);
127 }
128 
129 #endif
Halide::Internal::IRVisitor::visit
virtual void visit(const IntImm *)
check_image3
int check_image3(const Halide::Buffer< T > &im, const F &func)
Definition: check_call_graphs.h:101
CheckCalls::producer
std::string producer
Definition: check_call_graphs.h:20
Halide::Internal::IRVisitor
A base class for algorithms that need to recursively walk over the IR.
Definition: IRVisitor.h:21
CheckCalls::calls
CallGraphs calls
Definition: check_call_graphs.h:19
Halide::Internal::IRHandle::accept
void accept(IRVisitor *v) const
Dispatch to the correct visitor method for this node.
Definition: Expr.h:190
Halide::Internal::Load
Load a value from a named symbol if predicate is true.
Definition: IR.h:199
Halide::Buffer
A Halide::Buffer is a named shared reference to a Halide::Runtime::Buffer.
Definition: Argument.h:16
Halide::Internal::ProducerConsumer::is_producer
bool is_producer
Definition: IR.h:299
Halide::Internal::ProducerConsumer::body
Stmt body
Definition: IR.h:300
Halide::Internal::ProducerConsumer
This node is a helpful annotation to do with permissions.
Definition: IR.h:297
CheckCalls
Definition: check_call_graphs.h:17
CallGraphs
std::map< std::string, std::vector< std::string > > CallGraphs
Definition: check_call_graphs.h:14
Halide::Internal::ProducerConsumer::name
std::string name
Definition: IR.h:298
check_image
auto check_image(const Halide::Buffer< T > &im, const F &func) -> decltype(std::declval< F >()(0, 0), int())
Definition: check_call_graphs.h:119
check_image2
int check_image2(const Halide::Buffer< T > &im, const F &func)
Definition: check_call_graphs.h:86
Halide::Internal::Load::name
std::string name
Definition: IR.h:200
check_call_graphs
int check_call_graphs(CallGraphs &result, CallGraphs &expected)
Definition: check_call_graphs.h:51