Halide
cost_model_schedule.h
Go to the documentation of this file.
1 #include "Halide.h"
2 
3 using namespace Halide;
4 
/// Applies a fixed, autoscheduler-generated (then hand-edited) schedule to the
/// Funcs of the cost-model pipeline.
///
/// @param pipeline Pipeline whose Funcs are scheduled. Funcs are looked up by
///        index with Pipeline::get_func(), i.e. by position in the pipeline's
///        topological order, so the hard-coded indices below only match the
///        exact pipeline structure this schedule was generated for.
///
/// NOTE(review): this listing comes from a Doxygen source page, and its
/// embedded line numbering skips lines 10, 12, 276, 334, 419, 429, 480 and
/// 488 — those statements are missing from this copy. Each of the six gaps
/// inside the schedule sits right after a Func name and before its first
/// directive. The page's tooltips also mention Func::store_in and
/// MemoryType::Stack, which nothing visible here uses. Presumably the missing
/// lines were further `using ::Halide::...;` declarations and
/// `.store_in(MemoryType::Stack)` calls — confirm against the real header.
5 inline void do_cost_model_schedule(Halide::Pipeline pipeline) {
6  // Generated by the autoscheduler, then hand-edited to remove most unrolls.
7  // RoundUp and ShiftInwards tail strategies were also manually replaced with GuardWithIf.
    // NOTE(review): some unroll() calls remain (on r1316_z, wi, nii, w, n,
    // v299i, v301i), and two splits of conv1_stage2_1_d_def__ pass no tail
    // strategy, so they use the default TailStrategy::Auto.
8 
9  using ::Halide::Func;
11  using ::Halide::RVar;
13  using ::Halide::Var;
    // Fetch Func handles by index in topological order (Pipeline::get_func).
    // Indices 0-4, 9-11 and 24 are never fetched. Several fetched handles
    // (e.g. f0, f2, relu1, loss_output_0_d_def__) receive no directives in
    // this copy. All of those keep whatever default schedule they already have.
14  Func loss_output = pipeline.get_func(57);
15  Func sum_1 = pipeline.get_func(56);
16  Func f2 = pipeline.get_func(55);
17  Func sum = pipeline.get_func(54);
18  Func prediction_output = pipeline.get_func(53);
19  Func updated_bias1 = pipeline.get_func(52);
20  Func bias1_im_0_d_def__ = pipeline.get_func(51);
21  Func conv1_stage1_0_d_def___1 = pipeline.get_func(50);
22  Func updated_filter1 = pipeline.get_func(49);
23  Func filter1_im_0_d_def__ = pipeline.get_func(48);
24  Func updated_head2_bias = pipeline.get_func(47);
25  Func head2_bias_im_0_d_def__ = pipeline.get_func(46);
26  Func head2_conv_0_d_def___1 = pipeline.get_func(45);
27  Func updated_head2_filter = pipeline.get_func(44);
28  Func head2_filter_im_0_d_def__ = pipeline.get_func(43);
29  Func head2_conv_1_d_def__ = pipeline.get_func(42);
30  Func head2_relu_0_d_def__ = pipeline.get_func(41);
31  Func updated_head1_bias = pipeline.get_func(40);
32  Func head1_bias_im_0_d_def__ = pipeline.get_func(39);
33  Func head1_conv_0_d_def___1 = pipeline.get_func(38);
34  Func updated_head1_filter = pipeline.get_func(37);
35  Func head1_filter_im_0_d_def__ = pipeline.get_func(36);
36  Func squashed_head1_filter_0_d_def__ = pipeline.get_func(35);
37  Func squashed_head1_filter_broadcast_0_d_def__ = pipeline.get_func(34);
38  Func head1_conv_1_d_def__ = pipeline.get_func(33);
39  Func conv1_stage1_1_d_def__ = pipeline.get_func(32);
40  Func conv1_stage2_0_d_def___1 = pipeline.get_func(31);
41  Func conv1_stage2_1_d_def__ = pipeline.get_func(30);
42  Func sum_1_d_def__ = pipeline.get_func(29);
43  Func relu1_0_d_def__ = pipeline.get_func(28);
44  Func f0_0_d_def__ = pipeline.get_func(27);
45  Func cost_per_stage_output_0_d_def__ = pipeline.get_func(26);
46  Func f1_1_d_def__ = pipeline.get_func(25);
47  Func f2_0_d_def__ = pipeline.get_func(23);
48  Func sum_1_1_d_def__ = pipeline.get_func(22);
49  Func loss_output_0_d_def__ = pipeline.get_func(21);
50  Func adjoint = pipeline.get_func(20);
51  Func f1 = pipeline.get_func(19);
52  Func cost_per_stage_output = pipeline.get_func(18);
53  Func f0 = pipeline.get_func(17);
54  Func relu1 = pipeline.get_func(16);
55  Func conv1_stage2 = pipeline.get_func(15);
56  Func head2_relu = pipeline.get_func(14);
57  Func head2_conv = pipeline.get_func(13);
58  Func normalized_schedule_features = pipeline.get_func(12);
59  Func conv1_stage1 = pipeline.get_func(8);
60  Func head1_conv = pipeline.get_func(7);
61  Func squashed_head1_filter_broadcast = pipeline.get_func(6);
62  Func squashed_head1_filter = pipeline.get_func(5);
    // Loop variables are read back from the Funcs' current stage schedules
    // (StageSchedule::dims()), so the directives below name the same
    // dimensions the Funcs were defined with. The "*i" names are fresh Vars
    // used as the inner halves of split().
63  Var c(head2_conv_0_d_def___1.get_schedule().dims()[0].var);
64  Var ci("ci");
65  Var n(sum.get_schedule().dims()[0].var);
66  Var ni("ni");
67  Var nii("nii");
68  Var niii("niii");
69  Var r1316_z(filter1_im_0_d_def__.update(0).get_schedule().dims()[2].var);
70  Var r1512_y(filter1_im_0_d_def__.update(1).get_schedule().dims()[1].var);
71  Var s(squashed_head1_filter_0_d_def__.get_schedule().dims()[1].var);
72  Var si("si");
73  Var v11(bias1_im_0_d_def__.get_schedule().dims()[0].var);
74  Var v11i("v11i");
75  Var v12(filter1_im_0_d_def__.get_schedule().dims()[0].var);
76  Var v13(filter1_im_0_d_def__.get_schedule().dims()[1].var);
77  Var v13i("v13i");
78  Var v14(head1_bias_im_0_d_def__.get_schedule().dims()[0].var);
79  Var v14i("v14i");
80  Var v21(head2_bias_im_0_d_def__.get_schedule().dims()[0].var);
81  Var v21i("v21i");
82  Var v22(head2_filter_im_0_d_def__.get_schedule().dims()[0].var);
83  Var v22i("v22i");
84  Var v23(head2_filter_im_0_d_def__.get_schedule().dims()[1].var);
85  Var v298(updated_head1_filter.get_schedule().dims()[0].var);
86  Var v298i("v298i");
87  Var v299(updated_head1_filter.get_schedule().dims()[1].var);
88  Var v299i("v299i");
89  Var v300(updated_head1_filter.get_schedule().dims()[2].var);
90  Var v301(updated_head1_filter.get_schedule().dims()[3].var);
91  Var v301i("v301i");
92  Var v302(updated_head1_bias.get_schedule().dims()[0].var);
93  Var v302i("v302i");
94  Var v303(updated_head1_bias.get_schedule().dims()[1].var);
95  Var v304(updated_head2_filter.get_schedule().dims()[0].var);
96  Var v304i("v304i");
97  Var v305(updated_head2_filter.get_schedule().dims()[1].var);
98  Var v306(updated_head2_filter.get_schedule().dims()[2].var);
99  Var v307(updated_head2_bias.get_schedule().dims()[0].var);
100  Var v307i("v307i");
101  Var v308(updated_head2_bias.get_schedule().dims()[1].var);
102  Var v309(updated_filter1.get_schedule().dims()[0].var);
103  Var v309i("v309i");
104  Var v310(updated_filter1.get_schedule().dims()[1].var);
105  Var v311(updated_filter1.get_schedule().dims()[2].var);
106  Var v312(updated_bias1.get_schedule().dims()[0].var);
107  Var v312i("v312i");
108  Var v313(updated_bias1.get_schedule().dims()[1].var);
109  Var w(head2_conv_0_d_def___1.get_schedule().dims()[1].var);
110  Var wi("wi");
    // Reduction variables of the update stages, recovered the same way.
    // (r89_x is declared but not referenced by any directive below.)
111  RVar r1294_x(head2_relu_0_d_def__.update(0).get_schedule().dims()[0].var);
112  RVar r1316_x(filter1_im_0_d_def__.update(0).get_schedule().dims()[0].var);
113  RVar r1316_y(filter1_im_0_d_def__.update(0).get_schedule().dims()[1].var);
114  RVar r1336_x(conv1_stage1_1_d_def__.update(0).get_schedule().dims()[0].var);
115  RVar r1400_x(head2_filter_im_0_d_def__.update(0).get_schedule().dims()[0].var);
116  RVar r1400_y(head2_filter_im_0_d_def__.update(0).get_schedule().dims()[1].var);
117  RVar r1421_x(head2_bias_im_0_d_def__.update(0).get_schedule().dims()[0].var);
118  RVar r1421_y(head2_bias_im_0_d_def__.update(0).get_schedule().dims()[1].var);
119  RVar r1491_x(head1_conv_1_d_def__.update(0).get_schedule().dims()[0].var);
120  RVar r1512_x(filter1_im_0_d_def__.update(1).get_schedule().dims()[0].var);
121  RVar r1532_x(bias1_im_0_d_def__.update(0).get_schedule().dims()[0].var);
122  RVar r1594_x(head1_bias_im_0_d_def__.update(0).get_schedule().dims()[0].var);
123  RVar r1614_x(squashed_head1_filter_0_d_def__.update(0).get_schedule().dims()[0].var);
124  RVar r31_x(head1_conv.update(0).get_schedule().dims()[0].var);
125  RVar r31_y(head1_conv.update(0).get_schedule().dims()[1].var);
126  RVar r40_x(head2_conv.update(0).get_schedule().dims()[0].var);
127  RVar r54_x(conv1_stage1.update(0).get_schedule().dims()[0].var);
128  RVar r63_x(conv1_stage2.update(0).get_schedule().dims()[0].var);
129  RVar r81_x(f1.update(0).get_schedule().dims()[0].var);
130  RVar r89_x(sum_1.update(0).get_schedule().dims()[0].var);
131  RVar r94_x(sum.update(0).get_schedule().dims()[0].var);
132  RVar r94_y(sum.update(0).get_schedule().dims()[1].var);
133  loss_output
134  .compute_root();
135  sum_1
136  .compute_root();
    // Only obtains the update(0) stage handle; no directives are applied to it.
137  sum_1.update(0);
138  sum
139  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
140  .vectorize(ni)
141  .compute_root()
142  .reorder(ni, n)
143  .serial(n);
144  sum.update(0)
145  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
146  .vectorize(ni)
147  .reorder(ni, r94_x, r94_y, n)
148  .serial(n);
149  prediction_output
150  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
151  .vectorize(ni)
152  .compute_root()
153  .reorder(ni, n)
154  .serial(n);
155  updated_bias1
156  .split(v312, v312, v312i, 8, TailStrategy::GuardWithIf)
157  .vectorize(v312i)
158  .compute_root()
159  .reorder(v312i, v312, v313)
160  .fuse(v312, v313, v312)
161  .serial(v312);
162  updated_bias1.update(0)
163  .split(v312, v312, v312i, 8, TailStrategy::GuardWithIf)
164  .vectorize(v312i)
165  .reorder(v312i, v312)
166  .serial(v312);
167  updated_bias1.update(1)
168  .split(v312, v312, v312i, 8, TailStrategy::GuardWithIf)
169  .vectorize(v312i)
170  .reorder(v312i, v312)
171  .serial(v312);
172  updated_bias1.update(2)
173  .split(v312, v312, v312i, 8, TailStrategy::GuardWithIf)
174  .vectorize(v312i)
175  .reorder(v312i, v312)
176  .serial(v312);
177  updated_bias1.update(3)
178  .split(v312, v312, v312i, 8, TailStrategy::GuardWithIf)
179  .vectorize(v312i)
180  .reorder(v312i, v312)
181  .serial(v312);
182  bias1_im_0_d_def__
183  .split(v11, v11, v11i, 8, TailStrategy::GuardWithIf)
184  .vectorize(v11i)
185  .compute_at(updated_bias1, v312)
186  .reorder(v11i, v11);
187  bias1_im_0_d_def__.update(0)
188  .split(v11, v11, v11i, 8, TailStrategy::GuardWithIf)
189  .vectorize(v11i)
190  .reorder(v11i, v11, r1532_x);
191  updated_filter1
192  .split(v309, v309, v309i, 8, TailStrategy::GuardWithIf)
193  .vectorize(v309i)
194  .compute_root()
195  .reorder(v309i, v311, v309, v310)
196  .fuse(v309, v310, v309)
197  .serial(v309);
198  updated_filter1.update(0)
199  .split(v309, v309, v309i, 8, TailStrategy::GuardWithIf)
200  .vectorize(v309i)
201  .reorder(v309i, v309, v310)
202  .fuse(v309, v310, v309)
203  .serial(v309);
204  updated_filter1.update(1)
205  .split(v309, v309, v309i, 8, TailStrategy::GuardWithIf)
206  .vectorize(v309i)
207  .reorder(v309i, v309, v310)
208  .fuse(v309, v310, v309)
209  .serial(v309);
210  updated_filter1.update(2)
211  .split(v309, v309, v309i, 8, TailStrategy::GuardWithIf)
212  .vectorize(v309i)
213  .reorder(v309i, v309, v310)
214  .fuse(v309, v310, v309)
215  .serial(v309);
216  updated_filter1.update(3)
217  .split(v309, v309, v309i, 8, TailStrategy::GuardWithIf)
218  .vectorize(v309i)
219  .reorder(v309i, v309, v310)
220  .fuse(v309, v310, v309)
221  .serial(v309);
222  filter1_im_0_d_def__
223  .split(v13, v13, v13i, 8, TailStrategy::GuardWithIf)
224  .vectorize(v13i)
225  .compute_root()
226  .reorder(v13i, v13, v12)
227  .fuse(v13, v12, v13)
228  .parallel(v13)
229  .reorder_storage(v13, v12);
    // vectorize(r1316_z, 8) splits r1316_z by 8 and vectorizes the inner part;
    // the unroll() that follows then applies to the remaining outer r1316_z.
230  filter1_im_0_d_def__.update(0)
231  .reorder(r1316_z, r1316_x, r1316_y, v12)
232  .vectorize(r1316_z, 8)
233  .unroll(r1316_z)
234  .parallel(v12);
235  filter1_im_0_d_def__.update(1)
236  .reorder(r1512_x, r1512_y, v12)
237  .vectorize(r1512_y)
238  .parallel(v12);
239  updated_head2_bias
240  .split(v307, v307, v307i, 8, TailStrategy::GuardWithIf)
241  .vectorize(v307i)
242  .compute_root()
243  .reorder(v307i, v307, v308)
244  .fuse(v307, v308, v307)
245  .serial(v307);
246  updated_head2_bias.update(0)
247  .split(v307, v307, v307i, 8, TailStrategy::GuardWithIf)
248  .vectorize(v307i)
249  .reorder(v307i, v307)
250  .serial(v307);
251  updated_head2_bias.update(1)
252  .split(v307, v307, v307i, 8, TailStrategy::GuardWithIf)
253  .vectorize(v307i)
254  .reorder(v307i, v307)
255  .serial(v307);
256  updated_head2_bias.update(2)
257  .split(v307, v307, v307i, 8, TailStrategy::GuardWithIf)
258  .vectorize(v307i)
259  .reorder(v307i, v307)
260  .serial(v307);
261  updated_head2_bias.update(3)
262  .split(v307, v307, v307i, 8, TailStrategy::GuardWithIf)
263  .vectorize(v307i)
264  .reorder(v307i, v307)
265  .serial(v307);
266  head2_bias_im_0_d_def__
267  .split(v21, v21, v21i, 8, TailStrategy::GuardWithIf)
268  .vectorize(v21i)
269  .compute_at(updated_head2_bias, v307)
270  .reorder(v21i, v21);
271  head2_bias_im_0_d_def__.update(0)
272  .split(v21, v21, v21i, 8, TailStrategy::GuardWithIf)
273  .vectorize(v21i)
274  .reorder(v21i, v21, r1421_x, r1421_y);
275  head2_conv_0_d_def___1
277  .split(c, c, ci, 8, TailStrategy::GuardWithIf)
278  .vectorize(ci)
279  .compute_at(head2_bias_im_0_d_def__, v21)
280  .reorder(ci, c, w, n);
    // The fused outer loop is parallel for the pure stage and for updates 0
    // and 3, but serial for updates 1 and 2.
281  updated_head2_filter
282  .split(v304, v304, v304i, 8, TailStrategy::GuardWithIf)
283  .vectorize(v304i)
284  .compute_root()
285  .reorder(v304i, v306, v304, v305)
286  .fuse(v304, v305, v304)
287  .parallel(v304);
288  updated_head2_filter.update(0)
289  .split(v304, v304, v304i, 8, TailStrategy::GuardWithIf)
290  .vectorize(v304i)
291  .reorder(v304i, v304, v305)
292  .fuse(v304, v305, v304)
293  .parallel(v304);
294  updated_head2_filter.update(1)
295  .split(v304, v304, v304i, 8, TailStrategy::GuardWithIf)
296  .vectorize(v304i)
297  .reorder(v304i, v304, v305)
298  .fuse(v304, v305, v304)
299  .serial(v304);
300  updated_head2_filter.update(2)
301  .split(v304, v304, v304i, 8, TailStrategy::GuardWithIf)
302  .vectorize(v304i)
303  .reorder(v304i, v304, v305)
304  .fuse(v304, v305, v304)
305  .serial(v304);
306  updated_head2_filter.update(3)
307  .split(v304, v304, v304i, 8, TailStrategy::GuardWithIf)
308  .vectorize(v304i)
309  .reorder(v304i, v304, v305)
310  .fuse(v304, v305, v304)
311  .parallel(v304);
312  head2_filter_im_0_d_def__
313  .split(v22, v22, v22i, 8, TailStrategy::GuardWithIf)
314  .vectorize(v22i)
315  .compute_at(updated_head2_filter, v304)
316  .reorder(v22i, v22, v23);
317  head2_filter_im_0_d_def__.update(0)
318  .split(v22, v22, v22i, 8, TailStrategy::GuardWithIf)
319  .vectorize(v22i)
320  .reorder(v22i, v22, v23, r1400_x, r1400_y);
321  head2_conv_1_d_def__
322  .split(w, w, wi, 2, TailStrategy::GuardWithIf)
323  .split(n, n, ni, 128, TailStrategy::GuardWithIf)
324  .split(c, c, ci, 8, TailStrategy::GuardWithIf)
325  .split(ni, ni, nii, 4, TailStrategy::GuardWithIf)
326  .unroll(wi)
327  .unroll(nii)
328  .vectorize(ci)
329  .compute_root()
330  .reorder(ci, wi, nii, c, ni, w, n)
331  .fuse(w, n, w)
332  .serial(w);
333  head2_relu_0_d_def__
335  .split(c, c, ci, 8, TailStrategy::GuardWithIf)
336  .unroll(w)
337  .unroll(n)
338  .vectorize(ci)
339  .compute_at(head2_conv_1_d_def__, c)
340  .reorder(ci, c, w, n);
341  head2_relu_0_d_def__.update(0)
342  .split(c, c, ci, 8, TailStrategy::GuardWithIf)
343  .unroll(w)
344  .unroll(n)
345  .vectorize(ci)
346  .reorder(ci, c, w, n, r1294_x);
347  updated_head1_bias
348  .split(v302, v302, v302i, 8, TailStrategy::GuardWithIf)
349  .vectorize(v302i)
350  .compute_root()
351  .reorder(v302i, v302, v303)
352  .serial(v303);
353  updated_head1_bias.update(0)
354  .split(v302, v302, v302i, 8, TailStrategy::GuardWithIf)
355  .vectorize(v302i)
356  .reorder(v302i, v302);
357  updated_head1_bias.update(1)
358  .split(v302, v302, v302i, 8, TailStrategy::GuardWithIf)
359  .vectorize(v302i)
360  .reorder(v302i, v302);
361  updated_head1_bias.update(2)
362  .split(v302, v302, v302i, 8, TailStrategy::GuardWithIf)
363  .vectorize(v302i)
364  .reorder(v302i, v302);
365  updated_head1_bias.update(3)
366  .split(v302, v302, v302i, 8, TailStrategy::GuardWithIf)
367  .vectorize(v302i)
368  .reorder(v302i, v302);
369  head1_bias_im_0_d_def__
370  .split(v14, v14, v14i, 8, TailStrategy::GuardWithIf)
371  .vectorize(v14i)
372  .compute_root()
373  .reorder(v14i, v14);
374  head1_bias_im_0_d_def__.update(0)
375  .split(v14, v14, v14i, 8, TailStrategy::GuardWithIf)
376  .vectorize(v14i)
377  .reorder(v14i, v14, r1594_x);
378  updated_head1_filter
379  .split(v299, v299, v299i, 5, TailStrategy::GuardWithIf)
380  .split(v301, v301, v301i, 2, TailStrategy::GuardWithIf)
381  .split(v298, v298, v298i, 8, TailStrategy::GuardWithIf)
382  .unroll(v299i)
383  .unroll(v301i)
384  .vectorize(v298i)
385  .compute_root()
386  .reorder(v298i, v298, v299i, v301i, v299, v300, v301)
387  .fuse(v300, v301, v300)
388  .fuse(v299, v300, v299)
389  .serial(v299);
390  updated_head1_filter.update(0)
391  .split(v299, v299, v299i, 5, TailStrategy::GuardWithIf)
392  .split(v298, v298, v298i, 8, TailStrategy::GuardWithIf)
393  .vectorize(v298i)
394  .reorder(v298i, v298, v299i, v299, v300)
395  .fuse(v299, v300, v299)
396  .serial(v299);
397  updated_head1_filter.update(1)
398  .split(v299, v299, v299i, 5, TailStrategy::GuardWithIf)
399  .split(v298, v298, v298i, 8, TailStrategy::GuardWithIf)
400  .vectorize(v298i)
401  .reorder(v298i, v298, v299i, v299, v300)
402  .fuse(v299, v300, v299)
403  .serial(v299);
404  updated_head1_filter.update(2)
405  .split(v299, v299, v299i, 5, TailStrategy::GuardWithIf)
406  .split(v298, v298, v298i, 8, TailStrategy::GuardWithIf)
407  .vectorize(v298i)
408  .reorder(v298i, v298, v299i, v299, v300)
409  .fuse(v299, v300, v299)
410  .serial(v299);
411  updated_head1_filter.update(3)
412  .split(v299, v299, v299i, 5, TailStrategy::GuardWithIf)
413  .split(v298, v298, v298i, 8, TailStrategy::GuardWithIf)
414  .vectorize(v298i)
415  .reorder(v298i, v298, v299i, v299, v300)
416  .fuse(v299, v300, v299)
417  .serial(v299);
418  squashed_head1_filter_0_d_def__
420  .split(c, c, ci, 8, TailStrategy::GuardWithIf)
421  .vectorize(ci)
422  .compute_at(updated_head1_filter, v298)
423  .reorder(ci, c, s, n);
424  squashed_head1_filter_0_d_def__.update(0)
425  .split(c, c, ci, 8, TailStrategy::GuardWithIf)
426  .vectorize(ci)
427  .reorder(ci, c, s, n, r1614_x);
    // Computed per iteration of v299i, but its storage is allocated at the
    // enclosing (fused) v299 loop of updated_head1_filter.
428  squashed_head1_filter_broadcast_0_d_def__
430  .split(c, c, ci, 8, TailStrategy::GuardWithIf)
431  .vectorize(ci)
432  .compute_at(updated_head1_filter, v299i)
433  .store_at(updated_head1_filter, v299)
434  .reorder(ci, c, w, s, n);
435  head1_conv_1_d_def__
436  .split(c, c, ci, 8, TailStrategy::GuardWithIf)
437  .vectorize(ci)
438  .compute_root()
439  .reorder(ci, c, w)
440  .serial(w);
441  head1_conv_1_d_def__.update(0)
442  .split(c, c, ci, 8, TailStrategy::GuardWithIf)
443  .vectorize(ci)
444  .reorder(ci, c, r1491_x, w)
445  .serial(w);
446  conv1_stage1_1_d_def__
447  .split(w, w, wi, 8, TailStrategy::GuardWithIf)
448  .vectorize(wi)
449  .compute_root()
450  .reorder(wi, w, c)
451  .fuse(w, c, w)
452  .serial(w)
453  .reorder_storage(w, c);
454  conv1_stage1_1_d_def__.update(0)
455  .split(w, w, wi, 8, TailStrategy::GuardWithIf)
456  .vectorize(wi)
457  .reorder(wi, r1336_x, w, c)
458  .fuse(w, c, w)
459  .serial(w);
    // The splits of ni and nii below pass no tail strategy, so the default
    // (TailStrategy::Auto) applies. The 4-way split of the 32-wide ni and the
    // 2-way split of the resulting 4-wide nii divide evenly — presumably why
    // no explicit strategy was needed (TODO confirm no tail code is emitted).
460  conv1_stage2_1_d_def__
461  .split(c, c, ci, 14, TailStrategy::GuardWithIf)
462  .split(n, n, ni, 32, TailStrategy::GuardWithIf)
463  .split(ni, ni, nii, 4)
464  .split(nii, nii, niii, 2)
465  .split(w, w, wi, 8, TailStrategy::GuardWithIf)
466  .vectorize(wi)
467  .compute_root()
468  .reorder(wi, w, ci, niii, nii, ni, c, n)
469  .fuse(c, n, c)
470  .parallel(c)
471  .reorder_storage(w, c, n);
472  sum_1_d_def__
473  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
474  .unroll(n)
475  .vectorize(ni)
476  .compute_at(conv1_stage2_1_d_def__, c)
477  .reorder(ni, n);
478 
    // Schedule the identity wrapper returned by in() rather than
    // relu1_0_d_def__ itself, then compute relu1_0_d_def__ inside the
    // wrapper's w loop. Both in() calls are expected to return the same
    // global wrapper (Halide reuses an existing one) — TODO confirm.
479  relu1_0_d_def__.in()
481  .split(c, c, ci, 8, TailStrategy::GuardWithIf)
482  .vectorize(ci)
483  .compute_at(conv1_stage2_1_d_def__, nii)
484  .reorder(ci, c, w, n);
485  relu1_0_d_def__.compute_at(relu1_0_d_def__.in(), w);
486 
487  cost_per_stage_output_0_d_def__
489  .split(w, w, wi, 8, TailStrategy::GuardWithIf)
490  .vectorize(wi)
491  .compute_at(conv1_stage2_1_d_def__, nii)
492  .store_at(conv1_stage2_1_d_def__, ni)
493  .reorder(wi, w, n)
494  .reorder_storage(w, n);
495  f1_1_d_def__
496  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
497  .unroll(n)
498  .vectorize(ni)
499  .compute_at(conv1_stage2_1_d_def__, c)
500  .reorder(ni, n);
501  adjoint
502  .compute_root();
503  f1
504  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
505  .vectorize(ni)
506  .compute_root()
507  .reorder(ni, n)
508  .serial(n);
509  f1.update(0)
510  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
511  .vectorize(ni)
512  .reorder(ni, r81_x, n)
513  .serial(n);
514  cost_per_stage_output
515  .split(n, n, ni, 128, TailStrategy::GuardWithIf)
516  .split(w, w, wi, 2, TailStrategy::GuardWithIf)
517  .split(ni, ni, nii, 8, TailStrategy::GuardWithIf)
518  .vectorize(nii)
519  .compute_root()
520  .reorder(nii, ni, wi, n, w)
521  .fuse(n, w, n)
522  .serial(n);
523  conv1_stage2
524  .split(n, n, ni, 512, TailStrategy::GuardWithIf)
525  .split(c, c, ci, 10, TailStrategy::GuardWithIf)
526  .split(w, w, wi, 2, TailStrategy::GuardWithIf)
527  .split(ni, ni, nii, 8, TailStrategy::GuardWithIf)
528  .vectorize(nii)
529  .compute_root()
530  .reorder(nii, ni, ci, wi, n, c, w)
531  .fuse(c, w, c)
532  .fuse(n, c, n)
533  .serial(n)
534  .reorder_storage(n, c, w);
535  conv1_stage2.update(0)
536  .split(n, n, ni, 512, TailStrategy::GuardWithIf)
537  .split(c, c, ci, 10, TailStrategy::GuardWithIf)
538  .split(w, w, wi, 2, TailStrategy::GuardWithIf)
539  .split(ni, ni, nii, 8, TailStrategy::GuardWithIf)
540  .vectorize(nii)
541  .reorder(nii, r63_x, ni, ci, wi, n, c, w)
542  .fuse(c, w, c)
543  .fuse(n, c, n)
544  .serial(n);
545  head2_relu
546  .split(c, c, ci, 3, TailStrategy::GuardWithIf)
547  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
548  .vectorize(ni)
549  .compute_root()
550  .reorder(ni, n, ci, c, w)
551  .fuse(c, w, c)
552  .serial(c)
553  .reorder_storage(n, c, w);
    // The pure stage's fused outer loop is serial; the update's is parallel.
554  head2_conv
555  .split(n, n, ni, 512, TailStrategy::GuardWithIf)
556  .split(c, c, ci, 6, TailStrategy::GuardWithIf)
557  .split(ni, ni, nii, 8, TailStrategy::GuardWithIf)
558  .vectorize(nii)
559  .compute_root()
560  .reorder(nii, ni, ci, n, c, w)
561  .fuse(c, w, c)
562  .fuse(n, c, n)
563  .serial(n)
564  .reorder_storage(n, c, w);
565  head2_conv.update(0)
566  .split(n, n, ni, 512, TailStrategy::GuardWithIf)
567  .split(c, c, ci, 6, TailStrategy::GuardWithIf)
568  .split(ni, ni, nii, 8, TailStrategy::GuardWithIf)
569  .vectorize(nii)
570  .reorder(nii, r40_x, ni, ci, n, c, w)
571  .fuse(c, w, c)
572  .fuse(n, c, n)
573  .parallel(n);
574  normalized_schedule_features
575  .split(c, c, ci, 11, TailStrategy::GuardWithIf)
576  .split(n, n, ni, 8, TailStrategy::GuardWithIf)
577  .vectorize(ni)
578  .compute_root()
579  .reorder(ni, n, ci, c, s)
580  .fuse(c, s, c)
581  .serial(c);
582  conv1_stage1
583  .split(w, w, wi, 8, TailStrategy::GuardWithIf)
584  .vectorize(wi)
585  .compute_root()
586  .reorder(wi, w, c)
587  .fuse(w, c, w)
588  .serial(w)
589  .reorder_storage(w, c);
590  conv1_stage1.update(0)
591  .split(w, w, wi, 8, TailStrategy::GuardWithIf)
592  .vectorize(wi)
593  .reorder(wi, r54_x, w, c)
594  .fuse(w, c, w)
595  .serial(w);
596  head1_conv
597  .split(w, w, wi, 8, TailStrategy::GuardWithIf)
598  .vectorize(wi)
599  .compute_root()
600  .reorder(wi, w, c)
601  .fuse(w, c, w)
602  .serial(w)
603  .reorder_storage(w, c);
604  head1_conv.update(0)
605  .split(w, w, wi, 8, TailStrategy::GuardWithIf)
606  .vectorize(wi)
607  .reorder(wi, r31_x, r31_y, w, c)
608  .fuse(w, c, w)
609  .serial(w);
610  squashed_head1_filter
611  .split(s, s, si, 8, TailStrategy::GuardWithIf)
612  .vectorize(si)
613  .compute_at(head1_conv, w)
614  .reorder(si, s, c, n)
615  .reorder_storage(s, c, n);
616 }
Halide::Stage::parallel
Stage & parallel(const VarOrRVar &var)
Halide::sum
Expr sum(Expr, const std::string &s="sum")
An inline reduction.
Halide::Stage::reorder
Stage & reorder(const std::vector< VarOrRVar > &vars)
Halide::Var
A Halide variable, to be used when defining functions.
Definition: Var.h:19
Halide::TailStrategy::GuardWithIf
@ GuardWithIf
Guard the inner loop with an if statement that prevents evaluation beyond the original extent.
Halide::Func::reorder
Func & reorder(const std::vector< VarOrRVar > &vars)
Reorder variables to have the given nesting order, from innermost out.
Halide::Func::store_at
Func & store_at(const Func &f, const Var &var)
Allocate storage for this function within f's loop over var.
Halide::Func::vectorize
Func & vectorize(const VarOrRVar &var)
Mark a dimension to be computed all-at-once as a single vector.
Halide::Func::compute_at
Func & compute_at(const Func &f, const Var &var)
Compute this function as needed for each unique value of the given var for the given calling function...
Halide::Func::fuse
Func & fuse(const VarOrRVar &inner, const VarOrRVar &outer, const VarOrRVar &fused)
Join two dimensions into a single fused dimension.
Halide::TailStrategy
TailStrategy
Different ways to handle a tail case in a split when the factor does not provably divide the extent.
Definition: Schedule.h:32
Halide::Pipeline
A class representing a Halide pipeline.
Definition: Pipeline.h:108
Halide
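The top-level Halide namespace. (Doxygen attached this namespace's brief description from the file that defines FunctionDAG: "This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...")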
Definition: AbstractGenerator.h:19
Halide::Internal::StageSchedule::dims
const std::vector< Dim > & dims() const
The list and ordering of dimensions used to evaluate this function, after all splits have taken place...
Halide::Func::update
Stage update(int idx=0)
Get a handle on an update step for the purposes of scheduling it.
Halide::Func::serial
Func & serial(const VarOrRVar &var)
Mark a dimension to be traversed serially.
Halide::Stage::unroll
Stage & unroll(const VarOrRVar &var)
Halide::Func::parallel
Func & parallel(const VarOrRVar &var)
Mark a dimension to be traversed in parallel.
Halide::Stage::vectorize
Stage & vectorize(const VarOrRVar &var)
Halide::Func::in
Func in(const Func &f)
Creates and returns a new identity Func that wraps this Func.
Halide::Func::store_in
Func & store_in(MemoryType memory_type)
Set the type of memory this Func should be stored in.
Halide::Func
A halide function.
Definition: Func.h:687
Halide::Func::compute_root
Func & compute_root()
Compute all of this function once ahead of time.
Halide::Func::unroll
Func & unroll(const VarOrRVar &var)
Mark a dimension to be completely unrolled.
Halide::RVar
A reduction variable represents a single dimension of a reduction domain (RDom).
Definition: RDom.h:29
Halide::Func::get_schedule
const Internal::StageSchedule & get_schedule() const
Return the current StageSchedule associated with this initial Stage of this Func.
Definition: Func.h:2474
Halide::MemoryType
MemoryType
An enum describing different address spaces to be used with Func::store_in.
Definition: Expr.h:347
Halide::MemoryType::Stack
@ Stack
Stack memory.
do_cost_model_schedule
void do_cost_model_schedule(Halide::Pipeline pipeline)
Definition: cost_model_schedule.h:5
Halide::Pipeline::get_func
Func get_func(size_t index)
Return handle to the index-th Func within the pipeline based on the topological order.
Halide::Func::split
Func & split(const VarOrRVar &old, const VarOrRVar &outer, const VarOrRVar &inner, const Expr &factor, TailStrategy tail=TailStrategy::Auto)
Split a dimension into inner and outer subdimensions with the given names, where the inner dimension ...
Halide::Stage::get_schedule
const Internal::StageSchedule & get_schedule() const
Return the current StageSchedule associated with this Stage.
Definition: Func.h:107