Halide
Derivative.h
Go to the documentation of this file.
1 #ifndef HALIDE_DERIVATIVE_H
2 #define HALIDE_DERIVATIVE_H
3 
4 /** \file
5  * Automatic differentiation
6  */
7 
#include "Expr.h"
#include "Func.h"

#include <map>
#include <string>
#include <utility>
#include <vector>
14 
15 namespace Halide {
16 
17 /**
18  * Helper structure storing the adjoints Func.
19  * Use d(func) or d(buffer) to obtain the derivative Func.
20  */
21 class Derivative {
22 public:
23  // function name & update_id, for initialization update_id == -1
24  using FuncKey = std::pair<std::string, int>;
25 
26  explicit Derivative(const std::map<FuncKey, Func> &adjoints_in)
27  : adjoints(adjoints_in) {
28  }
29  explicit Derivative(std::map<FuncKey, Func> &&adjoints_in)
30  : adjoints(std::move(adjoints_in)) {
31  }
32 
33  // These all return an undefined Func if no derivative is found
34  // (typically, if the input Funcs aren't differentiable)
35  Func operator()(const Func &func, int update_id = -1) const;
36  Func operator()(const Buffer<> &buffer) const;
37  Func operator()(const Param<> &param) const;
38  Func operator()(const std::string &name) const;
39 
40 private:
41  const std::map<FuncKey, Func> adjoints;
42 };
43 
44 /**
45  * Given a Func and a corresponding adjoint, (back)propagate the
46  * adjoint to all dependent Funcs, buffers, and parameters.
47  * The bounds of output and adjoint need to be specified with pair {min, extent}
48  * For each Func the output depends on, and for the pure definition and
49  * each update of that Func, it generates a derivative Func stored in
50  * the Derivative.
51  */
52 Derivative propagate_adjoints(const Func &output,
53  const Func &adjoint,
54  const Region &output_bounds);
55 /**
56  * Given a Func and a corresponding adjoint buffer, (back)propagate the
57  * adjoint to all dependent Funcs, buffers, and parameters.
58  * For each Func the output depends on, and for the pure definition and
59  * each update of that Func, it generates a derivative Func stored in
60  * the Derivative.
61  */
62 Derivative propagate_adjoints(const Func &output,
63  const Buffer<float> &adjoint);
64 /**
65  * Given a scalar Func with size 1, (back)propagate the gradient
66  * to all dependent Funcs, buffers, and parameters.
67  * For each Func the output depends on, and for the pure definition and
68  * each update of that Func, it generates a derivative Func stored in
69  * the Derivative.
70  */
71 Derivative propagate_adjoints(const Func &output);
72 
73 } // namespace Halide
74 
75 #endif
Halide::Derivative::Derivative
Derivative(const std::map< FuncKey, Func > &adjoints_in)
Definition: Derivative.h:26
Halide::Region
std::vector< Range > Region
A multi-dimensional box.
Definition: Expr.h:344
Halide::Derivative::Derivative
Derivative(std::map< FuncKey, Func > &&adjoints_in)
Definition: Derivative.h:29
Halide::Derivative::FuncKey
std::pair< std::string, int > FuncKey
Definition: Derivative.h:24
Halide
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
Definition: AbstractGenerator.h:19
Func.h
Halide::Buffer<>
Halide::Derivative::operator()
Func operator()(const Func &func, int update_id=-1) const
Expr.h
Halide::Func
A halide function.
Definition: Func.h:687
Halide::Derivative
Helper structure storing the adjoint Funcs.
Definition: Derivative.h:21
Halide::propagate_adjoints
Derivative propagate_adjoints(const Func &output, const Func &adjoint, const Region &output_bounds)
Given a Func and a corresponding adjoint, (back)propagate the adjoint to all dependent Funcs,...
Halide::Param
A scalar parameter to a halide pipeline.
Definition: Param.h:22