Wan Xinyi commited on
Commit
be3048f
1 Parent(s): 594a8f9
Files changed (3) hide show
  1. app.py +15 -9
  2. hand_schedule.py +84 -0
  3. svg_event.py +7 -6
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import gradio as gr
2
  import auto_schedule
3
  import v_schedule
 
4
  from PIL import Image
5
  from svg_event import render_manual_graph
6
  import pathlib
@@ -46,10 +47,13 @@ def calculate(p, m, f, b, w, c, mem):
46
  baseline_acceleration=None
47
  baseline_image=None
48
  else:
49
- baseline_time=(f+b+w)*m + (f+b+w+c)*(p-1)
 
 
 
 
50
  baseline_bubble=percentage(baseline_time/(f+b+w)/m - 1)
51
  baseline_acceleration=percentage(0)
52
- baseline_image=None
53
 
54
 
55
  zb_result = auto_schedule.auto_schedule(p, m, auto_schedule.GraphConfig(
@@ -91,6 +95,8 @@ def calculate(p, m, f, b, w, c, mem):
91
  zbv_acceleration=percentage(baseline_time/zbv_time - 1) if baseline_time is not None else None
92
 
93
  max_time = max([baseline_time, zb_time, zbv_time])
 
 
94
  zb_image = get_schedule_image(zb_result, max_time)
95
  zbv_image = get_schedule_image(zbv_result, max_time)
96
 
@@ -110,23 +116,23 @@ with gr.Blocks() as demo:
110
  with gr.Group():
111
  gr.Markdown("Costs. All costs are used as integers. For ZBV schedules, this is the time of two virtual stages on a stage combined.")
112
  with gr.Row():
113
- f=gr.Number(label="Time of F", value=8, interactive=True, precision=0)
114
- b=gr.Number(label="Time of B", value=8, interactive=True, precision=0)
115
- w=gr.Number(label="Time of W", value=8, interactive=True, precision=0)
116
- c=gr.Number(label="Time of one P2P communication", value=1, interactive=True, precision=0)
117
  with gr.Group():
118
  gr.Markdown("Activation memory limit.")
119
  def update_mem(p, s, mem):
120
  print("update")
121
  if s=="custom":
122
  return mem
123
- return p*int(s[:1])
124
- memsel=gr.Radio(choices=["1p (Same as 1F1B)", "2p", "3p", "custom"], value="1p (Same as 1F1B)")
125
  mem=gr.Number(label="Custom memory limit in terms of pending F on a stage. For ZBV schedules, this is relative to two virtual stages on a stage combined.", value=p.value, interactive=True, precision=0)
126
  memsel.change(update_mem, inputs=[p, memsel, mem], outputs=mem)
127
  p.change(update_mem, inputs=[p, memsel, mem], outputs=mem)
128
 
129
- button=gr.Button("Calculate")
130
 
131
  with gr.Group():
132
  gr.Markdown("1F1B")
 
1
  import gradio as gr
2
  import auto_schedule
3
  import v_schedule
4
+ import hand_schedule
5
  from PIL import Image
6
  from svg_event import render_manual_graph
7
  import pathlib
 
47
  baseline_acceleration=None
48
  baseline_image=None
49
  else:
50
+ baseline_result = hand_schedule.get_hand_schedule(p, m, f, b + w, 0, c)
51
+ baseline_result = [
52
+ list(filter(lambda x: x.type in {'F', 'B'}, r)) for r in baseline_result
53
+ ]
54
+ baseline_time = get_schedule_time(baseline_result)
55
  baseline_bubble=percentage(baseline_time/(f+b+w)/m - 1)
56
  baseline_acceleration=percentage(0)
 
57
 
58
 
59
  zb_result = auto_schedule.auto_schedule(p, m, auto_schedule.GraphConfig(
 
95
  zbv_acceleration=percentage(baseline_time/zbv_time - 1) if baseline_time is not None else None
96
 
97
  max_time = max([baseline_time, zb_time, zbv_time])
98
+ print(max_time)
99
+ baseline_image = get_schedule_image(baseline_result, max_time)
100
  zb_image = get_schedule_image(zb_result, max_time)
101
  zbv_image = get_schedule_image(zbv_result, max_time)
102
 
 
116
  with gr.Group():
117
  gr.Markdown("Costs. All costs are used as integers. For ZBV schedules, this is the time of two virtual stages on a stage combined.")
118
  with gr.Row():
119
+ f=gr.Number(label="Time of F", value=100, interactive=True, precision=0)
120
+ b=gr.Number(label="Time of B", value=110, interactive=True, precision=0)
121
+ w=gr.Number(label="Time of W", value=90, interactive=True, precision=0)
122
+ c=gr.Number(label="Time of one P2P communication", value=5, interactive=True, precision=0)
123
  with gr.Group():
124
  gr.Markdown("Activation memory limit.")
125
  def update_mem(p, s, mem):
126
  print("update")
127
  if s=="custom":
128
  return mem
129
+ return int(p*float(s.split('p')[0]) + 0.5)
130
+ memsel=gr.Radio(choices=["1p (Same as 1F1B)", "1.5p", "2p", "3p", "custom"], value="1p (Same as 1F1B)")
131
  mem=gr.Number(label="Custom memory limit in terms of pending F on a stage. For ZBV schedules, this is relative to two virtual stages on a stage combined.", value=p.value, interactive=True, precision=0)
132
  memsel.change(update_mem, inputs=[p, memsel, mem], outputs=mem)
133
  p.change(update_mem, inputs=[p, memsel, mem], outputs=mem)
134
 
135
+ button=gr.Button("Calculate", variant="primary")
136
 
137
  with gr.Group():
138
  gr.Markdown("1F1B")
hand_schedule.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ @dataclass(eq=True, frozen=True)
4
+ class ScheduledNode:
5
+ type: str
6
+ stage: int
7
+ minibatch: int
8
+ start_time: int
9
+ completion_time: int
10
+ rollback: bool = False
11
+
12
+
13
+ def get_hand_schedule(_p, _n, _f, _b, _w, _c, warmup_c=1):
14
+ assert _n >= 2 * _p
15
+ stage = [[] for _ in range(_p)]
16
+ for rank in range(_p):
17
+ warmup = (_p - rank - 1) * warmup_c
18
+ for _ in range(warmup):
19
+ stage[rank].append(0)
20
+ for i in range(_n):
21
+ if warmup + i < _n:
22
+ stage[rank].append(0)
23
+ stage[rank].append(1)
24
+ if warmup + i >= (_p - 1) * warmup_c:
25
+ stage[rank].append(2)
26
+ for _ in range((_p - 1) * warmup_c - warmup):
27
+ stage[rank].append(2)
28
+ labels = ["F", "B", "W"]
29
+ for rank in range(_p):
30
+ rank_str = " " * rank
31
+ for i in range(_n * 3):
32
+ rank_str += labels[stage[rank][i]]
33
+ # print(rank_str)
34
+ size = _p * _n * 3
35
+ def get_id(_i, _j, _k):
36
+ return _i * _p * _n + _j * _n + _k
37
+ t = [-1] * size
38
+ e = [0] * _p
39
+ fc = [0] * _p
40
+ bc = [0] * _p
41
+ for i in range(3 * _n):
42
+ for rank in range(_p):
43
+ last = e[rank]
44
+ if stage[rank][i] == 0:
45
+ tmp = e[rank] + _f
46
+ if rank > 0:
47
+ assert t[get_id(0, rank - 1, fc[rank])] > 0
48
+ tmp = max(tmp, t[get_id(0, rank - 1, fc[rank])] + _c + _f)
49
+ e[rank] = tmp
50
+ t[get_id(0, rank, fc[rank])] = tmp
51
+ fc[rank] += 1
52
+ elif stage[rank][i] == 1:
53
+ tmp = e[rank] + _b
54
+ if rank < _p - 1:
55
+ assert t[get_id(1, rank + 1, bc[rank])] > 0
56
+ tmp = max(tmp, t[get_id(1, rank + 1, bc[rank])] + _c + _b)
57
+ e[rank] = tmp
58
+ t[get_id(1, rank, bc[rank])] = tmp
59
+ bc[rank] += 1
60
+ else:
61
+ tmp = e[rank] + _w
62
+ e[rank] = tmp
63
+ t[get_id(2, rank, i - fc[rank] - bc[rank])] = tmp
64
+ # if rank == _p - 1:
65
+ # print(_f, _b, _w, _c, "->", rank, i, stage[rank][i], e[rank], e[rank] - last)
66
+ max_time = 0
67
+ for rank in range(_p):
68
+ if warmup_c == 2:
69
+ max_time = max(max_time, e[rank] - t[get_id(0, rank, 0)] + _f)
70
+ else:
71
+ max_time = max(max_time, e[rank])
72
+ # print(rank, "->", e[rank])
73
+ # exit(0)
74
+ res = [[] for _ in range(_p)]
75
+ for rank in range(_p):
76
+ for i in range(_n):
77
+ res[rank].append(ScheduledNode(
78
+ "F", rank, i, t[get_id(0, rank, i)] - _f, t[get_id(0, rank, i)]))
79
+ res[rank].append(ScheduledNode(
80
+ "B", rank, i, t[get_id(1, rank, i)] - _b, t[get_id(1, rank, i)]))
81
+ res[rank].append(ScheduledNode(
82
+ "W", rank, i, t[get_id(2, rank, i)] - _w, t[get_id(2, rank, i)]))
83
+ res[rank] = sorted(res[rank], key=lambda x: x.start_time)
84
+ return res
svg_event.py CHANGED
@@ -170,8 +170,8 @@ def draw_experiment_and_schedule(exp_events, sched_events, output_filename, tail
170
  d.save_svg(output_filename)
171
 
172
 
173
- def draw_events(events, output_filename, include_w=True, include_o=True, tail=50):
174
- canvas_info = CanvasInfo(events, tail, center_title_height=0, enable_info=True)
175
  max_len = canvas_info.max_len
176
  # height = canvas_info.height
177
  # info_height = canvas_info.info_height
@@ -185,8 +185,9 @@ def draw_events(events, output_filename, include_w=True, include_o=True, tail=50
185
 
186
 
187
  class CanvasInfo:
188
- def __init__(self, events, tail, center_title_height=CENTER_TITLE_HEIGHT, enable_info=True):
189
- last_time = max(max([e["completion_time"] for e in dev_evs]) for dev_evs in events)
 
190
  self.max_len = (last_time + TIME_PER_UNIT - 1) // TIME_PER_UNIT + tail
191
 
192
  self.height = SPAN_HEIGHT * len(events) + BORDER_SIZE * (len(events) + 1)
@@ -233,7 +234,7 @@ def plot_events(ctx, events, title_text: str, canvas_info: CanvasInfo, include_w
233
  if ENABLE_BATCH_ID:
234
  minibatch = str(e["minibatch"])
235
  center = (start + end) // 2
236
- data_ctx.text(h, center, minibatch, font_scale=0.7, fill='black' if e["chunk"] == 0 else 'white')
237
  if ENABLE_BORDER:
238
  data_ctx.line(h+SPAN_HEIGHT, 0, h+SPAN_HEIGHT+BORDER_SIZE, max_len - 1)
239
 
@@ -340,7 +341,7 @@ def render_manual_graph(data, longest_time, enable_batch_id = False):
340
  #events = load_json_data("no-bb-schedule.json")
341
 
342
  path = os.path.join(tempfile.mkdtemp(), 'a.svg')
343
- draw_events(events, path, include_w=True, include_o=False, tail=50)
344
  return path
345
 
346
 
 
170
  d.save_svg(output_filename)
171
 
172
 
173
+ def draw_events(events, output_filename, include_w=True, include_o=True, tail=50, longest_time=None):
174
+ canvas_info = CanvasInfo(events, tail, center_title_height=0, enable_info=True, longest_time=longest_time)
175
  max_len = canvas_info.max_len
176
  # height = canvas_info.height
177
  # info_height = canvas_info.info_height
 
185
 
186
 
187
  class CanvasInfo:
188
+ def __init__(self, events, tail, center_title_height=CENTER_TITLE_HEIGHT, enable_info=True, longest_time=None):
189
+
190
+ last_time = max(max([e["completion_time"] for e in dev_evs]) for dev_evs in events) if longest_time is None else longest_time
191
  self.max_len = (last_time + TIME_PER_UNIT - 1) // TIME_PER_UNIT + tail
192
 
193
  self.height = SPAN_HEIGHT * len(events) + BORDER_SIZE * (len(events) + 1)
 
234
  if ENABLE_BATCH_ID:
235
  minibatch = str(e["minibatch"])
236
  center = (start + end) // 2
237
+ data_ctx.text(h, center, minibatch, font_scale=0.6, fill='black' if e["chunk"] == 0 else 'white')
238
  if ENABLE_BORDER:
239
  data_ctx.line(h+SPAN_HEIGHT, 0, h+SPAN_HEIGHT+BORDER_SIZE, max_len - 1)
240
 
 
341
  #events = load_json_data("no-bb-schedule.json")
342
 
343
  path = os.path.join(tempfile.mkdtemp(), 'a.svg')
344
+ draw_events(events, path, include_w=True, include_o=False, tail=50, longest_time=longest_time * time_scale)
345
  return path
346
 
347