import gradio as gr import torch from transformers import pipeline from moviepy.editor import VideoFileClip from PIL import Image import os # Kiểm tra thiết bị sử dụng GPU hay CPU device = "cuda" if torch.cuda.is_available() else "cpu" # Tải các mô hình phân loại ảnh và video từ Hugging Face image_classifier = pipeline("image-classification", model="google/vit-base-patch16-224-in21k", device=0 if device == "cuda" else -1) video_classifier = pipeline("video-classification", model="google/vit-base-patch16-224-in21k", device=0 if device == "cuda" else -1) # Hàm phân loại ảnh def classify_image(image, model_name): # Tùy chọn chọn model ảnh khác nếu người dùng yêu cầu if model_name == "ViT": classifier = image_classifier else: classifier = image_classifier # Chỉnh sửa ở đây nếu muốn hỗ trợ thêm các mô hình khác # Phân loại ảnh result = classifier(image) return result[0]['label'], result[0]['score'] # Hàm phân loại video def classify_video(video, model_name): # Tùy chọn chọn model video khác nếu người dùng yêu cầu if model_name == "ViT": classifier = video_classifier else: classifier = video_classifier # Chỉnh sửa ở đây nếu muốn hỗ trợ thêm các mô hình khác # Đọc video và trích xuất các frame (ở đây đơn giản là lấy 1 frame đầu tiên) clip = VideoFileClip(video.name) frame = clip.get_frame(0) image = Image.fromarray(frame) # Phân loại frame đầu tiên của video result = classifier(image) return result[0]['label'], result[0]['score'] # Giao diện Gradio with gr.Blocks() as demo: with gr.TabbedInterface() as tabs: with gr.TabItem("Image Classification"): gr.Markdown("### Upload an image for classification") with gr.Row(): model_choice_image = gr.Dropdown(choices=["ViT", "ResNet"], label="Choose a Model", value="ViT") image_input = gr.Image(type="pil", label="Upload Image") image_output_label = gr.Textbox(label="Prediction") image_output_score = gr.Textbox(label="Confidence Score") classify_image_button = gr.Button("Classify Image") classify_image_button.click(classify_image, inputs=[image_input, model_choice_image], outputs=[image_output_label, image_output_score]) with gr.TabItem("Video Classification"): gr.Markdown("### Upload a video for classification") with gr.Row(): model_choice_video = gr.Dropdown(choices=["ViT", "ResNet"], label="Choose a Model", value="ViT") video_input = gr.Video(label="Upload Video") video_output_label = gr.Textbox(label="Prediction") video_output_score = gr.Textbox(label="Confidence Score") classify_video_button = gr.Button("Classify Video") classify_video_button.click(classify_video, inputs=[video_input, model_choice_video], outputs=[video_output_label, video_output_score]) demo.launch()