Alhdrawi commited on
Commit
d02841f
·
verified ·
1 Parent(s): 429b762

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -22
app.py CHANGED
@@ -1,12 +1,12 @@
1
  import gradio as gr
2
  import torch
 
 
 
3
  from torchvision import transforms
4
  from PIL import Image
5
 
6
- # اسم الريبو الخاص بك على Hugging Face
7
  REPO_ID = "Alhdrawi/x_alhdrawi"
8
-
9
- # أسماء الملفات المرفوعة
10
  MODEL_FILES = [
11
  "best_128_0.0002_original_15000_0.859.pt",
12
  "best_128_0.0002_original_8000_0.857.pt",
@@ -20,34 +20,44 @@ MODEL_FILES = [
20
  "best_64_5e-05_original_22000_0.864.pt",
21
  ]
22
 
23
- # نفس المعالجة المستخدمة
24
- transform = transforms.Compose([
25
- transforms.Resize((224, 224)),
26
- transforms.ToTensor(),
27
- transforms.Normalize(mean=[0.485], std=[0.229])
28
- ])
29
-
30
- # الأمراض اللي يشخصها
31
  diseases = [
32
  "Atelectasis", "Cardiomegaly", "Consolidation", "Edema", "Effusion",
33
  "Emphysema", "Fibrosis", "Hernia", "Infiltration", "Mass", "Nodule",
34
  "Pleural_Thickening", "Pneumonia", "Pneumothorax"
35
  ]
36
 
 
 
 
 
 
 
37
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
38
  model = None
 
 
 
 
 
 
 
 
 
 
 
39
 
40
- def load_model(checkpoint_file):
41
- global model
42
- print(f"Loading model: {checkpoint_file}")
43
- model_url = f"https://huggingface.co/{REPO_ID}/resolve/main/{checkpoint_file}"
44
- model = torch.load(model_url, map_location=device)
45
  model.eval()
46
- return f"تم تحميل النموذج: {checkpoint_file}"
 
47
 
48
  def predict(image):
49
  if model is None:
50
- return "الرجاء تحميل نموذج أولاً"
51
  img = transform(image.convert("L")).unsqueeze(0).to(device)
52
  with torch.no_grad():
53
  outputs = model(img)
@@ -56,16 +66,15 @@ def predict(image):
56
  return results
57
 
58
  with gr.Blocks() as demo:
59
- gr.Markdown("## 🧠 CheXzero | اختر النموذج وارفع صورة أشعة")
60
  with gr.Row():
61
- model_selector = gr.Dropdown(label="اختر نموذج", choices=MODEL_FILES, value=MODEL_FILES[-1])
62
- load_button = gr.Button("تحميل النموذج")
63
  load_status = gr.Textbox(label="الحالة", interactive=False)
64
  with gr.Row():
65
  image_input = gr.Image(type="pil", label="صورة الأشعة")
66
  output = gr.Label(num_top_classes=5)
67
 
68
- load_button.click(fn=load_model, inputs=model_selector, outputs=load_status)
69
  image_input.change(fn=predict, inputs=image_input, outputs=output)
70
 
71
  demo.launch()
 
1
  import gradio as gr
2
  import torch
3
+ import random
4
+ import requests
5
+ import os
6
  from torchvision import transforms
7
  from PIL import Image
8
 
 
9
  REPO_ID = "Alhdrawi/x_alhdrawi"
 
 
10
  MODEL_FILES = [
11
  "best_128_0.0002_original_15000_0.859.pt",
12
  "best_128_0.0002_original_8000_0.857.pt",
 
20
  "best_64_5e-05_original_22000_0.864.pt",
21
  ]
22
 
 
 
 
 
 
 
 
 
23
  diseases = [
24
  "Atelectasis", "Cardiomegaly", "Consolidation", "Edema", "Effusion",
25
  "Emphysema", "Fibrosis", "Hernia", "Infiltration", "Mass", "Nodule",
26
  "Pleural_Thickening", "Pneumonia", "Pneumothorax"
27
  ]
28
 
29
+ transform = transforms.Compose([
30
+ transforms.Resize((224, 224)),
31
+ transforms.ToTensor(),
32
+ transforms.Normalize(mean=[0.485], std=[0.229])
33
+ ])
34
+
35
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
  model = None
37
+ current_model_path = None
38
+
39
+ def download_model_file(filename):
40
+ url = f"https://huggingface.co/{REPO_ID}/resolve/main/{filename}"
41
+ local_path = f"./{filename}"
42
+ if not os.path.exists(local_path):
43
+ print(f"Downloading model from {url}")
44
+ response = requests.get(url)
45
+ with open(local_path, "wb") as f:
46
+ f.write(response.content)
47
+ return local_path
48
 
49
+ def load_model(_):
50
+ global model, current_model_path
51
+ selected_model = random.choice(MODEL_FILES)
52
+ local_path = download_model_file(selected_model)
53
+ model = torch.load(local_path, map_location=device)
54
  model.eval()
55
+ current_model_path = selected_model
56
+ return f"تم تحميل النموذج العشوائي: {selected_model}"
57
 
58
  def predict(image):
59
  if model is None:
60
+ return "الرجاء تحميل نموذج أولاً"
61
  img = transform(image.convert("L")).unsqueeze(0).to(device)
62
  with torch.no_grad():
63
  outputs = model(img)
 
66
  return results
67
 
68
  with gr.Blocks() as demo:
69
+ gr.Markdown("## 🧠 CheXzero | اختر صورة أشعة لتحليلها عبر نموذج يتم تحميله عشوائيًا")
70
  with gr.Row():
71
+ load_button = gr.Button("تحميل نموذج عشوائي")
 
72
  load_status = gr.Textbox(label="الحالة", interactive=False)
73
  with gr.Row():
74
  image_input = gr.Image(type="pil", label="صورة الأشعة")
75
  output = gr.Label(num_top_classes=5)
76
 
77
+ load_button.click(fn=load_model, inputs=None, outputs=load_status)
78
  image_input.change(fn=predict, inputs=image_input, outputs=output)
79
 
80
  demo.launch()