File size: 1,836 Bytes
96d7ad8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from typing import List, Literal
import numpy as np


def generate_parameters_with_timesteps(
    start: int,
    num: int,
    stop: int = None,
    method: Literal["linear", "two_stage", "three_stage", "fix_two_stage"] = "linear",
    n_fix_start: int = 3,
) -> List[float]:
    """Build a list of per-step parameter values going from ``start`` to ``stop``.

    Args:
        start: Value used at the beginning of the schedule.
        num: Number of values requested.
        stop: Value used at the end of the schedule. When it is ``None`` or
            equal to ``start`` the schedule is constant and ``method`` is
            not consulted at all.
        method: Name of the schedule shape. ``"linear"``, ``"two_stage"``,
            ``"three_stage"`` and ``"fix_two_stage"`` each dispatch to the
            matching ``generate_*_parameters`` helper.
        n_fix_start: Forwarded to the ``"fix_two_stage"`` helper only; it is
            the number of leading entries that keep the ``start`` value.

    Returns:
        The list of parameter values produced by the selected helper, or
        ``[start] * num`` for the constant case.

    Raises:
        ValueError: If ``start`` and ``stop`` differ and ``method`` is not
            one of the supported names.
    """
    if stop is None or start == stop:
        # Constant schedule: nothing to interpolate, so the method is ignored.
        params = [start] * num
    else:
        if method == "linear":
            params = generate_linear_parameters(start, stop, num)
        elif method == "two_stage":
            params = generate_two_stages_parameters(start, stop, num)
        elif method == "three_stage":
            params = generate_three_stages_parameters(start, stop, num)
        elif method == "fix_two_stage":
            params = generate_fix_two_stages_parameters(start, stop, num, n_fix_start)
        else:
            # List every supported method and separate the offending value
            # from the text so the message is readable.
            raise ValueError(
                "now only support linear, two_stage, three_stage, "
                f"fix_two_stage, but given {method}"
            )
    return params


def generate_linear_parameters(start, stop, num):
    """Return ``num`` evenly spaced values from ``start`` to ``stop`` as a list.

    The values come from ``np.linspace`` and are wrapped with ``list`` (not
    ``ndarray.tolist``), so the elements remain NumPy scalars.
    """
    samples = np.linspace(start, stop, num)
    return list(samples)


def generate_two_stages_parameters(start, stop, num):
    """Return a step schedule: ``start`` for the first half, ``stop`` after.

    The first stage gets ``num // 2`` entries; the second stage receives the
    remainder, so it is one entry longer when ``num`` is odd.
    """
    head_count = num // 2
    head = [start] * head_count
    tail = [stop] * (num - head_count)
    return head + tail


def generate_fix_two_stages_parameters(start, stop, num, n_fix_start: int) -> List:
    """Return a step schedule with a fixed-length ``start`` stage.

    Args:
        start: Value used for the leading entries.
        stop: Value used for all remaining entries.
        num: Total number of entries to produce.
        n_fix_start: Requested number of leading ``start`` entries. It is
            clamped into ``[0, num]`` so the result never exceeds ``num``
            entries.

    Returns:
        ``[start] * k + [stop] * (num - k)`` where ``k`` is the clamped
        ``n_fix_start``.
    """
    # Without clamping, n_fix_start > num (or a negative n_fix_start) yields a
    # list longer than num: a negative repeat count produces an empty list
    # while the other segment keeps its full, oversized length.
    num_start = max(0, min(n_fix_start, num))
    num_end = num - num_start
    return [start] * num_start + [stop] * num_end


def generate_three_stages_parameters(start, stop, num):
    """Return a three-step schedule: ``start``, a midpoint, then ``stop``.

    The first two stages each get ``num // 3`` entries and the last stage
    takes whatever remains. The midpoint is ``(start + stop) // 2``, i.e. it
    is floor-divided rather than a true average.
    """
    stage_len = num // 3
    midpoint = (start + stop) // 2
    head = [start] * stage_len
    body = [midpoint] * stage_len
    tail = [stop] * (num - 2 * stage_len)
    return head + body + tail