x_alhdrawi1 / app.py
Alhdrawi's picture
Update app.py
48badb6 verified
import gradio as gr
import torch
import os
import urllib.request
from torchvision import transforms
from PIL import Image
import torch.nn as nn
# إعدادات النموذج
REPO_ID = "Alhdrawi/x_alhdrawi"
MODEL_FILE = "best_128_0.0002_original_15000_0.859.pt"
MODEL_URL = f"https://huggingface.co/{REPO_ID}/resolve/main/{MODEL_FILE}"
MODEL_LOCAL_PATH = f"/tmp/{MODEL_FILE}"
# قائمة الأمراض التي يتوقعها النموذج
diseases = [
"Atelectasis", "Cardiomegaly", "Consolidation", "Edema", "Effusion",
"Emphysema", "Fibrosis", "Hernia", "Infiltration", "Mass", "Nodule",
"Pleural_Thickening", "Pneumonia", "Pneumothorax"
]
# تحويل الصورة مثل ما دربت النموذج
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485], std=[0.229])
])
# تعريف بنية النموذج (نفس اللي استخدمته وقت التدريب)
class SimpleCNN(nn.Module):
def __init__(self, num_classes=14):
super(SimpleCNN, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(1, 32, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2),
nn.Conv2d(32, 64, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.AdaptiveAvgPool2d((1, 1))
)
self.classifier = nn.Linear(64, num_classes)
def forward(self, x):
x = self.features(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x
# تحميل النموذج
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleCNN(num_classes=len(diseases)).to(device)
def download_and_load_model():
if not os.path.exists(MODEL_LOCAL_PATH):
print(f"Downloading model from {MODEL_URL}")
urllib.request.urlretrieve(MODEL_URL, MODEL_LOCAL_PATH)
state_dict = torch.load(MODEL_LOCAL_PATH, map_location=device)
model.load_state_dict(state_dict)
model.eval()
print(f"✅ Model loaded from {MODEL_FILE}")
# دالة التنبؤ
def predict(image):
img = transform(image.convert("L")).unsqueeze(0).to(device)
with torch.no_grad():
outputs = model(img)
probs = torch.sigmoid(outputs).cpu().squeeze().numpy()
results = {d: round(float(p), 3) for d, p in zip(diseases, probs)}
return results
# تحميل النموذج عند بدء التشغيل
download_and_load_model()
# واجهة Gradio
with gr.Blocks() as demo:
gr.Markdown(f"## 🧠 CheXzero | النموذج المستخدم: `{MODEL_FILE}`")
with gr.Row():
image_input = gr.Image(type="pil", label="صورة أشعة X-Ray")
output = gr.Label(num_top_classes=5)
image_input.change(fn=predict, inputs=image_input, outputs=output)
demo.launch()