14 using ::Halide::MemoryType;
16 using ::Halide::TailStrategy;
21 Var c(relu.get_schedule().dims()[0].var);
23 Var n(relu.get_schedule().dims()[3].var);
24 Var x(relu.get_schedule().dims()[1].var);
27 Var y(relu.get_schedule().dims()[2].var);
30 RVar r13_x(conv.update(0).get_schedule().dims()[0].var);
31 RVar r13_y(conv.update(0).get_schedule().dims()[1].var);
32 RVar r13_z(conv.update(0).get_schedule().dims()[2].var);
33 Var yi_serial_outer(
"yi_serial_outer");
34 Var xi_serial_outer(
"xi_serial_outer");
35 Var ci_serial_outer(
"ci_serial_outer");
37 .split(c, c, ci, 24, TailStrategy::ShiftInwards)
38 .split(x, x, xi, 16, TailStrategy::ShiftInwards)
39 .split(y, y, yi, 4, TailStrategy::ShiftInwards)
40 .split(xi, xi, xii, 4, TailStrategy::ShiftInwards)
41 .split(yi, yi, yii, 2, TailStrategy::ShiftInwards)
45 .reorder(xii, yii, ci, xi, yi, c, x, y, n)
50 .split(ci, ci_serial_outer, ci, 24, TailStrategy::GuardWithIf)
52 .split(xi, xi_serial_outer, xi, 4, TailStrategy::GuardWithIf)
54 .split(yi, yi_serial_outer, yi, 2, TailStrategy::GuardWithIf)
57 .split(c, c, ci, 24, TailStrategy::GuardWithIf)
58 .split(x, x, xi, 4, TailStrategy::GuardWithIf)
59 .split(y, y, yi, 16, TailStrategy::GuardWithIf)
60 .split(yi, yi, yii, 2, TailStrategy::GuardWithIf)
62 .reorder(yii, r13_x, r13_y, r13_z, ci, xi, yi, c, x, y, n)
67 .split(ci, ci_serial_outer, ci, 24, TailStrategy::GuardWithIf)
69 .split(xi, xi_serial_outer, xi, 4, TailStrategy::GuardWithIf)
71 .split(yi, yi_serial_outer, yi, 8, TailStrategy::GuardWithIf)
74 .split(c, c, ci, 24, TailStrategy::ShiftInwards)
75 .split(x, x, xi, 4, TailStrategy::ShiftInwards)
76 .split(y, y, yi, 16, TailStrategy::ShiftInwards)
77 .split(yi, yi, yii, 2, TailStrategy::ShiftInwards)
80 .reorder(yii, ci, xi, yi, c, x, y, n)
85 .split(ci, ci_serial_outer, ci, 24, TailStrategy::GuardWithIf)
87 .split(xi, xi_serial_outer, xi, 4, TailStrategy::GuardWithIf)
89 .split(yi, yi_serial_outer, yi, 8, TailStrategy::GuardWithIf)
91 conv.in(relu).store_in(MemoryType::Register).compute_at(relu, ci).bound_extent(c, 1).unroll(c).bound_extent(x, 4).unroll(x).bound_extent(y, 2).unroll(y).bound_extent(n, 1).unroll(n);