Halide
Derivative.h
Go to the documentation of this file.
1 #ifndef HALIDE_DERIVATIVE_H
2 #define HALIDE_DERIVATIVE_H
3
4 /** \file
5  * Automatic differentiation
6  */
7
8 #include "Expr.h"
9 #include "Func.h"
10 #include "Module.h"
11
12 #include <map>
13 #include <string>
14 #include <vector>
15
16 namespace Halide {
17
18 /**
19  * Helper structure storing the adjoint Funcs.
20  * Given a Derivative d, use d(func), d(buffer), or d(param) to obtain the derivative Func.
21  */
22 class Derivative {
23 public:
24  // Key is (function name, update_id); update_id == -1 denotes the initialization
25  using FuncKey = std::pair<std::string, int>;
26
27  explicit Derivative(const std::map<FuncKey, Func> &adjoints_in)
29  }
32  }
33
34  // These all return an undefined Func if no derivative is found
35  // (typically, if the input Funcs aren't differentiable)
36  Func operator()(const Func &func, int update_id = -1) const;
37  Func operator()(const Buffer<> &buffer) const;
38  Func operator()(const Param<> &param) const;
39
40 private:
42 };
43
44 /**
45  * Given a Func and a corresponding adjoint, (back)propagate the
46  * adjoint to all dependent Funcs, buffers, and parameters.
47  * The bounds of output and adjoint must be specified as {min, extent} pairs.
48  * For each Func the output depends on, and for the pure definition and
49  * each update of that Func, it generates a derivative Func stored in
50  * the Derivative.
51  */
54  const Region &output_bounds);
55 /**
56  * Given a Func and a corresponding adjoint buffer, (back)propagate the
57  * adjoint to all dependent Funcs, buffers, and parameters.
58  * For each Func the output depends on, and for the pure definition and
59  * each update of that Func, it generates a derivative Func stored in
60  * the Derivative.
61  */
64 /**
65  * Given a scalar Func with size 1, (back)propagate the gradient
66  * to all dependent Funcs, buffers, and parameters.
67  * For each Func the output depends on, and for the pure definition and
68  * each update of that Func, it generates a derivative Func stored in
69  * the Derivative.
70  */
72
73 } // namespace Halide
74
75 #endif
Halide::Derivative::Derivative
Derivative(const std::map< FuncKey, Func > &adjoints_in)
Definition: Derivative.h:27
Halide::Region
std::vector< Range > Region
A multi-dimensional box.
Definition: Expr.h:343
Halide::Derivative::Derivative
Definition: Derivative.h:30
Halide::Derivative::FuncKey
std::pair< std::string, int > FuncKey
Definition: Derivative.h:25
Halide
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
Func.h
Halide::Buffer<>
Halide::Derivative::operator()
Func operator()(const Func &func, int update_id=-1) const
Expr.h
Halide::Func
A halide function.
Definition: Func.h:667
Halide::Derivative
Helper structure storing the adjoint Funcs.
Definition: Derivative.h:22