Halide 19.0.0
Halide compiler and libraries
Loading...
Searching...
No Matches
cost_model_schedule.h
Go to the documentation of this file.
1#include "Halide.h"
2
3using namespace Halide;
4
6 // Generated by autoscheduler, manually remove unrolls.
7 // Also manually replaced all RoundUp and ShiftInwards with GuardWithIf.
8
9 using ::Halide::Func;
10 using ::Halide::MemoryType;
11 using ::Halide::RVar;
12 using ::Halide::TailStrategy;
13 using ::Halide::Var;
14 Func loss_output = pipeline.get_func(57);
15 Func sum_1 = pipeline.get_func(56);
16 Func f2 = pipeline.get_func(55);
17 Func sum = pipeline.get_func(54);
18 Func prediction_output = pipeline.get_func(53);
19 Func updated_bias1 = pipeline.get_func(52);
20 Func bias1_im_0_d_def__ = pipeline.get_func(51);
21 Func conv1_stage1_0_d_def___1 = pipeline.get_func(50);
22 Func updated_filter1 = pipeline.get_func(49);
23 Func filter1_im_0_d_def__ = pipeline.get_func(48);
24 Func updated_head2_bias = pipeline.get_func(47);
25 Func head2_bias_im_0_d_def__ = pipeline.get_func(46);
26 Func head2_conv_0_d_def___1 = pipeline.get_func(45);
27 Func updated_head2_filter = pipeline.get_func(44);
28 Func head2_filter_im_0_d_def__ = pipeline.get_func(43);
29 Func head2_conv_1_d_def__ = pipeline.get_func(42);
30 Func head2_relu_0_d_def__ = pipeline.get_func(41);
31 Func updated_head1_bias = pipeline.get_func(40);
32 Func head1_bias_im_0_d_def__ = pipeline.get_func(39);
33 Func head1_conv_0_d_def___1 = pipeline.get_func(38);
34 Func updated_head1_filter = pipeline.get_func(37);
35 Func head1_filter_im_0_d_def__ = pipeline.get_func(36);
36 Func squashed_head1_filter_0_d_def__ = pipeline.get_func(35);
37 Func squashed_head1_filter_broadcast_0_d_def__ = pipeline.get_func(34);
38 Func head1_conv_1_d_def__ = pipeline.get_func(33);
39 Func conv1_stage1_1_d_def__ = pipeline.get_func(32);
40 Func conv1_stage2_0_d_def___1 = pipeline.get_func(31);
41 Func conv1_stage2_1_d_def__ = pipeline.get_func(30);
42 Func sum_1_d_def__ = pipeline.get_func(29);
43 Func relu1_0_d_def__ = pipeline.get_func(28);
44 Func f0_0_d_def__ = pipeline.get_func(27);
45 Func cost_per_stage_output_0_d_def__ = pipeline.get_func(26);
46 Func f1_1_d_def__ = pipeline.get_func(25);
47 Func f2_0_d_def__ = pipeline.get_func(23);
48 Func sum_1_1_d_def__ = pipeline.get_func(22);
49 Func loss_output_0_d_def__ = pipeline.get_func(21);
50 Func adjoint = pipeline.get_func(20);
51 Func f1 = pipeline.get_func(19);
52 Func cost_per_stage_output = pipeline.get_func(18);
53 Func f0 = pipeline.get_func(17);
54 Func relu1 = pipeline.get_func(16);
55 Func conv1_stage2 = pipeline.get_func(15);
56 Func head2_relu = pipeline.get_func(14);
57 Func head2_conv = pipeline.get_func(13);
58 Func normalized_schedule_features = pipeline.get_func(12);
59 Func conv1_stage1 = pipeline.get_func(8);
60 Func head1_conv = pipeline.get_func(7);
61 Func squashed_head1_filter_broadcast = pipeline.get_func(6);
62 Func squashed_head1_filter = pipeline.get_func(5);
63 Var c(head2_conv_0_d_def___1.get_schedule().dims()[0].var);
64 Var ci("ci");
65 Var n(sum.get_schedule().dims()[0].var);
66 Var ni("ni");
67 Var nii("nii");
68 Var niii("niii");
69 Var r1316_z(filter1_im_0_d_def__.update(0).get_schedule().dims()[2].var);
70 Var r1512_y(filter1_im_0_d_def__.update(1).get_schedule().dims()[1].var);
71 Var s(squashed_head1_filter_0_d_def__.get_schedule().dims()[1].var);
72 Var si("si");
73 Var v11(bias1_im_0_d_def__.get_schedule().dims()[0].var);
74 Var v11i("v11i");
75 Var v12(filter1_im_0_d_def__.get_schedule().dims()[0].var);
76 Var v13(filter1_im_0_d_def__.get_schedule().dims()[1].var);
77 Var v13i("v13i");
78 Var v14(head1_bias_im_0_d_def__.get_schedule().dims()[0].var);
79 Var v14i("v14i");
80 Var v21(head2_bias_im_0_d_def__.get_schedule().dims()[0].var);
81 Var v21i("v21i");
82 Var v22(head2_filter_im_0_d_def__.get_schedule().dims()[0].var);
83 Var v22i("v22i");
84 Var v23(head2_filter_im_0_d_def__.get_schedule().dims()[1].var);
85 Var v298(updated_head1_filter.get_schedule().dims()[0].var);
86 Var v298i("v298i");
87 Var v299(updated_head1_filter.get_schedule().dims()[1].var);
88 Var v299i("v299i");
89 Var v300(updated_head1_filter.get_schedule().dims()[2].var);
90 Var v301(updated_head1_filter.get_schedule().dims()[3].var);
91 Var v301i("v301i");
92 Var v302(updated_head1_bias.get_schedule().dims()[0].var);
93 Var v302i("v302i");
94 Var v303(updated_head1_bias.get_schedule().dims()[1].var);
95 Var v304(updated_head2_filter.get_schedule().dims()[0].var);
96 Var v304i("v304i");
97 Var v305(updated_head2_filter.get_schedule().dims()[1].var);
98 Var v306(updated_head2_filter.get_schedule().dims()[2].var);
99 Var v307(updated_head2_bias.get_schedule().dims()[0].var);
100 Var v307i("v307i");
101 Var v308(updated_head2_bias.get_schedule().dims()[1].var);
102 Var v309(updated_filter1.get_schedule().dims()[0].var);
103 Var v309i("v309i");
104 Var v310(updated_filter1.get_schedule().dims()[1].var);
105 Var v311(updated_filter1.get_schedule().dims()[2].var);
106 Var v312(updated_bias1.get_schedule().dims()[0].var);
107 Var v312i("v312i");
108 Var v313(updated_bias1.get_schedule().dims()[1].var);
109 Var w(head2_conv_0_d_def___1.get_schedule().dims()[1].var);
110 Var wi("wi");
111 RVar r1294_x(head2_relu_0_d_def__.update(0).get_schedule().dims()[0].var);
112 RVar r1316_x(filter1_im_0_d_def__.update(0).get_schedule().dims()[0].var);
113 RVar r1316_y(filter1_im_0_d_def__.update(0).get_schedule().dims()[1].var);
114 RVar r1336_x(conv1_stage1_1_d_def__.update(0).get_schedule().dims()[0].var);
115 RVar r1400_x(head2_filter_im_0_d_def__.update(0).get_schedule().dims()[0].var);
116 RVar r1400_y(head2_filter_im_0_d_def__.update(0).get_schedule().dims()[1].var);
117 RVar r1421_x(head2_bias_im_0_d_def__.update(0).get_schedule().dims()[0].var);
118 RVar r1421_y(head2_bias_im_0_d_def__.update(0).get_schedule().dims()[1].var);
119 RVar r1491_x(head1_conv_1_d_def__.update(0).get_schedule().dims()[0].var);
120 RVar r1512_x(filter1_im_0_d_def__.update(1).get_schedule().dims()[0].var);
121 RVar r1532_x(bias1_im_0_d_def__.update(0).get_schedule().dims()[0].var);
122 RVar r1594_x(head1_bias_im_0_d_def__.update(0).get_schedule().dims()[0].var);
123 RVar r1614_x(squashed_head1_filter_0_d_def__.update(0).get_schedule().dims()[0].var);
124 RVar r31_x(head1_conv.update(0).get_schedule().dims()[0].var);
125 RVar r31_y(head1_conv.update(0).get_schedule().dims()[1].var);
126 RVar r40_x(head2_conv.update(0).get_schedule().dims()[0].var);
127 RVar r54_x(conv1_stage1.update(0).get_schedule().dims()[0].var);
128 RVar r63_x(conv1_stage2.update(0).get_schedule().dims()[0].var);
129 RVar r81_x(f1.update(0).get_schedule().dims()[0].var);
130 RVar r89_x(sum_1.update(0).get_schedule().dims()[0].var);
131 RVar r94_x(sum.update(0).get_schedule().dims()[0].var);
132 RVar r94_y(sum.update(0).get_schedule().dims()[1].var);
133 loss_output
134 .compute_root();
135 sum_1
136 .compute_root();
137 sum_1.update(0);
138 sum
139 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
140 .vectorize(ni)
141 .compute_root()
142 .reorder(ni, n)
143 .serial(n);
144 sum.update(0)
145 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
146 .vectorize(ni)
147 .reorder(ni, r94_x, r94_y, n)
148 .serial(n);
149 prediction_output
150 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
151 .vectorize(ni)
152 .compute_root()
153 .reorder(ni, n)
154 .serial(n);
155 updated_bias1
156 .split(v312, v312, v312i, 8, TailStrategy::GuardWithIf)
157 .vectorize(v312i)
158 .compute_root()
159 .reorder(v312i, v312, v313)
160 .fuse(v312, v313, v312)
161 .serial(v312);
162 updated_bias1.update(0)
163 .split(v312, v312, v312i, 8, TailStrategy::GuardWithIf)
164 .vectorize(v312i)
165 .reorder(v312i, v312)
166 .serial(v312);
167 updated_bias1.update(1)
168 .split(v312, v312, v312i, 8, TailStrategy::GuardWithIf)
169 .vectorize(v312i)
170 .reorder(v312i, v312)
171 .serial(v312);
172 updated_bias1.update(2)
173 .split(v312, v312, v312i, 8, TailStrategy::GuardWithIf)
174 .vectorize(v312i)
175 .reorder(v312i, v312)
176 .serial(v312);
177 updated_bias1.update(3)
178 .split(v312, v312, v312i, 8, TailStrategy::GuardWithIf)
179 .vectorize(v312i)
180 .reorder(v312i, v312)
181 .serial(v312);
182 bias1_im_0_d_def__
183 .split(v11, v11, v11i, 8, TailStrategy::GuardWithIf)
184 .vectorize(v11i)
185 .compute_at(updated_bias1, v312)
186 .reorder(v11i, v11);
187 bias1_im_0_d_def__.update(0)
188 .split(v11, v11, v11i, 8, TailStrategy::GuardWithIf)
189 .vectorize(v11i)
190 .reorder(v11i, v11, r1532_x);
191 updated_filter1
192 .split(v309, v309, v309i, 8, TailStrategy::GuardWithIf)
193 .vectorize(v309i)
194 .compute_root()
195 .reorder(v309i, v311, v309, v310)
196 .fuse(v309, v310, v309)
197 .serial(v309);
198 updated_filter1.update(0)
199 .split(v309, v309, v309i, 8, TailStrategy::GuardWithIf)
200 .vectorize(v309i)
201 .reorder(v309i, v309, v310)
202 .fuse(v309, v310, v309)
203 .serial(v309);
204 updated_filter1.update(1)
205 .split(v309, v309, v309i, 8, TailStrategy::GuardWithIf)
206 .vectorize(v309i)
207 .reorder(v309i, v309, v310)
208 .fuse(v309, v310, v309)
209 .serial(v309);
210 updated_filter1.update(2)
211 .split(v309, v309, v309i, 8, TailStrategy::GuardWithIf)
212 .vectorize(v309i)
213 .reorder(v309i, v309, v310)
214 .fuse(v309, v310, v309)
215 .serial(v309);
216 updated_filter1.update(3)
217 .split(v309, v309, v309i, 8, TailStrategy::GuardWithIf)
218 .vectorize(v309i)
219 .reorder(v309i, v309, v310)
220 .fuse(v309, v310, v309)
221 .serial(v309);
222 filter1_im_0_d_def__
223 .split(v13, v13, v13i, 8, TailStrategy::GuardWithIf)
224 .vectorize(v13i)
225 .compute_root()
226 .reorder(v13i, v13, v12)
227 .fuse(v13, v12, v13)
228 .parallel(v13)
229 .reorder_storage(v13, v12);
230 filter1_im_0_d_def__.update(0)
231 .reorder(r1316_z, r1316_x, r1316_y, v12)
232 .vectorize(r1316_z, 8)
233 .unroll(r1316_z)
234 .parallel(v12);
235 filter1_im_0_d_def__.update(1)
236 .reorder(r1512_x, r1512_y, v12)
237 .vectorize(r1512_y)
238 .parallel(v12);
239 updated_head2_bias
240 .split(v307, v307, v307i, 8, TailStrategy::GuardWithIf)
241 .vectorize(v307i)
242 .compute_root()
243 .reorder(v307i, v307, v308)
244 .fuse(v307, v308, v307)
245 .serial(v307);
246 updated_head2_bias.update(0)
247 .split(v307, v307, v307i, 8, TailStrategy::GuardWithIf)
248 .vectorize(v307i)
249 .reorder(v307i, v307)
250 .serial(v307);
251 updated_head2_bias.update(1)
252 .split(v307, v307, v307i, 8, TailStrategy::GuardWithIf)
253 .vectorize(v307i)
254 .reorder(v307i, v307)
255 .serial(v307);
256 updated_head2_bias.update(2)
257 .split(v307, v307, v307i, 8, TailStrategy::GuardWithIf)
258 .vectorize(v307i)
259 .reorder(v307i, v307)
260 .serial(v307);
261 updated_head2_bias.update(3)
262 .split(v307, v307, v307i, 8, TailStrategy::GuardWithIf)
263 .vectorize(v307i)
264 .reorder(v307i, v307)
265 .serial(v307);
266 head2_bias_im_0_d_def__
267 .split(v21, v21, v21i, 8, TailStrategy::GuardWithIf)
268 .vectorize(v21i)
269 .compute_at(updated_head2_bias, v307)
270 .reorder(v21i, v21);
271 head2_bias_im_0_d_def__.update(0)
272 .split(v21, v21, v21i, 8, TailStrategy::GuardWithIf)
273 .vectorize(v21i)
274 .reorder(v21i, v21, r1421_x, r1421_y);
275 head2_conv_0_d_def___1
276 .store_in(MemoryType::Stack)
277 .split(c, c, ci, 8, TailStrategy::GuardWithIf)
278 .vectorize(ci)
279 .compute_at(head2_bias_im_0_d_def__, v21)
280 .reorder(ci, c, w, n);
281 updated_head2_filter
282 .split(v304, v304, v304i, 8, TailStrategy::GuardWithIf)
283 .vectorize(v304i)
284 .compute_root()
285 .reorder(v304i, v306, v304, v305)
286 .fuse(v304, v305, v304)
287 .parallel(v304);
288 updated_head2_filter.update(0)
289 .split(v304, v304, v304i, 8, TailStrategy::GuardWithIf)
290 .vectorize(v304i)
291 .reorder(v304i, v304, v305)
292 .fuse(v304, v305, v304)
293 .parallel(v304);
294 updated_head2_filter.update(1)
295 .split(v304, v304, v304i, 8, TailStrategy::GuardWithIf)
296 .vectorize(v304i)
297 .reorder(v304i, v304, v305)
298 .fuse(v304, v305, v304)
299 .serial(v304);
300 updated_head2_filter.update(2)
301 .split(v304, v304, v304i, 8, TailStrategy::GuardWithIf)
302 .vectorize(v304i)
303 .reorder(v304i, v304, v305)
304 .fuse(v304, v305, v304)
305 .serial(v304);
306 updated_head2_filter.update(3)
307 .split(v304, v304, v304i, 8, TailStrategy::GuardWithIf)
308 .vectorize(v304i)
309 .reorder(v304i, v304, v305)
310 .fuse(v304, v305, v304)
311 .parallel(v304);
312 head2_filter_im_0_d_def__
313 .split(v22, v22, v22i, 8, TailStrategy::GuardWithIf)
314 .vectorize(v22i)
315 .compute_at(updated_head2_filter, v304)
316 .reorder(v22i, v22, v23);
317 head2_filter_im_0_d_def__.update(0)
318 .split(v22, v22, v22i, 8, TailStrategy::GuardWithIf)
319 .vectorize(v22i)
320 .reorder(v22i, v22, v23, r1400_x, r1400_y);
321 head2_conv_1_d_def__
322 .split(w, w, wi, 2, TailStrategy::GuardWithIf)
323 .split(n, n, ni, 128, TailStrategy::GuardWithIf)
324 .split(c, c, ci, 8, TailStrategy::GuardWithIf)
325 .split(ni, ni, nii, 4, TailStrategy::GuardWithIf)
326 .unroll(wi)
327 .unroll(nii)
328 .vectorize(ci)
329 .compute_root()
330 .reorder(ci, wi, nii, c, ni, w, n)
331 .fuse(w, n, w)
332 .serial(w);
333 head2_relu_0_d_def__
334 .store_in(MemoryType::Stack)
335 .split(c, c, ci, 8, TailStrategy::GuardWithIf)
336 .unroll(w)
337 .unroll(n)
338 .vectorize(ci)
339 .compute_at(head2_conv_1_d_def__, c)
340 .reorder(ci, c, w, n);
341 head2_relu_0_d_def__.update(0)
342 .split(c, c, ci, 8, TailStrategy::GuardWithIf)
343 .unroll(w)
344 .unroll(n)
345 .vectorize(ci)
346 .reorder(ci, c, w, n, r1294_x);
347 updated_head1_bias
348 .split(v302, v302, v302i, 8, TailStrategy::GuardWithIf)
349 .vectorize(v302i)
350 .compute_root()
351 .reorder(v302i, v302, v303)
352 .serial(v303);
353 updated_head1_bias.update(0)
354 .split(v302, v302, v302i, 8, TailStrategy::GuardWithIf)
355 .vectorize(v302i)
356 .reorder(v302i, v302);
357 updated_head1_bias.update(1)
358 .split(v302, v302, v302i, 8, TailStrategy::GuardWithIf)
359 .vectorize(v302i)
360 .reorder(v302i, v302);
361 updated_head1_bias.update(2)
362 .split(v302, v302, v302i, 8, TailStrategy::GuardWithIf)
363 .vectorize(v302i)
364 .reorder(v302i, v302);
365 updated_head1_bias.update(3)
366 .split(v302, v302, v302i, 8, TailStrategy::GuardWithIf)
367 .vectorize(v302i)
368 .reorder(v302i, v302);
369 head1_bias_im_0_d_def__
370 .split(v14, v14, v14i, 8, TailStrategy::GuardWithIf)
371 .vectorize(v14i)
372 .compute_root()
373 .reorder(v14i, v14);
374 head1_bias_im_0_d_def__.update(0)
375 .split(v14, v14, v14i, 8, TailStrategy::GuardWithIf)
376 .vectorize(v14i)
377 .reorder(v14i, v14, r1594_x);
378 updated_head1_filter
379 .split(v299, v299, v299i, 5, TailStrategy::GuardWithIf)
380 .split(v301, v301, v301i, 2, TailStrategy::GuardWithIf)
381 .split(v298, v298, v298i, 8, TailStrategy::GuardWithIf)
382 .unroll(v299i)
383 .unroll(v301i)
384 .vectorize(v298i)
385 .compute_root()
386 .reorder(v298i, v298, v299i, v301i, v299, v300, v301)
387 .fuse(v300, v301, v300)
388 .fuse(v299, v300, v299)
389 .serial(v299);
390 updated_head1_filter.update(0)
391 .split(v299, v299, v299i, 5, TailStrategy::GuardWithIf)
392 .split(v298, v298, v298i, 8, TailStrategy::GuardWithIf)
393 .vectorize(v298i)
394 .reorder(v298i, v298, v299i, v299, v300)
395 .fuse(v299, v300, v299)
396 .serial(v299);
397 updated_head1_filter.update(1)
398 .split(v299, v299, v299i, 5, TailStrategy::GuardWithIf)
399 .split(v298, v298, v298i, 8, TailStrategy::GuardWithIf)
400 .vectorize(v298i)
401 .reorder(v298i, v298, v299i, v299, v300)
402 .fuse(v299, v300, v299)
403 .serial(v299);
404 updated_head1_filter.update(2)
405 .split(v299, v299, v299i, 5, TailStrategy::GuardWithIf)
406 .split(v298, v298, v298i, 8, TailStrategy::GuardWithIf)
407 .vectorize(v298i)
408 .reorder(v298i, v298, v299i, v299, v300)
409 .fuse(v299, v300, v299)
410 .serial(v299);
411 updated_head1_filter.update(3)
412 .split(v299, v299, v299i, 5, TailStrategy::GuardWithIf)
413 .split(v298, v298, v298i, 8, TailStrategy::GuardWithIf)
414 .vectorize(v298i)
415 .reorder(v298i, v298, v299i, v299, v300)
416 .fuse(v299, v300, v299)
417 .serial(v299);
418 squashed_head1_filter_0_d_def__
419 .store_in(MemoryType::Stack)
420 .split(c, c, ci, 8, TailStrategy::GuardWithIf)
421 .vectorize(ci)
422 .compute_at(updated_head1_filter, v298)
423 .reorder(ci, c, s, n);
424 squashed_head1_filter_0_d_def__.update(0)
425 .split(c, c, ci, 8, TailStrategy::GuardWithIf)
426 .vectorize(ci)
427 .reorder(ci, c, s, n, r1614_x);
428 squashed_head1_filter_broadcast_0_d_def__
429 .store_in(MemoryType::Stack)
430 .split(c, c, ci, 8, TailStrategy::GuardWithIf)
431 .vectorize(ci)
432 .compute_at(updated_head1_filter, v299i)
433 .store_at(updated_head1_filter, v299)
434 .reorder(ci, c, w, s, n);
435 head1_conv_1_d_def__
436 .split(c, c, ci, 8, TailStrategy::GuardWithIf)
437 .vectorize(ci)
438 .compute_root()
439 .reorder(ci, c, w)
440 .serial(w);
441 head1_conv_1_d_def__.update(0)
442 .split(c, c, ci, 8, TailStrategy::GuardWithIf)
443 .vectorize(ci)
444 .reorder(ci, c, r1491_x, w)
445 .serial(w);
446 conv1_stage1_1_d_def__
447 .split(w, w, wi, 8, TailStrategy::GuardWithIf)
448 .vectorize(wi)
449 .compute_root()
450 .reorder(wi, w, c)
451 .fuse(w, c, w)
452 .serial(w)
453 .reorder_storage(w, c);
454 conv1_stage1_1_d_def__.update(0)
455 .split(w, w, wi, 8, TailStrategy::GuardWithIf)
456 .vectorize(wi)
457 .reorder(wi, r1336_x, w, c)
458 .fuse(w, c, w)
459 .serial(w);
460 conv1_stage2_1_d_def__
461 .split(c, c, ci, 14, TailStrategy::GuardWithIf)
462 .split(n, n, ni, 32, TailStrategy::GuardWithIf)
463 .split(ni, ni, nii, 4)
464 .split(nii, nii, niii, 2)
465 .split(w, w, wi, 8, TailStrategy::GuardWithIf)
466 .vectorize(wi)
467 .compute_root()
468 .reorder(wi, w, ci, niii, nii, ni, c, n)
469 .fuse(c, n, c)
470 .parallel(c)
471 .reorder_storage(w, c, n);
472 sum_1_d_def__
473 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
474 .unroll(n)
475 .vectorize(ni)
476 .compute_at(conv1_stage2_1_d_def__, c)
477 .reorder(ni, n);
478
479 relu1_0_d_def__.in()
480 .store_in(MemoryType::Stack)
481 .split(c, c, ci, 8, TailStrategy::GuardWithIf)
482 .vectorize(ci)
483 .compute_at(conv1_stage2_1_d_def__, nii)
484 .reorder(ci, c, w, n);
485 relu1_0_d_def__.compute_at(relu1_0_d_def__.in(), w);
486
487 cost_per_stage_output_0_d_def__
488 .store_in(MemoryType::Stack)
489 .split(w, w, wi, 8, TailStrategy::GuardWithIf)
490 .vectorize(wi)
491 .compute_at(conv1_stage2_1_d_def__, nii)
492 .store_at(conv1_stage2_1_d_def__, ni)
493 .reorder(wi, w, n)
494 .reorder_storage(w, n);
495 f1_1_d_def__
496 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
497 .unroll(n)
498 .vectorize(ni)
499 .compute_at(conv1_stage2_1_d_def__, c)
500 .reorder(ni, n);
501 adjoint
502 .compute_root();
503 f1
504 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
505 .vectorize(ni)
506 .compute_root()
507 .reorder(ni, n)
508 .serial(n);
509 f1.update(0)
510 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
511 .vectorize(ni)
512 .reorder(ni, r81_x, n)
513 .serial(n);
514 cost_per_stage_output
515 .split(n, n, ni, 128, TailStrategy::GuardWithIf)
516 .split(w, w, wi, 2, TailStrategy::GuardWithIf)
517 .split(ni, ni, nii, 8, TailStrategy::GuardWithIf)
518 .vectorize(nii)
519 .compute_root()
520 .reorder(nii, ni, wi, n, w)
521 .fuse(n, w, n)
522 .serial(n);
523 conv1_stage2
524 .split(n, n, ni, 512, TailStrategy::GuardWithIf)
525 .split(c, c, ci, 10, TailStrategy::GuardWithIf)
526 .split(w, w, wi, 2, TailStrategy::GuardWithIf)
527 .split(ni, ni, nii, 8, TailStrategy::GuardWithIf)
528 .vectorize(nii)
529 .compute_root()
530 .reorder(nii, ni, ci, wi, n, c, w)
531 .fuse(c, w, c)
532 .fuse(n, c, n)
533 .serial(n)
534 .reorder_storage(n, c, w);
535 conv1_stage2.update(0)
536 .split(n, n, ni, 512, TailStrategy::GuardWithIf)
537 .split(c, c, ci, 10, TailStrategy::GuardWithIf)
538 .split(w, w, wi, 2, TailStrategy::GuardWithIf)
539 .split(ni, ni, nii, 8, TailStrategy::GuardWithIf)
540 .vectorize(nii)
541 .reorder(nii, r63_x, ni, ci, wi, n, c, w)
542 .fuse(c, w, c)
543 .fuse(n, c, n)
544 .serial(n);
545 head2_relu
546 .split(c, c, ci, 3, TailStrategy::GuardWithIf)
547 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
548 .vectorize(ni)
549 .compute_root()
550 .reorder(ni, n, ci, c, w)
551 .fuse(c, w, c)
552 .serial(c)
553 .reorder_storage(n, c, w);
554 head2_conv
555 .split(n, n, ni, 512, TailStrategy::GuardWithIf)
556 .split(c, c, ci, 6, TailStrategy::GuardWithIf)
557 .split(ni, ni, nii, 8, TailStrategy::GuardWithIf)
558 .vectorize(nii)
559 .compute_root()
560 .reorder(nii, ni, ci, n, c, w)
561 .fuse(c, w, c)
562 .fuse(n, c, n)
563 .serial(n)
564 .reorder_storage(n, c, w);
565 head2_conv.update(0)
566 .split(n, n, ni, 512, TailStrategy::GuardWithIf)
567 .split(c, c, ci, 6, TailStrategy::GuardWithIf)
568 .split(ni, ni, nii, 8, TailStrategy::GuardWithIf)
569 .vectorize(nii)
570 .reorder(nii, r40_x, ni, ci, n, c, w)
571 .fuse(c, w, c)
572 .fuse(n, c, n)
573 .parallel(n);
574 normalized_schedule_features
575 .split(c, c, ci, 11, TailStrategy::GuardWithIf)
576 .split(n, n, ni, 8, TailStrategy::GuardWithIf)
577 .vectorize(ni)
578 .compute_root()
579 .reorder(ni, n, ci, c, s)
580 .fuse(c, s, c)
581 .serial(c);
582 conv1_stage1
583 .split(w, w, wi, 8, TailStrategy::GuardWithIf)
584 .vectorize(wi)
585 .compute_root()
586 .reorder(wi, w, c)
587 .fuse(w, c, w)
588 .serial(w)
589 .reorder_storage(w, c);
590 conv1_stage1.update(0)
591 .split(w, w, wi, 8, TailStrategy::GuardWithIf)
592 .vectorize(wi)
593 .reorder(wi, r54_x, w, c)
594 .fuse(w, c, w)
595 .serial(w);
596 head1_conv
597 .split(w, w, wi, 8, TailStrategy::GuardWithIf)
598 .vectorize(wi)
599 .compute_root()
600 .reorder(wi, w, c)
601 .fuse(w, c, w)
602 .serial(w)
603 .reorder_storage(w, c);
604 head1_conv.update(0)
605 .split(w, w, wi, 8, TailStrategy::GuardWithIf)
606 .vectorize(wi)
607 .reorder(wi, r31_x, r31_y, w, c)
608 .fuse(w, c, w)
609 .serial(w);
610 squashed_head1_filter
611 .split(s, s, si, 8, TailStrategy::GuardWithIf)
612 .vectorize(si)
613 .compute_at(head1_conv, w)
614 .reorder(si, s, c, n)
615 .reorder_storage(s, c, n);
616}
void do_cost_model_schedule(Halide::Pipeline pipeline)
A Halide function.
Definition Func.h:700
Func & store_in(MemoryType memory_type)
Set the type of memory this Func should be stored in.
Func & reorder(const std::vector< VarOrRVar > &vars)
Reorder variables to have the given nesting order, from innermost out.
Stage update(int idx=0)
Get a handle on an update step for the purposes of scheduling it.
Func & serial(const VarOrRVar &var)
Mark a dimension to be traversed serially.
Func & split(const VarOrRVar &old, const VarOrRVar &outer, const VarOrRVar &inner, const Expr &factor, TailStrategy tail=TailStrategy::Auto)
Split a dimension into inner and outer subdimensions with the given names, where the inner dimension iterates from 0 to factor-1.
Func & compute_root()
Compute all of this function once ahead of time.
Func & unroll(const VarOrRVar &var)
Mark a dimension to be completely unrolled.
Func & store_at(const Func &f, const Var &var)
Allocate storage for this function within f's loop over var.
Func & parallel(const VarOrRVar &var)
Mark a dimension to be traversed in parallel.
Func & vectorize(const VarOrRVar &var)
Mark a dimension to be computed all-at-once as a single vector.
Func & fuse(const VarOrRVar &inner, const VarOrRVar &outer, const VarOrRVar &fused)
Join two dimensions into a single fused dimension.
Func in(const Func &f)
Creates and returns a new identity Func that wraps this Func.
Func & compute_at(const Func &f, const Var &var)
Compute this function as needed for each unique value of the given var for the given calling function f.
const Internal::StageSchedule & get_schedule() const
Return the current StageSchedule associated with this initial Stage of this Func.
Definition Func.h:2608
const std::vector< Dim > & dims() const
The list and ordering of dimensions used to evaluate this function, after all splits have taken place. The first dimension in the list corresponds to the innermost loop, and the last to the outermost.
A class representing a Halide pipeline.
Definition Pipeline.h:107
Func get_func(size_t index)
Return a handle to the index-th Func within the pipeline, based on the topological order.
A reduction variable represents a single dimension of a reduction domain (RDom).
Definition RDom.h:29
Stage & vectorize(const VarOrRVar &var)
Stage & unroll(const VarOrRVar &var)
Stage & reorder(const std::vector< VarOrRVar > &vars)
Stage & parallel(const VarOrRVar &var)
const Internal::StageSchedule & get_schedule() const
Return the current StageSchedule associated with this Stage.
Definition Func.h:106
A Halide variable, to be used when defining functions.
Definition Var.h:19
This file defines the class FunctionDAG, which is our representation of a Halide pipeline, and contains methods for using Halide's bounds tools to query properties of it.
Expr sum(Expr, const std::string &s="sum")
An inline reduction.