File size: 1,880 Bytes
07c6a04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from utils import generate_func, read_prompt_list

import videosys
from videosys import OpenSoraConfig, OpenSoraPipeline
from videosys.models.open_sora import OpenSoraPABConfig


def attention_ablation_func(pab_kwargs, prompt_list, output_dir):
    pab_config = OpenSoraPABConfig(**pab_kwargs)
    config = OpenSoraConfig(enable_pab=True, pab_config=pab_config)
    pipeline = OpenSoraPipeline(config)

    generate_func(pipeline, prompt_list, output_dir)


def main(prompt_list):
    # spatial
    gap_list = [2, 3, 4, 5]
    for gap in gap_list:
        pab_kwargs = {
            "spatial_broadcast": True,
            "spatial_gap": gap,
            "temporal_broadcast": False,
            "cross_broadcast": False,
            "mlp_skip": False,
        }
        output_dir = f"./samples/attention_ablation/spatial_g{gap}"
        attention_ablation_func(pab_kwargs, prompt_list, output_dir)

    # temporal
    gap_list = [3, 4, 5, 6]
    for gap in gap_list:
        pab_kwargs = {
            "spatial_broadcast": False,
            "temporal_broadcast": True,
            "temporal_gap": gap,
            "cross_broadcast": False,
            "mlp_skip": False,
        }
        output_dir = f"./samples/attention_ablation/temporal_g{gap}"
        attention_ablation_func(pab_kwargs, prompt_list, output_dir)

    # cross
    gap_list = [5, 6, 7, 8]
    for gap in gap_list:
        pab_kwargs = {
            "spatial_broadcast": False,
            "temporal_broadcast": False,
            "cross_broadcast": True,
            "cross_gap": gap,
            "mlp_skip": False,
        }
        output_dir = f"./samples/attention_ablation/cross_g{gap}"
        attention_ablation_func(pab_kwargs, prompt_list, output_dir)


if __name__ == "__main__":
    videosys.initialize(42)
    prompt_list = read_prompt_list("vbench/VBench_full_info.json")
    main(prompt_list)