Spaces:
Runtime error
Runtime error
import gradio as gr | |
import torch | |
import os | |
import urllib.request | |
from torchvision import transforms | |
from PIL import Image | |
import torch.nn as nn | |
# إعدادات النموذج | |
REPO_ID = "Alhdrawi/x_alhdrawi" | |
MODEL_FILE = "best_128_0.0002_original_15000_0.859.pt" | |
MODEL_URL = f"https://huggingface.co/{REPO_ID}/resolve/main/{MODEL_FILE}" | |
MODEL_LOCAL_PATH = f"/tmp/{MODEL_FILE}" | |
# قائمة الأمراض التي يتوقعها النموذج | |
diseases = [ | |
"Atelectasis", "Cardiomegaly", "Consolidation", "Edema", "Effusion", | |
"Emphysema", "Fibrosis", "Hernia", "Infiltration", "Mass", "Nodule", | |
"Pleural_Thickening", "Pneumonia", "Pneumothorax" | |
] | |
# تحويل الصورة مثل ما دربت النموذج | |
transform = transforms.Compose([ | |
transforms.Resize((224, 224)), | |
transforms.ToTensor(), | |
transforms.Normalize(mean=[0.485], std=[0.229]) | |
]) | |
# تعريف بنية النموذج (نفس اللي استخدمته وقت التدريب) | |
class SimpleCNN(nn.Module): | |
def __init__(self, num_classes=14): | |
super(SimpleCNN, self).__init__() | |
self.features = nn.Sequential( | |
nn.Conv2d(1, 32, kernel_size=3, padding=1), | |
nn.ReLU(inplace=True), | |
nn.MaxPool2d(kernel_size=2), | |
nn.Conv2d(32, 64, kernel_size=3, padding=1), | |
nn.ReLU(inplace=True), | |
nn.AdaptiveAvgPool2d((1, 1)) | |
) | |
self.classifier = nn.Linear(64, num_classes) | |
def forward(self, x): | |
x = self.features(x) | |
x = x.view(x.size(0), -1) | |
x = self.classifier(x) | |
return x | |
# تحميل النموذج | |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
model = SimpleCNN(num_classes=len(diseases)).to(device) | |
def download_and_load_model(): | |
if not os.path.exists(MODEL_LOCAL_PATH): | |
print(f"Downloading model from {MODEL_URL}") | |
urllib.request.urlretrieve(MODEL_URL, MODEL_LOCAL_PATH) | |
state_dict = torch.load(MODEL_LOCAL_PATH, map_location=device) | |
model.load_state_dict(state_dict) | |
model.eval() | |
print(f"✅ Model loaded from {MODEL_FILE}") | |
# دالة التنبؤ | |
def predict(image): | |
img = transform(image.convert("L")).unsqueeze(0).to(device) | |
with torch.no_grad(): | |
outputs = model(img) | |
probs = torch.sigmoid(outputs).cpu().squeeze().numpy() | |
results = {d: round(float(p), 3) for d, p in zip(diseases, probs)} | |
return results | |
# تحميل النموذج عند بدء التشغيل | |
download_and_load_model() | |
# واجهة Gradio | |
with gr.Blocks() as demo: | |
gr.Markdown(f"## 🧠 CheXzero | النموذج المستخدم: `{MODEL_FILE}`") | |
with gr.Row(): | |
image_input = gr.Image(type="pil", label="صورة أشعة X-Ray") | |
output = gr.Label(num_top_classes=5) | |
image_input.change(fn=predict, inputs=image_input, outputs=output) | |
demo.launch() | |