# MuseVSpace / MuseV / musev/utils/timesteps_util.py
# Repository page header: "anchorxia's picture", "add musev", 96d7ad8
from typing import List, Literal, Optional

import numpy as np
def generate_parameters_with_timesteps(
    start: int,
    num: int,
    stop: Optional[int] = None,
    method: Literal["linear", "two_stage", "three_stage", "fix_two_stage"] = "linear",
    n_fix_start: int = 3,
) -> List[float]:
    """Build a list of ``num`` parameter values going from ``start`` to ``stop``.

    Args:
        start: Value used at the beginning of the schedule.
        num: Number of values to generate.
        stop: Value used at the end of the schedule. When it is ``None`` or
            equal to ``start`` the result is simply ``start`` repeated ``num``
            times and ``method`` is not consulted.
        method: How to move from ``start`` to ``stop``:
            ``"linear"`` interpolates evenly, ``"two_stage"`` switches once at
            the halfway point, ``"three_stage"`` inserts a midpoint plateau,
            and ``"fix_two_stage"`` switches after ``n_fix_start`` values.
        n_fix_start: Number of leading ``start`` values; only used by the
            ``"fix_two_stage"`` method.

    Returns:
        The generated list of parameter values.

    Raises:
        ValueError: If ``method`` is not one of the supported names (only
            checked when a real ``start`` -> ``stop`` transition is requested).
    """
    # Nothing to schedule: a constant value is enough.
    if stop is None or start == stop:
        return [start] * num

    if method == "linear":
        return generate_linear_parameters(start, stop, num)
    if method == "two_stage":
        return generate_two_stages_parameters(start, stop, num)
    if method == "three_stage":
        return generate_three_stages_parameters(start, stop, num)
    if method == "fix_two_stage":
        return generate_fix_two_stages_parameters(start, stop, num, n_fix_start)

    raise ValueError(
        "now only support linear, two_stage, three_stage, fix_two_stage, "
        f"but given {method}"
    )
def generate_linear_parameters(start, stop, num):
    """Return ``num`` evenly spaced values from ``start`` to ``stop`` inclusive.

    The values come from ``np.linspace`` and are returned as a plain Python
    list whose elements are NumPy scalars.
    """
    evenly_spaced = np.linspace(start, stop, num)
    return list(evenly_spaced)
def generate_two_stages_parameters(start, stop, num):
    """Return ``num`` values: ``start`` for the first half, then ``stop``.

    The first stage gets ``num // 2`` entries; the second stage takes the
    remainder, so it receives the extra element when ``num`` is odd.
    """
    head_len = num // 2
    head = [start] * head_len
    tail = [stop] * (num - head_len)
    return head + tail
def generate_fix_two_stages_parameters(start, stop, num, n_fix_start: int) -> List:
    """Return ``num`` values: ``n_fix_start`` copies of ``start``, then ``stop``.

    Args:
        start: Value used for the leading stage.
        stop: Value used for the remaining entries.
        num: Total number of values to generate.
        n_fix_start: Requested length of the leading ``start`` stage.

    Returns:
        A list of length ``num`` (empty when ``num <= 0``).
    """
    # Clamp the leading stage to [0, num]; otherwise a request larger than
    # ``num`` (or a negative one) would yield a list whose length is not ``num``.
    num_start = max(0, min(n_fix_start, num))
    num_end = num - num_start
    return [start] * num_start + [stop] * num_end
def generate_three_stages_parameters(start, stop, num):
    """Return ``num`` values in three plateaus: ``start``, midpoint, ``stop``.

    The first two plateaus each get ``num // 3`` entries and the final
    ``stop`` plateau takes whatever remains.

    Args:
        start: Value of the first plateau.
        stop: Value of the last plateau.
        num: Total number of values to generate.

    Returns:
        A list of ``num`` values.
    """
    total = start + stop
    if isinstance(start, (int, np.integer)) and isinstance(stop, (int, np.integer)):
        # Integer endpoints keep an integer midpoint.
        middle = total // 2
    else:
        # Floor division on floats would floor the midpoint and could push it
        # outside [start, stop] (e.g. 0.5 and 1.0 -> 0.0); use the true midpoint.
        middle = total / 2

    num_start = num // 3
    num_middle = num_start
    num_end = num - num_start - num_middle
    return [start] * num_start + [middle] * num_middle + [stop] * num_end