# Hugging Face Spaces app (page status at capture time: Running)
import json
import os
import random

import gradio as gr
from datasets import config as datasets_config
from datasets import load_dataset
# from huggingface_hub import Repository  # superseded by snapshot_download
from huggingface_hub import snapshot_download
# Load (and cache) the dataset through the `datasets` library.
dataset = load_dataset("Vchitect/VBench_sampled_video")

# Cache directory used by `datasets`. HF_DATASETS_CACHE is not guaranteed to
# be set, so fall back to the path `datasets` itself resolved instead of
# raising KeyError. (A literal "~/.cache/..." string would not work here:
# os.listdir does not expand "~".)
root = os.environ.get('HF_DATASETS_CACHE') or str(datasets_config.HF_DATASETS_CACHE)
print(root)
def print_directory_contents(path, indent=0):
    """Recursively print the tree under ``path``, one entry per line.

    Each nesting level adds one leading space. Entries are printed in sorted
    order so repeated runs produce the same listing.

    Args:
        path: directory to list.
        indent: current depth; callers normally leave it at 0.

    A directory that cannot be listed produces a bracketed note instead of
    aborting the walk. Symlinks to directories are printed but not followed,
    so a symlink cycle cannot cause unbounded recursion.
    """
    prefix = ' ' * indent
    try:
        entries = sorted(os.listdir(path))
    except PermissionError:
        print(prefix + "[权限错误,无法访问该目录]")
        return
    except OSError as err:
        # e.g. the path vanished or is not a directory: report it, don't crash.
        print(prefix + f"[无法访问该目录: {err}]")
        return
    for item in entries:
        item_path = os.path.join(path, item)
        print(prefix + item)
        # Only descend into real directories; following directory symlinks
        # could loop forever on a cycle.
        if os.path.isdir(item_path) and not os.path.islink(item_path):
            print_directory_contents(item_path, indent + 1)
# Dump the `datasets` cache tree to the logs (debug aid for locating files).
print_directory_contents(root)

# Fetch the raw dataset repository. The code below expects one sub-directory
# per model, each holding that model's sampled videos.
hf_token = os.environ.get('hf_token')  # optional; None lets huggingface_hub use its default credentials
submission_url = "Vchitect/VBench_sampled_video"
local_dir = "VBench_sampled_video"
snapshot_download(
    repo_id=submission_url,
    repo_type="dataset",
    local_dir=local_dir,
    token=hf_token,
)

# Each entry is a model's directory *path* (local_dir/<model>), so
# get_random_video can join "<dimension>/<prompt>" onto it directly.
# Plain files (README.md, .gitattributes, ...) and hidden directories
# (.cache, .git) are not models and are skipped.
model_names = sorted(
    os.path.join(local_dir, name)
    for name in os.listdir(local_dir)
    if not name.startswith('.') and os.path.isdir(os.path.join(local_dir, name))
)
# Per-dimension prompt lists. get_random_video indexes this by dimension name
# and joins each element onto a model directory as a path component, so it is
# expected to map dimension -> list of video file names.
with open("videos_by_dimension.json", encoding="utf-8") as f:
    dimension = json.load(f)['videos_by_dimension']

# Evaluation dimensions get_random_video samples from; each must be a key of
# `dimension`.
types = [
    'appearance_style', 'color', 'temporal_style', 'spatial_relationship',
    'temporal_flickering', 'scene', 'multiple_objects', 'object_class',
    'human_action', 'overall_consistency', 'subject_consistency',
]
def get_random_video(max_attempts=10):
    """Return the path of a randomly chosen sampled video, or None.

    Each attempt draws a dimension from the module-level ``types`` list, a
    prompt entry from ``dimension[<dimension>]`` and a model directory from
    ``model_names``. It then checks two on-disk layouts in order:
    ``<model>/<dimension>/<prompt>`` and the flat ``<model>/<prompt>``.

    Args:
        max_attempts: number of random draws to try before giving up. A draw
            can miss when a model has no video for the chosen prompt;
            retrying avoids handing the gr.Video output a nonexistent path.

    Returns:
        The first existing video path found, or ``None`` when every attempt
        missed or there is nothing to sample from. Each miss is printed with
        an ``error:`` prefix.
    """
    if not types or not model_names:
        print('error: nothing to sample from (types or model_names is empty)')
        return None
    for _ in range(max(1, max_attempts)):
        dim_name = random.choice(types)  # avoid shadowing the builtin `type`
        prompts = dimension[dim_name]
        if not prompts:
            print('error: no prompts for dimension', dim_name)
            continue
        prompt = random.choice(prompts)
        model_name = random.choice(model_names)
        nested_path = os.path.join(model_name, dim_name, prompt)
        flat_path = os.path.join(model_name, prompt)
        for video_path in (nested_path, flat_path):
            if os.path.isfile(video_path):
                print(video_path)
                return video_path
        print('error:', flat_path)
    return None
# Gradio callback
def display_video():
    """Produce the value for the interface's single video output.

    Takes no inputs and forwards whatever ``get_random_video()`` returns.
    """
    return get_random_video()
# Gradio UI: no inputs; display_video supplies the single gr.Video output.
interface = gr.Interface(
    fn=display_video,
    inputs=[],
    outputs=gr.Video(label="随机视频展示"),
    title="随机视频展示",
    description="从 Vchitect/VBench_sampled_video 数据集中随机展示一个视频。",
)

# Launch the web app only when executed as a script, not when imported.
if __name__ == "__main__":
    interface.launch()