Halide 19.0.0
Halide compiler and libraries
Loading...
Searching...
No Matches
State.h
Go to the documentation of this file.
1#ifndef STATE_H
2#define STATE_H
3
4#include "ASLog.h"
5#include "Cache.h"
6#include "CostModel.h"
7#include "DefaultCostModel.h"
8#include "Featurization.h"
9#include "Halide.h"
10#include "LoopNest.h"
11#include "PerfectHashMap.h"
12#include <map>
13#include <utility>
14
15namespace Halide {
16namespace Internal {
17namespace Autoscheduler {
18
19// A struct representing an intermediate state in the tree search.
20// It represents a partial schedule for some pipeline.
21struct State {
23 // The LoopNest this state corresponds to.
25 // The parent that generated this state.
27 // Cost of this state, as evaluated by the cost model.
28 double cost = 0;
29 // Number of decisions made at this state (used for finding which DAG node to schedule).
31 // Penalization is determined based on structural hash during beam search.
32 bool penalized = false;
33
34 // The C++ source code of the generated schedule for this State.
35 // Computed if `apply_schedule` is called.
37
38 // The number of times a cost is enqueued into the cost model,
39 // for all states.
41
42 State() = default;
43 State(const State &) = delete;
44 State(State &&) = delete;
45 void operator=(const State &) = delete;
46 void operator=(State &&) = delete;
47
48 // Compute a structural hash based on depth and num_decisions_made.
49 // Defers to root->structural_hash().
50 uint64_t structural_hash(int depth) const;
51
52 // Compute the featurization of this state (based on `root`),
53 // and store features in `features`. Defers to `root->compute_features()`.
55 const Adams2019Params &params,
57 const CachingOptions &cache_options);
58
59 // Calls `compute_featurization` and prints those features to `out`.
61 const Adams2019Params &params,
62 const CachingOptions &cache_options,
63 std::ostream &out);
64
65 // Performs some pruning to decide if this state is worth queuing in
66 // the cost_model. If it is, calls `cost_model->enqueue` and returns true;
67 // otherwise sets `cost` equal to a large value and returns false.
68 bool calculate_cost(const FunctionDAG &dag, const Adams2019Params &params,
69 CostModel *cost_model, const CachingOptions &cache_options,
70 int verbosity = 99);
71
72 // Make a child copy of this state. The loop nest is const (we
73 // make mutated copies of it, rather than mutating it), so we can
74 // continue to point to the same one, which makes this a cheap
75 // operation.
77
78 // Generate the successor states to this state.
79 // If they are not pruned by `calculate_cost()`,
80 // then calls `accept_child()` on them.
82 const Adams2019Params &params,
83 CostModel *cost_model,
84 std::function<void(IntrusivePtr<State> &&)> &accept_child,
85 Cache *cache) const;
86
87 // Dumps cost, the `root` LoopNest, and then `schedule_source` to `os`.
88 void dump(std::ostream &os) const;
89
90 // Apply the schedule represented by this state to a Halide
91 // Pipeline. Also generate source code for the schedule for the
92 // user to copy-paste to freeze this schedule as a permanent artifact.
93 // Also fills `schedule_source`.
94 void apply_schedule(const FunctionDAG &dag, const Adams2019Params &params);
95};
96
97} // namespace Autoscheduler
98} // namespace Internal
99} // namespace Halide
100
101#endif // STATE_H
A class representing a reference count to be used with IntrusivePtr.
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
@ Internal
Not visible externally, similar to 'static' linkage in C.
unsigned __INT64_TYPE__ uint64_t
bool calculate_cost(const FunctionDAG &dag, const Adams2019Params &params, CostModel *cost_model, const CachingOptions &cache_options, int verbosity=99)
void generate_children(const FunctionDAG &dag, const Adams2019Params &params, CostModel *cost_model, std::function< void(IntrusivePtr< State > &&)> &accept_child, Cache *cache) const
void dump(std::ostream &os) const
void operator=(const State &)=delete
IntrusivePtr< const State > parent
Definition State.h:26
uint64_t structural_hash(int depth) const
IntrusivePtr< const LoopNest > root
Definition State.h:24
void save_featurization(const FunctionDAG &dag, const Adams2019Params &params, const CachingOptions &cache_options, std::ostream &out)
void apply_schedule(const FunctionDAG &dag, const Adams2019Params &params)
void compute_featurization(const FunctionDAG &dag, const Adams2019Params &params, StageMap< ScheduleFeatures > *features, const CachingOptions &cache_options)
IntrusivePtr< State > make_child() const
Intrusive shared pointers have a reference count (a RefCount object) stored in the class itself.