Halide
DefaultCostModel.h
Go to the documentation of this file.
1 #ifndef DEFAULT_COST_MODEL_H
2 #define DEFAULT_COST_MODEL_H
3 
4 #include "CostModel.h"
5 #include "Statistics.h"
6 #include "Weights.h"
7 #include <string>
8 
9 namespace Halide {
10 
11 class DefaultCostModel : public CostModel {
12 private:
13  Internal::Weights weights;
14  Runtime::Buffer<float> schedule_feat_queue, pipeline_feat_queue, costs, costs_per_stage;
15  Runtime::Buffer<double *> cost_ptrs;
16  std::vector<std::vector<double> *> cost_per_stage_ptrs;
17  int cursor, num_stages, num_cores;
18  int batch_id{0};
19 
20  const std::string weights_in_path, weights_out_path;
21  const bool randomize_weights;
22 
23  Runtime::Buffer<float>
24  head1_filter_update, head1_bias_update,
25  head2_filter_update, head2_bias_update,
26  conv1_filter_update, conv1_bias_update;
27  int timestep = 0;
28 
29  Internal::Autoscheduler::Statistics &stats;
30 
31 public:
32  DefaultCostModel(const std::string &weights_in_path,
33  const std::string &weights_out_path,
34  bool randomize_weights,
36  : weights_in_path(weights_in_path),
37  weights_out_path(weights_out_path),
38  randomize_weights(randomize_weights),
39  stats{stats} {
40  load_weights();
41  }
42 
43  // Configure the cost model for the algorithm to be scheduled.
44  void set_pipeline_features(const Internal::Autoscheduler::FunctionDAG &dag,
45  const Internal::Autoscheduler::Anderson2021Params &params) override;
46  void set_pipeline_features(const Runtime::Buffer<float> &, int n);
47 
48  // Enqueue a schedule to be evaluated. The second version of this method returns a buffer of
49  // schedule_features that should be filled in by the caller.
50  void enqueue(const Internal::Autoscheduler::FunctionDAG &dag,
52  double *cost_ptr,
53  std::vector<double> *cost_per_stage_ptr) override;
54  void enqueue(int ns, Runtime::Buffer<float> *schedule_feats, double *cost_ptr, std::vector<double> *cost_per_stage_ptr);
55 
56  // Evaluate all schedules in the queue.
57  void evaluate_costs() override;
58 
59  // Discard all schedules in the queue.
60  void reset() override;
61 
62  // Update model weights using true measured runtimes.
63  float backprop(const Runtime::Buffer<const float> &true_runtimes, float learning_rate);
64 
65  // Save/Load the model weights to/from disk.
66  void save_weights();
67  void load_weights();
68 };
69 
70 std::unique_ptr<DefaultCostModel> make_default_cost_model(Internal::Autoscheduler::Statistics &stats,
71  const std::string &weights_in_dir = "",
72  const std::string &weights_out_dir = "",
73  bool randomize_weights = false);
74 } // namespace Halide
75 
76 #endif // DEFAULT_COST_MODEL_H
Halide::DefaultCostModel::enqueue
void enqueue(const Internal::Autoscheduler::FunctionDAG &dag, const Halide::Internal::Autoscheduler::StageMapOfScheduleFeatures &schedule_feats, double *cost_ptr, std::vector< double > *cost_per_stage_ptr) override
Halide::DefaultCostModel::backprop
float backprop(const Runtime::Buffer< const float > &true_runtimes, float learning_rate)
Halide::DefaultCostModel::reset
void reset() override
Halide::DefaultCostModel::evaluate_costs
void evaluate_costs() override
Halide::Internal::Autoscheduler::Statistics
Definition: Statistics.h:63
Halide::make_default_cost_model
std::unique_ptr< DefaultCostModel > make_default_cost_model(Internal::Autoscheduler::Statistics &stats, const std::string &weights_in_dir="", const std::string &weights_out_dir="", bool randomize_weights=false)
Halide::DefaultCostModel::save_weights
void save_weights()
Halide::DefaultCostModel::DefaultCostModel
DefaultCostModel(const std::string &weights_in_path, const std::string &weights_out_path, bool randomize_weights, Internal::Autoscheduler::Statistics &stats)
Definition: DefaultCostModel.h:32
Halide
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
Definition: AbstractGenerator.h:19
Weights.h
Halide::DefaultCostModel::load_weights
void load_weights()
CostModel.h
Statistics.h
PerfectHashMap
Definition: PerfectHashMap.h:38
Halide::DefaultCostModel::set_pipeline_features
void set_pipeline_features(const Internal::Autoscheduler::FunctionDAG &dag, const Internal::Autoscheduler::Anderson2021Params &params) override