Alhdrawi commited on
Commit
48badb6
·
verified ·
1 Parent(s): 89f3ba0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -45
app.py CHANGED
@@ -1,46 +1,33 @@
1
  import gradio as gr
2
  import torch
3
- import random
 
4
  from torchvision import transforms
5
  from PIL import Image
6
- import urllib.request
7
- import os
8
 
9
- # اسم الريبو الخاص بك على Hugging Face
10
  REPO_ID = "Alhdrawi/x_alhdrawi"
 
 
 
11
 
12
- # أسماء الملفات المرفوعة
13
- MODEL_FILES = [
14
- "best_128_0.0002_original_15000_0.859.pt",
15
- "best_128_0.0002_original_8000_0.857.pt",
16
- "best_128_5e-05_original_22000_0.855.pt",
17
- "best_64_0.0001_original_16000_0.861.pt",
18
- "best_64_0.0001_original_17000_0.863.pt",
19
- "best_64_0.0001_original_35000_0.864.pt",
20
- "best_64_0.0002_original_23000_0.854.pt",
21
- "best_64_5e-05_original_16000_0.858.pt",
22
- "best_64_5e-05_original_18000_0.862.pt",
23
- "best_64_5e-05_original_22000_0.864.pt",
24
  ]
25
 
26
- # نفس المعالجة المستخدمة
27
  transform = transforms.Compose([
28
  transforms.Resize((224, 224)),
29
  transforms.ToTensor(),
30
  transforms.Normalize(mean=[0.485], std=[0.229])
31
  ])
32
 
33
- # الأمراض اللي يشخصها
34
- diseases = [
35
- "Atelectasis", "Cardiomegaly", "Consolidation", "Edema", "Effusion",
36
- "Emphysema", "Fibrosis", "Hernia", "Infiltration", "Mass", "Nodule",
37
- "Pleural_Thickening", "Pneumonia", "Pneumothorax"
38
- ]
39
-
40
- # تعرّف كلاس النموذج - لازم تعدله حسب المعمارية
41
- import torch.nn as nn
42
-
43
- class SimpleCNN(nn.Module): # عدل هذا الكلاس حسب اللي دربت عليه
44
  def __init__(self, num_classes=14):
45
  super(SimpleCNN, self).__init__()
46
  self.features = nn.Sequential(
@@ -59,28 +46,22 @@ class SimpleCNN(nn.Module): # عدل هذا الكلاس حسب اللي درب
59
  x = self.classifier(x)
60
  return x
61
 
 
62
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
63
- model = None
64
- selected_model_file = random.choice(MODEL_FILES)
65
 
66
  def download_and_load_model():
67
- global model
68
- url = f"https://huggingface.co/{REPO_ID}/resolve/main/{selected_model_file}"
69
- local_path = f"/tmp/{selected_model_file}"
70
-
71
- if not os.path.exists(local_path):
72
- print(f"Downloading model from {url}")
73
- urllib.request.urlretrieve(url, local_path)
74
 
75
- model = SimpleCNN(num_classes=len(diseases)).to(device)
76
- state_dict = torch.load(local_path, map_location=device)
77
  model.load_state_dict(state_dict)
78
  model.eval()
79
- print(f"✅ Model loaded: {selected_model_file}")
80
 
 
81
  def predict(image):
82
- if model is None:
83
- return "النموذج غير محمّل"
84
  img = transform(image.convert("L")).unsqueeze(0).to(device)
85
  with torch.no_grad():
86
  outputs = model(img)
@@ -88,13 +69,14 @@ def predict(image):
88
  results = {d: round(float(p), 3) for d, p in zip(diseases, probs)}
89
  return results
90
 
91
- # تحميل النموذج تلقائيًا عند بداية التشغيل
92
  download_and_load_model()
93
 
 
94
  with gr.Blocks() as demo:
95
- gr.Markdown(f"## 🧠 CheXzero | النموذج العشوائي المحمّل: `{selected_model_file}`")
96
  with gr.Row():
97
- image_input = gr.Image(type="pil", label="صورة الأشعة")
98
  output = gr.Label(num_top_classes=5)
99
 
100
  image_input.change(fn=predict, inputs=image_input, outputs=output)
 
1
  import gradio as gr
2
  import torch
3
+ import os
4
+ import urllib.request
5
  from torchvision import transforms
6
  from PIL import Image
7
+ import torch.nn as nn
 
8
 
9
+ # إعدادات النموذج
10
  REPO_ID = "Alhdrawi/x_alhdrawi"
11
+ MODEL_FILE = "best_128_0.0002_original_15000_0.859.pt"
12
+ MODEL_URL = f"https://huggingface.co/{REPO_ID}/resolve/main/{MODEL_FILE}"
13
+ MODEL_LOCAL_PATH = f"/tmp/{MODEL_FILE}"
14
 
15
+ # قائمة الأمراض التي يتوقعها النموذج
16
+ diseases = [
17
+ "Atelectasis", "Cardiomegaly", "Consolidation", "Edema", "Effusion",
18
+ "Emphysema", "Fibrosis", "Hernia", "Infiltration", "Mass", "Nodule",
19
+ "Pleural_Thickening", "Pneumonia", "Pneumothorax"
 
 
 
 
 
 
 
20
  ]
21
 
22
+ # تحويل الصورة مثل ما دربت النموذج
23
  transform = transforms.Compose([
24
  transforms.Resize((224, 224)),
25
  transforms.ToTensor(),
26
  transforms.Normalize(mean=[0.485], std=[0.229])
27
  ])
28
 
29
+ # تعريف بنية النموذج (نفس اللي استخدمته وقت التدريب)
30
+ class SimpleCNN(nn.Module):
 
 
 
 
 
 
 
 
 
31
  def __init__(self, num_classes=14):
32
  super(SimpleCNN, self).__init__()
33
  self.features = nn.Sequential(
 
46
  x = self.classifier(x)
47
  return x
48
 
49
+ # تحميل النموذج
50
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
51
+ model = SimpleCNN(num_classes=len(diseases)).to(device)
 
52
 
53
  def download_and_load_model():
54
+ if not os.path.exists(MODEL_LOCAL_PATH):
55
+ print(f"Downloading model from {MODEL_URL}")
56
+ urllib.request.urlretrieve(MODEL_URL, MODEL_LOCAL_PATH)
 
 
 
 
57
 
58
+ state_dict = torch.load(MODEL_LOCAL_PATH, map_location=device)
 
59
  model.load_state_dict(state_dict)
60
  model.eval()
61
+ print(f"✅ Model loaded from {MODEL_FILE}")
62
 
63
+ # دالة التنبؤ
64
  def predict(image):
 
 
65
  img = transform(image.convert("L")).unsqueeze(0).to(device)
66
  with torch.no_grad():
67
  outputs = model(img)
 
69
  results = {d: round(float(p), 3) for d, p in zip(diseases, probs)}
70
  return results
71
 
72
+ # تحميل النموذج عند بدء التشغيل
73
  download_and_load_model()
74
 
75
+ # واجهة Gradio
76
  with gr.Blocks() as demo:
77
+ gr.Markdown(f"## 🧠 CheXzero | النموذج المستخدم: `{MODEL_FILE}`")
78
  with gr.Row():
79
+ image_input = gr.Image(type="pil", label="صورة أشعة X-Ray")
80
  output = gr.Label(num_top_classes=5)
81
 
82
  image_input.change(fn=predict, inputs=image_input, outputs=output)