# Hugging Face Spaces file-view residue (author: Questaaaa;
# last commit: "Update app.py", d1e8a6c, verified; file size 1.3 kB).
import gradio as gr
from transformers import AutoModelForImageClassification, AutoFeatureExtractor
import torch
from PIL import Image
import numpy as np
import json
import requests
# Load the BEiT image-classification model and its preprocessing pipeline
# (fetched from the Hugging Face Hub, or the local cache, by from_pretrained).
model_name = "microsoft/beit-base-patch16-224"
model = AutoModelForImageClassification.from_pretrained(model_name)
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)

# Class-index -> label mapping.
# Use the table bundled with the checkpoint (`config.id2label`): it is the
# mapping the classifier head was trained with. The previous version downloaded
# an ImageNet-21k WordNet-ID file at import time (no timeout, no status check),
# whose indices do not line up with this ImageNet-1k fine-tuned head.
# Keys are kept as strings so existing `imagenet_classes.get(str(idx))` lookups
# keep working.
imagenet_classes = {str(idx): label for idx, label in model.config.id2label.items()}
# Classification function used as the Gradio callback.
def classify_image(image):
    """Classify an image with the BEiT model and return a readable label.

    Args:
        image: The image supplied by Gradio -- a NumPy array or a
            ``PIL.Image.Image``. ``None`` when the user submits without
            providing an image.

    Returns:
        ``"Predicted class: <label> (ID: <index>)"``, or a short prompt to
        upload an image when ``image`` is ``None``.
    """
    # Gradio passes None for an empty image input; the preprocessor would
    # otherwise raise on it.
    if image is None:
        return "Please upload an image."

    # Normalise PIL input to 3-channel RGB before converting to an array, so
    # RGBA / greyscale / palette images do not reach the preprocessor with an
    # unexpected channel count.
    if isinstance(image, Image.Image):
        image = np.array(image.convert("RGB"))

    # Resize/normalise into the tensor format the model expects.
    inputs = feature_extractor(images=image, return_tensors="pt")

    # Inference only: no gradients needed.
    with torch.no_grad():
        logits = model(**inputs).logits
    # Single image in the batch, so argmax over the last axis yields one index.
    predicted_class_idx = logits.argmax(-1).item()

    # Look the index up in the label table shipped with the checkpoint so it
    # matches the classifier head (config id2label keys are ints).
    class_name = model.config.id2label.get(predicted_class_idx, "Unknown")
    return f"Predicted class: {class_name} (ID: {predicted_class_idx})"
# Build the Gradio UI: one image input, one text output.
demo = gr.Interface(
    fn=classify_image,
    inputs="image",
    outputs="text",
    title="Image Classification Demo",
)

# Start the web server only when this file is run as a script
# (e.g. `python app.py`), so importing the module does not block on launch().
if __name__ == "__main__":
    demo.launch()