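"""Ablation over PAB (Pyramid Attention Broadcast) gaps for Open-Sora.

Sweeps the broadcast gap separately for spatial, temporal, and cross
attention, generating videos for the VBench prompt list under
./samples/attention_ablation/.
"""
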
from utils import generate_func, read_prompt_list
import videosys
from videosys import OpenSoraConfig, OpenSoraPipeline
from videosys.models.open_sora import OpenSoraPABConfig


def attention_ablation_func(pab_kwargs, prompt_list, output_dir):
    # Build an Open-Sora pipeline with PAB enabled for this ablation setting,
    # then generate a video for every prompt into output_dir.
    pab_config = OpenSoraPABConfig(**pab_kwargs)
    config = OpenSoraConfig(enable_pab=True, pab_config=pab_config)
    pipeline = OpenSoraPipeline(config)

    generate_func(pipeline, prompt_list, output_dir)


def main(prompt_list):
    # spatial attention: sweep the broadcast gap
    gap_list = [2, 3, 4, 5]
    for gap in gap_list:
        pab_kwargs = {
            "spatial_broadcast": True,
            "spatial_gap": gap,
            "temporal_broadcast": False,
            "cross_broadcast": False,
            "mlp_skip": False,
        }
        output_dir = f"./samples/attention_ablation/spatial_g{gap}"
        attention_ablation_func(pab_kwargs, prompt_list, output_dir)

    # temporal attention: sweep the broadcast gap
    gap_list = [3, 4, 5, 6]
    for gap in gap_list:
        pab_kwargs = {
            "spatial_broadcast": False,
            "temporal_broadcast": True,
            "temporal_gap": gap,
            "cross_broadcast": False,
            "mlp_skip": False,
        }
        output_dir = f"./samples/attention_ablation/temporal_g{gap}"
        attention_ablation_func(pab_kwargs, prompt_list, output_dir)

    # cross attention: sweep the broadcast gap
    gap_list = [5, 6, 7, 8]
    for gap in gap_list:
        pab_kwargs = {
            "spatial_broadcast": False,
            "temporal_broadcast": False,
            "cross_broadcast": True,
            "cross_gap": gap,
            "mlp_skip": False,
        }
        output_dir = f"./samples/attention_ablation/cross_g{gap}"
        attention_ablation_func(pab_kwargs, prompt_list, output_dir)


if __name__ == "__main__":
    # Initialize VideoSys, load the VBench prompt list, and run all ablations.
    videosys.initialize(42)
    prompt_list = read_prompt_list("vbench/VBench_full_info.json")
    main(prompt_list)