File size: 3,524 Bytes
d0269f7
44efa53
d0269f7
1dfa4f0
 
4f58d50
afc843a
06c102b
1f87132
1dfa4f0
06c102b
5af1ec5
06c102b
 
8c362bb
c52b92b
06c102b
 
 
8c362bb
 
 
 
 
 
 
 
 
 
 
4f58d50
b27f379
4b2f4dc
c9095ee
4f58d50
69d1f9e
f49b111
 
69d1f9e
 
 
 
 
 
 
 
1dfa4f0
 
69d1f9e
 
 
 
 
 
 
 
 
c9095ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69d1f9e
 
1dfa4f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os
import json
import shutil
import gradio as gr
import random
from huggingface_hub import Repository,HfApi
from huggingface_hub import snapshot_download
# from datasets import load_dataset
from datasets import config

# Hugging Face access token, read from the `hf_token` environment variable.
# A missing variable raises KeyError here, at import time.
hf_token = os.environ['hf_token']  # make sure the environment variable holds your token

local_dir = "VBench_sampled_video"  # local folder path
# NOTE(review): the commented-out lines below look like earlier experiments with
# `datasets.load_dataset` and its cache directory -- confirm before deleting.
# dataset = load_dataset("Vchitect/VBench_sampled_video")
# print(os.listdir("~/.cache/huggingface/datasets/Vchitect___VBench_sampled_video/"))
# root = "~/.cache/huggingface/datasets/Vchitect___VBench_sampled_video/"
# print(config.HF_DATASETS_CACHE)
# root = config.HF_DATASETS_CACHE
# print(root)
def print_directory_contents(path, indent=0):
    """Recursively print the entries under ``path`` as an indented tree.

    Each entry name is printed on its own line, prefixed with four spaces per
    ``indent`` level; sub-directories are descended into with ``indent + 1``.
    A directory that cannot be listed because of a ``PermissionError`` gets a
    placeholder line instead of its contents.
    """
    prefix = '    ' * indent

    # Listing is the only step that can hit a permission problem.
    try:
        entries = os.listdir(path)
    except PermissionError:
        print(prefix + "[权限错误,无法访问该目录]")
        return

    for name in entries:
        print(prefix + name)
        child = os.path.join(path, name)
        if os.path.isdir(child):
            print_directory_contents(child, indent + 1)

# Prepare the local download folder and the Hub client used to fetch videos.
os.makedirs(local_dir, exist_ok=True)
# A single client is enough: the previous explicit-endpoint instance was
# overwritten by this assignment on the very next line and was never used.
hf_api = HfApi(token=hf_token)
repo_id = "Vchitect/VBench_sampled_video"

# Model folders to sample from.
model_names = ['Gen-2', 'Gen-3']

# Mapping loaded from the "videos_by_dimension" key of the JSON file.
# Read it as UTF-8 explicitly instead of relying on the platform locale.
with open("videos_by_dimension.json", encoding="utf-8") as f:
    dimension = json.load(f)['videos_by_dimension']

# with open("all_videos.json") as f:
    # all_videos = json.load(f)

# Keys of `dimension` that get_random_video samples from.
types = [
    'appearance_style',
    'color',
    'temporal_style',
    'spatial_relationship',
    'temporal_flickering',
    'scene',
    'multiple_objects',
    'object_class',
    'human_action',
    'overall_consistency',
    'subject_consistency',
]

def get_random_video():
    """Download one randomly chosen sample video and return its local path.

    A dimension from ``types``, an entry from ``dimension[<that dimension>]``
    (used as the file name) and a model from ``model_names`` are each picked
    uniformly at random.  The file is requested from the dataset repo under
    ``<model>/<dimension>/`` first and, if that fails, directly under
    ``<model>/``.

    Returns:
        The local file path returned by ``hf_hub_download``, or ``None`` when
        the file could not be fetched from either location.
    """
    # `dimension_name` rather than `type`, to avoid shadowing the builtin.
    dimension_name = random.choice(types)
    prompt = random.choice(dimension[dimension_name])
    model_name = random.choice(model_names)

    # Repo paths use forward slashes regardless of the local OS, so the
    # subfolder is built explicitly instead of with os.path.join.
    candidate_subfolders = (f"{model_name}/{dimension_name}", model_name)

    for subfolder in candidate_subfolders:
        try:
            # Keep the returned path: it is what the caller hands to Gradio.
            video_path = hf_api.hf_hub_download(
                repo_id=repo_id,
                filename=prompt,
                subfolder=subfolder,
                repo_type="dataset",
                local_dir=local_dir,
            )
        except Exception as e:
            # Best effort: any failure (missing file, network, auth) is
            # logged and the next candidate location is tried.
            print(f"[PATH]{subfolder} NOT in hf repo or download failed")
            print(e)
            continue
        print('video_path:', video_path)
        return video_path

    # Neither location worked; let the caller decide how to present that.
    return None

# Gradio interface
def display_video():
    """Gradio callback: return whatever ``get_random_video`` produces."""
    return get_random_video()

# Build the UI: no input components, one video output fed by display_video.
interface = gr.Interface(
    fn=display_video,
    inputs=[],
    outputs=gr.Video(label="随机视频展示"),
    title="随机视频展示",
    description="从 Vchitect/VBench_sampled_video 数据集中随机展示一个视频。",
)

# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    interface.launch()