File size: 2,999 Bytes
285ea09
b94db6f
 
d3d18f9
285ea09
b94db6f
285ea09
 
b94db6f
 
d3d18f9
 
b94db6f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d3d18f9
 
b94db6f
 
 
 
 
 
 
285ea09
b94db6f
 
 
 
285ea09
b94db6f
285ea09
b94db6f
285ea09
b94db6f
 
 
 
 
 
 
e42e16e
b94db6f
e42e16e
b94db6f
e42e16e
285ea09
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import gradio as gr
from transformers import pipeline
from PIL import Image
import torch

# Kiểm tra thiết bị sử dụng GPU hay CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# Tải các mô hình phân loại ảnh và video từ Hugging Face
image_classifier = pipeline("image-classification", model="google/vit-base-patch16-224-in21k", device=0 if device == "cuda" else -1)

# Sử dụng mô hình phân loại video có sẵn trên Hugging Face
video_classifier = pipeline("video-classification", model="google/vit-base-patch16-224-in21k", device=0 if device == "cuda" else -1)

# Hàm phân loại ảnh
def classify_image(image, model_name):
    # Tùy chọn chọn model ảnh khác nếu người dùng yêu cầu
    if model_name == "ViT":
        classifier = image_classifier
    else:
        classifier = image_classifier  # Chỉnh sửa ở đây nếu muốn hỗ trợ thêm các mô hình khác

    # Phân loại ảnh
    result = classifier(image)
    return result[0]['label'], result[0]['score']

# Hàm phân loại video
def classify_video(video, model_name):
    # Tùy chọn chọn model video khác nếu người dùng yêu cầu
    if model_name == "ViT":
        classifier = video_classifier
    else:
        classifier = video_classifier  # Chỉnh sửa ở đây nếu muốn hỗ trợ thêm các mô hình khác

    # Phân loại video trực tiếp mà không cần trích xuất frame
    result = classifier(video)
    return result[0]['label'], result[0]['score']

# Giao diện Gradio
with gr.Blocks() as demo:
    with gr.TabbedInterface() as tabs:
        with gr.TabItem("Image Classification"):
            gr.Markdown("### Upload an image for classification")
            with gr.Row():
                model_choice_image = gr.Dropdown(choices=["ViT", "ResNet"], label="Choose a Model", value="ViT")
                image_input = gr.Image(type="pil", label="Upload Image")
                image_output_label = gr.Textbox(label="Prediction")
                image_output_score = gr.Textbox(label="Confidence Score")

            classify_image_button = gr.Button("Classify Image")

            classify_image_button.click(classify_image, inputs=[image_input, model_choice_image], outputs=[image_output_label, image_output_score])

        with gr.TabItem("Video Classification"):
            gr.Markdown("### Upload a video for classification")
            with gr.Row():
                model_choice_video = gr.Dropdown(choices=["ViT", "ResNet"], label="Choose a Model", value="ViT")
                video_input = gr.Video(label="Upload Video")
                video_output_label = gr.Textbox(label="Prediction")
                video_output_score = gr.Textbox(label="Confidence Score")

            classify_video_button = gr.Button("Classify Video")

            classify_video_button.click(classify_video, inputs=[video_input, model_choice_video], outputs=[video_output_label, video_output_score])

    demo.launch()