Halide 19.0.0
Halide compiler and libraries
Loading...
Searching...
No Matches
Derivative.h
Go to the documentation of this file.
1#ifndef HALIDE_DERIVATIVE_H
2#define HALIDE_DERIVATIVE_H
3
4/** \file
5 * Automatic differentiation
6 */
7
8#include "Expr.h"
9#include "Func.h"
10
11#include <map>
12#include <string>
13#include <vector>
14
15namespace Halide {
16
17/**
18 * Helper structure storing the adjoint Funcs.
19 * Use d(func) or d(buffer) to obtain the derivative Func.
20 */
21class Derivative {
22public:
23 // function name & update_id; update_id == -1 denotes the initial (pure) definition
24 using FuncKey = std::pair<std::string, int>;
25
26 explicit Derivative(const std::map<FuncKey, Func> &adjoints_in)
27 : adjoints(adjoints_in) {
28 }
29 explicit Derivative(std::map<FuncKey, Func> &&adjoints_in)
30 : adjoints(std::move(adjoints_in)) {
31 }
32
33 // These all return an undefined Func if no derivative is found
34 // (typically, if the input Funcs aren't differentiable)
35 Func operator()(const Func &func, int update_id = -1) const;
36 Func operator()(const Buffer<> &buffer) const;
37 Func operator()(const Param<> &param) const;
38 Func operator()(const std::string &name) const;
39
40private:
41 const std::map<FuncKey, Func> adjoints;
42};
43
44/**
45 * Given a Func and a corresponding adjoint, (back)propagate the
46 * adjoint to all dependent Funcs, buffers, and parameters.
47 * The bounds of the output and adjoint must be specified as {min, extent} pairs.
48 * For each Func the output depends on, and for the pure definition and
49 * each update of that Func, it generates a derivative Func stored in
50 * the Derivative.
51 */
52Derivative propagate_adjoints(const Func &output,
53 const Func &adjoint,
54 const Region &output_bounds);
55/**
56 * Given a Func and a corresponding adjoint buffer, (back)propagate the
57 * adjoint to all dependent Funcs, buffers, and parameters.
58 * For each Func the output depends on, and for the pure definition and
59 * each update of that Func, it generates a derivative Func stored in
60 * the Derivative.
61 */
62Derivative propagate_adjoints(const Func &output,
63 const Buffer<float> &adjoint);
64/**
65 * Given a scalar Func with size 1, (back)propagate the gradient
66 * to all dependent Funcs, buffers, and parameters.
67 * For each Func the output depends on, and for the pure definition and
68 * each update of that Func, it generates a derivative Func stored in
69 * the Derivative.
70 */
71Derivative propagate_adjoints(const Func &output);
72
73} // namespace Halide
74
75#endif
Base classes for Halide expressions (Halide::Expr) and statements (Halide::Internal::Stmt)
Defines Func - the front-end handle on a halide function, and related classes.
Helper structure storing the adjoint Funcs.
Definition Derivative.h:21
Func operator()(const std::string &name) const
Derivative(std::map< FuncKey, Func > &&adjoints_in)
Definition Derivative.h:29
std::pair< std::string, int > FuncKey
Definition Derivative.h:24
Func operator()(const Func &func, int update_id=-1) const
Func operator()(const Buffer<> &buffer) const
Func operator()(const Param<> &param) const
Derivative(const std::map< FuncKey, Func > &adjoints_in)
Definition Derivative.h:26
A halide function.
Definition Func.h:700
A scalar parameter to a halide pipeline.
Definition Param.h:22
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
Derivative propagate_adjoints(const Func &output, const Func &adjoint, const Region &output_bounds)
Given a Func and a corresponding adjoint, (back)propagate the adjoint to all dependent Funcs,...
std::vector< Range > Region
A multi-dimensional box.
Definition Expr.h:350