Wan Xinyi
committed on
Commit
•
be3048f
1
Parent(s):
594a8f9
Add 1f1b
Browse files- app.py +15 -9
- hand_schedule.py +84 -0
- svg_event.py +7 -6
app.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import gradio as gr
|
2 |
import auto_schedule
|
3 |
import v_schedule
|
|
|
4 |
from PIL import Image
|
5 |
from svg_event import render_manual_graph
|
6 |
import pathlib
|
@@ -46,10 +47,13 @@ def calculate(p, m, f, b, w, c, mem):
|
|
46 |
baseline_acceleration=None
|
47 |
baseline_image=None
|
48 |
else:
|
49 |
-
|
|
|
|
|
|
|
|
|
50 |
baseline_bubble=percentage(baseline_time/(f+b+w)/m - 1)
|
51 |
baseline_acceleration=percentage(0)
|
52 |
-
baseline_image=None
|
53 |
|
54 |
|
55 |
zb_result = auto_schedule.auto_schedule(p, m, auto_schedule.GraphConfig(
|
@@ -91,6 +95,8 @@ def calculate(p, m, f, b, w, c, mem):
|
|
91 |
zbv_acceleration=percentage(baseline_time/zbv_time - 1) if baseline_time is not None else None
|
92 |
|
93 |
max_time = max([baseline_time, zb_time, zbv_time])
|
|
|
|
|
94 |
zb_image = get_schedule_image(zb_result, max_time)
|
95 |
zbv_image = get_schedule_image(zbv_result, max_time)
|
96 |
|
@@ -110,23 +116,23 @@ with gr.Blocks() as demo:
|
|
110 |
with gr.Group():
|
111 |
gr.Markdown("Costs. All costs are used as integers. For ZBV schedules, this is the time of two virtual stages on a stage combined.")
|
112 |
with gr.Row():
|
113 |
-
f=gr.Number(label="Time of F", value=
|
114 |
-
b=gr.Number(label="Time of B", value=
|
115 |
-
w=gr.Number(label="Time of W", value=
|
116 |
-
c=gr.Number(label="Time of one P2P communication", value=
|
117 |
with gr.Group():
|
118 |
gr.Markdown("Activation memory limit.")
|
119 |
def update_mem(p, s, mem):
|
120 |
print("update")
|
121 |
if s=="custom":
|
122 |
return mem
|
123 |
-
return p*
|
124 |
-
memsel=gr.Radio(choices=["1p (Same as 1F1B)", "2p", "3p", "custom"], value="1p (Same as 1F1B)")
|
125 |
mem=gr.Number(label="Custom memory limit in terms of pending F on a stage. For ZBV schedules, this is relative to two virtual stages on a stage combined.", value=p.value, interactive=True, precision=0)
|
126 |
memsel.change(update_mem, inputs=[p, memsel, mem], outputs=mem)
|
127 |
p.change(update_mem, inputs=[p, memsel, mem], outputs=mem)
|
128 |
|
129 |
-
button=gr.Button("Calculate")
|
130 |
|
131 |
with gr.Group():
|
132 |
gr.Markdown("1F1B")
|
|
|
1 |
import gradio as gr
|
2 |
import auto_schedule
|
3 |
import v_schedule
|
4 |
+
import hand_schedule
|
5 |
from PIL import Image
|
6 |
from svg_event import render_manual_graph
|
7 |
import pathlib
|
|
|
47 |
baseline_acceleration=None
|
48 |
baseline_image=None
|
49 |
else:
|
50 |
+
baseline_result = hand_schedule.get_hand_schedule(p, m, f, b + w, 0, c)
|
51 |
+
baseline_result = [
|
52 |
+
list(filter(lambda x: x.type in {'F', 'B'}, r)) for r in baseline_result
|
53 |
+
]
|
54 |
+
baseline_time = get_schedule_time(baseline_result)
|
55 |
baseline_bubble=percentage(baseline_time/(f+b+w)/m - 1)
|
56 |
baseline_acceleration=percentage(0)
|
|
|
57 |
|
58 |
|
59 |
zb_result = auto_schedule.auto_schedule(p, m, auto_schedule.GraphConfig(
|
|
|
95 |
zbv_acceleration=percentage(baseline_time/zbv_time - 1) if baseline_time is not None else None
|
96 |
|
97 |
max_time = max([baseline_time, zb_time, zbv_time])
|
98 |
+
print(max_time)
|
99 |
+
baseline_image = get_schedule_image(baseline_result, max_time)
|
100 |
zb_image = get_schedule_image(zb_result, max_time)
|
101 |
zbv_image = get_schedule_image(zbv_result, max_time)
|
102 |
|
|
|
116 |
with gr.Group():
|
117 |
gr.Markdown("Costs. All costs are used as integers. For ZBV schedules, this is the time of two virtual stages on a stage combined.")
|
118 |
with gr.Row():
|
119 |
+
f=gr.Number(label="Time of F", value=100, interactive=True, precision=0)
|
120 |
+
b=gr.Number(label="Time of B", value=110, interactive=True, precision=0)
|
121 |
+
w=gr.Number(label="Time of W", value=90, interactive=True, precision=0)
|
122 |
+
c=gr.Number(label="Time of one P2P communication", value=5, interactive=True, precision=0)
|
123 |
with gr.Group():
|
124 |
gr.Markdown("Activation memory limit.")
|
125 |
def update_mem(p, s, mem):
    """Return the memory limit implied by the current radio selection.

    Args:
        p: value multiplied by the selected factor.
        s: selected radio label; either "custom" or a label whose text
            before the first 'p' parses as a float (e.g. "1.5p").
        mem: current custom memory limit, returned untouched for "custom".

    Returns:
        ``mem`` when ``s`` is "custom", otherwise ``p`` times the selected
        factor, rounded half-up to an int.
    """
    print("update")
    if s != "custom":
        # Everything before the first 'p' in the label is the multiplier.
        factor = float(s.partition('p')[0])
        # Adding 0.5 before truncating rounds half-up for positive values.
        return int(p * factor + 0.5)
    return mem
|
130 |
+
memsel=gr.Radio(choices=["1p (Same as 1F1B)", "1.5p", "2p", "3p", "custom"], value="1p (Same as 1F1B)")
|
131 |
mem=gr.Number(label="Custom memory limit in terms of pending F on a stage. For ZBV schedules, this is relative to two virtual stages on a stage combined.", value=p.value, interactive=True, precision=0)
|
132 |
memsel.change(update_mem, inputs=[p, memsel, mem], outputs=mem)
|
133 |
p.change(update_mem, inputs=[p, memsel, mem], outputs=mem)
|
134 |
|
135 |
+
button=gr.Button("Calculate", variant="primary")
|
136 |
|
137 |
with gr.Group():
|
138 |
gr.Markdown("1F1B")
|
hand_schedule.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
|
3 |
+
@dataclass(frozen=True)  # eq defaults to True, so instances stay comparable and hashable
class ScheduledNode:
    """Immutable record of one scheduled pass on one pipeline stage."""

    # Pass label; get_hand_schedule emits "F", "B" and "W".
    type: str
    # Index of the stage (rank) the pass runs on.
    stage: int
    # Index of the minibatch the pass belongs to.
    minibatch: int
    # Time at which the pass starts.
    start_time: int
    # Time at which the pass finishes.
    completion_time: int
    # Defaults to False; nothing in this file sets it otherwise.
    rollback: bool = False
|
11 |
+
|
12 |
+
|
13 |
+
def get_hand_schedule(_p, _n, _f, _b, _w, _c, warmup_c=1):
    """Build a fixed F/B/W schedule for every stage and compute its timings.

    Each stage gets a hand-written operation order: a run of warm-up F
    passes (more of them on earlier stages), then a phase that interleaves
    F, B and W passes, then the W passes that are still outstanding. The
    order is then replayed step by step to assign a completion time to
    every pass.

    Args:
        _p: number of pipeline stages (ranks).
        _n: number of minibatches; must be at least ``2 * _p``.
        _f: duration of one F pass.
        _b: duration of one B pass.
        _w: duration of one W pass.
        _c: delay added between a pass finishing on the neighbouring stage
            and the dependent pass on this stage.
        warmup_c: multiplier applied to the number of warm-up F passes.

    Returns:
        A list with one entry per stage. Each entry is that stage's list of
        ``ScheduledNode`` objects (an F, B and W node for every minibatch),
        sorted by start time.

    Raises:
        AssertionError: if ``_n < 2 * _p``, or if a pass is reached before
            the pass it depends on has been timed.
    """
    assert _n >= 2 * _p, "_n must be at least 2 * _p"

    # Warm-up length of the first stage; later stages warm up for less.
    full_warmup = (_p - 1) * warmup_c

    # Per-stage operation order, encoded as 0 = F, 1 = B, 2 = W.
    stage = [[] for _ in range(_p)]
    for rank in range(_p):
        ops = stage[rank]
        warmup = (_p - rank - 1) * warmup_c
        ops.extend([0] * warmup)
        for i in range(_n):
            if warmup + i < _n:
                ops.append(0)
            ops.append(1)
            if warmup + i >= full_warmup:
                ops.append(2)
        # W passes that were held back while this stage was catching up.
        ops.extend([2] * (full_warmup - warmup))

    def get_id(_i, _j, _k):
        # Flatten (pass kind, stage, minibatch) into an index into ``t``.
        return _i * _p * _n + _j * _n + _k

    t = [-1] * (_p * _n * 3)  # completion time per pass; -1 = not timed yet
    e = [0] * _p              # time at which each stage becomes free
    fc = [0] * _p             # F passes timed so far on each stage
    bc = [0] * _p             # B passes timed so far on each stage

    for i in range(3 * _n):
        for rank in range(_p):
            op = stage[rank][i]
            if op == 0:
                done = e[rank] + _f
                if rank > 0:
                    # F waits for the same minibatch's F on the previous stage.
                    upstream = t[get_id(0, rank - 1, fc[rank])]
                    # -1 marks "not timed"; 0 is a valid completion time when
                    # the pass cost is zero, hence >= rather than >.
                    assert upstream >= 0, "F dependency has not been timed"
                    done = max(done, upstream + _c + _f)
                e[rank] = done
                t[get_id(0, rank, fc[rank])] = done
                fc[rank] += 1
            elif op == 1:
                done = e[rank] + _b
                if rank < _p - 1:
                    # B waits for the same minibatch's B on the next stage.
                    downstream = t[get_id(1, rank + 1, bc[rank])]
                    assert downstream >= 0, "B dependency has not been timed"
                    done = max(done, downstream + _c + _b)
                e[rank] = done
                t[get_id(1, rank, bc[rank])] = done
                bc[rank] += 1
            else:
                done = e[rank] + _w
                e[rank] = done
                # i passes precede this one on the stage, so the number of
                # W passes already timed is i - fc - bc.
                t[get_id(2, rank, i - fc[rank] - bc[rank])] = done

    res = []
    for rank in range(_p):
        nodes = []
        for i in range(_n):
            for kind, label, cost in ((0, "F", _f), (1, "B", _b), (2, "W", _w)):
                end = t[get_id(kind, rank, i)]
                nodes.append(ScheduledNode(label, rank, i, end - cost, end))
        res.append(sorted(nodes, key=lambda x: x.start_time))
    return res
|
svg_event.py
CHANGED
@@ -170,8 +170,8 @@ def draw_experiment_and_schedule(exp_events, sched_events, output_filename, tail
|
|
170 |
d.save_svg(output_filename)
|
171 |
|
172 |
|
173 |
-
def draw_events(events, output_filename, include_w=True, include_o=True, tail=50):
|
174 |
-
canvas_info = CanvasInfo(events, tail, center_title_height=0, enable_info=True)
|
175 |
max_len = canvas_info.max_len
|
176 |
# height = canvas_info.height
|
177 |
# info_height = canvas_info.info_height
|
@@ -185,8 +185,9 @@ def draw_events(events, output_filename, include_w=True, include_o=True, tail=50
|
|
185 |
|
186 |
|
187 |
class CanvasInfo:
|
188 |
-
def __init__(self, events, tail, center_title_height=CENTER_TITLE_HEIGHT, enable_info=True):
|
189 |
-
|
|
|
190 |
self.max_len = (last_time + TIME_PER_UNIT - 1) // TIME_PER_UNIT + tail
|
191 |
|
192 |
self.height = SPAN_HEIGHT * len(events) + BORDER_SIZE * (len(events) + 1)
|
@@ -233,7 +234,7 @@ def plot_events(ctx, events, title_text: str, canvas_info: CanvasInfo, include_w
|
|
233 |
if ENABLE_BATCH_ID:
|
234 |
minibatch = str(e["minibatch"])
|
235 |
center = (start + end) // 2
|
236 |
-
data_ctx.text(h, center, minibatch, font_scale=0.
|
237 |
if ENABLE_BORDER:
|
238 |
data_ctx.line(h+SPAN_HEIGHT, 0, h+SPAN_HEIGHT+BORDER_SIZE, max_len - 1)
|
239 |
|
@@ -340,7 +341,7 @@ def render_manual_graph(data, longest_time, enable_batch_id = False):
|
|
340 |
#events = load_json_data("no-bb-schedule.json")
|
341 |
|
342 |
path = os.path.join(tempfile.mkdtemp(), 'a.svg')
|
343 |
-
draw_events(events, path, include_w=True, include_o=False, tail=50)
|
344 |
return path
|
345 |
|
346 |
|
|
|
170 |
d.save_svg(output_filename)
|
171 |
|
172 |
|
173 |
+
def draw_events(events, output_filename, include_w=True, include_o=True, tail=50, longest_time=None):
|
174 |
+
canvas_info = CanvasInfo(events, tail, center_title_height=0, enable_info=True, longest_time=longest_time)
|
175 |
max_len = canvas_info.max_len
|
176 |
# height = canvas_info.height
|
177 |
# info_height = canvas_info.info_height
|
|
|
185 |
|
186 |
|
187 |
class CanvasInfo:
|
188 |
+
def __init__(self, events, tail, center_title_height=CENTER_TITLE_HEIGHT, enable_info=True, longest_time=None):
|
189 |
+
|
190 |
+
last_time = max(max([e["completion_time"] for e in dev_evs]) for dev_evs in events) if longest_time is None else longest_time
|
191 |
self.max_len = (last_time + TIME_PER_UNIT - 1) // TIME_PER_UNIT + tail
|
192 |
|
193 |
self.height = SPAN_HEIGHT * len(events) + BORDER_SIZE * (len(events) + 1)
|
|
|
234 |
if ENABLE_BATCH_ID:
|
235 |
minibatch = str(e["minibatch"])
|
236 |
center = (start + end) // 2
|
237 |
+
data_ctx.text(h, center, minibatch, font_scale=0.6, fill='black' if e["chunk"] == 0 else 'white')
|
238 |
if ENABLE_BORDER:
|
239 |
data_ctx.line(h+SPAN_HEIGHT, 0, h+SPAN_HEIGHT+BORDER_SIZE, max_len - 1)
|
240 |
|
|
|
341 |
#events = load_json_data("no-bb-schedule.json")
|
342 |
|
343 |
path = os.path.join(tempfile.mkdtemp(), 'a.svg')
|
344 |
+
draw_events(events, path, include_w=True, include_o=False, tail=50, longest_time=longest_time * time_scale)
|
345 |
return path
|
346 |
|
347 |
|