# Nyamdavaa Amar
# Edit presets
# cf49f13
# Number of slots in one period of a stage's repeating pattern; get_pattern_str
# fills them with at most one of each of the six op letters "FfBbWw".
pattern_size = 6
from collections import Counter
from dataclasses import dataclass
@dataclass(eq=True, frozen=True)
class ScheduledNode:
    """One op placed on a stage's timeline by transform_schedule.

    Frozen, so instances are immutable and hashable; they compare field by field.
    """

    # Upper-cased op letter; transform_schedule only emits "F" or "B".
    type: str
    # 0 for "F"/"b" ops, 1 for "f"/"B" ops (see transform_schedule).
    chunk: int
    # Index of the stage (row of the input schedule) the op runs on.
    stage: int
    # 0-based ordinal of this op letter among same-letter ops on its stage.
    minibatch: int
    # Simulated start time (completion_time minus the op's cost).
    start_time: int
    # Simulated completion time, in the cost units given to transform_schedule.
    completion_time: int


def transform_schedule(schedule, f, b, w, c):
    """Simulate per-stage op orders and return them as timed ScheduledNodes.

    An op starts once all of its inputs have completed:

    * the previous op on the same stage (no extra delay);
    * for "F"/"B": the op with the same letter and minibatch on stage - 1,
      plus ``c``;
    * for "f"/"b": the op with the same letter and minibatch on stage + 1,
      plus ``c``.

    A cross-stage input that does not actually appear on the neighbouring
    stage is still evaluated with these same rules (it just has no same-stage
    predecessor).

    Args:
        schedule: One sequence per stage of single-character op letters in
            execution order; whitespace entries are empty slots and skipped.
        f: Cost of an "F"/"f" op.
        b: Cost of a "B"/"b" op, charged together with ``w``.
        w: Extra cost folded into every "B"/"b" op.
        c: Delay added to every cross-stage dependency.

    Returns:
        One list of ScheduledNode per stage, in each stage's original order.

    Raises:
        KeyError: if an op letter other than "F", "f", "B" or "b" occurs.
        ValueError: if the dependencies form a cycle.
    """
    num_stages = len(schedule)
    cost = {
        'F': f,
        'B': b + w,
        'f': f,
        'b': b + w,
    }

    # Per stage: the (letter, minibatch) sequence with empty slots removed, and
    # for every op except a stage's first, the (letter, minibatch) before it.
    stage_order = []
    local_prev = {}
    for sid, row in enumerate(schedule):
        seen = Counter()
        order = []
        for p in row:
            if not p.strip():
                continue
            mb = seen[p]
            if order:
                local_prev[(sid, p, mb)] = order[-1]
            order.append((p, mb))
            seen[p] += 1
        stage_order.append(order)

    time_map = {}

    def inputs(key):
        """Return (same-stage predecessor key, cross-stage input key); None if absent."""
        sid, kind, mb = key
        prev = local_prev.get(key)
        prev_key = (sid, *prev) if prev is not None else None
        if kind in "FB" and sid > 0:
            cross_key = (sid - 1, kind, mb)
        elif kind in "fb" and sid + 1 < num_stages:
            cross_key = (sid + 1, kind, mb)
        else:
            cross_key = None
        return prev_key, cross_key

    def get_time(key):
        """Return the memoised completion time of ``key`` = (stage, letter, mb).

        Uses an explicit stack instead of recursion: dependency chains grow with
        the number of ops and would otherwise exceed Python's recursion limit.
        """
        in_progress = set()
        stack = [(key, False)]
        while stack:
            node, ready = stack.pop()
            if node in time_map:
                continue
            prev_key, cross_key = inputs(node)
            if ready:
                # Every input was pushed above this entry, so it is computed now.
                start = time_map[prev_key] if prev_key is not None else 0
                if cross_key is not None:
                    start = max(start, time_map[cross_key] + c)
                time_map[node] = start + cost[node[1]]
                in_progress.discard(node)
                continue
            # Re-reaching a node whose inputs are still being resolved means
            # the node (transitively) depends on itself.
            if node in in_progress:
                raise ValueError(f"cyclic dependency involving op {node!r}")
            in_progress.add(node)
            stack.append((node, True))
            for dep in (prev_key, cross_key):
                if dep is not None and dep not in time_map:
                    stack.append((dep, False))
        return time_map[key]

    result = []
    for sid, order in enumerate(stage_order):
        row_nodes = []
        for p, mb in order:
            end = get_time((sid, p, mb))
            row_nodes.append(ScheduledNode(
                type=p.upper(),
                # "f"/"B"/"W" belong to chunk 1, "F"/"b"/"w" to chunk 0.
                chunk=int(p in "fBW"),
                stage=sid,
                minibatch=mb,
                start_time=end - cost[p],
                completion_time=end,
            ))
        result.append(row_nodes)
    return result
def get_pattern_str(pos):
    """Render one stage's pattern as a string of length pattern_size.

    Args:
        pos: Slot index for each op letter, in the order "FfBbWw" (pos[0] is
            the slot of "F", pos[1] of "f", ...). A negative entry means the op
            is absent. Must have at most six entries.

    Returns:
        A pattern_size-character string with each present op letter at its
        slot and spaces elsewhere. If two letters share a slot, the one later
        in "FfBbWw" order wins.
    """
    notations = "FfBbWw"
    pattern = [" "] * pattern_size
    for i, slot in enumerate(pos):
        if slot >= 0:
            pattern[slot] = notations[i]
    # Join once instead of growing a string with repeated `+=`.
    return "".join(pattern)
def init_repeated_schedule(p, m, patterns):
    """Tile each stage's pattern into a long, mutable row of slots.

    Args:
        p: Number of stages; patterns[0] .. patterns[p - 1] are used.
        m: Number of minibatches; together with p it sets the row length.
        patterns: Per-stage slot positions, in the form get_pattern_str takes.

    Returns:
        A list of p lists of single characters, each being that stage's
        pattern string repeated 4 * p + m + 1 times.
    """
    repeats = 4 * p + m + 1
    return [list(get_pattern_str(patterns[sid]) * repeats) for sid in range(p)]
def clear_invalid(repeated, stage, pos, offset=-1):
    """Blank every pattern_size-th slot of one stage's row, starting at ``pos``.

    Walks from ``pos`` in steps of ``offset * pattern_size`` (backwards with the
    default offset of -1) and writes " " into each visited slot until the index
    leaves the row. Nothing happens if ``pos`` is already outside the row. A
    zero ``offset`` never leaves the row, so it loops forever when ``pos`` is
    inside it.

    Returns:
        ``repeated``, modified in place.
    """
    row = repeated[stage]
    stride = offset * pattern_size
    idx = pos
    while 0 <= idx < len(row):
        row[idx] = ' '
        idx += stride
    return repeated
def clear_invalid_index(repeated, m):
    """Trim each stage's periodic row so every op letter keeps only m copies.

    ``repeated`` (e.g. as built by init_repeated_schedule) repeats each stage's
    pattern_size-slot pattern many times. For each op letter on each stage,
    this finds the first occurrence at or after a shared cursor ``index``. It
    keeps the m copies at index, index + pattern_size, ...,
    index + pattern_size * (m - 1), and blanks every other copy in that row.

    The cursor starts at slot ``pattern_size`` and only moves forward. Letters
    are processed as F (stages 0..p-1), f (stages p-1..0), B (stages 0..p-1),
    b (stages p-1..0). Each F/f/B/b first kept copy is therefore strictly later
    than the one processed before it. If a letter is not found within
    pattern_size slots of the cursor, the cursor just advances pattern_size
    slots and that row is left untouched for that letter. For B/b, the paired
    W/w is located within the pattern_size slots right after the kept B/b and
    trimmed the same way.

    Args:
        repeated: List of per-stage lists of single-character slots; modified
            in place.
        m: Number of copies of each op letter to keep.

    Returns:
        ``repeated`` itself.
    """
    p = len(repeated)
    # Shared scan cursor; starts at the beginning of the second pattern period.
    index = pattern_size
    for identifier in "FfBb":
        # F and B are visited from stage 0 upwards, f and b from the last stage down.
        if identifier in "FB":
            _iter = range(p)
        else:
            _iter = range(p - 1, -1, -1)
        for i in _iter:
            # Scan at most pattern_size slots from the cursor (j itself is unused).
            for j in range(pattern_size):
                if repeated[i][index] == identifier:
                    # Blank all copies before `index` and from the m-th period after it on.
                    clear_invalid(repeated, i, index - pattern_size, offset=-1)
                    clear_invalid(repeated, i, index + pattern_size * m, offset=1)
                    index += 1
                    if identifier in "Bb":
                        # Trim the paired W/w, using its first occurrence
                        # within pattern_size slots after this B/b.
                        w_identifier = {'B': 'W', 'b': 'w'}[identifier]
                        for k in range(pattern_size):
                            if repeated[i][index + k] == w_identifier:
                                clear_invalid(repeated, i, index + k - pattern_size, offset=-1)
                                clear_invalid(repeated, i, index + k + pattern_size * m, offset=1)
                                break
                    break
                index += 1
    return repeated
def process_warmup_without_increasing_peak_mem(schedules, m):
    """Move ops into earlier free slots of their stage, respecting dependencies.

    Each op is pulled to the earliest free slot after everything it depends on
    (positions recorded in ``loc``). B/b/W/w are moved without any memory
    check. F/f are only moved across slots whose running F/f-minus-W/w count
    (``mem``) is below ``peak_mem``, so that count never exceeds ``peak_mem``.
    ``peak_mem`` is the maximum of that count over all stages of the input, so
    a stage whose own peak is lower may rise up to it.

    Args:
        schedules: List of per-stage lists of single-character slots
            (" " = empty, otherwise one of "FfBbWw"); all rows are assumed to
            have the length of ``schedules[0]``. Modified in place.
        m: Number of minibatches. ``loc`` holds m + 2 entries per stage, so a
            stage with more than m + 1 ops of one letter raises IndexError.

    Returns:
        ``schedules`` itself.
    """
    # Largest running F/f-minus-W/w count over every stage and slot.
    peak_mem = 0
    # mem[sid][i]: running count of F/f minus W/w on stage sid up to and including slot i.
    mem = [[0 for _ in range(len(schedules[0]))] for _ in range(len(schedules))]
    # loc[sid][k][op]: slot of the k-th (1-based) `op` on stage sid once placed; -1 = not yet seen.
    loc = [[{key: -1 for key in ('F', 'f', 'B', 'b', 'W', 'w')} for _ in range(m + 2)] for _ in range(len(schedules))]
    # cntr[sid][op]: how many `op`s the sweep has reached so far on stage sid.
    cntr = [{key: 0 for key in ('F', 'f', 'B', 'b', 'W', 'w')} for _ in range(len(schedules))]
    for sid in range(len(schedules)):
        cur = 0
        for i in range(len(schedules[sid])):
            if schedules[sid][i] in "Ff":
                cur += 1
            if schedules[sid][i] in "Ww":
                cur -= 1
            mem[sid][i] = cur
            peak_mem = max(peak_mem, cur)
    # Sweep slot by slot, and within a slot from stage 0 upwards, so every op
    # reached earlier in the sweep already has its final position in `loc`.
    for i in range(len(schedules[0])):
        for sid in range(len(schedules)):
            if schedules[sid][i] == ' ':
                continue
            cntr[sid][schedules[sid][i]] += 1
            # 1-based ordinal of this op among same-letter ops on this stage,
            # used as its minibatch index in `loc`.
            cnt = cntr[sid][schedules[sid][i]]
            # pos: latest slot of anything this op must come after.
            pos = -1
            # Same letter, previous minibatch, same stage.
            if cnt > 1:
                pos = loc[sid][cnt - 1][schedules[sid][i]]
            # W/w come after the B/b of the same minibatch on the same stage.
            if schedules[sid][i] == 'W':
                pos = max(pos, loc[sid][cnt]['B'])
            if schedules[sid][i] == 'w':
                pos = max(pos, loc[sid][cnt]['b'])
            # F comes after F of the same minibatch on the previous stage.
            if schedules[sid][i] == 'F' and sid > 0:
                pos = max(pos, loc[sid - 1][cnt]['F'])
            # f comes after f on the next stage; on the last stage, after F there.
            if schedules[sid][i] == 'f':
                if sid != len(schedules) - 1:
                    pos = max(pos, loc[sid + 1][cnt]['f'])
                else :
                    pos = max(pos, loc[sid][cnt]['F'])
            # B comes after the previous stage's W; on stage 0, after f there.
            if schedules[sid][i] == 'B':
                if sid != 0:
                    # Because B and W are always combined, wait for the previous stage's W.
                    pos = max(pos, loc[sid - 1][cnt]['W'])
                else :
                    pos = max(pos, loc[sid][cnt]['f'])
            # b comes after the next stage's w; on the last stage, after W there.
            if schedules[sid][i] == 'b':
                if sid != len(schedules) - 1:
                    # Because B and W are always combined, wait for the next stage's w.
                    pos = max(pos, loc[sid + 1][cnt]['w'])
                else :
                    pos = max(pos, loc[sid][cnt]['W'])
            # Earliest candidate slot; skip occupied slots, stopping at i.
            pos += 1
            while schedules[sid][pos] != ' ' and pos < i:
                pos += 1
            # B/b also need slot pos + 1 to be free (presumably room for the
            # paired W/w -- TODO confirm).
            if schedules[sid][i] in "Bb":
                while pos < i and (schedules[sid][pos] != ' ' or schedules[sid][pos + 1] != ' '):
                    pos += 1
            # No earlier slot available: the op stays where it is.
            if pos == i:
                loc[sid][cnt][schedules[sid][i]] = i
                continue
            # B/b/W/w never raise the running count, so move them straight to pos.
            if schedules[sid][i] in "BbWw":
                schedules[sid][pos] = schedules[sid][i]
                schedules[sid][i] = ' '
                # A W/w that now runs earlier lowers the running count over [pos, i).
                if schedules[sid][pos] in "Ww":
                    for j in range(pos, i):
                        mem[sid][j] -= 1
                loc[sid][cnt][schedules[sid][pos]] = pos
                continue

            # Only F/f reach here. Moving one earlier raises the running count
            # over [new slot, i), so walk left only across slots still below
            # peak_mem, then right to the first free slot.
            place = i
            while place > pos and mem[sid][place - 1] < peak_mem:
                place -= 1
            while place < i and schedules[sid][place] != ' ':
                place += 1
            if place == i:
                loc[sid][cnt][schedules[sid][i]] = i
                continue
            pos = place
            schedules[sid][pos] = schedules[sid][i]
            schedules[sid][i] = ' '
            for j in range(pos, i):
                mem[sid][j] += 1
            loc[sid][cnt][schedules[sid][pos]] = pos
    return schedules
def schedule_by_pattern(p, m, patterns):
    """Build the per-stage slot grid from one pattern per stage.

    The pipeline is:

    1. tile each pattern (init_repeated_schedule);
    2. keep m copies of each op letter (clear_invalid_index);
    3. pull ops into earlier free slots
       (process_warmup_without_increasing_peak_mem);
    4. on each stage, blank every op beyond the m-th of its letter, scanning
       left to right.

    Returns:
        A list of per-stage lists of single-character slots (" " = empty).
    """
    tiled = init_repeated_schedule(p, m, patterns)
    grid = process_warmup_without_increasing_peak_mem(clear_invalid_index(tiled, m), m)
    for row in grid:
        kept = dict.fromkeys("FfBbWw", 0)
        for idx, op in enumerate(row):
            if op == ' ':
                continue
            if kept[op] < m:
                kept[op] += 1
            else:
                row[idx] = ' '
    return grid
def create_whole_pattern(p):
    """Return, for each of the p stages, the slot (0-5) of each op letter.

    Row ``sid`` lists the slots of "F", "f", "B", "b", "W", "w" in that order.
    Times are assigned on a global clock and reduced modulo 6:

    * F runs down stages 0..p-1, then f runs back up p-1..0, one tick apart;
    * B/W pairs then run down stages 0..p-1 and b/w pairs back up, two ticks
      per pair, with three extra ticks inserted after every third pair.
    """
    slots = [[0] * 6 for _ in range(p)]
    clock = 0

    for sid in range(p):
        clock += 1
        slots[sid][0] = clock
    for sid in reversed(range(p)):
        clock += 1
        slots[sid][1] = clock

    clock += 1
    if p % 3 == 0:
        clock += 3
    phase = (3 - (p + 2) % 3) % 3

    def place_pair(sid, first_col, second_col):
        # Put two consecutive ticks on one stage and advance the clock.
        nonlocal clock, phase
        slots[sid][first_col] = clock
        slots[sid][second_col] = clock + 1
        clock += 2
        phase += 1
        if phase == 3:
            phase = 0
            clock += 3

    for sid in range(p):
        place_pair(sid, 2, 4)
    for sid in reversed(range(p)):
        place_pair(sid, 3, 5)

    return [[t % 6 for t in row] for row in slots]
def schedule(p, m, cost):
    """Generate the timed schedule for p stages and m minibatches.

    Args:
        p: Number of stages.
        m: How many ops of each letter every stage ends up with.
        cost: Tuple (f, b, w, c), passed positionally to transform_schedule.

    Returns:
        Per-stage lists of ScheduledNode, as produced by transform_schedule.
    """
    grid = schedule_by_pattern(p, m, create_whole_pattern(p))
    # transform_schedule charges W/w time to the B/b op (cost b + w) and has no
    # cost entry for W/w, so remove them from the grid first.
    for row in grid:
        for idx, op in enumerate(row):
            if op in "Ww":
                row[idx] = ' '
    return transform_schedule(grid, *cost)