Halide
testing.h
Go to the documentation of this file.
1 #ifndef _TESTING_H_
2 #define _TESTING_H_
3 
#include "Halide.h"
#include <cmath>
#include <exception>
#include <functional>
#include <iostream>
#include <utility>
#include <vector>
9 
10 namespace Testing {
11 
12 template<typename T>
13 bool neq(T a, T b, T tol) {
14  return std::abs(a - b) > tol;
15 }
16 
17 // Check 3-dimension buffer
18 template<typename T, typename F>
19 auto check_result(const Halide::Buffer<T> &buf, T tol, F f) -> decltype(std::declval<F>()(0, 0, 0), bool()) {
20  class err : std::exception {
21  public:
22  static void vector(const std::vector<T> &v) {
23  for (size_t i = 0; i < v.size(); i++) {
24  if (i > 0) {
25  std::cerr << ",";
26  }
27  std::cerr << +v[i]; // use unary + to promote uint8_t from char to numeric
28  }
29  }
30  };
31  try {
32  buf.for_each_element([&](int x, int y) {
33  std::vector<T> expected;
34  std::vector<T> result;
35  for (int c = 0; c < buf.channels(); c++) {
36  expected.push_back(f(x, y, c));
37  result.push_back(buf(x, y, c));
38  }
39  for (int c = 0; c < buf.channels(); c++) {
40  if (neq(result[c], expected[c], tol)) {
41  std::cerr << "Error: result (";
42  err::vector(result);
43  std::cerr << ") should be (";
44  err::vector(expected);
45  std::cerr << ") at x=" << x << " y=" << y << "\n";
46  throw err();
47  }
48  }
49  });
50  } catch (err &) {
51  return false;
52  }
53  return true;
54 }
55 
56 // Check 2-dimension buffer
57 template<typename T, typename F>
58 auto check_result(const Halide::Buffer<T> &buf, T tol, F f) -> decltype(std::declval<F>()(0, 0), bool()) {
59  class err : std::exception {};
60  try {
61  buf.for_each_element([&](int x, int y) {
62  const T expected = f(x, y);
63  const T result = buf(x, y);
64  if (neq(result, expected, tol)) {
65  std::cerr << "Error: result (";
66  std::cerr << +result;
67  std::cerr << ") should be (";
68  std::cerr << +expected;
69  std::cerr << ") at x=" << x << " y=" << y << "\n";
70  throw err();
71  }
72  });
73  } catch (err &) {
74  return false;
75  }
76  return true;
77 }
78 
79 // Shorthand to check with tolerance=0
80 template<typename T, typename Func>
81 bool check_result(const Halide::Buffer<T> &buf, Func f) {
82  return check_result<T>(buf, 0, f);
83 }
84 } // namespace Testing
85 
86 #endif // _TESTING_H_
Testing::check_result
auto check_result(const Halide::Buffer< T > &buf, T tol, F f) -> decltype(std::declval< F >()(0, 0, 0), bool())
Definition: testing.h:19
Halide::Buffer
A Halide::Buffer is a named shared reference to a Halide::Runtime::Buffer.
Definition: Argument.h:16
Testing::neq
bool neq(T a, T b, T tol)
Definition: testing.h:13
Halide::abs
Expr abs(Expr a)
Returns the absolute value of a signed integer or floating-point expression.
Testing
Definition: testing.h:10
buf
char * buf
Definition: printer.h:32