// Halide tutorial lesson 13: Tuples // This lesson describes how to write Funcs that evaluate to multiple // values. // On linux, you can compile and run it like so: // g++ lesson_13*.cpp -g -I <path/to/Halide.h> -L <path/to/libHalide.so> -lHalide -lpthread -ldl -o lesson_13 -std=c++17 // LD_LIBRARY_PATH=<path/to/libHalide.so> ./lesson_13 // On os x: // g++ lesson_13*.cpp -g -I <path/to/Halide.h> -L <path/to/libHalide.so> -lHalide -o lesson_13 -std=c++17 // DYLD_LIBRARY_PATH=<path/to/libHalide.dylib> ./lesson_13 // If you have the entire Halide source tree, you can also build it by // running: // make tutorial_lesson_13_tuples // in a shell with the current directory at the top of the halide // source tree. #include "Halide.h" #include <algorithm> #include <stdio.h> using namespace Halide; int main(int argc, char **argv) { // So far Funcs (such as the one below) have evaluated to a single // scalar value for each point in their domain. Func single_valued; Var x, y; single_valued(x, y) = x + y; // One way to write a Func that returns a collection of values is // to add an additional dimension that indexes that // collection. This is how we typically deal with color. For // example, the Func below represents a collection of three values // for every x, y coordinate indexed by c. Func color_image; Var c; color_image(x, y, c) = select(c == 0, 245, // Red value c == 1, 42, // Green value 132); // Blue value // Since this pattern appears quite often, Halide provides a // syntatic sugar to write the code above as the following, // using the "mux" function. // color_image(x, y, c) = mux(c, {245, 42, 132}); // This method is often convenient because it makes it easy to // operate on this Func in a way that treats each item in the // collection equally: Func brighter; brighter(x, y, c) = color_image(x, y, c) + 10; // However this method is also inconvenient for three reasons. // // 1) Funcs are defined over an infinite domain, so users of this // Func can for example access color_image(x, y, -17), which is // not a meaningful value and is probably indicative of a bug. // // 2) It requires a select, which can impact performance if not // bounded and unrolled: // brighter.bound(c, 0, 3).unroll(c); // // 3) With this method, all values in the collection must have the // same type. While the above two issues are merely inconvenient, // this one is a hard limitation that makes it impossible to // express certain things in this way. // It is also possible to represent a collection of values as a // collection of Funcs: Func func_array[3]; func_array[0](x, y) = x + y; func_array[1](x, y) = sin(x); func_array[2](x, y) = cos(y); // This method avoids the three problems above, but introduces a // new annoyance. Because these are separate Funcs, it is // difficult to schedule them so that they are all computed // together inside a single loop over x, y. // A third alternative is to define a Func as evaluating to a // Tuple instead of an Expr. A Tuple is a fixed-size collection of // Exprs. Each Expr in a Tuple may have a different type. The // following function evaluates to an integer value (x+y), and a // floating point value (sin(x*y)). Func multi_valued; multi_valued(x, y) = Tuple(x + y, sin(x * y)); // Realizing a tuple-valued Func returns a collection of // Buffers. We call this a Realization. It's equivalent to a // std::vector of Buffer objects: { Realization r = multi_valued.realize({80, 60}); assert(r.size() == 2); Buffer<int> im0 = r[0]; Buffer<float> im1 = r[1]; assert(im0(30, 40) == 30 + 40); assert(im1(30, 40) == sinf(30 * 40)); } // All Tuple elements are evaluated together over the same domain // in the same loop nest, but stored in distinct allocations. The // equivalent C++ code to the above is: { int multi_valued_0[80 * 60]; float multi_valued_1[80 * 60]; for (int y = 0; y < 80; y++) { for (int x = 0; x < 60; x++) { multi_valued_0[x + 60 * y] = x + y; multi_valued_1[x + 60 * y] = sinf(x * y); } } } // When compiling ahead-of-time, a Tuple-valued Func evaluates // into multiple distinct output halide_buffer_t structs. These appear in // order at the end of the function signature: // int multi_valued(...input buffers and params..., // halide_buffer_t *output_1, halide_buffer_t *output_2); // You can construct a Tuple by passing multiple Exprs to the // Tuple constructor as we did above. Perhaps more elegantly, you // can also take advantage of initializer lists and just // enclose your Exprs in braces: Func multi_valued_2; multi_valued_2(x, y) = {x + y, sin(x * y)}; // Calls to a multi-valued Func cannot be treated as Exprs. The // following is a syntax error: // Func consumer; // consumer(x, y) = multi_valued_2(x, y) + 10; // Instead you must index a Tuple with square brackets to retrieve // the individual Exprs: Expr integer_part = multi_valued_2(x, y)[0]; Expr floating_part = multi_valued_2(x, y)[1]; Func consumer; consumer(x, y) = {integer_part + 10, floating_part + 10.0f}; // Tuple reductions. { // Tuples are particularly useful in reductions, as they allow // the reduction to maintain complex state as it walks along // its domain. The simplest example is an argmax. // First we create a Buffer to take the argmax over. Func input_func; input_func(x) = sin(x); Buffer<float> input = input_func.realize({100}); // Then we define a 2-valued Tuple which tracks the index of // the maximum value and the value itself. Func arg_max; // Pure definition. arg_max() = {0, input(0)}; // Update definition. RDom r(1, 99); Expr old_index = arg_max()[0]; Expr old_max = arg_max()[1]; Expr new_index = select(old_max < input(r), r, old_index); Expr new_max = max(input(r), old_max); arg_max() = {new_index, new_max}; // The equivalent C++ is: int arg_max_0 = 0; float arg_max_1 = input(0); for (int r = 1; r < 100; r++) { int old_index = arg_max_0; float old_max = arg_max_1; int new_index = old_max < input(r) ? r : old_index; float new_max = std::max(input(r), old_max); // In a tuple update definition, all loads and computation // are done before any stores, so that all Tuple elements // are updated atomically with respect to recursive calls // to the same Func. arg_max_0 = new_index; arg_max_1 = new_max; } // Let's verify that the Halide and C++ found the same maximum // value and index. { Realization r = arg_max.realize(); Buffer<int> r0 = r[0]; Buffer<float> r1 = r[1]; assert(arg_max_0 == r0(0)); assert(arg_max_1 == r1(0)); } // Halide provides argmax and argmin as built-in reductions // similar to sum, product, maximum, and minimum. They return // a Tuple consisting of the point in the reduction domain // corresponding to that value, and the value itself. In the // case of ties they return the first value found. We'll use // one of these in the following section. } // Tuples for user-defined types. { // Tuples can also be a convenient way to represent compound // objects such as complex numbers. Defining an object that // can be converted to and from a Tuple is one way to extend // Halide's type system with user-defined types. struct Complex { Expr real, imag; // Construct from a Tuple Complex(Tuple t) : real(t[0]), imag(t[1]) { } // Construct from a pair of Exprs Complex(Expr r, Expr i) : real(r), imag(i) { } // Construct from a call to a Func by treating it as a Tuple Complex(FuncRef t) : Complex(Tuple(t)) { } // Convert to a Tuple operator Tuple() const { return {real, imag}; } // Complex addition Complex operator+(const Complex &other) const { return {real + other.real, imag + other.imag}; } // Complex multiplication Complex operator*(const Complex &other) const { return {real * other.real - imag * other.imag, real * other.imag + imag * other.real}; } // Complex magnitude, squared for efficiency Expr magnitude_squared() const { return real * real + imag * imag; } // Other complex operators would go here. The above are // sufficient for this example. }; // Let's use the Complex struct to compute a Mandelbrot set. Func mandelbrot; // The initial complex value corresponding to an x, y coordinate // in our Func. Complex initial(x / 15.0f - 2.5f, y / 6.0f - 2.0f); // Pure definition. Var t; mandelbrot(x, y, t) = Complex(0.0f, 0.0f); // We'll use an update definition to take 12 steps. RDom r(1, 12); Complex current = mandelbrot(x, y, r - 1); // The following line uses the complex multiplication and // addition we defined above. mandelbrot(x, y, r) = current * current + initial; // We'll use another tuple reduction to compute the iteration // number where the value first escapes a circle of radius 4. // This can be expressed as an argmin of a boolean - we want // the index of the first time the given boolean expression is // false (we consider false to be less than true). The argmax // would return the index of the first time the expression is // true. Expr escape_condition = Complex(mandelbrot(x, y, r)).magnitude_squared() < 16.0f; Tuple first_escape = argmin(escape_condition); // We only want the index, not the value, but argmin returns // both, so we'll index the argmin Tuple expression using // square brackets to get the Expr representing the index. Func escape; escape(x, y) = first_escape[0]; // Realize the pipeline and print the result as ascii art. Buffer<int> result = escape.realize({61, 25}); const char *code = " .:-~*={}&%#@"; for (int y = 0; y < result.height(); y++) { for (int x = 0; x < result.width(); x++) { printf("%c", code[result(x, y)]); } printf("\n"); } } printf("Success!\n"); return 0; }