Spaces:
Running
Running
| import os | |
| import json | |
| import shutil | |
| import gradio as gr | |
| import random | |
| from huggingface_hub import Repository,HfApi | |
| from huggingface_hub import snapshot_download | |
| # from datasets import load_dataset | |
| from datasets import config | |
# Hugging Face access token taken from the Space secret named "hf_token".
# Read with .get() so a missing secret does not crash the app at import time;
# HfApi(token=None) then falls back to any locally saved credentials, which is
# sufficient for public repositories.
hf_token = os.environ.get('hf_token')
# Local folder that downloaded videos are written into.
local_dir = "VBench_sampled_video"
def print_directory_contents(path, indent=0):
    """Recursively print the names of all files and folders under ``path``.

    Each entry is printed on its own line, prefixed by ``indent`` spaces.
    The contents of sub-directories are printed one space deeper. If a
    directory cannot be listed because of a PermissionError, a marker line is
    printed in its place instead.
    """
    pad = ' ' * indent
    try:
        entries = os.listdir(path)
    except PermissionError:
        # Marker text: "[permission error, cannot access this directory]"
        print(pad + "[权限错误,无法访问该目录]")
        return
    for name in entries:
        print(pad + name)
        child = os.path.join(path, name)
        if os.path.isdir(child):
            # Descend into sub-directories with one extra space of indent.
            print_directory_contents(child, indent + 1)
# Prepare for pulling videos from the dataset: make sure the local download
# folder exists.
os.makedirs(local_dir, exist_ok=True)
# Single Hub API client used for all downloads.
hf_api = HfApi(token=hf_token)
repo_id = "Vchitect/VBench_sampled_video"
# Model folders in the dataset repo to sample videos from.
model_names = ['Gen-2', 'Gen-3']
# dimension[<dimension name>] is a list whose entries get_random_video uses as
# file names inside the repo (schema inferred from that usage — confirm).
# Explicit UTF-8 so decoding does not depend on the host locale.
with open("videos_by_dimension.json", encoding="utf-8") as f:
    dimension = json.load(f)['videos_by_dimension']
# Evaluation dimensions to sample from; each must be a key of `dimension`.
types = ['appearance_style', 'color', 'temporal_style', 'spatial_relationship', 'temporal_flickering', 'scene', 'multiple_objects', 'object_class', 'human_action', 'overall_consistency', 'subject_consistency']
def _download_video(filename, subfolder):
    """Download one file of the dataset repo into ``local_dir``; return its local path."""
    return hf_api.hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        subfolder=subfolder,
        # Must be the string "dataset"; there is no `dataset` variable in scope.
        repo_type="dataset",
        local_dir=local_dir,
    )


def get_random_video():
    """Download a randomly chosen video from the VBench dataset repo.

    Picks a random dimension from ``types``, a random entry of
    ``dimension[<that dimension>]`` (used as the file name in the repo), and a
    random model from ``model_names``. The file is first requested from
    ``<model>/<dimension>/``. If that download fails, it is requested directly
    from ``<model>/``.

    Returns:
        The local path of the downloaded file, or ``None`` if both download
        attempts failed.
    """
    # Named `dimension_type` rather than `type` to avoid shadowing the builtin.
    dimension_type = random.choice(types)
    prompt = random.choice(dimension[dimension_type])
    model_name = random.choice(model_names)

    # Hub repo paths are POSIX-style on every host OS, so join with "/".
    video_path_subfolder = f"{model_name}/{dimension_type}"
    # Broad except: this runs at the UI boundary and any download failure
    # (missing entry, network error, ...) should trigger the fallback.
    try:
        return _download_video(prompt, video_path_subfolder)
    except Exception as e:
        print(f"[PATH]{video_path_subfolder} NOT in hf repo, try {model_name}")
        print(e)

    # Fallback: presumably some models store their videos directly under the
    # model folder rather than per dimension.
    try:
        return _download_video(prompt, model_name)
    except Exception as e:
        print(e)

    print('error: could not download', prompt)
    return None
# Gradio interface
def display_video():
    """Gradio callback: return whatever get_random_video() produces.

    The value is handed to the gr.Video output component, so it is expected
    to be a local video file path.
    """
    return get_random_video()
# Gradio UI with no inputs: each run calls display_video and shows the
# returned video. User-facing texts stay in Chinese
# ("Random video display" / "Randomly show one video from the
# Vchitect/VBench_sampled_video dataset.").
interface = gr.Interface(
    fn=display_video,
    inputs=[],
    outputs=gr.Video(label="随机视频展示"),
    title="随机视频展示",
    description="从 Vchitect/VBench_sampled_video 数据集中随机展示一个视频。",
)

if __name__ == "__main__":
    # Start the web app only when this file is executed as a script.
    interface.launch()