Halide 19.0.0
Halide compiler and libraries
cost_model_schedule.h
Go to the documentation of this file.
1#include "Halide.h"
2
3using namespace Halide;
4
6 // Generated by autoscheduler, manually remove unrolls.
7 // Also manually replaced all RoundUp and ShiftInwards with GuardWithIf.
8
9 using ::Halide::Func;
10 using ::Halide::MemoryType;
11 using ::Halide::RVar;
12 using ::Halide::TailStrategy;
13 using ::Halide::Var;
14 Func loss_output = pipeline.get_func(55);
15 Func sum_1 = pipeline.get_func(54);
16 Func f2 = pipeline.get_func(53);
17 Func sum = pipeline.get_func(52);
18 Func prediction_output = pipeline.get_func(51);
19 Func updated_bias1 = pipeline.get_func(50);
20 Func bias1_im_0_d_def__ = pipeline.get_func(49);
21 Func conv1_stage1_0_d_def___1 = pipeline.get_func(48);
22 Func updated_filter1 = pipeline.get_func(47);
23 Func filter1_im_0_d_def__ = pipeline.get_func(46);
24 Func updated_head2_bias = pipeline.get_func(45);
25 Func head2_bias_im_0_d_def__ = pipeline.get_func(44);
26 Func head2_conv_0_d_def___1 = pipeline.get_func(43);
27 Func updated_head2_filter = pipeline.get_func(42);
28 Func head2_filter_im_0_d_def__ = pipeline.get_func(41);
29 Func head2_conv_1_d_def__ = pipeline.get_func(40);
30 Func head2_relu_0_d_def__ = pipeline.get_func(39);
31 Func updated_head1_bias = pipeline.get_func(38);
32 Func head1_bias_im_0_d_def__ = pipeline.get_func(37);
33 Func head1_conv_0_d_def___1 = pipeline.get_func(36);
34 Func updated_head1_filter = pipeline.get_func(35);
35 Func head1_filter_im_0_d_def__ = pipeline.get_func(34);
36 Func squashed_head1_filter_0_d_def__ = pipeline.get_func(33);
37 Func squashed_head1_filter_broadcast_0_d_def__ = pipeline.get_func(32);
38 Func head1_conv_1_d_def__ = pipeline.get_func(31);
39 Func conv1_stage1_1_d_def__ = pipeline.get_func(30);
40 Func conv1_stage2_0_d_def___1 = pipeline.get_func(29);
41 Func conv1_stage2_1_d_def__ = pipeline.get_func(28);
42 Func sum_1_d_def__ = pipeline.get_func(27);
43 Func relu1_0_d_def__ = pipeline.get_func(26);
44 Func f0_0_d_def__ = pipeline.get_func(25);
45 Func f1_1_d_def__ = pipeline.get_func(24);
46 Func f2_0_d_def__ = pipeline.get_func(22);
47 Func sum_1_1_d_def__ = pipeline.get_func(21);
48 Func loss_output_0_d_def__ = pipeline.get_func(20);
49 Func adjoint = pipeline.get_func(19);
50 Func f1 = pipeline.get_func(18);
51 Func f0 = pipeline.get_func(17);
52 Func relu1 = pipeline.get_func(16);
53 Func conv1_stage2 = pipeline.get_func(15);
54 Func head2_relu = pipeline.get_func(14);
55 Func head2_conv = pipeline.get_func(13);
56 Func normalized_schedule_features = pipeline.get_func(12);
57 Func conv1_stage1 = pipeline.get_func(8);
58 Func head1_conv = pipeline.get_func(7);
59 Func squashed_head1_filter_broadcast = pipeline.get_func(6);
60 Func squashed_head1_filter = pipeline.get_func(5);
61 Var c(head2_conv_0_d_def___1.get_schedule().dims()[0].var);
62 Var ci("ci");
63 Var n(sum.get_schedule().dims()[0].var);
64 Var ni("ni");
65 Var nii("nii");
66 Var r1010_z(filter1_im_0_d_def__.update(0).get_schedule().dims()[2].var);
67 Var r1207_y(filter1_im_0_d_def__.update(1).get_schedule().dims()[1].var);
68 Var s(squashed_head1_filter_0_d_def__.get_schedule().dims()[1].var);
69 Var si("si");
70 Var v12(head2_bias_im_0_d_def__.get_schedule().dims()[0].var);
71 Var v12i("v12i");
72 Var v13(head2_filter_im_0_d_def__.get_schedule().dims()[0].var);
73 Var v13i("v13i");
74 Var v14(head2_filter_im_0_d_def__.get_schedule().dims()[1].var);
75 Var v2(bias1_im_0_d_def__.get_schedule().dims()[0].var);
76 Var v207(updated_head1_filter.get_schedule().dims()[0].var);
77 Var v207i("v207i");
78 Var v208(updated_head1_filter.get_schedule().dims()[1].var);
79 Var v208i("v208i");
80 Var v209(updated_head1_filter.get_schedule().dims()[2].var);
81 Var v209i("v209i");
82 Var v210(updated_head1_filter.get_schedule().dims()[3].var);
83 Var v210i("v210i");
84 Var v211(updated_head1_bias.get_schedule().dims()[0].var);
85 Var v211i("v211i");
86 Var v212(updated_head1_bias.get_schedule().dims()[1].var);
87 Var v213(updated_head2_filter.get_schedule().dims()[0].var);
88 Var v213i("v213i");
89 Var v214(updated_head2_filter.get_schedule().dims()[1].var);
90 Var v214i("v214i");
91 Var v215(updated_head2_filter.get_schedule().dims()[2].var);
92 Var v215i("v215i");
93 Var v216(updated_head2_bias.get_schedule().dims()[0].var);
94 Var v216i("v216i");
95 Var v217(updated_head2_bias.get_schedule().dims()[1].var);
96 Var v218(updated_filter1.get_schedule().dims()[0].var);
97 Var v218i("v218i");
98 Var v218ii("v218ii");
99 Var v219(updated_filter1.get_schedule().dims()[1].var);
100 Var v219i("v219i");
101 Var v220(updated_filter1.get_schedule().dims()[2].var);
102 Var v220i("v220i");
103 Var v221(updated_bias1.get_schedule().dims()[0].var);
104 Var v221i("v221i");
105 Var v222(updated_bias1.get_schedule().dims()[1].var);
106 Var v2i("v2i");
107 Var v3(filter1_im_0_d_def__.get_schedule().dims()[0].var);
108 Var v4(filter1_im_0_d_def__.get_schedule().dims()[1].var);
109 Var v4i("v4i");
110 Var v5(head1_bias_im_0_d_def__.get_schedule().dims()[0].var);
111 Var v5i("v5i");
112 Var w(head2_conv_0_d_def___1.get_schedule().dims()[1].var);
113 Var wi("wi");
114 Var wii("wii");
115 RVar r1010_x(filter1_im_0_d_def__.update(0).get_schedule().dims()[0].var);
116 RVar r1010_y(filter1_im_0_d_def__.update(0).get_schedule().dims()[1].var);
117 RVar r1029_x(conv1_stage1_1_d_def__.update(0).get_schedule().dims()[0].var);
118 RVar r1029_xi("r1029$xi");
119 RVar r1095_x(head2_filter_im_0_d_def__.update(0).get_schedule().dims()[0].var);
120 RVar r1095_y(head2_filter_im_0_d_def__.update(0).get_schedule().dims()[1].var);
121 RVar r1114_x(head2_bias_im_0_d_def__.update(0).get_schedule().dims()[0].var);
122 RVar r1114_y(head2_bias_im_0_d_def__.update(0).get_schedule().dims()[1].var);
123 RVar r1183_x(head1_conv_1_d_def__.update(0).get_schedule().dims()[0].var);
124 RVar r1207_x(filter1_im_0_d_def__.update(1).get_schedule().dims()[0].var);
125 RVar r1226_x(bias1_im_0_d_def__.update(0).get_schedule().dims()[0].var);
126 RVar r1302_x(head1_bias_im_0_d_def__.update(0).get_schedule().dims()[0].var);
127 RVar r1321_x(squashed_head1_filter_0_d_def__.update(0).get_schedule().dims()[0].var);
128 RVar r14_x(conv1_stage1.update(0).get_schedule().dims()[0].var);
129 RVar r19_x(conv1_stage2.update(0).get_schedule().dims()[0].var);
130 RVar r24_x(f1.update(0).get_schedule().dims()[0].var);
131 RVar r29_x(sum_1.update(0).get_schedule().dims()[0].var);
132 RVar r34_x(sum.update(0).get_schedule().dims()[0].var);
133 RVar r34_y(sum.update(0).get_schedule().dims()[1].var);
134 RVar r4_x(head1_conv.update(0).get_schedule().dims()[0].var);
135 RVar r4_y(head1_conv.update(0).get_schedule().dims()[1].var);
136 RVar r9_x(head2_conv.update(0).get_schedule().dims()[0].var);
137 RVar r986_x(head2_relu_0_d_def__.update(0).get_schedule().dims()[0].var);
138 loss_output
139 .compute_root();
140 sum_1
141 .compute_root();
142 sum_1.update(0);
143 sum
144 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
145 .vectorize(ni)
146 .compute_root()
147 .reorder(ni, n)
148 .parallel(n);
149 sum.update(0)
150 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
151 .vectorize(ni)
152 .reorder(ni, r34_x, r34_y, n)
153 .parallel(n);
154 prediction_output
155 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
156 .vectorize(ni)
157 .compute_root()
158 .reorder(ni, n)
159 .parallel(n);
160 updated_bias1
161 .split(v221, v221, v221i, 8, TailStrategy::GuardWithIf)
162 .vectorize(v221i)
163 .compute_root()
164 .reorder(v221i, v221, v222)
165 .fuse(v221, v222, v221)
166 .parallel(v221);
167 updated_bias1.update(0)
168 .split(v221, v221, v221i, 8, TailStrategy::GuardWithIf)
169 .vectorize(v221i)
170 .reorder(v221i, v221)
171 .parallel(v221);
172 updated_bias1.update(1)
173 .split(v221, v221, v221i, 8, TailStrategy::GuardWithIf)
174 .vectorize(v221i)
175 .reorder(v221i, v221)
176 .parallel(v221);
177 updated_bias1.update(2)
178 .split(v221, v221, v221i, 8, TailStrategy::GuardWithIf)
179 .vectorize(v221i)
180 .reorder(v221i, v221)
181 .parallel(v221);
182 updated_bias1.update(3)
183 .split(v221, v221, v221i, 8, TailStrategy::GuardWithIf)
184 .vectorize(v221i)
185 .reorder(v221i, v221)
186 .parallel(v221);
187 bias1_im_0_d_def__
188 .split(v2, v2, v2i, 8, TailStrategy::GuardWithIf)
189 .vectorize(v2i)
190 .compute_at(updated_bias1, v221)
191 .reorder(v2i, v2);
192 bias1_im_0_d_def__.update(0)
193 .split(v2, v2, v2i, 8, TailStrategy::GuardWithIf)
194 .vectorize(v2i)
195 .reorder(v2i, v2, r1226_x);
196 updated_filter1
197 .split(v218, v218, v218i, 16, TailStrategy::GuardWithIf)
198 .split(v219, v219, v219i, 2, TailStrategy::GuardWithIf)
199 .split(v220, v220, v220i, 2, TailStrategy::GuardWithIf)
200 .split(v218i, v218i, v218ii, 8, TailStrategy::GuardWithIf)
201 .vectorize(v218ii)
202 .compute_root()
203 .reorder(v218ii, v218i, v219i, v220i, v218, v219, v220)
204 .fuse(v219, v220, v219)
205 .fuse(v218, v219, v218)
206 .parallel(v218);
207 updated_filter1.update(0)
208 .split(v218, v218, v218i, 16, TailStrategy::GuardWithIf)
209 .split(v219, v219, v219i, 2, TailStrategy::GuardWithIf)
210 .split(v218i, v218i, v218ii, 8, TailStrategy::GuardWithIf)
211 .vectorize(v218ii)
212 .reorder(v218ii, v218i, v219i, v218, v219)
213 .fuse(v218, v219, v218)
214 .parallel(v218);
215 updated_filter1.update(1)
216 .split(v218, v218, v218i, 16, TailStrategy::GuardWithIf)
217 .split(v219, v219, v219i, 2, TailStrategy::GuardWithIf)
218 .split(v218i, v218i, v218ii, 8, TailStrategy::GuardWithIf)
219 .vectorize(v218ii)
220 .reorder(v218ii, v218i, v219i, v218, v219)
221 .fuse(v218, v219, v218)
222 .parallel(v218);
223 updated_filter1.update(2)
224 .split(v218, v218, v218i, 16, TailStrategy::GuardWithIf)
225 .split(v219, v219, v219i, 2, TailStrategy::GuardWithIf)
226 .split(v218i, v218i, v218ii, 8, TailStrategy::GuardWithIf)
227 .vectorize(v218ii)
228 .reorder(v218ii, v218i, v219i, v218, v219)
229 .fuse(v218, v219, v218)
230 .parallel(v218);
231 updated_filter1.update(3)
232 .split(v218, v218, v218i, 16, TailStrategy::GuardWithIf)
233 .split(v219, v219, v219i, 2, TailStrategy::GuardWithIf)
234 .split(v218i, v218i, v218ii, 8, TailStrategy::GuardWithIf)
235 .vectorize(v218ii)
236 .reorder(v218ii, v218i, v219i, v218, v219)
237 .fuse(v218, v219, v218)
238 .parallel(v218);
239 filter1_im_0_d_def__
240 .split(v4, v4, v4i, 8, TailStrategy::GuardWithIf)
241 .vectorize(v4i)
242 .compute_root()
243 .reorder(v4i, v4, v3)
244 .parallel(v3)
245 .reorder_storage(v4, v3);
246 filter1_im_0_d_def__.update(0)
247 .reorder(r1010_x, r1010_y, r1010_z, v3)
248 .parallel(v3);
249 filter1_im_0_d_def__.update(1)
250 .reorder(r1207_x, r1207_y, v3)
251 .parallel(v3);
252 updated_head2_bias
253 .split(v216, v216, v216i, 8, TailStrategy::GuardWithIf)
254 .vectorize(v216i)
255 .compute_root()
256 .reorder(v216i, v216, v217)
257 .fuse(v216, v217, v216)
258 .parallel(v216);
259 updated_head2_bias.update(0)
260 .split(v216, v216, v216i, 8, TailStrategy::GuardWithIf)
261 .vectorize(v216i)
262 .reorder(v216i, v216)
263 .parallel(v216);
264 updated_head2_bias.update(1)
265 .split(v216, v216, v216i, 8, TailStrategy::GuardWithIf)
266 .vectorize(v216i)
267 .reorder(v216i, v216)
268 .parallel(v216);
269 updated_head2_bias.update(2)
270 .split(v216, v216, v216i, 8, TailStrategy::GuardWithIf)
271 .vectorize(v216i)
272 .reorder(v216i, v216)
273 .parallel(v216);
274 updated_head2_bias.update(3)
275 .split(v216, v216, v216i, 8, TailStrategy::GuardWithIf)
276 .vectorize(v216i)
277 .reorder(v216i, v216)
278 .parallel(v216);
279 head2_bias_im_0_d_def__
280 .split(v12, v12, v12i, 8, TailStrategy::GuardWithIf)
281 .vectorize(v12i)
282 .compute_at(updated_head2_bias, v216)
283 .reorder(v12i, v12);
284 head2_bias_im_0_d_def__.update(0)
285 .split(v12, v12, v12i, 8, TailStrategy::GuardWithIf)
286 .vectorize(v12i)
287 .reorder(v12i, v12, r1114_x, r1114_y);
288 head2_conv_0_d_def___1
289 .store_in(MemoryType::Stack)
290 .split(c, c, ci, 8, TailStrategy::GuardWithIf)
291 .vectorize(ci)
292 .compute_at(head2_bias_im_0_d_def__, v12)
293 .reorder(ci, c, w, n);
294 updated_head2_filter
295 .split(v213, v213, v213i, 8, TailStrategy::GuardWithIf)
296 .split(v214, v214, v214i, 2, TailStrategy::GuardWithIf)
297 .split(v215, v215, v215i, 2, TailStrategy::GuardWithIf)
298 .vectorize(v213i)
299 .compute_root()
300 .reorder(v213i, v214i, v215i, v213, v214, v215)
301 .fuse(v214, v215, v214)
302 .fuse(v213, v214, v213)
303 .parallel(v213);
304 updated_head2_filter.update(0)
305 .split(v213, v213, v213i, 8, TailStrategy::GuardWithIf)
306 .split(v214, v214, v214i, 2, TailStrategy::GuardWithIf)
307 .vectorize(v213i)
308 .reorder(v213i, v214i, v213, v214)
309 .fuse(v213, v214, v213)
310 .parallel(v213);
311 updated_head2_filter.update(1)
312 .split(v213, v213, v213i, 8, TailStrategy::GuardWithIf)
313 .split(v214, v214, v214i, 2, TailStrategy::GuardWithIf)
314 .vectorize(v213i)
315 .reorder(v213i, v214i, v213, v214)
316 .fuse(v213, v214, v213)
317 .parallel(v213);
318 updated_head2_filter.update(2)
319 .split(v213, v213, v213i, 8, TailStrategy::GuardWithIf)
320 .split(v214, v214, v214i, 2, TailStrategy::GuardWithIf)
321 .vectorize(v213i)
322 .reorder(v213i, v214i, v213, v214)
323 .fuse(v213, v214, v213)
324 .parallel(v213);
325 updated_head2_filter.update(3)
326 .split(v213, v213, v213i, 8, TailStrategy::GuardWithIf)
327 .split(v214, v214, v214i, 2, TailStrategy::GuardWithIf)
328 .vectorize(v213i)
329 .reorder(v213i, v214i, v213, v214)
330 .fuse(v213, v214, v213)
331 .parallel(v213);
332 head2_filter_im_0_d_def__
333 .store_in(MemoryType::Stack)
334 .split(v13, v13, v13i, 8, TailStrategy::GuardWithIf)
335 .vectorize(v13i)
336 .compute_at(updated_head2_filter, v214i)
337 .reorder(v13i, v13, v14);
338 head2_filter_im_0_d_def__.update(0)
339 .split(v13, v13, v13i, 8, TailStrategy::GuardWithIf)
340 .vectorize(v13i)
341 .reorder(v13i, v13, v14, r1095_x, r1095_y);
342 head2_conv_1_d_def__
343 .split(n, n, ni, 5, TailStrategy::GuardWithIf)
344 .split(c, c, ci, 8, TailStrategy::GuardWithIf)
345 .vectorize(ci)
346 .compute_root()
347 .reorder(ci, ni, c, w, n)
348 .parallel(n);
349 head2_relu_0_d_def__
350 .store_in(MemoryType::Stack)
351 .split(c, c, ci, 8, TailStrategy::GuardWithIf)
352 .vectorize(ci)
353 .compute_at(head2_conv_1_d_def__, c)
354 .reorder(ci, c, w, n);
355 head2_relu_0_d_def__.update(0)
356 .split(c, c, ci, 8, TailStrategy::GuardWithIf)
357 .vectorize(ci)
358 .reorder(ci, c, w, n, r986_x);
359 updated_head1_bias
360 .split(v211, v211, v211i, 8, TailStrategy::GuardWithIf)
361 .vectorize(v211i)
362 .compute_root()
363 .reorder(v211i, v211, v212)
364 .parallel(v212);
365 updated_head1_bias.update(0)
366 .split(v211, v211, v211i, 8, TailStrategy::GuardWithIf)
367 .vectorize(v211i)
368 .reorder(v211i, v211);
369 updated_head1_bias.update(1)
370 .split(v211, v211, v211i, 8, TailStrategy::GuardWithIf)
371 .vectorize(v211i)
372 .reorder(v211i, v211);
373 updated_head1_bias.update(2)
374 .split(v211, v211, v211i, 8, TailStrategy::GuardWithIf)
375 .vectorize(v211i)
376 .reorder(v211i, v211);
377 updated_head1_bias.update(3)
378 .split(v211, v211, v211i, 8, TailStrategy::GuardWithIf)
379 .vectorize(v211i)
380 .reorder(v211i, v211);
381 head1_bias_im_0_d_def__
382 .split(v5, v5, v5i, 8, TailStrategy::GuardWithIf)
383 .vectorize(v5i)
384 .compute_root()
385 .reorder(v5i, v5);
386 head1_bias_im_0_d_def__.update(0)
387 .split(v5, v5, v5i, 8, TailStrategy::GuardWithIf)
388 .vectorize(v5i)
389 .reorder(v5i, v5, r1302_x);
390 updated_head1_filter
391 .split(v208, v208, v208i, 2, TailStrategy::GuardWithIf)
392 .split(v209, v209, v209i, 2, TailStrategy::GuardWithIf)
393 .split(v210, v210, v210i, 2, TailStrategy::GuardWithIf)
394 .split(v207, v207, v207i, 8, TailStrategy::GuardWithIf)
395 .vectorize(v207i)
396 .compute_root()
397 .reorder(v207i, v207, v208i, v209i, v210i, v208, v209, v210)
398 .fuse(v209, v210, v209)
399 .fuse(v208, v209, v208)
400 .parallel(v208);
401 updated_head1_filter.update(0)
402 .split(v208, v208, v208i, 2, TailStrategy::GuardWithIf)
403 .split(v209, v209, v209i, 2, TailStrategy::GuardWithIf)
404 .split(v207, v207, v207i, 8, TailStrategy::GuardWithIf)
405 .vectorize(v207i)
406 .reorder(v207i, v207, v208i, v209i, v208, v209)
407 .fuse(v208, v209, v208)
408 .parallel(v208);
409 updated_head1_filter.update(1)
410 .split(v208, v208, v208i, 2, TailStrategy::GuardWithIf)
411 .split(v209, v209, v209i, 2, TailStrategy::GuardWithIf)
412 .split(v207, v207, v207i, 8, TailStrategy::GuardWithIf)
413 .vectorize(v207i)
414 .reorder(v207i, v207, v208i, v209i, v208, v209)
415 .fuse(v208, v209, v208)
416 .parallel(v208);
417 updated_head1_filter.update(2)
418 .split(v208, v208, v208i, 2, TailStrategy::GuardWithIf)
419 .split(v209, v209, v209i, 2, TailStrategy::GuardWithIf)
420 .split(v207, v207, v207i, 8, TailStrategy::GuardWithIf)
421 .vectorize(v207i)
422 .reorder(v207i, v207, v208i, v209i, v208, v209)
423 .fuse(v208, v209, v208)
424 .parallel(v208);
425 updated_head1_filter.update(3)
426 .split(v208, v208, v208i, 2, TailStrategy::GuardWithIf)
427 .split(v209, v209, v209i, 2, TailStrategy::GuardWithIf)
428 .split(v207, v207, v207i, 8, TailStrategy::GuardWithIf)
429 .vectorize(v207i)
430 .reorder(v207i, v207, v208i, v209i, v208, v209)
431 .fuse(v208, v209, v208)
432 .parallel(v208);
433 squashed_head1_filter_0_d_def__
434 .store_in(MemoryType::Stack)
435 .split(c, c, ci, 8, TailStrategy::GuardWithIf)
436 .vectorize(ci)
437 .compute_at(updated_head1_filter, v207)
438 .reorder(ci, c, s, n);
439 squashed_head1_filter_0_d_def__.update(0)
440 .split(c, c, ci, 8, TailStrategy::GuardWithIf)
441 .vectorize(ci)
442 .reorder(ci, c, s, n, r1321_x);
443 head1_conv_1_d_def__
444 .split(c, c, ci, 8, TailStrategy::GuardWithIf)
445 .vectorize(ci)
446 .compute_root()
447 .reorder(ci, c, w)
448 .parallel(w);
449 head1_conv_1_d_def__.update(0)
450 .split(c, c, ci, 8, TailStrategy::GuardWithIf)
451 .vectorize(ci)
452 .reorder(ci, c, r1183_x, w)
453 .parallel(w);
454 conv1_stage1_1_d_def__
455 .split(w, w, wi, 8, TailStrategy::GuardWithIf)
456 .vectorize(wi)
457 .compute_root()
458 .reorder(wi, w, c)
459 .parallel(c)
460 .reorder_storage(w, c);
461 conv1_stage1_1_d_def__.update(0)
462 .split(r1029_x, r1029_x, r1029_xi, 2, TailStrategy::GuardWithIf)
463 .split(w, w, wi, 8, TailStrategy::GuardWithIf)
464 .vectorize(wi)
465 .reorder(wi, r1029_xi, r1029_x, w, c)
466 .parallel(c);
467 conv1_stage2_0_d_def___1
468 .store_in(MemoryType::Stack)
469 .split(w, w, wi, 8, TailStrategy::GuardWithIf)
470 .vectorize(wi)
471 .compute_at(conv1_stage1_1_d_def__, r1029_xi)
472 .store_at(conv1_stage1_1_d_def__, r1029_x)
473 .reorder(wi, w, c, n)
474 .reorder_storage(w, c, n);
475 conv1_stage2_1_d_def__
476 .split(c, c, ci, 2, TailStrategy::GuardWithIf)
477 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
478 .vectorize(ni)
479 .compute_root()
480 .reorder(ni, n, ci, w, c)
481 .parallel(c)
482 .reorder_storage(n, c, w);
483 sum_1_d_def__
484 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
485 .vectorize(ni)
486 .compute_root()
487 .reorder(ni, n)
488 .parallel(n);
489 relu1_0_d_def__
490 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
491 .split(w, w, wi, 2, TailStrategy::GuardWithIf)
492 .vectorize(ni)
493 .compute_root()
494 .reorder(ni, c, wi, n, w)
495 .fuse(n, w, n)
496 .parallel(n)
497 .reorder_storage(n, c, w);
498 relu1_0_d_def__.update(0)
499 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
500 .split(w, w, wi, 2, TailStrategy::GuardWithIf)
501 .vectorize(ni)
502 .reorder(ni, wi, n, w)
503 .fuse(n, w, n)
504 .parallel(n);
505 relu1_0_d_def__.update(1)
506 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
507 .split(w, w, wi, 2, TailStrategy::GuardWithIf)
508 .vectorize(ni)
509 .reorder(ni, wi, n, w)
510 .fuse(n, w, n)
511 .parallel(n);
512 relu1_0_d_def__.update(2)
513 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
514 .split(w, w, wi, 2, TailStrategy::GuardWithIf)
515 .vectorize(ni)
516 .reorder(ni, wi, n, w)
517 .fuse(n, w, n)
518 .parallel(n);
519 relu1_0_d_def__.update(3)
520 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
521 .split(w, w, wi, 2, TailStrategy::GuardWithIf)
522 .vectorize(ni)
523 .reorder(ni, wi, n, w)
524 .fuse(n, w, n)
525 .parallel(n);
526 relu1_0_d_def__.update(4)
527 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
528 .split(w, w, wi, 2, TailStrategy::GuardWithIf)
529 .vectorize(ni)
530 .reorder(ni, wi, n, w)
531 .fuse(n, w, n)
532 .parallel(n);
533 relu1_0_d_def__.update(5)
534 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
535 .split(w, w, wi, 2, TailStrategy::GuardWithIf)
536 .vectorize(ni)
537 .reorder(ni, wi, n, w)
538 .fuse(n, w, n)
539 .parallel(n);
540 relu1_0_d_def__.update(6)
541 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
542 .split(w, w, wi, 2, TailStrategy::GuardWithIf)
543 .vectorize(ni)
544 .reorder(ni, wi, n, w)
545 .fuse(n, w, n)
546 .parallel(n);
547 relu1_0_d_def__.update(7)
548 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
549 .split(w, w, wi, 2, TailStrategy::GuardWithIf)
550 .vectorize(ni)
551 .reorder(ni, wi, n, w)
552 .fuse(n, w, n)
553 .parallel(n);
554 relu1_0_d_def__.update(8)
555 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
556 .split(w, w, wi, 2, TailStrategy::GuardWithIf)
557 .vectorize(ni)
558 .reorder(ni, wi, n, w)
559 .fuse(n, w, n)
560 .parallel(n);
561 relu1_0_d_def__.update(9)
562 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
563 .split(w, w, wi, 2, TailStrategy::GuardWithIf)
564 .vectorize(ni)
565 .reorder(ni, wi, n, w)
566 .fuse(n, w, n)
567 .parallel(n);
568 relu1_0_d_def__.update(10)
569 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
570 .split(w, w, wi, 2, TailStrategy::GuardWithIf)
571 .vectorize(ni)
572 .reorder(ni, wi, n, w)
573 .fuse(n, w, n)
574 .parallel(n);
575 relu1_0_d_def__.update(11)
576 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
577 .split(w, w, wi, 2, TailStrategy::GuardWithIf)
578 .vectorize(ni)
579 .reorder(ni, wi, n, w)
580 .fuse(n, w, n)
581 .parallel(n);
582 relu1_0_d_def__.update(12)
583 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
584 .split(w, w, wi, 2, TailStrategy::GuardWithIf)
585 .vectorize(ni)
586 .reorder(ni, wi, n, w)
587 .fuse(n, w, n)
588 .parallel(n);
589 relu1_0_d_def__.update(13)
590 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
591 .split(w, w, wi, 2, TailStrategy::GuardWithIf)
592 .vectorize(ni)
593 .reorder(ni, wi, n, w)
594 .fuse(n, w, n)
595 .parallel(n);
596 relu1_0_d_def__.update(14)
597 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
598 .split(w, w, wi, 2, TailStrategy::GuardWithIf)
599 .vectorize(ni)
600 .reorder(ni, wi, n, w)
601 .fuse(n, w, n)
602 .parallel(n);
603 relu1_0_d_def__.update(15)
604 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
605 .split(w, w, wi, 2, TailStrategy::GuardWithIf)
606 .vectorize(ni)
607 .reorder(ni, wi, n, w)
608 .fuse(n, w, n)
609 .parallel(n);
610 relu1_0_d_def__.update(16)
611 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
612 .split(w, w, wi, 2, TailStrategy::GuardWithIf)
613 .vectorize(ni)
614 .reorder(ni, wi, n, w)
615 .fuse(n, w, n)
616 .parallel(n);
617 relu1_0_d_def__.update(17)
618 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
619 .split(w, w, wi, 2, TailStrategy::GuardWithIf)
620 .vectorize(ni)
621 .reorder(ni, wi, n, w)
622 .fuse(n, w, n)
623 .parallel(n);
624 relu1_0_d_def__.update(18)
625 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
626 .split(w, w, wi, 2, TailStrategy::GuardWithIf)
627 .vectorize(ni)
628 .reorder(ni, wi, n, w)
629 .fuse(n, w, n)
630 .parallel(n);
631 relu1_0_d_def__.update(19)
632 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
633 .split(w, w, wi, 2, TailStrategy::GuardWithIf)
634 .vectorize(ni)
635 .reorder(ni, wi, n, w)
636 .fuse(n, w, n)
637 .parallel(n);
638 relu1_0_d_def__.update(20)
639 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
640 .split(w, w, wi, 2, TailStrategy::GuardWithIf)
641 .vectorize(ni)
642 .reorder(ni, wi, n, w)
643 .fuse(n, w, n)
644 .parallel(n);
645 relu1_0_d_def__.update(21)
646 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
647 .split(w, w, wi, 2, TailStrategy::GuardWithIf)
648 .vectorize(ni)
649 .reorder(ni, wi, n, w)
650 .fuse(n, w, n)
651 .parallel(n);
652 relu1_0_d_def__.update(22)
653 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
654 .split(w, w, wi, 2, TailStrategy::GuardWithIf)
655 .vectorize(ni)
656 .reorder(ni, wi, n, w)
657 .fuse(n, w, n)
658 .parallel(n);
659 relu1_0_d_def__.update(23)
660 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
661 .split(w, w, wi, 2, TailStrategy::GuardWithIf)
662 .vectorize(ni)
663 .reorder(ni, wi, n, w)
664 .fuse(n, w, n)
665 .parallel(n);
666 relu1_0_d_def__.update(24)
667 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
668 .split(w, w, wi, 2, TailStrategy::GuardWithIf)
669 .vectorize(ni)
670 .reorder(ni, wi, n, w)
671 .fuse(n, w, n)
672 .parallel(n);
673 relu1_0_d_def__.update(25)
674 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
675 .split(w, w, wi, 2, TailStrategy::GuardWithIf)
676 .vectorize(ni)
677 .reorder(ni, wi, n, w)
678 .fuse(n, w, n)
679 .parallel(n);
680 relu1_0_d_def__.update(26)
681 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
682 .split(w, w, wi, 2, TailStrategy::GuardWithIf)
683 .vectorize(ni)
684 .reorder(ni, wi, n, w)
685 .fuse(n, w, n)
686 .parallel(n);
687 relu1_0_d_def__.update(27)
688 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
689 .split(w, w, wi, 2, TailStrategy::GuardWithIf)
690 .vectorize(ni)
691 .reorder(ni, wi, n, w)
692 .fuse(n, w, n)
693 .parallel(n);
694 relu1_0_d_def__.update(28)
695 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
696 .split(w, w, wi, 2, TailStrategy::GuardWithIf)
697 .vectorize(ni)
698 .reorder(ni, wi, n, w)
699 .fuse(n, w, n)
700 .parallel(n);
701 relu1_0_d_def__.update(29)
702 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
703 .split(w, w, wi, 2, TailStrategy::GuardWithIf)
704 .vectorize(ni)
705 .reorder(ni, wi, n, w)
706 .fuse(n, w, n)
707 .parallel(n);
708 relu1_0_d_def__.update(30)
709 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
710 .split(w, w, wi, 2, TailStrategy::GuardWithIf)
711 .vectorize(ni)
712 .reorder(ni, wi, n, w)
713 .fuse(n, w, n)
714 .parallel(n);
715 relu1_0_d_def__.update(31)
716 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
717 .split(w, w, wi, 2, TailStrategy::GuardWithIf)
718 .vectorize(ni)
719 .reorder(ni, wi, n, w)
720 .fuse(n, w, n)
721 .parallel(n);
722 f0_0_d_def__
723 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
724 .split(w, w, wi, 2, TailStrategy::GuardWithIf)
725 .vectorize(ni)
726 .compute_root()
727 .reorder(ni, wi, n, w)
728 .fuse(n, w, n)
729 .parallel(n);
730 f1_1_d_def__
731 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
732 .vectorize(ni)
733 .compute_at(f0_0_d_def__, n)
734 .reorder(ni, n);
735 sum_1_1_d_def__
736 .compute_root();
737 f1
738 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
739 .vectorize(ni)
740 .compute_root()
741 .reorder(ni, n)
742 .parallel(n);
743 f1.update(0)
744 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
745 .vectorize(ni)
746 .reorder(ni, r24_x, n)
747 .parallel(n);
748 conv1_stage2
749 .split(c, c, ci, 8, TailStrategy::GuardWithIf)
750 .split(w, w, wi, 4, TailStrategy::GuardWithIf)
751 .split(wi, wi, wii, 2, TailStrategy::GuardWithIf)
752 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
753 .vectorize(ni)
754 .compute_root()
755 .reorder(ni, n, ci, wii, wi, c, w)
756 .fuse(c, w, c)
757 .parallel(c)
758 .reorder_storage(n, c, w);
759 conv1_stage2.update(0)
760 .split(c, c, ci, 8, TailStrategy::GuardWithIf)
761 .split(w, w, wi, 2, TailStrategy::GuardWithIf)
762 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
763 .vectorize(ni)
764 .reorder(ni, r19_x, n, ci, wi, c, w)
765 .fuse(c, w, c)
766 .parallel(c);
767 head2_relu
768 .split(c, c, ci, 3, TailStrategy::GuardWithIf)
769 .split(w, w, wi, 7, TailStrategy::GuardWithIf)
770 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
771 .vectorize(ni)
772 .compute_root()
773 .reorder(ni, n, ci, wi, c, w)
774 .fuse(c, w, c)
775 .parallel(c)
776 .reorder_storage(n, c, w);
777 head2_conv
778 .split(n, n, ni, 40, TailStrategy::GuardWithIf)
779 .split(c, c, ci, 12, TailStrategy::GuardWithIf)
780 .split(w, w, wi, 2, TailStrategy::GuardWithIf)
781 .split(ni, ni, nii, 8, TailStrategy::GuardWithIf)
782 .vectorize(nii)
783 .compute_root()
784 .reorder(nii, ni, ci, wi, n, c, w)
785 .fuse(c, w, c)
786 .fuse(n, c, n)
787 .parallel(n)
788 .reorder_storage(n, c, w);
789 head2_conv.update(0)
790 .split(n, n, ni, 40, TailStrategy::GuardWithIf)
791 .split(c, c, ci, 12, TailStrategy::GuardWithIf)
792 .split(w, w, wi, 2, TailStrategy::GuardWithIf)
793 .split(ni, ni, nii, 8, TailStrategy::GuardWithIf)
794 .vectorize(nii)
795 .reorder(nii, r9_x, ni, ci, wi, n, c, w)
796 .fuse(c, w, c)
797 .fuse(n, c, n)
798 .parallel(n);
799 normalized_schedule_features
800 .split(c, c, ci, 5, TailStrategy::GuardWithIf)
801 .split(s, s, si, 7, TailStrategy::GuardWithIf)
802 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
803 .vectorize(ni)
804 .compute_root()
805 .reorder(ni, n, ci, si, c, s)
806 .fuse(c, s, c)
807 .parallel(c);
808 conv1_stage1
809 .split(c, c, ci, 8, TailStrategy::GuardWithIf)
810 .vectorize(ci)
811 .compute_at(conv1_stage2, c)
812 .reorder(ci, c, w);
813 conv1_stage1.update(0)
814 .split(c, c, ci, 8, TailStrategy::GuardWithIf)
815 .vectorize(ci)
816 .reorder(ci, c, w, r14_x);
817 head1_conv
818 .split(c, c, ci, 8, TailStrategy::GuardWithIf)
819 .vectorize(ci)
820 .compute_root()
821 .reorder(ci, c, w)
822 .parallel(w);
823 head1_conv.update(0)
824 .split(c, c, ci, 8, TailStrategy::GuardWithIf)
825 .vectorize(ci)
826 .reorder(ci, c, r4_x, r4_y, w)
827 .parallel(w);
828 squashed_head1_filter_broadcast
829 .store_in(MemoryType::Stack)
830 .split(c, c, ci, 8, TailStrategy::GuardWithIf)
831 .vectorize(ci)
832 .compute_at(head1_conv, c)
833 .reorder(ci, c, w, s, n);
834 squashed_head1_filter
835 .split(s, s, si, 10, TailStrategy::GuardWithIf)
836 .split(n, n, ni, 2, TailStrategy::GuardWithIf)
837 .split(c, c, ci, 8, TailStrategy::GuardWithIf)
838 .vectorize(ci)
839 .compute_root()
840 .reorder(ci, c, si, ni, s, n)
841 .fuse(s, n, s)
842 .parallel(s);
843}
void do_cost_model_schedule(Halide::Pipeline pipeline)
A Halide function.
Definition Func.h:700
Func & store_in(MemoryType memory_type)
Set the type of memory this Func should be stored in.
Func & reorder(const std::vector< VarOrRVar > &vars)
Reorder variables to have the given nesting order, from innermost out.
Stage update(int idx=0)
Get a handle on an update step for the purposes of scheduling it.
Func & split(const VarOrRVar &old, const VarOrRVar &outer, const VarOrRVar &inner, const Expr &factor, TailStrategy tail=TailStrategy::Auto)
Split a dimension into inner and outer subdimensions with the given names, where the inner dimension ...
Func & compute_root()
Compute all of this function once ahead of time.
Func & store_at(const Func &f, const Var &var)
Allocate storage for this function within f's loop over var.
Func & parallel(const VarOrRVar &var)
Mark a dimension to be traversed in parallel.
Func & vectorize(const VarOrRVar &var)
Mark a dimension to be computed all-at-once as a single vector.
Func & fuse(const VarOrRVar &inner, const VarOrRVar &outer, const VarOrRVar &fused)
Join two dimensions into a single fused dimension.
Func & compute_at(const Func &f, const Var &var)
Compute this function as needed for each unique value of the given var for the given calling function...
const Internal::StageSchedule & get_schedule() const
Return the current StageSchedule associated with this initial Stage of this Func.
Definition Func.h:2608
const std::vector< Dim > & dims() const
The list and ordering of dimensions used to evaluate this function, after all splits have taken place...
A class representing a Halide pipeline.
Definition Pipeline.h:107
Func get_func(size_t index)
Return handle to the index-th Func within the pipeline based on the topological order.
A reduction variable represents a single dimension of a reduction domain (RDom).
Definition RDom.h:29
Stage & reorder(const std::vector< VarOrRVar > &vars)
Stage & parallel(const VarOrRVar &var)
const Internal::StageSchedule & get_schedule() const
Return the current StageSchedule associated with this Stage.
Definition Func.h:106
A Halide variable, to be used when defining functions.
Definition Var.h:19
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
Expr sum(Expr, const std::string &s="sum")
An inline reduction.