Halide
Derivative.h
Go to the documentation of this file.
1 #ifndef HALIDE_DERIVATIVE_H
2 #define HALIDE_DERIVATIVE_H
3 
4 /** \file
5  * Automatic differentiation
6  */
7 
8 #include "Expr.h"
9 #include "Func.h"
10 #include "Module.h"
11 
12 #include <map>
13 #include <string>
14 #include <vector>
15 
16 namespace Halide {
17 
18 /**
19  * Helper structure storing the adjoints Func.
20  * Use d(func) or d(buffer) to obtain the derivative Func.
21  */
22 class Derivative {
23 public:
24  // function name & update_id, for initialization update_id == -1
25  using FuncKey = std::pair<std::string, int>;
26 
27  explicit Derivative(const std::map<FuncKey, Func> &adjoints_in)
28  : adjoints(adjoints_in) {
29  }
30  explicit Derivative(std::map<FuncKey, Func> &&adjoints_in)
31  : adjoints(std::move(adjoints_in)) {
32  }
33 
34  // These all return an undefined Func if no derivative is found
35  // (typically, if the input Funcs aren't differentiable)
36  Func operator()(const Func &func, int update_id = -1) const;
37  Func operator()(const Buffer<> &buffer) const;
38  Func operator()(const Param<> &param) const;
39 
40 private:
41  const std::map<FuncKey, Func> adjoints;
42 };
43 
44 /**
45  * Given a Func and a corresponding adjoint, (back)propagate the
46  * adjoint to all dependent Funcs, buffers, and parameters.
47  * The bounds of output and adjoint need to be specified with pair {min, extent}
48  * For each Func the output depends on, and for the pure definition and
49  * each update of that Func, it generates a derivative Func stored in
50  * the Derivative.
51  */
52 Derivative propagate_adjoints(const Func &output,
53  const Func &adjoint,
54  const Region &output_bounds);
55 /**
56  * Given a Func and a corresponding adjoint buffer, (back)propagate the
57  * adjoint to all dependent Funcs, buffers, and parameters.
58  * For each Func the output depends on, and for the pure definition and
59  * each update of that Func, it generates a derivative Func stored in
60  * the Derivative.
61  */
62 Derivative propagate_adjoints(const Func &output,
63  const Buffer<float> &adjoint);
64 /**
65  * Given a scalar Func with size 1, (back)propagate the gradient
66  * to all dependent Funcs, buffers, and parameters.
67  * For each Func the output depends on, and for the pure definition and
68  * each update of that Func, it generates a derivative Func stored in
69  * the Derivative.
70  */
71 Derivative propagate_adjoints(const Func &output);
72 
73 } // namespace Halide
74 
75 #endif
Halide::Derivative::Derivative
Derivative(const std::map< FuncKey, Func > &adjoints_in)
Definition: Derivative.h:27
Halide::Region
std::vector< Range > Region
A multi-dimensional box.
Definition: Expr.h:343
Halide::Derivative::Derivative
Derivative(std::map< FuncKey, Func > &&adjoints_in)
Definition: Derivative.h:30
Halide::Derivative::FuncKey
std::pair< std::string, int > FuncKey
Definition: Derivative.h:25
Halide
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
Definition: AddAtomicMutex.h:21
Func.h
Halide::Buffer<>
Halide::Derivative::operator()
Func operator()(const Func &func, int update_id=-1) const
Expr.h
Halide::Func
A halide function.
Definition: Func.h:667
Halide::Derivative
Helper structure storing the adjoint Funcs.
Definition: Derivative.h:22
Halide::propagate_adjoints
Derivative propagate_adjoints(const Func &output, const Func &adjoint, const Region &output_bounds)
Given a Func and a corresponding adjoint, (back)propagate the adjoint to all dependent Funcs,...
Module.h
Halide::Param
A scalar parameter to a halide pipeline.
Definition: Param.h:22