Wan Xinyi
committed on
Commit
•
ac0b05c
1
Parent(s):
be3048f
Add some presets, support 1f1b with fewer microbatches
Browse files
- app.py +36 -6
- hand_schedule.py +20 -11
app.py
CHANGED
@@ -46,6 +46,7 @@ def calculate(p, m, f, b, w, c, mem):
|
|
46 |
baseline_bubble=None
|
47 |
baseline_acceleration=None
|
48 |
baseline_image=None
|
|
|
49 |
else:
|
50 |
baseline_result = hand_schedule.get_hand_schedule(p, m, f, b + w, 0, c)
|
51 |
baseline_result = [
|
@@ -70,11 +71,12 @@ def calculate(p, m, f, b, w, c, mem):
|
|
70 |
zb_bubble=percentage(zb_time/(f+b+w)/m - 1)
|
71 |
zb_acceleration=percentage(baseline_time/zb_time - 1) if baseline_time is not None else None
|
72 |
|
73 |
-
if mem < p:
|
74 |
zbv_time=None
|
75 |
zbv_bubble=None
|
76 |
zbv_acceleration=None
|
77 |
zbv_image=None
|
|
|
78 |
else:
|
79 |
zbv_graph = v_schedule.PipelineGraph(
|
80 |
n_stage=p,
|
@@ -94,10 +96,13 @@ def calculate(p, m, f, b, w, c, mem):
|
|
94 |
zbv_bubble=percentage(zbv_time/(f+b+w)/m - 1)
|
95 |
zbv_acceleration=percentage(baseline_time/zbv_time - 1) if baseline_time is not None else None
|
96 |
|
97 |
-
|
98 |
-
|
|
|
99 |
baseline_image = get_schedule_image(baseline_result, max_time)
|
|
|
100 |
zb_image = get_schedule_image(zb_result, max_time)
|
|
|
101 |
zbv_image = get_schedule_image(zbv_result, max_time)
|
102 |
|
103 |
return [baseline_time, baseline_bubble, baseline_acceleration, baseline_image, zb_time, zb_bubble, zb_acceleration, zb_image, zbv_time, zbv_bubble, zbv_acceleration, zbv_image]
|
@@ -105,6 +110,20 @@ def calculate(p, m, f, b, w, c, mem):
|
|
105 |
with gr.Blocks() as demo:
|
106 |
gr.Markdown(open("description1.md").read())
|
107 |
gr.Markdown("# Pipeline Scheduler Playground")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
with gr.Row():
|
109 |
with gr.Column(scale=1):
|
110 |
with gr.Group():
|
@@ -142,7 +161,7 @@ with gr.Blocks() as demo:
|
|
142 |
baseline_bubble=gr.Textbox("", label="Bubble Rate. Calculated as (1 - longest stage time/(F+B+W)/m).")
|
143 |
baseline_acceleration=gr.Textbox("", label="Acceleration compared to 1F1B")
|
144 |
with gr.Column(scale=4):
|
145 |
-
baseline_image=gr.Image(None, interactive=False, label="Schedule Image")
|
146 |
|
147 |
with gr.Group():
|
148 |
gr.Markdown("Zero Bubble Schedule")
|
@@ -152,7 +171,7 @@ with gr.Blocks() as demo:
|
|
152 |
zb_bubble=gr.Textbox("", label="Bubble Rate. Calculated as (1 - longest stage time/(F+B+W)/m).")
|
153 |
zb_acceleration=gr.Textbox("", label="Acceleration compared to 1F1B")
|
154 |
with gr.Column(scale=4):
|
155 |
-
zb_image=gr.Image(None, interactive=False, label="Schedule Image")
|
156 |
with gr.Group():
|
157 |
gr.Markdown("Zero Bubble V Schedule (ZBV)")
|
158 |
with gr.Row():
|
@@ -161,7 +180,18 @@ with gr.Blocks() as demo:
|
|
161 |
zbv_bubble=gr.Textbox("", label="Bubble Rate. Calculated as (1 - longest stage time/(F+B+W)/m).")
|
162 |
zbv_acceleration=gr.Textbox("", label="Acceleration compared to 1F1B")
|
163 |
with gr.Column(scale=4):
|
164 |
-
zbv_image=gr.Image(None, interactive=False, label="Schedule Image")
|
165 |
button.click(calculate, inputs=[p, m, f, b, w, c, mem], outputs=[baseline_time, baseline_bubble, baseline_acceleration, baseline_image, zb_time, zb_bubble, zb_acceleration, zb_image, zbv_time, zbv_bubble, zbv_acceleration, zbv_image])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
166 |
gr.Markdown(open("description2.md").read())
|
167 |
demo.launch()
|
|
|
46 |
baseline_bubble=None
|
47 |
baseline_acceleration=None
|
48 |
baseline_image=None
|
49 |
+
baseline_result=None
|
50 |
else:
|
51 |
baseline_result = hand_schedule.get_hand_schedule(p, m, f, b + w, 0, c)
|
52 |
baseline_result = [
|
|
|
71 |
zb_bubble=percentage(zb_time/(f+b+w)/m - 1)
|
72 |
zb_acceleration=percentage(baseline_time/zb_time - 1) if baseline_time is not None else None
|
73 |
|
74 |
+
if mem < p or m < 2 * p:
|
75 |
zbv_time=None
|
76 |
zbv_bubble=None
|
77 |
zbv_acceleration=None
|
78 |
zbv_image=None
|
79 |
+
zbv_result=None
|
80 |
else:
|
81 |
zbv_graph = v_schedule.PipelineGraph(
|
82 |
n_stage=p,
|
|
|
96 |
zbv_bubble=percentage(zbv_time/(f+b+w)/m - 1)
|
97 |
zbv_acceleration=percentage(baseline_time/zbv_time - 1) if baseline_time is not None else None
|
98 |
|
99 |
+
max_time = max(filter(lambda x: x is not None, [baseline_time, zb_time, zbv_time]))
|
100 |
+
print(max_time)
|
101 |
+
if baseline_result is not None:
|
102 |
baseline_image = get_schedule_image(baseline_result, max_time)
|
103 |
+
if zb_result is not None:
|
104 |
zb_image = get_schedule_image(zb_result, max_time)
|
105 |
+
if zbv_result is not None:
|
106 |
zbv_image = get_schedule_image(zbv_result, max_time)
|
107 |
|
108 |
return [baseline_time, baseline_bubble, baseline_acceleration, baseline_image, zb_time, zb_bubble, zb_acceleration, zb_image, zbv_time, zbv_bubble, zbv_acceleration, zbv_image]
|
|
|
110 |
with gr.Blocks() as demo:
|
111 |
gr.Markdown(open("description1.md").read())
|
112 |
gr.Markdown("# Pipeline Scheduler Playground")
|
113 |
+
presets = {
|
114 |
+
'Ideal Case 1p': (4, 12, 20, 20, 20, 0, '1p (Same as 1F1B)'),
|
115 |
+
'Ideal Case 2p': (4, 12, 20, 20, 20, 0, '2p'),
|
116 |
+
'Real Case 1p': (4, 12, 1049, 1122, 903, 79, '1p (Same as 1F1B)'),
|
117 |
+
'Real Case 2p': (4, 12, 1049, 1122, 903, 79, '2p'),
|
118 |
+
}
|
119 |
+
preset_buttons = {}
|
120 |
+
|
121 |
+
with gr.Group():
|
122 |
+
gr.Markdown("Preset Setups")
|
123 |
+
with gr.Row():
|
124 |
+
for (k, v) in presets.items():
|
125 |
+
preset_buttons[k] = gr.Button(k, variant="secondary")
|
126 |
+
|
127 |
with gr.Row():
|
128 |
with gr.Column(scale=1):
|
129 |
with gr.Group():
|
|
|
161 |
baseline_bubble=gr.Textbox("", label="Bubble Rate. Calculated as (1 - longest stage time/(F+B+W)/m).")
|
162 |
baseline_acceleration=gr.Textbox("", label="Acceleration compared to 1F1B")
|
163 |
with gr.Column(scale=4):
|
164 |
+
baseline_image=gr.Image(None, interactive=False, label="Schedule Image", show_label=False)
|
165 |
|
166 |
with gr.Group():
|
167 |
gr.Markdown("Zero Bubble Schedule")
|
|
|
171 |
zb_bubble=gr.Textbox("", label="Bubble Rate. Calculated as (1 - longest stage time/(F+B+W)/m).")
|
172 |
zb_acceleration=gr.Textbox("", label="Acceleration compared to 1F1B")
|
173 |
with gr.Column(scale=4):
|
174 |
+
zb_image=gr.Image(None, interactive=False, label="Schedule Image", show_label=False)
|
175 |
with gr.Group():
|
176 |
gr.Markdown("Zero Bubble V Schedule (ZBV)")
|
177 |
with gr.Row():
|
|
|
180 |
zbv_bubble=gr.Textbox("", label="Bubble Rate. Calculated as (1 - longest stage time/(F+B+W)/m).")
|
181 |
zbv_acceleration=gr.Textbox("", label="Acceleration compared to 1F1B")
|
182 |
with gr.Column(scale=4):
|
183 |
+
zbv_image=gr.Image(None, interactive=False, label="Schedule Image", show_label=False)
|
184 |
button.click(calculate, inputs=[p, m, f, b, w, c, mem], outputs=[baseline_time, baseline_bubble, baseline_acceleration, baseline_image, zb_time, zb_bubble, zb_acceleration, zb_image, zbv_time, zbv_bubble, zbv_acceleration, zbv_image])
|
185 |
+
|
186 |
+
for (k, v) in presets.items():
|
187 |
+
def update_preset(pb, p, m, f, b, w, c, mem):
|
188 |
+
print(pb)
|
189 |
+
print(presets[pb])
|
190 |
+
print(presets[pb][-1])
|
191 |
+
return *presets[pb],*calculate(*presets[pb][:-1], update_mem(p, presets[pb][-1], -1))
|
192 |
+
preset_buttons[k].click(
|
193 |
+
update_preset,
|
194 |
+
inputs=[preset_buttons[k], p, m, f, b, w, c, mem],
|
195 |
+
outputs=[p, m, f, b, w, c, memsel, baseline_time, baseline_bubble, baseline_acceleration, baseline_image, zb_time, zb_bubble, zb_acceleration, zb_image, zbv_time, zbv_bubble, zbv_acceleration, zbv_image])
|
196 |
gr.Markdown(open("description2.md").read())
|
197 |
demo.launch()
|
hand_schedule.py
CHANGED
@@ -11,8 +11,10 @@ class ScheduledNode:
|
|
11 |
|
12 |
|
13 |
def get_hand_schedule(_p, _n, _f, _b, _w, _c, warmup_c=1):
|
14 |
-
assert _n >= 2 * _p
|
15 |
stage = [[] for _ in range(_p)]
|
|
|
|
|
16 |
for rank in range(_p):
|
17 |
warmup = (_p - rank - 1) * warmup_c
|
18 |
for _ in range(warmup):
|
@@ -25,12 +27,13 @@ def get_hand_schedule(_p, _n, _f, _b, _w, _c, warmup_c=1):
|
|
25 |
stage[rank].append(2)
|
26 |
for _ in range((_p - 1) * warmup_c - warmup):
|
27 |
stage[rank].append(2)
|
28 |
-
labels = ["F", "B", "W"]
|
29 |
for rank in range(_p):
|
30 |
rank_str = " " * rank
|
31 |
-
for i in range(_n * 3):
|
|
|
32 |
rank_str += labels[stage[rank][i]]
|
33 |
-
|
34 |
size = _p * _n * 3
|
35 |
def get_id(_i, _j, _k):
|
36 |
return _i * _p * _n + _j * _n + _k
|
@@ -42,6 +45,8 @@ def get_hand_schedule(_p, _n, _f, _b, _w, _c, warmup_c=1):
|
|
42 |
for rank in range(_p):
|
43 |
last = e[rank]
|
44 |
if stage[rank][i] == 0:
|
|
|
|
|
45 |
tmp = e[rank] + _f
|
46 |
if rank > 0:
|
47 |
assert t[get_id(0, rank - 1, fc[rank])] > 0
|
@@ -50,17 +55,17 @@ def get_hand_schedule(_p, _n, _f, _b, _w, _c, warmup_c=1):
|
|
50 |
t[get_id(0, rank, fc[rank])] = tmp
|
51 |
fc[rank] += 1
|
52 |
elif stage[rank][i] == 1:
|
|
|
|
|
53 |
tmp = e[rank] + _b
|
54 |
if rank < _p - 1:
|
55 |
-
assert t[get_id(1, rank + 1, bc[rank])] > 0
|
56 |
tmp = max(tmp, t[get_id(1, rank + 1, bc[rank])] + _c + _b)
|
57 |
e[rank] = tmp
|
58 |
t[get_id(1, rank, bc[rank])] = tmp
|
59 |
bc[rank] += 1
|
60 |
-
|
61 |
-
|
62 |
-
e[rank] = tmp
|
63 |
-
t[get_id(2, rank, i - fc[rank] - bc[rank])] = tmp
|
64 |
# if rank == _p - 1:
|
65 |
# print(_f, _b, _w, _c, "->", rank, i, stage[rank][i], e[rank], e[rank] - last)
|
66 |
max_time = 0
|
@@ -73,7 +78,7 @@ def get_hand_schedule(_p, _n, _f, _b, _w, _c, warmup_c=1):
|
|
73 |
# exit(0)
|
74 |
res = [[] for _ in range(_p)]
|
75 |
for rank in range(_p):
|
76 |
-
for i in range(
|
77 |
res[rank].append(ScheduledNode(
|
78 |
"F", rank, i, t[get_id(0, rank, i)] - _f, t[get_id(0, rank, i)]))
|
79 |
res[rank].append(ScheduledNode(
|
@@ -81,4 +86,8 @@ def get_hand_schedule(_p, _n, _f, _b, _w, _c, warmup_c=1):
|
|
81 |
res[rank].append(ScheduledNode(
|
82 |
"W", rank, i, t[get_id(2, rank, i)] - _w, t[get_id(2, rank, i)]))
|
83 |
res[rank] = sorted(res[rank], key=lambda x: x.start_time)
|
84 |
-
return res
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
|
13 |
def get_hand_schedule(_p, _n, _f, _b, _w, _c, warmup_c=1):
|
14 |
+
# assert _n >= 2 * _p
|
15 |
stage = [[] for _ in range(_p)]
|
16 |
+
real_n = _n
|
17 |
+
_n = max(_n, _p)
|
18 |
for rank in range(_p):
|
19 |
warmup = (_p - rank - 1) * warmup_c
|
20 |
for _ in range(warmup):
|
|
|
27 |
stage[rank].append(2)
|
28 |
for _ in range((_p - 1) * warmup_c - warmup):
|
29 |
stage[rank].append(2)
|
30 |
+
labels = ["F", "B", "W", '.']
|
31 |
for rank in range(_p):
|
32 |
rank_str = " " * rank
|
33 |
+
# for i in range(_n * 3):
|
34 |
+
for i in range(len(stage[rank])):
|
35 |
rank_str += labels[stage[rank][i]]
|
36 |
+
print(rank_str)
|
37 |
size = _p * _n * 3
|
38 |
def get_id(_i, _j, _k):
|
39 |
return _i * _p * _n + _j * _n + _k
|
|
|
45 |
for rank in range(_p):
|
46 |
last = e[rank]
|
47 |
if stage[rank][i] == 0:
|
48 |
+
if fc[rank] >= real_n:
|
49 |
+
continue
|
50 |
tmp = e[rank] + _f
|
51 |
if rank > 0:
|
52 |
assert t[get_id(0, rank - 1, fc[rank])] > 0
|
|
|
55 |
t[get_id(0, rank, fc[rank])] = tmp
|
56 |
fc[rank] += 1
|
57 |
elif stage[rank][i] == 1:
|
58 |
+
if bc[rank] >= real_n:
|
59 |
+
continue
|
60 |
tmp = e[rank] + _b
|
61 |
if rank < _p - 1:
|
62 |
+
assert t[get_id(1, rank + 1, bc[rank])] > 0, f"{rank} {i} {bc[rank]}"
|
63 |
tmp = max(tmp, t[get_id(1, rank + 1, bc[rank])] + _c + _b)
|
64 |
e[rank] = tmp
|
65 |
t[get_id(1, rank, bc[rank])] = tmp
|
66 |
bc[rank] += 1
|
67 |
+
elif stage[rank][i] == 2:
|
68 |
+
continue
|
|
|
|
|
69 |
# if rank == _p - 1:
|
70 |
# print(_f, _b, _w, _c, "->", rank, i, stage[rank][i], e[rank], e[rank] - last)
|
71 |
max_time = 0
|
|
|
78 |
# exit(0)
|
79 |
res = [[] for _ in range(_p)]
|
80 |
for rank in range(_p):
|
81 |
+
for i in range(real_n):
|
82 |
res[rank].append(ScheduledNode(
|
83 |
"F", rank, i, t[get_id(0, rank, i)] - _f, t[get_id(0, rank, i)]))
|
84 |
res[rank].append(ScheduledNode(
|
|
|
86 |
res[rank].append(ScheduledNode(
|
87 |
"W", rank, i, t[get_id(2, rank, i)] - _w, t[get_id(2, rank, i)]))
|
88 |
res[rank] = sorted(res[rank], key=lambda x: x.start_time)
|
89 |
+
return res
|
90 |
+
|
91 |
+
if __name__ == "__main__":
|
92 |
+
print(get_hand_schedule(16, 16, 1, 1, 1, 0))
|
93 |
+
|