Halide
cost_model_schedule.h
Go to the documentation of this file.
1 #include "Halide.h"
2 
3 using namespace Halide;
4 
5 inline void do_cost_model_schedule(Halide::Pipeline pipeline) {
6  // Generated by autoscheduler, manually remove unrolls.
7  // Also manually replaced all RoundUp and ShiftInwards with GuardWithIf.
8 
9  using ::Halide::Func;
11  using ::Halide::RVar;
13  using ::Halide::Var;
14  Func loss_output = pipeline.get_func(55);
15  Func sum_1 = pipeline.get_func(54);
16  Func f2 = pipeline.get_func(53);
17  Func sum = pipeline.get_func(52);
18  Func prediction_output = pipeline.get_func(51);
19  Func updated_bias1 = pipeline.get_func(50);
20  Func bias1_im_0_d_def__ = pipeline.get_func(49);
21  Func conv1_stage1_0_d_def___1 = pipeline.get_func(48);
22  Func updated_filter1 = pipeline.get_func(47);
23  Func filter1_im_0_d_def__ = pipeline.get_func(46);
24  Func updated_head2_bias = pipeline.get_func(45);
25  Func head2_bias_im_0_d_def__ = pipeline.get_func(44);
26  Func head2_conv_0_d_def___1 = pipeline.get_func(43);
27  Func updated_head2_filter = pipeline.get_func(42);
28  Func head2_filter_im_0_d_def__ = pipeline.get_func(41);
29  Func head2_conv_1_d_def__ = pipeline.get_func(40);
30  Func head2_relu_0_d_def__ = pipeline.get_func(39);
31  Func updated_head1_bias = pipeline.get_func(38);
32  Func head1_bias_im_0_d_def__ = pipeline.get_func(37);
33  Func head1_conv_0_d_def___1 = pipeline.get_func(36);
34  Func updated_head1_filter = pipeline.get_func(35);
35  Func head1_filter_im_0_d_def__ = pipeline.get_func(34);
36  Func squashed_head1_filter_0_d_def__ = pipeline.get_func(33);
37  Func squashed_head1_filter_broadcast_0_d_def__ = pipeline.get_func(32);
38  Func head1_conv_1_d_def__ = pipeline.get_func(31);
39  Func conv1_stage1_1_d_def__ = pipeline.get_func(30);
40  Func conv1_stage2_0_d_def___1 = pipeline.get_func(29);
41  Func conv1_stage2_1_d_def__ = pipeline.get_func(28);
42  Func sum_1_d_def__ = pipeline.get_func(27);
43  Func relu1_0_d_def__ = pipeline.get_func(26);
44  Func f0_0_d_def__ = pipeline.get_func(25);
45  Func f1_1_d_def__ = pipeline.get_func(24);
46  Func f2_0_d_def__ = pipeline.get_func(22);
47  Func sum_1_1_d_def__ = pipeline.get_func(21);
48  Func loss_output_0_d_def__ = pipeline.get_func(20);
49  Func adjoint = pipeline.get_func(19);
50  Func f1 = pipeline.get_func(18);
51  Func f0 = pipeline.get_func(17);
52  Func relu1 = pipeline.get_func(16);
53  Func conv1_stage2 = pipeline.get_func(15);
54  Func head2_relu = pipeline.get_func(14);
55  Func head2_conv = pipeline.get_func(13);
56  Func normalized_schedule_features = pipeline.get_func(12);
57  Func conv1_stage1 = pipeline.get_func(8);
58  Func head1_conv = pipeline.get_func(7);
59  Func squashed_head1_filter_broadcast = pipeline.get_func(6);
60  Func squashed_head1_filter = pipeline.get_func(5);
61  Var c(head2_conv_0_d_def___1.get_schedule().dims()[0].var);
62  Var ci("ci");
63  Var n(sum.get_schedule().dims()[0].var);
64  Var ni("ni");
65  Var nii("nii");
66  Var r1010_z(filter1_im_0_d_def__.update(0).get_schedule().dims()[2].var);
67  Var r1207_y(filter1_im_0_d_def__.update(1).get_schedule().dims()[1].var);
68  Var s(squashed_head1_filter_0_d_def__.get_schedule().dims()[1].var);
69  Var si("si");
70  Var v12(head2_bias_im_0_d_def__.get_schedule().dims()[0].var);
71  Var v12i("v12i");
72  Var v13(head2_filter_im_0_d_def__.get_schedule().dims()[0].var);
73  Var v13i("v13i");
74  Var v14(head2_filter_im_0_d_def__.get_schedule().dims()[1].var);
75  Var v2(bias1_im_0_d_def__.get_schedule().dims()[0].var);
76  Var v207(updated_head1_filter.get_schedule().dims()[0].var);
77  Var v207i("v207i");
78  Var v208(updated_head1_filter.get_schedule().dims()[1].var);
79  Var v208i("v208i");
80  Var v209(updated_head1_filter.get_schedule().dims()[2].var);
81  Var v209i("v209i");
82  Var v210(updated_head1_filter.get_schedule().dims()[3].var);
83  Var v210i("v210i");
84  Var v211(updated_head1_bias.get_schedule().dims()[0].var);
85  Var v211i("v211i");
86  Var v212(updated_head1_bias.get_schedule().dims()[1].var);
87  Var v213(updated_head2_filter.get_schedule().dims()[0].var);
88  Var v213i("v213i");
89  Var v214(updated_head2_filter.get_schedule().dims()[1].var);
90  Var v214i("v214i");
91  Var v215(updated_head2_filter.get_schedule().dims()[2].var);
92  Var v215i("v215i");
93  Var v216(updated_head2_bias.get_schedule().dims()[0].var);
94  Var v216i("v216i");
95  Var v217(updated_head2_bias.get_schedule().dims()[1].var);
96  Var v218(updated_filter1.get_schedule().dims()[0].var);
97  Var v218i("v218i");
98  Var v218ii("v218ii");
99  Var v219(updated_filter1.get_schedule().dims()[1].var);
100  Var v219i("v219i");
101  Var v220(updated_filter1.get_schedule().dims()[2].var);
102  Var v220i("v220i");
103  Var v221(updated_bias1.get_schedule().dims()[0].var);
104  Var v221i("v221i");
105  Var v222(updated_bias1.get_schedule().dims()[1].var);
106  Var v2i("v2i");
107  Var v3(filter1_im_0_d_def__.get_schedule().dims()[0].var);
108  Var v4(filter1_im_0_d_def__.get_schedule().dims()[1].var);
109  Var v4i("v4i");
110  Var v5(head1_bias_im_0_d_def__.get_schedule().dims()[0].var);
111  Var v5i("v5i");
112  Var w(head2_conv_0_d_def___1.get_schedule().dims()[1].var);
113  Var wi("wi");
114  Var wii("wii");
115  RVar r1010_x(filter1_im_0_d_def__.update(0).get_schedule().dims()[0].var);
116  RVar r1010_y(filter1_im_0_d_def__.update(0).get_schedule().dims()[1].var);
117  RVar r1029_x(conv1_stage1_1_d_def__.update(0).get_schedule().dims()[0].var);
118  RVar r1029_xi("r1029$xi");
119  RVar r1095_x(head2_filter_im_0_d_def__.update(0).get_schedule().dims()[0].var);
120  RVar r1095_y(head2_filter_im_0_d_def__.update(0).get_schedule().dims()[1].var);
121  RVar r1114_x(head2_bias_im_0_d_def__.update(0).get_schedule().dims()[0].var);
122  RVar r1114_y(head2_bias_im_0_d_def__.update(0).get_schedule().dims()[1].var);
123  RVar r1183_x(head1_conv_1_d_def__.update(0).get_schedule().dims()[0].var);
124  RVar r1207_x(filter1_im_0_d_def__.update(1).get_schedule().dims()[0].var);
125  RVar r1226_x(bias1_im_0_d_def__.update(0).get_schedule().dims()[0].var);
126  RVar r1302_x(head1_bias_im_0_d_def__.update(0).get_schedule().dims()[0].var);
127  RVar r1321_x(squashed_head1_filter_0_d_def__.update(0).get_schedule().dims()[0].var);
128  RVar r14_x(conv1_stage1.update(0).get_schedule().dims()[0].var);
129  RVar r19_x(conv1_stage2.update(0).get_schedule().dims()[0].var);
130  RVar r24_x(f1.update(0).get_schedule().dims()[0].var);
131  RVar r29_x(sum_1.update(0).get_schedule().dims()[0].var);
132  RVar r34_x(sum.update(0).get_schedule().dims()[0].var);
133  RVar r34_y(sum.update(0).get_schedule().dims()[1].var);
134  RVar r4_x(head1_conv.update(0).get_schedule().dims()[0].var);
135  RVar r4_y(head1_conv.update(0).get_schedule().dims()[1].var);
136  RVar r9_x(head2_conv.update(0).get_schedule().dims()[0].var);
137  RVar r986_x(head2_relu_0_d_def__.update(0).get_schedule().dims()[0].var);
138  loss_output
139  .compute_root();
140  sum_1
141  .compute_root();
142  sum_1.update(0);
143  sum
144  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
145  .vectorize(ni)
146  .compute_root()
147  .reorder(ni, n)
148  .parallel(n);
149  sum.update(0)
150  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
151  .vectorize(ni)
152  .reorder(ni, r34_x, r34_y, n)
153  .parallel(n);
154  prediction_output
155  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
156  .vectorize(ni)
157  .compute_root()
158  .reorder(ni, n)
159  .parallel(n);
160  updated_bias1
161  .split(v221, v221, v221i, 8, TailStrategy::GuardWithIf)
162  .vectorize(v221i)
163  .compute_root()
164  .reorder(v221i, v221, v222)
165  .fuse(v221, v222, v221)
166  .parallel(v221);
167  updated_bias1.update(0)
168  .split(v221, v221, v221i, 8, TailStrategy::GuardWithIf)
169  .vectorize(v221i)
170  .reorder(v221i, v221)
171  .parallel(v221);
172  updated_bias1.update(1)
173  .split(v221, v221, v221i, 8, TailStrategy::GuardWithIf)
174  .vectorize(v221i)
175  .reorder(v221i, v221)
176  .parallel(v221);
177  updated_bias1.update(2)
178  .split(v221, v221, v221i, 8, TailStrategy::GuardWithIf)
179  .vectorize(v221i)
180  .reorder(v221i, v221)
181  .parallel(v221);
182  updated_bias1.update(3)
183  .split(v221, v221, v221i, 8, TailStrategy::GuardWithIf)
184  .vectorize(v221i)
185  .reorder(v221i, v221)
186  .parallel(v221);
187  bias1_im_0_d_def__
188  .split(v2, v2, v2i, 8, TailStrategy::GuardWithIf)
189  .vectorize(v2i)
190  .compute_at(updated_bias1, v221)
191  .reorder(v2i, v2);
192  bias1_im_0_d_def__.update(0)
193  .split(v2, v2, v2i, 8, TailStrategy::GuardWithIf)
194  .vectorize(v2i)
195  .reorder(v2i, v2, r1226_x);
196  updated_filter1
197  .split(v218, v218, v218i, 16, TailStrategy::GuardWithIf)
198  .split(v219, v219, v219i, 2, TailStrategy::GuardWithIf)
199  .split(v220, v220, v220i, 2, TailStrategy::GuardWithIf)
200  .split(v218i, v218i, v218ii, 8, TailStrategy::GuardWithIf)
201  .vectorize(v218ii)
202  .compute_root()
203  .reorder(v218ii, v218i, v219i, v220i, v218, v219, v220)
204  .fuse(v219, v220, v219)
205  .fuse(v218, v219, v218)
206  .parallel(v218);
207  updated_filter1.update(0)
208  .split(v218, v218, v218i, 16, TailStrategy::GuardWithIf)
209  .split(v219, v219, v219i, 2, TailStrategy::GuardWithIf)
210  .split(v218i, v218i, v218ii, 8, TailStrategy::GuardWithIf)
211  .vectorize(v218ii)
212  .reorder(v218ii, v218i, v219i, v218, v219)
213  .fuse(v218, v219, v218)
214  .parallel(v218);
215  updated_filter1.update(1)
216  .split(v218, v218, v218i, 16, TailStrategy::GuardWithIf)
217  .split(v219, v219, v219i, 2, TailStrategy::GuardWithIf)
218  .split(v218i, v218i, v218ii, 8, TailStrategy::GuardWithIf)
219  .vectorize(v218ii)
220  .reorder(v218ii, v218i, v219i, v218, v219)
221  .fuse(v218, v219, v218)
222  .parallel(v218);
223  updated_filter1.update(2)
224  .split(v218, v218, v218i, 16, TailStrategy::GuardWithIf)
225  .split(v219, v219, v219i, 2, TailStrategy::GuardWithIf)
226  .split(v218i, v218i, v218ii, 8, TailStrategy::GuardWithIf)
227  .vectorize(v218ii)
228  .reorder(v218ii, v218i, v219i, v218, v219)
229  .fuse(v218, v219, v218)
230  .parallel(v218);
231  updated_filter1.update(3)
232  .split(v218, v218, v218i, 16, TailStrategy::GuardWithIf)
233  .split(v219, v219, v219i, 2, TailStrategy::GuardWithIf)
234  .split(v218i, v218i, v218ii, 8, TailStrategy::GuardWithIf)
235  .vectorize(v218ii)
236  .reorder(v218ii, v218i, v219i, v218, v219)
237  .fuse(v218, v219, v218)
238  .parallel(v218);
239  filter1_im_0_d_def__
240  .split(v4, v4, v4i, 8, TailStrategy::GuardWithIf)
241  .vectorize(v4i)
242  .compute_root()
243  .reorder(v4i, v4, v3)
244  .parallel(v3)
245  .reorder_storage(v4, v3);
246  filter1_im_0_d_def__.update(0)
247  .reorder(r1010_x, r1010_y, r1010_z, v3)
248  .parallel(v3);
249  filter1_im_0_d_def__.update(1)
250  .reorder(r1207_x, r1207_y, v3)
251  .parallel(v3);
252  updated_head2_bias
253  .split(v216, v216, v216i, 8, TailStrategy::GuardWithIf)
254  .vectorize(v216i)
255  .compute_root()
256  .reorder(v216i, v216, v217)
257  .fuse(v216, v217, v216)
258  .parallel(v216);
259  updated_head2_bias.update(0)
260  .split(v216, v216, v216i, 8, TailStrategy::GuardWithIf)
261  .vectorize(v216i)
262  .reorder(v216i, v216)
263  .parallel(v216);
264  updated_head2_bias.update(1)
265  .split(v216, v216, v216i, 8, TailStrategy::GuardWithIf)
266  .vectorize(v216i)
267  .reorder(v216i, v216)
268  .parallel(v216);
269  updated_head2_bias.update(2)
270  .split(v216, v216, v216i, 8, TailStrategy::GuardWithIf)
271  .vectorize(v216i)
272  .reorder(v216i, v216)
273  .parallel(v216);
274  updated_head2_bias.update(3)
275  .split(v216, v216, v216i, 8, TailStrategy::GuardWithIf)
276  .vectorize(v216i)
277  .reorder(v216i, v216)
278  .parallel(v216);
279  head2_bias_im_0_d_def__
280  .split(v12, v12, v12i, 8, TailStrategy::GuardWithIf)
281  .vectorize(v12i)
282  .compute_at(updated_head2_bias, v216)
283  .reorder(v12i, v12);
284  head2_bias_im_0_d_def__.update(0)
285  .split(v12, v12, v12i, 8, TailStrategy::GuardWithIf)
286  .vectorize(v12i)
287  .reorder(v12i, v12, r1114_x, r1114_y);
288  head2_conv_0_d_def___1
290  .split(c, c, ci, 8, TailStrategy::GuardWithIf)
291  .vectorize(ci)
292  .compute_at(head2_bias_im_0_d_def__, v12)
293  .reorder(ci, c, w, n);
294  updated_head2_filter
295  .split(v213, v213, v213i, 8, TailStrategy::GuardWithIf)
296  .split(v214, v214, v214i, 2, TailStrategy::GuardWithIf)
297  .split(v215, v215, v215i, 2, TailStrategy::GuardWithIf)
298  .vectorize(v213i)
299  .compute_root()
300  .reorder(v213i, v214i, v215i, v213, v214, v215)
301  .fuse(v214, v215, v214)
302  .fuse(v213, v214, v213)
303  .parallel(v213);
304  updated_head2_filter.update(0)
305  .split(v213, v213, v213i, 8, TailStrategy::GuardWithIf)
306  .split(v214, v214, v214i, 2, TailStrategy::GuardWithIf)
307  .vectorize(v213i)
308  .reorder(v213i, v214i, v213, v214)
309  .fuse(v213, v214, v213)
310  .parallel(v213);
311  updated_head2_filter.update(1)
312  .split(v213, v213, v213i, 8, TailStrategy::GuardWithIf)
313  .split(v214, v214, v214i, 2, TailStrategy::GuardWithIf)
314  .vectorize(v213i)
315  .reorder(v213i, v214i, v213, v214)
316  .fuse(v213, v214, v213)
317  .parallel(v213);
318  updated_head2_filter.update(2)
319  .split(v213, v213, v213i, 8, TailStrategy::GuardWithIf)
320  .split(v214, v214, v214i, 2, TailStrategy::GuardWithIf)
321  .vectorize(v213i)
322  .reorder(v213i, v214i, v213, v214)
323  .fuse(v213, v214, v213)
324  .parallel(v213);
325  updated_head2_filter.update(3)
326  .split(v213, v213, v213i, 8, TailStrategy::GuardWithIf)
327  .split(v214, v214, v214i, 2, TailStrategy::GuardWithIf)
328  .vectorize(v213i)
329  .reorder(v213i, v214i, v213, v214)
330  .fuse(v213, v214, v213)
331  .parallel(v213);
332  head2_filter_im_0_d_def__
334  .split(v13, v13, v13i, 8, TailStrategy::GuardWithIf)
335  .vectorize(v13i)
336  .compute_at(updated_head2_filter, v214i)
337  .reorder(v13i, v13, v14);
338  head2_filter_im_0_d_def__.update(0)
339  .split(v13, v13, v13i, 8, TailStrategy::GuardWithIf)
340  .vectorize(v13i)
341  .reorder(v13i, v13, v14, r1095_x, r1095_y);
342  head2_conv_1_d_def__
343  .split(n, n, ni, 5, TailStrategy::GuardWithIf)
344  .split(c, c, ci, 8, TailStrategy::GuardWithIf)
345  .vectorize(ci)
346  .compute_root()
347  .reorder(ci, ni, c, w, n)
348  .parallel(n);
349  head2_relu_0_d_def__
351  .split(c, c, ci, 8, TailStrategy::GuardWithIf)
352  .vectorize(ci)
353  .compute_at(head2_conv_1_d_def__, c)
354  .reorder(ci, c, w, n);
355  head2_relu_0_d_def__.update(0)
356  .split(c, c, ci, 8, TailStrategy::GuardWithIf)
357  .vectorize(ci)
358  .reorder(ci, c, w, n, r986_x);
359  updated_head1_bias
360  .split(v211, v211, v211i, 8, TailStrategy::GuardWithIf)
361  .vectorize(v211i)
362  .compute_root()
363  .reorder(v211i, v211, v212)
364  .parallel(v212);
365  updated_head1_bias.update(0)
366  .split(v211, v211, v211i, 8, TailStrategy::GuardWithIf)
367  .vectorize(v211i)
368  .reorder(v211i, v211);
369  updated_head1_bias.update(1)
370  .split(v211, v211, v211i, 8, TailStrategy::GuardWithIf)
371  .vectorize(v211i)
372  .reorder(v211i, v211);
373  updated_head1_bias.update(2)
374  .split(v211, v211, v211i, 8, TailStrategy::GuardWithIf)
375  .vectorize(v211i)
376  .reorder(v211i, v211);
377  updated_head1_bias.update(3)
378  .split(v211, v211, v211i, 8, TailStrategy::GuardWithIf)
379  .vectorize(v211i)
380  .reorder(v211i, v211);
381  head1_bias_im_0_d_def__
382  .split(v5, v5, v5i, 8, TailStrategy::GuardWithIf)
383  .vectorize(v5i)
384  .compute_root()
385  .reorder(v5i, v5);
386  head1_bias_im_0_d_def__.update(0)
387  .split(v5, v5, v5i, 8, TailStrategy::GuardWithIf)
388  .vectorize(v5i)
389  .reorder(v5i, v5, r1302_x);
390  updated_head1_filter
391  .split(v208, v208, v208i, 2, TailStrategy::GuardWithIf)
392  .split(v209, v209, v209i, 2, TailStrategy::GuardWithIf)
393  .split(v210, v210, v210i, 2, TailStrategy::GuardWithIf)
394  .split(v207, v207, v207i, 8, TailStrategy::GuardWithIf)
395  .vectorize(v207i)
396  .compute_root()
397  .reorder(v207i, v207, v208i, v209i, v210i, v208, v209, v210)
398  .fuse(v209, v210, v209)
399  .fuse(v208, v209, v208)
400  .parallel(v208);
401  updated_head1_filter.update(0)
402  .split(v208, v208, v208i, 2, TailStrategy::GuardWithIf)
403  .split(v209, v209, v209i, 2, TailStrategy::GuardWithIf)
404  .split(v207, v207, v207i, 8, TailStrategy::GuardWithIf)
405  .vectorize(v207i)
406  .reorder(v207i, v207, v208i, v209i, v208, v209)
407  .fuse(v208, v209, v208)
408  .parallel(v208);
409  updated_head1_filter.update(1)
410  .split(v208, v208, v208i, 2, TailStrategy::GuardWithIf)
411  .split(v209, v209, v209i, 2, TailStrategy::GuardWithIf)
412  .split(v207, v207, v207i, 8, TailStrategy::GuardWithIf)
413  .vectorize(v207i)
414  .reorder(v207i, v207, v208i, v209i, v208, v209)
415  .fuse(v208, v209, v208)
416  .parallel(v208);
417  updated_head1_filter.update(2)
418  .split(v208, v208, v208i, 2, TailStrategy::GuardWithIf)
419  .split(v209, v209, v209i, 2, TailStrategy::GuardWithIf)
420  .split(v207, v207, v207i, 8, TailStrategy::GuardWithIf)
421  .vectorize(v207i)
422  .reorder(v207i, v207, v208i, v209i, v208, v209)
423  .fuse(v208, v209, v208)
424  .parallel(v208);
425  updated_head1_filter.update(3)
426  .split(v208, v208, v208i, 2, TailStrategy::GuardWithIf)
427  .split(v209, v209, v209i, 2, TailStrategy::GuardWithIf)
428  .split(v207, v207, v207i, 8, TailStrategy::GuardWithIf)
429  .vectorize(v207i)
430  .reorder(v207i, v207, v208i, v209i, v208, v209)
431  .fuse(v208, v209, v208)
432  .parallel(v208);
433  squashed_head1_filter_0_d_def__
435  .split(c, c, ci, 8, TailStrategy::GuardWithIf)
436  .vectorize(ci)
437  .compute_at(updated_head1_filter, v207)
438  .reorder(ci, c, s, n);
439  squashed_head1_filter_0_d_def__.update(0)
440  .split(c, c, ci, 8, TailStrategy::GuardWithIf)
441  .vectorize(ci)
442  .reorder(ci, c, s, n, r1321_x);
443  head1_conv_1_d_def__
444  .split(c, c, ci, 8, TailStrategy::GuardWithIf)
445  .vectorize(ci)
446  .compute_root()
447  .reorder(ci, c, w)
448  .parallel(w);
449  head1_conv_1_d_def__.update(0)
450  .split(c, c, ci, 8, TailStrategy::GuardWithIf)
451  .vectorize(ci)
452  .reorder(ci, c, r1183_x, w)
453  .parallel(w);
454  conv1_stage1_1_d_def__
455  .split(w, w, wi, 8, TailStrategy::GuardWithIf)
456  .vectorize(wi)
457  .compute_root()
458  .reorder(wi, w, c)
459  .parallel(c)
460  .reorder_storage(w, c);
461  conv1_stage1_1_d_def__.update(0)
462  .split(r1029_x, r1029_x, r1029_xi, 2, TailStrategy::GuardWithIf)
463  .split(w, w, wi, 8, TailStrategy::GuardWithIf)
464  .vectorize(wi)
465  .reorder(wi, r1029_xi, r1029_x, w, c)
466  .parallel(c);
467  conv1_stage2_0_d_def___1
469  .split(w, w, wi, 8, TailStrategy::GuardWithIf)
470  .vectorize(wi)
471  .compute_at(conv1_stage1_1_d_def__, r1029_xi)
472  .store_at(conv1_stage1_1_d_def__, r1029_x)
473  .reorder(wi, w, c, n)
474  .reorder_storage(w, c, n);
475  conv1_stage2_1_d_def__
476  .split(c, c, ci, 2, TailStrategy::GuardWithIf)
477  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
478  .vectorize(ni)
479  .compute_root()
480  .reorder(ni, n, ci, w, c)
481  .parallel(c)
482  .reorder_storage(n, c, w);
483  sum_1_d_def__
484  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
485  .vectorize(ni)
486  .compute_root()
487  .reorder(ni, n)
488  .parallel(n);
489  relu1_0_d_def__
490  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
491  .split(w, w, wi, 2, TailStrategy::GuardWithIf)
492  .vectorize(ni)
493  .compute_root()
494  .reorder(ni, c, wi, n, w)
495  .fuse(n, w, n)
496  .parallel(n)
497  .reorder_storage(n, c, w);
498  relu1_0_d_def__.update(0)
499  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
500  .split(w, w, wi, 2, TailStrategy::GuardWithIf)
501  .vectorize(ni)
502  .reorder(ni, wi, n, w)
503  .fuse(n, w, n)
504  .parallel(n);
505  relu1_0_d_def__.update(1)
506  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
507  .split(w, w, wi, 2, TailStrategy::GuardWithIf)
508  .vectorize(ni)
509  .reorder(ni, wi, n, w)
510  .fuse(n, w, n)
511  .parallel(n);
512  relu1_0_d_def__.update(2)
513  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
514  .split(w, w, wi, 2, TailStrategy::GuardWithIf)
515  .vectorize(ni)
516  .reorder(ni, wi, n, w)
517  .fuse(n, w, n)
518  .parallel(n);
519  relu1_0_d_def__.update(3)
520  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
521  .split(w, w, wi, 2, TailStrategy::GuardWithIf)
522  .vectorize(ni)
523  .reorder(ni, wi, n, w)
524  .fuse(n, w, n)
525  .parallel(n);
526  relu1_0_d_def__.update(4)
527  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
528  .split(w, w, wi, 2, TailStrategy::GuardWithIf)
529  .vectorize(ni)
530  .reorder(ni, wi, n, w)
531  .fuse(n, w, n)
532  .parallel(n);
533  relu1_0_d_def__.update(5)
534  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
535  .split(w, w, wi, 2, TailStrategy::GuardWithIf)
536  .vectorize(ni)
537  .reorder(ni, wi, n, w)
538  .fuse(n, w, n)
539  .parallel(n);
540  relu1_0_d_def__.update(6)
541  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
542  .split(w, w, wi, 2, TailStrategy::GuardWithIf)
543  .vectorize(ni)
544  .reorder(ni, wi, n, w)
545  .fuse(n, w, n)
546  .parallel(n);
547  relu1_0_d_def__.update(7)
548  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
549  .split(w, w, wi, 2, TailStrategy::GuardWithIf)
550  .vectorize(ni)
551  .reorder(ni, wi, n, w)
552  .fuse(n, w, n)
553  .parallel(n);
554  relu1_0_d_def__.update(8)
555  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
556  .split(w, w, wi, 2, TailStrategy::GuardWithIf)
557  .vectorize(ni)
558  .reorder(ni, wi, n, w)
559  .fuse(n, w, n)
560  .parallel(n);
561  relu1_0_d_def__.update(9)
562  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
563  .split(w, w, wi, 2, TailStrategy::GuardWithIf)
564  .vectorize(ni)
565  .reorder(ni, wi, n, w)
566  .fuse(n, w, n)
567  .parallel(n);
568  relu1_0_d_def__.update(10)
569  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
570  .split(w, w, wi, 2, TailStrategy::GuardWithIf)
571  .vectorize(ni)
572  .reorder(ni, wi, n, w)
573  .fuse(n, w, n)
574  .parallel(n);
575  relu1_0_d_def__.update(11)
576  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
577  .split(w, w, wi, 2, TailStrategy::GuardWithIf)
578  .vectorize(ni)
579  .reorder(ni, wi, n, w)
580  .fuse(n, w, n)
581  .parallel(n);
582  relu1_0_d_def__.update(12)
583  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
584  .split(w, w, wi, 2, TailStrategy::GuardWithIf)
585  .vectorize(ni)
586  .reorder(ni, wi, n, w)
587  .fuse(n, w, n)
588  .parallel(n);
589  relu1_0_d_def__.update(13)
590  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
591  .split(w, w, wi, 2, TailStrategy::GuardWithIf)
592  .vectorize(ni)
593  .reorder(ni, wi, n, w)
594  .fuse(n, w, n)
595  .parallel(n);
596  relu1_0_d_def__.update(14)
597  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
598  .split(w, w, wi, 2, TailStrategy::GuardWithIf)
599  .vectorize(ni)
600  .reorder(ni, wi, n, w)
601  .fuse(n, w, n)
602  .parallel(n);
603  relu1_0_d_def__.update(15)
604  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
605  .split(w, w, wi, 2, TailStrategy::GuardWithIf)
606  .vectorize(ni)
607  .reorder(ni, wi, n, w)
608  .fuse(n, w, n)
609  .parallel(n);
610  relu1_0_d_def__.update(16)
611  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
612  .split(w, w, wi, 2, TailStrategy::GuardWithIf)
613  .vectorize(ni)
614  .reorder(ni, wi, n, w)
615  .fuse(n, w, n)
616  .parallel(n);
617  relu1_0_d_def__.update(17)
618  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
619  .split(w, w, wi, 2, TailStrategy::GuardWithIf)
620  .vectorize(ni)
621  .reorder(ni, wi, n, w)
622  .fuse(n, w, n)
623  .parallel(n);
624  relu1_0_d_def__.update(18)
625  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
626  .split(w, w, wi, 2, TailStrategy::GuardWithIf)
627  .vectorize(ni)
628  .reorder(ni, wi, n, w)
629  .fuse(n, w, n)
630  .parallel(n);
631  relu1_0_d_def__.update(19)
632  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
633  .split(w, w, wi, 2, TailStrategy::GuardWithIf)
634  .vectorize(ni)
635  .reorder(ni, wi, n, w)
636  .fuse(n, w, n)
637  .parallel(n);
638  relu1_0_d_def__.update(20)
639  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
640  .split(w, w, wi, 2, TailStrategy::GuardWithIf)
641  .vectorize(ni)
642  .reorder(ni, wi, n, w)
643  .fuse(n, w, n)
644  .parallel(n);
645  relu1_0_d_def__.update(21)
646  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
647  .split(w, w, wi, 2, TailStrategy::GuardWithIf)
648  .vectorize(ni)
649  .reorder(ni, wi, n, w)
650  .fuse(n, w, n)
651  .parallel(n);
652  relu1_0_d_def__.update(22)
653  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
654  .split(w, w, wi, 2, TailStrategy::GuardWithIf)
655  .vectorize(ni)
656  .reorder(ni, wi, n, w)
657  .fuse(n, w, n)
658  .parallel(n);
659  relu1_0_d_def__.update(23)
660  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
661  .split(w, w, wi, 2, TailStrategy::GuardWithIf)
662  .vectorize(ni)
663  .reorder(ni, wi, n, w)
664  .fuse(n, w, n)
665  .parallel(n);
666  relu1_0_d_def__.update(24)
667  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
668  .split(w, w, wi, 2, TailStrategy::GuardWithIf)
669  .vectorize(ni)
670  .reorder(ni, wi, n, w)
671  .fuse(n, w, n)
672  .parallel(n);
673  relu1_0_d_def__.update(25)
674  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
675  .split(w, w, wi, 2, TailStrategy::GuardWithIf)
676  .vectorize(ni)
677  .reorder(ni, wi, n, w)
678  .fuse(n, w, n)
679  .parallel(n);
680  relu1_0_d_def__.update(26)
681  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
682  .split(w, w, wi, 2, TailStrategy::GuardWithIf)
683  .vectorize(ni)
684  .reorder(ni, wi, n, w)
685  .fuse(n, w, n)
686  .parallel(n);
687  relu1_0_d_def__.update(27)
688  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
689  .split(w, w, wi, 2, TailStrategy::GuardWithIf)
690  .vectorize(ni)
691  .reorder(ni, wi, n, w)
692  .fuse(n, w, n)
693  .parallel(n);
694  relu1_0_d_def__.update(28)
695  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
696  .split(w, w, wi, 2, TailStrategy::GuardWithIf)
697  .vectorize(ni)
698  .reorder(ni, wi, n, w)
699  .fuse(n, w, n)
700  .parallel(n);
701  relu1_0_d_def__.update(29)
702  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
703  .split(w, w, wi, 2, TailStrategy::GuardWithIf)
704  .vectorize(ni)
705  .reorder(ni, wi, n, w)
706  .fuse(n, w, n)
707  .parallel(n);
708  relu1_0_d_def__.update(30)
709  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
710  .split(w, w, wi, 2, TailStrategy::GuardWithIf)
711  .vectorize(ni)
712  .reorder(ni, wi, n, w)
713  .fuse(n, w, n)
714  .parallel(n);
715  relu1_0_d_def__.update(31)
716  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
717  .split(w, w, wi, 2, TailStrategy::GuardWithIf)
718  .vectorize(ni)
719  .reorder(ni, wi, n, w)
720  .fuse(n, w, n)
721  .parallel(n);
722  f0_0_d_def__
723  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
724  .split(w, w, wi, 2, TailStrategy::GuardWithIf)
725  .vectorize(ni)
726  .compute_root()
727  .reorder(ni, wi, n, w)
728  .fuse(n, w, n)
729  .parallel(n);
730  f1_1_d_def__
731  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
732  .vectorize(ni)
733  .compute_at(f0_0_d_def__, n)
734  .reorder(ni, n);
735  sum_1_1_d_def__
736  .compute_root();
737  f1
738  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
739  .vectorize(ni)
740  .compute_root()
741  .reorder(ni, n)
742  .parallel(n);
743  f1.update(0)
744  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
745  .vectorize(ni)
746  .reorder(ni, r24_x, n)
747  .parallel(n);
748  conv1_stage2
749  .split(c, c, ci, 8, TailStrategy::GuardWithIf)
750  .split(w, w, wi, 4, TailStrategy::GuardWithIf)
751  .split(wi, wi, wii, 2, TailStrategy::GuardWithIf)
752  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
753  .vectorize(ni)
754  .compute_root()
755  .reorder(ni, n, ci, wii, wi, c, w)
756  .fuse(c, w, c)
757  .parallel(c)
758  .reorder_storage(n, c, w);
759  conv1_stage2.update(0)
760  .split(c, c, ci, 8, TailStrategy::GuardWithIf)
761  .split(w, w, wi, 2, TailStrategy::GuardWithIf)
762  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
763  .vectorize(ni)
764  .reorder(ni, r19_x, n, ci, wi, c, w)
765  .fuse(c, w, c)
766  .parallel(c);
767  head2_relu
768  .split(c, c, ci, 3, TailStrategy::GuardWithIf)
769  .split(w, w, wi, 7, TailStrategy::GuardWithIf)
770  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
771  .vectorize(ni)
772  .compute_root()
773  .reorder(ni, n, ci, wi, c, w)
774  .fuse(c, w, c)
775  .parallel(c)
776  .reorder_storage(n, c, w);
777  head2_conv
778  .split(n, n, ni, 40, TailStrategy::GuardWithIf)
779  .split(c, c, ci, 12, TailStrategy::GuardWithIf)
780  .split(w, w, wi, 2, TailStrategy::GuardWithIf)
781  .split(ni, ni, nii, 8, TailStrategy::GuardWithIf)
782  .vectorize(nii)
783  .compute_root()
784  .reorder(nii, ni, ci, wi, n, c, w)
785  .fuse(c, w, c)
786  .fuse(n, c, n)
787  .parallel(n)
788  .reorder_storage(n, c, w);
789  head2_conv.update(0)
790  .split(n, n, ni, 40, TailStrategy::GuardWithIf)
791  .split(c, c, ci, 12, TailStrategy::GuardWithIf)
792  .split(w, w, wi, 2, TailStrategy::GuardWithIf)
793  .split(ni, ni, nii, 8, TailStrategy::GuardWithIf)
794  .vectorize(nii)
795  .reorder(nii, r9_x, ni, ci, wi, n, c, w)
796  .fuse(c, w, c)
797  .fuse(n, c, n)
798  .parallel(n);
799  normalized_schedule_features
800  .split(c, c, ci, 5, TailStrategy::GuardWithIf)
801  .split(s, s, si, 7, TailStrategy::GuardWithIf)
802  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
803  .vectorize(ni)
804  .compute_root()
805  .reorder(ni, n, ci, si, c, s)
806  .fuse(c, s, c)
807  .parallel(c);
808  conv1_stage1
809  .split(c, c, ci, 8, TailStrategy::GuardWithIf)
810  .vectorize(ci)
811  .compute_at(conv1_stage2, c)
812  .reorder(ci, c, w);
813  conv1_stage1.update(0)
814  .split(c, c, ci, 8, TailStrategy::GuardWithIf)
815  .vectorize(ci)
816  .reorder(ci, c, w, r14_x);
817  head1_conv
818  .split(c, c, ci, 8, TailStrategy::GuardWithIf)
819  .vectorize(ci)
820  .compute_root()
821  .reorder(ci, c, w)
822  .parallel(w);
823  head1_conv.update(0)
824  .split(c, c, ci, 8, TailStrategy::GuardWithIf)
825  .vectorize(ci)
826  .reorder(ci, c, r4_x, r4_y, w)
827  .parallel(w);
828  squashed_head1_filter_broadcast
830  .split(c, c, ci, 8, TailStrategy::GuardWithIf)
831  .vectorize(ci)
832  .compute_at(head1_conv, c)
833  .reorder(ci, c, w, s, n);
834  squashed_head1_filter
835  .split(s, s, si, 10, TailStrategy::GuardWithIf)
836  .split(n, n, ni, 2, TailStrategy::GuardWithIf)
837  .split(c, c, ci, 8, TailStrategy::GuardWithIf)
838  .vectorize(ci)
839  .compute_root()
840  .reorder(ci, c, si, ni, s, n)
841  .fuse(s, n, s)
842  .parallel(s);
843 }
Halide::Stage::parallel
Stage & parallel(const VarOrRVar &var)
Halide::sum
Expr sum(Expr, const std::string &s="sum")
An inline reduction.
Halide::Stage::reorder
Stage & reorder(const std::vector< VarOrRVar > &vars)
Halide::Var
A Halide variable, to be used when defining functions.
Definition: Var.h:19
Halide::TailStrategy::GuardWithIf
@ GuardWithIf
Guard the inner loop with an if statement that prevents evaluation beyond the original extent.
Halide::Func::reorder
Func & reorder(const std::vector< VarOrRVar > &vars)
Reorder variables to have the given nesting order, from innermost out.
Halide::Func::store_at
Func & store_at(const Func &f, const Var &var)
Allocate storage for this function within f's loop over var.
Halide::Func::vectorize
Func & vectorize(const VarOrRVar &var)
Mark a dimension to be computed all-at-once as a single vector.
Halide::Func::compute_at
Func & compute_at(const Func &f, const Var &var)
Compute this function as needed for each unique value of the given var for the given calling function...
Halide::Func::fuse
Func & fuse(const VarOrRVar &inner, const VarOrRVar &outer, const VarOrRVar &fused)
Join two dimensions into a single fused dimension.
Halide::TailStrategy
TailStrategy
Different ways to handle a tail case in a split when the factor does not provably divide the extent.
Definition: Schedule.h:32
Halide::Pipeline
A class representing a Halide pipeline.
Definition: Pipeline.h:97
Halide
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
Definition: AddAtomicMutex.h:21
Halide::Internal::StageSchedule::dims
const std::vector< Dim > & dims() const
The list and ordering of dimensions used to evaluate this function, after all splits have taken place...
Halide::Func::update
Stage update(int idx=0)
Get a handle on an update step for the purposes of scheduling it.
Halide::Func::parallel
Func & parallel(const VarOrRVar &var)
Mark a dimension to be traversed in parallel.
Halide::Func::store_in
Func & store_in(MemoryType memory_type)
Set the type of memory this Func should be stored in.
Halide::Func
A halide function.
Definition: Func.h:667
Halide::Func::compute_root
Func & compute_root()
Compute all of this function once ahead of time.
Halide::RVar
A reduction variable represents a single dimension of a reduction domain (RDom).
Definition: RDom.h:29
Halide::Func::get_schedule
const Internal::StageSchedule & get_schedule() const
Return the current StageSchedule associated with this initial Stage of this Func.
Definition: Func.h:2445
Halide::MemoryType
MemoryType
An enum describing different address spaces to be used with Func::store_in.
Definition: Expr.h:346
Halide::MemoryType::Stack
@ Stack
Stack memory.
Halide::Pipeline::get_func
Func get_func(size_t index)
Return handle to the index-th Func within the pipeline based on the topological order.
do_cost_model_schedule
void do_cost_model_schedule(Halide::Pipeline pipeline)
Definition: cost_model_schedule.h:5
Halide::Func::split
Func & split(const VarOrRVar &old, const VarOrRVar &outer, const VarOrRVar &inner, const Expr &factor, TailStrategy tail=TailStrategy::Auto)
Split a dimension into inner and outer subdimensions with the given names, where the inner dimension ...
Halide::Stage::get_schedule
const Internal::StageSchedule & get_schedule() const
Return the current StageSchedule associated with this Stage.
Definition: Func.h:107