File size: 3,517 Bytes
d0269f7
 
 
 
1dfa4f0
 
4f58d50
06c102b
1f87132
1dfa4f0
06c102b
4f58d50
06c102b
 
8c362bb
c52b92b
06c102b
 
 
8c362bb
 
 
 
 
 
 
 
 
 
 
4f58d50
 
 
 
 
69d1f9e
4f58d50
 
 
 
 
 
 
 
 
69d1f9e
 
 
 
 
 
 
 
 
 
 
1dfa4f0
 
69d1f9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1dfa4f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import os
import shutil

# Move (not copy) ./repository.py over the installed huggingface_hub's
# repository.py. NOTE(review): presumably a local patch to the library; it runs
# before the gradio / huggingface_hub imports below, likely so that the patched
# module is the one that gets imported -- confirm. The destination hard-codes a
# Python 3.10 site-packages path, and shutil.move raises FileNotFoundError if
# ./repository.py is absent (e.g. it was already moved on an earlier run).
shutil.move("repository.py", "/usr/local/lib/python3.10/site-packages/huggingface_hub/repository.py")
import json
import random

import gradio as gr
from huggingface_hub import Repository, HfApi, hf_hub_download
# from datasets import load_dataset
from datasets import config

# Hugging Face access token; the 'hf_token' environment variable must be set
# (a missing variable raises KeyError here).
hf_token = os.environ['hf_token']
# Web URL of the dataset repository. NOTE(review): not referenced elsewhere in
# this file.
submission_url = "https://huggingface.co/datasets/Vchitect/VBench_sampled_video"
# Local folder that the dataset files are downloaded into.
local_dir = "VBench_sampled_video"
# dataset = load_dataset("Vchitect/VBench_sampled_video")
# print(os.listdir("~/.cache/huggingface/datasets/Vchitect___VBench_sampled_video/"))
# root = "~/.cache/huggingface/datasets/Vchitect___VBench_sampled_video/"
# print(config.HF_DATASETS_CACHE)
# root = config.HF_DATASETS_CACHE
# print(root)
def print_directory_contents(path, indent=0):
    # 打印当前目录的内容
    try:
        for item in os.listdir(path):
            item_path = os.path.join(path, item)
            print('    ' * indent + item)  # 使用缩进打印文件或文件夹
            if os.path.isdir(item_path):  # 如果是目录,则递归调用
                print_directory_contents(item_path, indent + 1)
    except PermissionError:
        print('    ' * indent + "[权限错误,无法访问该目录]")

# Fetch the dataset: download every file of the dataset repo into `local_dir`,
# mirroring the repo's folder layout so its top-level folders can be listed below.
os.makedirs(local_dir, exist_ok=True)
hf_api = HfApi(token=hf_token)
repo_id = "Vchitect/VBench_sampled_video"
dataset_files = hf_api.list_repo_files(repo_id=repo_id, token=hf_token, repo_type='dataset')

for file in dataset_files:
    print(file)
    # `local_dir=` (not `cache_dir=`) places the file at local_dir/<repo path>
    # instead of inside the hub cache's datasets--.../snapshots/<rev>/ tree;
    # the os.listdir(local_dir) call below depends on that layout.
    hf_hub_download(
        repo_id=repo_id,
        filename=file,
        token=hf_token,
        repo_type='dataset',
        local_dir=local_dir,
    )

repo = HfApi(endpoint="https://huggingface.co", token=hf_token)  # NOTE(review): unused in this file

# Model names are the top-level sub-directories of the download. Hidden entries
# (e.g. a `.cache` metadata folder, if the installed huggingface_hub creates
# one) and plain files at the repo root are skipped; sorted for a stable order.
model_names = sorted(
    name
    for name in os.listdir(local_dir)
    if not name.startswith('.') and os.path.isdir(os.path.join(local_dir, name))
)

# Mapping: dimension name -> list of prompt entries (indexed that way by
# get_random_video).
with open("videos_by_dimension.json", encoding="utf-8") as f:
    dimension = json.load(f)['videos_by_dimension']

# with open("all_videos.json") as f:
    # all_videos = json.load(f)

# Dimension names that get_random_video draws from. Each is used as a key into
# `dimension` and, for the nested layout, as a sub-folder name inside a model
# folder.
types = [
    'appearance_style',
    'color',
    'temporal_style',
    'spatial_relationship',
    'temporal_flickering',
    'scene',
    'multiple_objects',
    'object_class',
    'human_action',
    'overall_consistency',
    'subject_consistency',
]

def get_random_video():
    """Pick a random sampled video and return its file path.

    A dimension is drawn uniformly from `types`, then a prompt entry from
    ``dimension[<that dimension>]``, then a model folder from `model_names`.
    Two layouts are tried, in order:
    ``<local_dir>/<model>/<dimension>/<prompt>`` and then
    ``<local_dir>/<model>/<prompt>``.

    Returns:
        The first candidate path that exists (it is also printed). If neither
        exists, ``error:`` plus the last candidate is printed and that last
        (flat-layout) candidate is returned anyway.

    Raises:
        IndexError: if `types`, the chosen dimension's prompt list, or
            `model_names` is empty (raised by ``random.choice``).
    """
    dim = random.choice(types)
    prompt = random.choice(dimension[dim])
    model_name = random.choice(model_names)

    # `model_names` are entries of os.listdir(local_dir), so candidate paths
    # must be rooted at local_dir; without it they resolve against the CWD.
    candidates = (
        os.path.join(local_dir, model_name, dim, prompt),  # per-dimension sub-folder
        os.path.join(local_dir, model_name, prompt),       # flat layout
    )
    for video_path in candidates:
        if os.path.exists(video_path):
            print(video_path)
            return video_path
    print('error:', video_path)
    return video_path

# Gradio interface
def display_video():
    """Gradio callback: return the path of a randomly chosen video.

    Takes no inputs and delegates entirely to ``get_random_video``.
    """
    return get_random_video()

# Gradio app with no input components: display_video is called with no
# arguments and its returned path is rendered in a gr.Video output.
interface = gr.Interface(
    fn=display_video,
    inputs=[],
    outputs=gr.Video(label="随机视频展示"),
    title="随机视频展示",
    description="从 Vchitect/VBench_sampled_video 数据集中随机展示一个视频。",
)

if __name__ == "__main__":
    interface.launch()