File size: 3,126 Bytes
1dfa4f0
 
8c362bb
 
7fc7ae9
1dfa4f0
8c362bb
 
 
7a7e0cf
8c362bb
c52b92b
 
 
8c362bb
 
 
 
 
 
 
 
 
 
 
 
 
 
69d1f9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1dfa4f0
 
69d1f9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1dfa4f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import json
import os
import random

import gradio as gr
from datasets import load_dataset
from huggingface_hub import Repository

# Download (or reuse from cache) the VBench sampled-video dataset.
dataset = load_dataset("Vchitect/VBench_sampled_video")

# Directory the `datasets` library caches into. Fall back to the default
# location ($HF_HOME/datasets, with HF_HOME defaulting to ~/.cache/huggingface)
# when HF_DATASETS_CACHE is unset, instead of crashing with KeyError.
root = os.environ.get(
    'HF_DATASETS_CACHE',
    os.path.join(
        os.environ.get('HF_HOME', os.path.expanduser('~/.cache/huggingface')),
        'datasets',
    ),
)
print(root)
def print_directory_contents(path, indent=0):
    """Recursively print the tree of files and folders under *path*.

    Entries are printed one per line in sorted order, indented four spaces
    per nesting level. Symlinked directories are listed but not descended
    into, so a link cycle cannot cause unbounded recursion.

    Args:
        path: Directory whose contents are printed.
        indent: Current nesting depth; callers normally leave it at 0.
    """
    prefix = '    ' * indent
    try:
        entries = sorted(os.listdir(path))
    except PermissionError:
        print(prefix + "[权限错误,无法访问该目录]")
        return
    except OSError as exc:
        # Missing path, not a directory, etc.: report it and keep going, since
        # this is a diagnostic helper and must not abort the app at start-up.
        print(prefix + f"[cannot access directory: {exc}]")
        return

    for item in entries:
        item_path = os.path.join(path, item)
        print(prefix + item)
        # Only descend into real directories; following directory symlinks
        # could loop forever on a cyclic link.
        if os.path.isdir(item_path) and not os.path.islink(item_path):
            print_directory_contents(item_path, indent + 1)

# Dump the datasets cache tree so the logs show where load_dataset put files.
print_directory_contents(root)

# Configuration for the local clone of the dataset repository. It is defined
# here so the clone below does not fail with NameError.
hf_token = os.environ.get('hf_token')  # optional; None means anonymous access
submission_url = "Vchitect/VBench_sampled_video"  # dataset repository id
local_dir = "VBench_sampled_video"  # local checkout folder

# Clone the dataset repository into local_dir and pull the latest revision.
# NOTE(review): huggingface_hub marks `Repository` as deprecated in recent
# releases; `snapshot_download` is the documented replacement.
submission_repo = Repository(local_dir=local_dir, clone_from=submission_url, use_auth_token=hf_token, repo_type="dataset")
submission_repo.git_pull()

# get_random_video treats each entry here as a model folder. Keep only visible
# directories so .git, .gitattributes, README files etc. are never drawn.
model_names = sorted(
    name
    for name in os.listdir(local_dir)
    if not name.startswith('.') and os.path.isdir(os.path.join(local_dir, name))
)

# Mapping from evaluation dimension to the list of video file names
# (presumably one file per prompt) that get_random_video samples from.
with open("videos_by_dimension.json", encoding="utf-8") as f:
    dimension = json.load(f)['videos_by_dimension']

# Dimensions that random sampling draws from; each must be a key of `dimension`.
types = ['appearance_style', 'color', 'temporal_style', 'spatial_relationship', 'temporal_flickering', 'scene', 'multiple_objects', 'object_class', 'human_action', 'overall_consistency', 'subject_consistency']

def get_random_video():
    """Pick a random (dimension, prompt, model) combination and return its video path.

    Uses the module-level ``types``, ``dimension``, ``model_names`` and
    ``local_dir``. Two layouts are tried inside the model folder: first
    ``<local_dir>/<model>/<dimension>/<prompt>``, then
    ``<local_dir>/<model>/<prompt>``.

    Returns:
        The first candidate path that exists. If neither exists, logs
        ``error:`` and returns the last candidate tried, which does not exist.
        NOTE(review): returning a missing path makes the video component fail;
        consider returning None instead.
    """
    dim_type = random.choice(types)
    prompt = random.choice(dimension[dim_type])
    model_name = random.choice(model_names)

    # model_names come from os.listdir(local_dir), so paths must be resolved
    # relative to local_dir rather than the current working directory.
    model_dir = os.path.join(local_dir, model_name)
    candidates = (
        os.path.join(model_dir, dim_type, prompt),  # per-dimension sub-folder
        os.path.join(model_dir, prompt),  # flat layout
    )
    for video_path in candidates:
        if os.path.exists(video_path):
            print(video_path)
            return video_path

    print('error:', video_path)
    return video_path

# Callback wired into the Gradio interface below.
def display_video():
    """Return the path of a randomly selected video for the output player."""
    return get_random_video()

# Single-output app: no inputs, one video player showing a random sample.
interface = gr.Interface(
    fn=display_video,
    inputs=[],
    outputs=gr.Video(label="随机视频展示"),
    title="随机视频展示",
    description="从 Vchitect/VBench_sampled_video 数据集中随机展示一个视频。",
)

if __name__ == "__main__":
    # Start the Gradio web server only when run as a script.
    interface.launch()