baliddeki committed
Commit d4d94d2 · 1 Parent(s): 0372187
Files changed (2):
  1. README.md +79 -3
  2. app.py +91 -106
README.md CHANGED
@@ -1,6 +1,6 @@
  ---
- title: Phronesis
- emoji: 🌖
+ title: Phronesis Medical Report Generator
+ emoji: 🧠
  colorFrom: green
  colorTo: gray
  sdk: gradio
@@ -8,6 +8,82 @@ sdk_version: 5.4.0
  app_file: app.py
  pinned: false
  short_description: 'REPORT GEN AND CLASSIFICATION MODEL '
+
+ ---
+
+ # 🧠 Phronesis: Medical Image Diagnosis & Report Generator
+
+ **Phronesis** is a multimodal AI tool that classifies medical CT scan images (DICOM or standard formats) and generates diagnostic reports using a combination of video classification and medical language generation.
+
+ ---
+
+ ## 🚀 Demo
+
+ Upload a set of DICOM (`.dcm`, `.ima`) or image (`.png`, `.jpg`) files representing slices of a CT scan. The model will:
+
+ - 🏷️ Predict a class: **acute**, **normal**, **chronic**, or **lacunar**
+ - 📋 Generate a short **radiology report**
+
+ [Live App →](https://huggingface.co/spaces/baliddeki/phronesis-ml-endpoint)
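For programmatic access, the Space can also be called from Python. The snippet below is a minimal sketch, assuming a recent `gradio_client` (for `handle_file`) and Gradio's default `/predict` endpoint name for an `Interface`; the local file names are hypothetical.

```python
# Illustrative client call; file paths and the endpoint name are assumptions.
from gradio_client import Client, handle_file

client = Client("baliddeki/phronesis-ml-endpoint")
predicted_class, report = client.predict(
    [handle_file("slice_001.dcm"), handle_file("slice_002.dcm")],  # hypothetical local slices
    api_name="/predict",
)
print(predicted_class)
print(report)
```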
+
+ ---
+
+ ## 🏗️ Model Architecture
+
+ - **Vision Backbone**: `3D ResNet-18` pretrained on Kinetics-400
+ - **Language Head**: `BioBART v2` (pretrained biomedical seq2seq model)
+ - **Bridge Module**: Custom `ImageToTextProjector` to align visual features with the language model
+ - **CombinedModel**: Unified architecture for classification + report generation
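`CombinedModel` and `ImageToTextProjector` live in `model.py`, which is not part of this commit. The sketch below is illustrative only and assumes one plausible wiring consistent with `app.py`: the 3D ResNet-18 (with its `fc` replaced) yields a 512-d clip feature, a linear head classifies it, and the projector lifts it into BioBART's hidden size to condition report generation. The real `model.py` may differ.

```python
# Illustrative sketch only; the actual model.py in the repo may be structured differently.
import torch
import torch.nn as nn
from transformers.modeling_outputs import BaseModelOutput


class ImageToTextProjector(nn.Module):
    """Project a 512-d visual feature into the language model's hidden size."""
    def __init__(self, visual_dim, text_dim):
        super().__init__()
        self.proj = nn.Linear(visual_dim, text_dim)

    def forward(self, x):
        return self.proj(x)


class CombinedModel(nn.Module):
    def __init__(self, video_model, report_generator, num_classes, projector, tokenizer):
        super().__init__()
        self.video_model = video_model            # r3d_18 with fc replaced to output 512
        self.report_generator = report_generator  # BioBART seq2seq model
        self.projector = projector
        self.tokenizer = tokenizer
        self.classifier = nn.Linear(512, num_classes)

    def forward(self, clips):                      # clips: [B, 3, 16, 112, 112]
        feats = self.video_model(clips)            # [B, 512]
        class_logits = self.classifier(feats)      # [B, num_classes]
        cond = self.projector(feats).unsqueeze(1)  # [B, 1, d_model]
        # Treat the projected feature as a one-token encoder output; depending on
        # the transformers version, generate() may also need an attention_mask.
        enc = BaseModelOutput(last_hidden_state=cond)
        gen_ids = self.report_generator.generate(encoder_outputs=enc, max_length=128)
        reports = self.tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
        return class_logits, reports, feats
```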
+
+ ---
+
+ ## 🧪 Tasks
+
+ - **Image Classification**: Categorizes brain CT scans into one of four classes.
+ - **Report Generation**: Produces diagnostic text conditioned on image features.
+
  ---
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ## 🖼️ Input Format
+
+ - Minimum 1, maximum ~30 image slices per scan.
+ - Acceptable file formats:
+   - DICOM (`.dcm`, `.ima`)
+   - PNG, JPEG
+
+ The model will sample or pad the series to 16 frames for temporal context.
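The sampling/padding rule is the one implemented in `app.py` further down in this commit: evenly spaced indices when there are more than 16 slices, repetition of the last slice when there are fewer. A standalone sketch (the helper name `sample_or_pad` is not in the repo):

```python
import numpy as np

def sample_or_pad(slices, n_frames=16):
    """Evenly sample down to n_frames, or pad by repeating the last slice."""
    if len(slices) >= n_frames:
        idx = np.linspace(0, len(slices) - 1, n_frames, dtype=int)
        return [slices[i] for i in idx]
    return slices + [slices[-1]] * (n_frames - len(slices))

print(len(sample_or_pad(list(range(30)))))  # 30 slices -> 16
print(len(sample_or_pad(list(range(5)))))   # 5 slices -> 16 (last slice repeated)
```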
+
+ ---
+
+ ## 📦 Dependencies
+
+ This app uses:
+ - `torch`
+ - `transformers`
+ - `torchvision`
+ - `huggingface_hub`
+ - `pydicom`
+ - `gradio`
+ - `PIL`, `numpy`
+
+ ---
+
+ ## 🔐 Notes
+
+ - This demo loads a private model from the Hugging Face Hub. Set your `HF_TOKEN` as a secret on the Space if needed.
+ - Do **not** use this for real clinical decisions – it is intended for research/demo use only.
+
+ ---
+
+ ## 🙋‍♂️ Credits
+
+ Developed by [@baliddeki](https://huggingface.co/baliddeki)
+
+ Model weights: [`baliddeki/phronesis-ml`](https://huggingface.co/baliddeki/phronesis-ml)
+ Language model: [`GanjinZero/biobart-v2-base`](https://huggingface.co/GanjinZero/biobart-v2-base)
+
+ ---
+
+ ## 📄 License
+
+ MIT or Apache 2.0 (add yours here)
app.py CHANGED
@@ -1,129 +1,114 @@
- #app.py
- import os
- import io
- import uvicorn
-
+ # app.py
  import torch
- from fastapi import FastAPI, File, UploadFile, HTTPException
- from fastapi.responses import JSONResponse
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
- from torchvision import models, transforms
- from PIL import Image
  import numpy as np
+ from PIL import Image
+ import io
+ import gradio as gr
+ from torchvision import models, transforms
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
  from huggingface_hub import hf_hub_download
+ from model import CombinedModel, ImageToTextProjector
  import pydicom
+ import os
  import gc
- from model import CombinedModel, ImageToTextProjector
- from fastapi import FastAPI, Request
- from fastapi.middleware.cors import CORSMiddleware
 
- app = FastAPI()
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
- app.add_middleware(
-     CORSMiddleware,
-     allow_origins=["*"],
-     allow_credentials=True,
-     allow_methods=["*"],
-     allow_headers=["*"],
- )
+ # Load tokenizer and models
+ HF_TOKEN = os.getenv("HF_TOKEN")
 
- @app.get("/")
- async def root(request: Request):
-     return {"message": "Welcome to Phronesis"}
+ tokenizer = AutoTokenizer.from_pretrained("baliddeki/phronesis-ml", token=HF_TOKEN)
+ video_model = models.video.r3d_18(weights="KINETICS400_V1")
+ video_model.fc = torch.nn.Linear(video_model.fc.in_features, 512)
 
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ report_generator = AutoModelForSeq2SeqLM.from_pretrained("GanjinZero/biobart-v2-base")
+ projector = ImageToTextProjector(512, report_generator.config.d_model)
 
- def dicom_to_png(dicom_data):
-     try:
-         dicom_file = pydicom.dcmread(dicom_data)
-         if not hasattr(dicom_file, 'PixelData'):
-             raise HTTPException(status_code=400, detail="No pixel data in DICOM file.")
-
-         pixel_array = dicom_file.pixel_array.astype(np.float32)
-         pixel_array = ((pixel_array - pixel_array.min()) / (pixel_array.ptp())) * 255.0
-         pixel_array = pixel_array.astype(np.uint8)
-
-         img = Image.fromarray(pixel_array).convert("L")
-         return img
-     except Exception as e:
-         raise HTTPException(status_code=500, detail=f"Error converting DICOM to PNG: {e}")
-
- # Set up secure model initialization
- HF_TOKEN = os.getenv('HF_TOKEN')
- if not HF_TOKEN:
-     raise ValueError("Missing Hugging Face token in environment variables.")
-
- try:
-     report_generator_tokenizer = AutoTokenizer.from_pretrained(
-         "baliddeki/phronesis-ml",
-         token=HF_TOKEN if HF_TOKEN else None
-     )
-     video_model = models.video.r3d_18(weights="KINETICS400_V1")
-     video_model.fc = torch.nn.Linear(video_model.fc.in_features, 512)
-     report_generator = AutoModelForSeq2SeqLM.from_pretrained("GanjinZero/biobart-v2-base")
-     projector = ImageToTextProjector(512, report_generator.config.d_model)
-     num_classes = 4
-     combined_model = CombinedModel(video_model, report_generator, num_classes, projector, report_generator_tokenizer)
-     model_file = hf_hub_download("baliddeki/phronesis-ml", "pytorch_model.bin", token=HF_TOKEN)
-     state_dict = torch.load(model_file, map_location=device)
-     combined_model.load_state_dict(state_dict)
-     combined_model.eval()
- except Exception as e:
-     raise SystemExit(f"Error loading models: {e}")
-
- image_transform = transforms.Compose([
-     transforms.Resize((112, 112)),
-     transforms.ToTensor(),
-     transforms.Normalize(mean=[0.43216, 0.394666, 0.37645], std=[0.22803, 0.22145, 0.216989])
- ])
+ num_classes = 4
+ class_names = ["acute", "normal", "chronic", "lacunar"]
+ combined_model = CombinedModel(
+     video_model, report_generator, num_classes, projector, tokenizer
+ )
+ model_file = hf_hub_download(
+     "baliddeki/phronesis-ml", "pytorch_model.bin", token=HF_TOKEN
+ )
+ state_dict = torch.load(model_file, map_location=device)
+ combined_model.load_state_dict(state_dict)
+ combined_model.eval()
+
+ # Image transforms
+ image_transform = transforms.Compose(
+     [
+         transforms.Resize((112, 112)),
+         transforms.ToTensor(),
+         transforms.Normalize(
+             mean=[0.43216, 0.394666, 0.37645], std=[0.22803, 0.22145, 0.216989]
+         ),
+     ]
+ )
 
- class_names = ["acute", "normal", "chronic", "lacunar"]
 
- @app.post("/predict/")
- async def predict(files: list[UploadFile]):
-     print(f"Received {len(files)} files")
-     n_frames = 16
-     images = []
-
-     for file in files:
-         ext = file.filename.split('.')[-1].lower()
-         try:
-             if ext in ['dcm', 'ima']:
-                 dicom_img = dicom_to_png(await file.read())
-                 images.append(dicom_img.convert("RGB"))
-             elif ext in ['png', 'jpeg', 'jpg']:
-                 img = Image.open(io.BytesIO(await file.read())).convert("RGB")
-                 images.append(img)
-             else:
-                 raise HTTPException(status_code=400, detail="Unsupported file type.")
-         except Exception as e:
-             raise HTTPException(status_code=500, detail=f"Error processing file {file.filename}: {e}")
+ def dicom_to_image(file_bytes):
+     dicom_file = pydicom.dcmread(io.BytesIO(file_bytes))
+     pixel_array = dicom_file.pixel_array.astype(np.float32)
+     pixel_array = ((pixel_array - pixel_array.min()) / pixel_array.ptp()) * 255.0
+     pixel_array = pixel_array.astype(np.uint8)
+     return Image.fromarray(pixel_array).convert("RGB")
 
+
+ def predict(images):
      if not images:
-         return JSONResponse(content={"error": "No valid images provided."}, status_code=400)
-
-     if len(images) >= n_frames:
-         images_sampled = [images[i] for i in np.linspace(0, len(images) - 1, n_frames, dtype=int)]
+         return "No image uploaded.", ""
+
+     # Convert images
+     processed_imgs = []
+     for img in images:
+         filename = img.name.lower()
+         if filename.endswith((".dcm", ".ima")):
+             dicom_img = dicom_to_image(img.read())
+             processed_imgs.append(dicom_img)
+         else:
+             pil_img = Image.open(img).convert("RGB")
+             processed_imgs.append(pil_img)
+
+     # Sample or pad
+     n_frames = 16
+     if len(processed_imgs) >= n_frames:
+         images_sampled = [
+             processed_imgs[i]
+             for i in np.linspace(0, len(processed_imgs) - 1, n_frames, dtype=int)
+         ]
      else:
-         images_sampled = images + [images[-1]] * (n_frames - len(images))
+         images_sampled = processed_imgs + [processed_imgs[-1]] * (
+             n_frames - len(processed_imgs)
+         )
 
-     image_tensors = [image_transform(img) for img in images_sampled]
-     images_tensor = torch.stack(image_tensors).permute(1, 0, 2, 3).unsqueeze(0).to(device)
+     tensor_imgs = [image_transform(i) for i in images_sampled]
+     input_tensor = torch.stack(tensor_imgs).permute(1, 0, 2, 3).unsqueeze(0).to(device)
 
      with torch.no_grad():
-         class_outputs, generated_report, _ = combined_model(images_tensor)
-         predicted_class = torch.argmax(class_outputs, dim=1).item()
-         predicted_class_name = class_names[predicted_class]
+         class_logits, report, _ = combined_model(input_tensor)
+         class_pred = torch.argmax(class_logits, dim=1).item()
+         class_name = class_names[class_pred]
 
      gc.collect()
      if torch.cuda.is_available():
          torch.cuda.empty_cache()
-
-     return {
-         "predicted_class": predicted_class_name,
-         "generated_report": generated_report[0] if generated_report else "No report generated."
-     }
-
- if __name__ == "__main__":
-     uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))
+
+     return class_name, report[0] if report else "No report generated."
+
+
+ # Gradio interface
+ demo = gr.Interface(
+     fn=predict,
+     inputs=gr.File(
+         file_types=[".dcm", ".jpg", ".jpeg", ".png"],
+         file_count="multiple",
+         label="Upload CT Scan Images",
+     ),
+     outputs=[gr.Textbox(label="Predicted Class"), gr.Textbox(label="Generated Report")],
+     title="Phronesis Medical Report Generator",
+     description="Upload CT scan DICOM or image files. Returns diagnosis classification and generated report.",
+ )
+
+ demo.launch()
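The `permute(1, 0, 2, 3).unsqueeze(0)` in `predict` rearranges the stacked slices into the `[batch, channels, frames, height, width]` layout that `r3d_18` expects. A quick shape check with dummy tensors (no model weights needed):

```python
import torch

# 16 transformed slices, each [3, 112, 112] after the torchvision transforms above.
frames = [torch.zeros(3, 112, 112) for _ in range(16)]

stacked = torch.stack(frames)                    # [16, 3, 112, 112]
clip = stacked.permute(1, 0, 2, 3).unsqueeze(0)  # [1, 3, 16, 112, 112]

assert clip.shape == (1, 3, 16, 112, 112)  # layout expected by r3d_18
print(clip.shape)
```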