Halide
check_call_graphs.h
Go to the documentation of this file.
1 #ifndef CHECK_CALL_GRAPHS_H
2 #define CHECK_CALL_GRAPHS_H
3 
#include <algorithm>
#include <assert.h>
#include <functional>
#include <iostream>
#include <map>
#include <numeric>
#include <stdio.h>
#include <string>
#include <string.h>
#include <utility>
#include <vector>

#include "Halide.h"
13 
// Call graph representation: maps each caller (producer) name to the list of
// names of its callees.
using CallGraphs = std::map<std::string, std::vector<std::string>>;
15 
16 // For each producer node, find all functions that it calls.
18 public:
19  CallGraphs calls; // Caller -> vector of callees
20  std::string producer = "";
21 
22 private:
24 
26  if (op->is_producer) {
27  std::string old_producer = producer;
28  producer = op->name;
29  calls[producer]; // Make sure each producer is allocated a slot
30  // Group the callees of the 'produce' and 'update' together
31  auto new_stmt = mutate(op->body);
32  producer = old_producer;
33  return new_stmt;
34  } else {
36  }
37  }
38 
39  Halide::Expr visit(const Halide::Internal::Load *op) override {
40  if (!producer.empty()) {
41  assert(calls.count(producer) > 0);
42  std::vector<std::string> &callees = calls[producer];
43  if (std::find(callees.begin(), callees.end(), op->name) == callees.end()) {
44  callees.push_back(op->name);
45  }
46  }
48  }
49 };
50 
51 // These are declared "inline" to avoid "unused function" warnings
52 inline int check_call_graphs(Halide::Pipeline p, CallGraphs &expected) {
53  // Add a custom lowering pass that scrapes the call graph. We give ownership
54  // of it to the Pipeline, whose lifetime escapes this function.
55  CheckCalls *checker = new CheckCalls;
56  p.add_custom_lowering_pass(checker);
58  CallGraphs &result = checker->calls;
59 
60  if (result.size() != expected.size()) {
61  printf("Expect %d callers instead of %d\n", (int)expected.size(), (int)result.size());
62  return 1;
63  }
64  for (auto &iter : expected) {
65  if (result.count(iter.first) == 0) {
66  printf("Expect %s to be in the call graphs\n", iter.first.c_str());
67  return 1;
68  }
69  std::vector<std::string> &expected_callees = iter.second;
70  std::vector<std::string> &result_callees = result[iter.first];
71  std::sort(expected_callees.begin(), expected_callees.end());
72  std::sort(result_callees.begin(), result_callees.end());
73  if (expected_callees != result_callees) {
74  std::string expected_str = std::accumulate(
75  expected_callees.begin(), expected_callees.end(), std::string{},
76  [](const std::string &a, const std::string &b) {
77  return a.empty() ? b : a + ", " + b;
78  });
79  std::string result_str = std::accumulate(
80  result_callees.begin(), result_callees.end(), std::string{},
81  [](const std::string &a, const std::string &b) {
82  return a.empty() ? b : a + ", " + b;
83  });
84 
85  printf("Expect callees of %s to be (%s); got (%s) instead\n",
86  iter.first.c_str(), expected_str.c_str(), result_str.c_str());
87  return 1;
88  }
89  }
90  return 0;
91 }
92 
93 template<typename T, typename F>
94 inline int check_image2(const Halide::Buffer<T> &im, const F &func) {
95  for (int y = 0; y < im.height(); y++) {
96  for (int x = 0; x < im.width(); x++) {
97  T correct = func(x, y);
98  if (im(x, y) != correct) {
99  std::cout << "im(" << x << ", " << y << ") = " << im(x, y)
100  << " instead of " << correct << "\n";
101  return 1;
102  }
103  }
104  }
105  return 0;
106 }
107 
108 template<typename T, typename F>
109 inline int check_image3(const Halide::Buffer<T> &im, const F &func) {
110  for (int z = 0; z < im.channels(); z++) {
111  for (int y = 0; y < im.height(); y++) {
112  for (int x = 0; x < im.width(); x++) {
113  T correct = func(x, y, z);
114  if (im(x, y, z) != correct) {
115  std::cout << "im(" << x << ", " << y << ", " << z << ") = "
116  << im(x, y, z) << " instead of " << correct << "\n";
117  return 1;
118  }
119  }
120  }
121  }
122  return 0;
123 }
124 
125 template<typename T, typename F>
126 inline auto // SFINAE: returns int if F has arity of 2
127 check_image(const Halide::Buffer<T> &im, const F &func) -> decltype(std::declval<F>()(0, 0), int()) {
128  return check_image2(im, func);
129 }
130 
131 template<typename T, typename F>
132 inline auto // SFINAE: returns int if F has arity of 3
133 check_image(const Halide::Buffer<T> &im, const F &func) -> decltype(std::declval<F>()(0, 0, 0), int()) {
134  return check_image3(im, func);
135 }
136 
137 #endif
check_image3
int check_image3(const Halide::Buffer< T > &im, const F &func)
Definition: check_call_graphs.h:109
Halide::Pipeline::infer_arguments
std::vector< Argument > infer_arguments(const Internal::Stmt &body)
CheckCalls::producer
std::string producer
Definition: check_call_graphs.h:20
CheckCalls::calls
CallGraphs calls
Definition: check_call_graphs.h:19
Halide::Internal::Stmt
A reference-counted handle to a statement node.
Definition: Expr.h:418
Halide::Pipeline
A class representing a Halide pipeline.
Definition: Pipeline.h:108
Halide::Internal::Load
Load a value from a named symbol if predicate is true.
Definition: IR.h:209
Halide::Buffer
A Halide::Buffer is a named shared reference to a Halide::Runtime::Buffer.
Definition: Argument.h:16
Halide::Internal::ProducerConsumer::is_producer
bool is_producer
Definition: IR.h:309
Halide::Internal::IRMutator::visit
virtual Expr visit(const IntImm *)
Halide::Internal::ProducerConsumer::body
Stmt body
Definition: IR.h:310
Halide::Internal::ProducerConsumer
This node is a helpful annotation to do with permissions.
Definition: IR.h:307
Halide::Internal::IRMutator
A base class for passes over the IR which modify it (e.g. replacing a variable with a value (Substitute.h), or constant-folding).
Definition: IRMutator.h:26
CheckCalls
Definition: check_call_graphs.h:17
Halide::Pipeline::add_custom_lowering_pass
void add_custom_lowering_pass(T *pass)
Add a custom pass to be used during lowering.
Definition: Pipeline.h:375
CallGraphs
std::map< std::string, std::vector< std::string > > CallGraphs
Definition: check_call_graphs.h:14
Halide::Internal::ProducerConsumer::name
std::string name
Definition: IR.h:308
check_image
auto check_image(const Halide::Buffer< T > &im, const F &func) -> decltype(std::declval< F >()(0, 0), int())
Definition: check_call_graphs.h:127
check_image2
int check_image2(const Halide::Buffer< T > &im, const F &func)
Definition: check_call_graphs.h:94
Halide::Expr
A fragment of Halide syntax.
Definition: Expr.h:257
Halide::Internal::Load::name
std::string name
Definition: IR.h:210
check_call_graphs
int check_call_graphs(Halide::Pipeline p, CallGraphs &expected)
Definition: check_call_graphs.h:52
Halide::Internal::IRMutator::mutate
virtual Expr mutate(const Expr &expr)
This is the main interface for using a mutator.
Halide::Pipeline::compile_to_module
Module compile_to_module(const std::vector< Argument > &args, const std::string &fn_name, const Target &target=get_target_from_environment(), LinkageType linkage_type=LinkageType::ExternalPlusMetadata)
Create an internal representation of lowered code as a self contained Module suitable for further compilation.