Halide
Weights.h
Go to the documentation of this file.
1 #ifndef _WEIGHTS
2 #define _WEIGHTS
3 
4 #include <cstdint>
5 #include <iostream>
6 #include <string>
7 
8 #include "Featurization.h"
9 #include "HalideBuffer.h"
10 #include "NetworkSize.h"
11 
12 namespace Halide {
13 namespace Internal {
14 
15 struct Weights {
18 
21 
24 
27 
28  template<typename F>
29  void for_each_buffer(F f) {
30  f(head1_filter);
31  f(head1_bias);
32  f(head2_filter);
33  f(head2_bias);
34  f(conv1_filter);
35  f(conv1_bias);
36  }
37 
38  void randomize(uint32_t seed);
39 
40  bool load(std::istream &i);
41  bool save(std::ostream &o) const;
42 
43  bool load_from_file(const std::string &filename);
44  bool save_to_file(const std::string &filename) const;
45 
46  // Load/save from the 'classic' form of six raw data files
47  bool load_from_dir(const std::string &dir);
48  bool save_to_dir(const std::string &dir) const;
49 };
50 
51 } // namespace Internal
52 } // namespace Halide
53 
54 #endif // _WEIGHTS
Halide::Internal::Weights::load_from_dir
bool load_from_dir(const std::string &dir)
Halide::Internal::Weights::save_to_dir
bool save_to_dir(const std::string &dir) const
Halide::Internal::ScheduleFeatures::version
static constexpr uint32_t version()
Definition: Featurization.h:170
Halide::Internal::Weights::head2_bias
Halide::Runtime::Buffer< float > head2_bias
Definition: Weights.h:23
HalideBuffer.h
Halide::Internal::Weights::save_to_file
bool save_to_file(const std::string &filename) const
Halide::Internal::Weights::for_each_buffer
void for_each_buffer(F f)
Definition: Weights.h:29
Halide::Internal::Weights::pipeline_features_version
uint32_t pipeline_features_version
Definition: Weights.h:16
Halide
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
Definition: AddAtomicMutex.h:21
Halide::Internal::PipelineFeatures::version
static constexpr uint32_t version()
Definition: Featurization.h:20
Halide::LinkageType::Internal
@ Internal
Not visible externally, similar to 'static' linkage in C.
NetworkSize.h
Halide::Internal::Weights::schedule_features_version
uint32_t schedule_features_version
Definition: Weights.h:17
Halide::Internal::Weights::save
bool save(std::ostream &o) const
Halide::Runtime::Buffer< float >
Halide::head2_channels
const int head2_channels
Definition: NetworkSize.h:8
Halide::Internal::Weights::load_from_file
bool load_from_file(const std::string &filename)
Halide::Internal::Weights::conv1_filter
Halide::Runtime::Buffer< float > conv1_filter
Definition: Weights.h:25
Halide::conv1_channels
const int conv1_channels
Definition: NetworkSize.h:9
Halide::head1_channels
const int head1_channels
Definition: NetworkSize.h:7
Featurization.h
Halide::Internal::Weights::load
bool load(std::istream &i)
Halide::head1_h
const int head1_h
Definition: NetworkSize.h:7
Halide::Internal::Weights::head1_bias
Halide::Runtime::Buffer< float > head1_bias
Definition: Weights.h:20
Halide::Internal::Weights::randomize
void randomize(uint32_t seed)
Halide::Internal::Weights::head1_filter
Halide::Runtime::Buffer< float > head1_filter
Definition: Weights.h:19
Halide::head1_w
const int head1_w
Definition: NetworkSize.h:7
uint32_t
unsigned __INT32_TYPE__ uint32_t
Definition: runtime_internal.h:21
Halide::Internal::Weights
Definition: Weights.h:15
Halide::head2_w
const int head2_w
Definition: NetworkSize.h:8
Halide::Internal::Weights::conv1_bias
Halide::Runtime::Buffer< float > conv1_bias
Definition: Weights.h:26
Halide::Internal::Weights::head2_filter
Halide::Runtime::Buffer< float > head2_filter
Definition: Weights.h:22