File size: 2,999 Bytes
285ea09 b94db6f d3d18f9 285ea09 b94db6f 285ea09 b94db6f d3d18f9 b94db6f d3d18f9 b94db6f 285ea09 b94db6f 285ea09 b94db6f 285ea09 b94db6f 285ea09 b94db6f e42e16e b94db6f e42e16e b94db6f e42e16e 285ea09 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
import gradio as gr
from transformers import pipeline
from PIL import Image
import torch
# Kiểm tra thiết bị sử dụng GPU hay CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
# Tải các mô hình phân loại ảnh và video từ Hugging Face
image_classifier = pipeline("image-classification", model="google/vit-base-patch16-224-in21k", device=0 if device == "cuda" else -1)
# Sử dụng mô hình phân loại video có sẵn trên Hugging Face
video_classifier = pipeline("video-classification", model="google/vit-base-patch16-224-in21k", device=0 if device == "cuda" else -1)
# Hàm phân loại ảnh
def classify_image(image, model_name):
# Tùy chọn chọn model ảnh khác nếu người dùng yêu cầu
if model_name == "ViT":
classifier = image_classifier
else:
classifier = image_classifier # Chỉnh sửa ở đây nếu muốn hỗ trợ thêm các mô hình khác
# Phân loại ảnh
result = classifier(image)
return result[0]['label'], result[0]['score']
# Hàm phân loại video
def classify_video(video, model_name):
# Tùy chọn chọn model video khác nếu người dùng yêu cầu
if model_name == "ViT":
classifier = video_classifier
else:
classifier = video_classifier # Chỉnh sửa ở đây nếu muốn hỗ trợ thêm các mô hình khác
# Phân loại video trực tiếp mà không cần trích xuất frame
result = classifier(video)
return result[0]['label'], result[0]['score']
# Giao diện Gradio
with gr.Blocks() as demo:
with gr.TabbedInterface() as tabs:
with gr.TabItem("Image Classification"):
gr.Markdown("### Upload an image for classification")
with gr.Row():
model_choice_image = gr.Dropdown(choices=["ViT", "ResNet"], label="Choose a Model", value="ViT")
image_input = gr.Image(type="pil", label="Upload Image")
image_output_label = gr.Textbox(label="Prediction")
image_output_score = gr.Textbox(label="Confidence Score")
classify_image_button = gr.Button("Classify Image")
classify_image_button.click(classify_image, inputs=[image_input, model_choice_image], outputs=[image_output_label, image_output_score])
with gr.TabItem("Video Classification"):
gr.Markdown("### Upload a video for classification")
with gr.Row():
model_choice_video = gr.Dropdown(choices=["ViT", "ResNet"], label="Choose a Model", value="ViT")
video_input = gr.Video(label="Upload Video")
video_output_label = gr.Textbox(label="Prediction")
video_output_score = gr.Textbox(label="Confidence Score")
classify_video_button = gr.Button("Classify Video")
classify_video_button.click(classify_video, inputs=[video_input, model_choice_video], outputs=[video_output_label, video_output_score])
demo.launch()
|