Halide 19.0.0
Halide compiler and libraries
Loading...
Searching...
No Matches
Weights.h
Go to the documentation of this file.
1#ifndef _WEIGHTS
2#define _WEIGHTS
3
4#include <cstdint>
5#include <iostream>
6#include <string>
7
8#include "Featurization.h"
9#include "HalideBuffer.h"
10#include "NetworkSize.h"
11
12namespace Halide {
13namespace Internal {
14
// Weights: groups six weight buffers together with declarations for
// randomizing them and for loading/saving them via streams, a single file,
// or a directory. Only declarations are visible in this header listing; the
// non-template member functions are defined elsewhere.
15struct Weights {
// NOTE(review): original lines 16-26 (the data members) are absent from this
// rendered listing. According to the cross-reference index further down this
// page they are:
//   uint32_t pipeline_features_version          (Weights.h:16)
//   uint32_t schedule_features_version          (Weights.h:17)
//   Halide::Runtime::Buffer<float> head1_filter (Weights.h:19)
//   Halide::Runtime::Buffer<float> head1_bias   (Weights.h:20)
//   Halide::Runtime::Buffer<float> head2_filter (Weights.h:22)
//   Halide::Runtime::Buffer<float> head2_bias   (Weights.h:23)
//   Halide::Runtime::Buffer<float> conv1_filter (Weights.h:25)
//   Halide::Runtime::Buffer<float> conv1_bias   (Weights.h:26)
18
21
24
27
// Invokes the callable `f` once on each of the six weight buffers, always in
// this order: head1_filter, head1_bias, head2_filter, head2_bias,
// conv1_filter, conv1_bias. `f` is taken by value. This member is not
// const-qualified, so it cannot be called through a const Weights.
28 template<typename F>
29 void for_each_buffer(F f) {
30 f(head1_filter);
31 f(head1_bias);
32 f(head2_filter);
33 f(head2_bias);
34 f(conv1_filter);
35 f(conv1_bias);
36 }
37
// Declared only. Presumably fills the weight buffers with pseudo-random
// values derived from `seed` -- TODO confirm against the implementation.
38 void randomize(uint32_t seed);
39
// Stream-based load/save, declared only. Both return bool; presumably true
// means success -- TODO confirm. save() is const, load() is not.
40 bool load(std::istream &i);
41 bool save(std::ostream &o) const;
42
// File-based counterparts taking a filename. NOTE(review): presumably these
// open `filename` and forward to load()/save() above -- confirm in the
// implementation file.
43 bool load_from_file(const std::string &filename);
44 bool save_to_file(const std::string &filename) const;
45
46 // Load/save from the 'classic' form of six raw data files
// NOTE(review): the per-buffer file names inside `dir` are not visible in
// this header; see the implementation for the exact layout.
47 bool load_from_dir(const std::string &dir);
48 bool save_to_dir(const std::string &dir) const;
50
51} // namespace Internal
52} // namespace Halide
53
54#endif // _WEIGHTS
Defines a Buffer type that wraps halide_buffer_t and adds functionality, and methods for more co...
A templated Buffer class that wraps halide_buffer_t and adds functionality.
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
const int head2_w
Definition NetworkSize.h:8
@ Internal
Not visible externally, similar to 'static' linkage in C.
const int head2_channels
Definition NetworkSize.h:8
const int head1_w
Definition NetworkSize.h:7
const int head1_channels
Definition NetworkSize.h:7
const int conv1_channels
Definition NetworkSize.h:9
const int head1_h
Definition NetworkSize.h:7
unsigned __INT32_TYPE__ uint32_t
static constexpr uint32_t version()
static constexpr uint32_t version()
Halide::Runtime::Buffer< float > conv1_bias
Definition Weights.h:26
Halide::Runtime::Buffer< float > conv1_filter
Definition Weights.h:25
bool load_from_dir(const std::string &dir)
bool save_to_dir(const std::string &dir) const
void for_each_buffer(F f)
Definition Weights.h:29
uint32_t schedule_features_version
Definition Weights.h:17
void randomize(uint32_t seed)
bool save_to_file(const std::string &filename) const
bool load(std::istream &i)
uint32_t pipeline_features_version
Definition Weights.h:16
Halide::Runtime::Buffer< float > head1_bias
Definition Weights.h:20
bool load_from_file(const std::string &filename)
Halide::Runtime::Buffer< float > head1_filter
Definition Weights.h:19
bool save(std::ostream &o) const
Halide::Runtime::Buffer< float > head2_filter
Definition Weights.h:22
Halide::Runtime::Buffer< float > head2_bias
Definition Weights.h:23