File size: 3,511 Bytes
d4d94d2
79c27a2
 
d4d94d2
 
 
 
 
79c27a2
d4d94d2
79c27a2
d4d94d2
79c27a2
 
d4d94d2
79c27a2
d4d94d2
e65d0e5
79c27a2
8c6ba75
d4d94d2
 
 
79c27a2
d4d94d2
 
79c27a2
d4d94d2
 
8c6ba75
e65d0e5
8c6ba75
d4d94d2
 
e65d0e5
d4d94d2
 
8c6ba75
 
 
 
 
79c27a2
d4d94d2
 
 
 
 
 
79c27a2
a54164e
 
8c6ba75
d4d94d2
 
8c6ba75
 
d4d94d2
8c6ba75
 
d4d94d2
8c6ba75
 
d4d94d2
 
 
 
 
8c6ba75
d4d94d2
79c27a2
8c6ba75
79c27a2
d4d94d2
 
79c27a2
 
d4d94d2
 
 
79c27a2
 
 
 
d4d94d2
 
 
8c6ba75
18a45f9
 
 
 
 
 
 
 
d4d94d2
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# app.py
import torch
import numpy as np
from PIL import Image
import io
import gradio as gr
from torchvision import models, transforms
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from huggingface_hub import hf_hub_download
from model import CombinedModel, ImageToTextProjector
import pydicom
import os
import gc

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

HF_TOKEN = os.getenv("HF_TOKEN")
os.environ["HF_HOME"] = "/tmp/huggingface_cache"
# huggingface_hub / transformers resolve their cache location from HF_HOME when
# they are imported, which has already happened at the top of this file, so the
# assignment above does not redirect their downloads by itself. Pass the
# matching hub cache directory explicitly to every download call instead.
HF_CACHE_DIR = os.path.join(os.environ["HF_HOME"], "hub")

# Model loading
tokenizer = AutoTokenizer.from_pretrained(
    "baliddeki/phronesis-ml", token=HF_TOKEN, cache_dir=HF_CACHE_DIR
)
# Video backbone: its final linear layer is replaced by a 512-wide projection.
video_model = models.video.r3d_18(weights="KINETICS400_V1")
video_model.fc = torch.nn.Linear(video_model.fc.in_features, 512)

report_generator = AutoModelForSeq2SeqLM.from_pretrained(
    "GanjinZero/biobart-v2-base", cache_dir=HF_CACHE_DIR
)
# Presumably maps the 512-d video features into the text model's hidden size
# (d_model) — confirm against ImageToTextProjector.
projector = ImageToTextProjector(512, report_generator.config.d_model)

# Label order must match the classifier's output indices (predict() indexes
# this list with argmax over the logits).
class_names = ["acute", "normal", "chronic", "lacunar"]
num_classes = len(class_names)
combined_model = CombinedModel(video_model, report_generator, num_classes, projector, tokenizer)

model_file = hf_hub_download(
    "baliddeki/phronesis-ml", "pytorch_model.bin", token=HF_TOKEN, cache_dir=HF_CACHE_DIR
)
# Load onto the CPU: load_state_dict copies values into the model's (still
# CPU-resident) parameters and `.to(device)` moves them afterwards. Mapping the
# checkpoint straight to `device` would keep a second full copy of the weights
# on the GPU through the module-level `state_dict` reference.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted repo (consider weights_only=True if the installed torch supports it).
state_dict = torch.load(model_file, map_location="cpu")
combined_model.load_state_dict(state_dict)
combined_model.to(device)
combined_model.eval()

# Per-frame preprocessing applied before frames are stacked into a clip.
# The mean/std values match torchvision's published Kinetics-400 video
# normalization statistics.
_FRAME_SIZE = (112, 112)
_KINETICS_MEAN = [0.43216, 0.394666, 0.37645]
_KINETICS_STD = [0.22803, 0.22145, 0.216989]

image_transform = transforms.Compose(
    [
        transforms.Resize(_FRAME_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=_KINETICS_MEAN, std=_KINETICS_STD),
    ]
)

def dicom_to_image(file_bytes):
    """Decode raw DICOM bytes into an 8-bit RGB PIL image.

    Pixel values are min-max scaled onto 0-255. An image with no dynamic
    range (all pixels equal) maps to all zeros instead of dividing by zero.

    Args:
        file_bytes: the raw contents of a DICOM file.

    Returns:
        A PIL ``Image`` in ``"RGB"`` mode.
    """
    dicom_file = pydicom.dcmread(io.BytesIO(file_bytes))
    pixel_array = dicom_file.pixel_array.astype(np.float32)
    # np.ptp rather than ndarray.ptp(): the method was removed in NumPy 2.0.
    value_range = np.ptp(pixel_array)
    if value_range > 0:
        pixel_array = ((pixel_array - pixel_array.min()) / value_range) * 255.0
    else:
        # Constant image: 0/0 would yield NaN, whose uint8 cast is undefined.
        pixel_array = np.zeros_like(pixel_array)
    pixel_array = pixel_array.astype(np.uint8)
    # NOTE(review): a multi-frame DICOM yields a 3-D array here, which
    # Image.fromarray may reject — confirm uploads are single-frame.
    return Image.fromarray(pixel_array).convert("RGB")

def _load_image(file_obj):
    """Open one uploaded file as an RGB PIL image.

    ``file_obj`` may be a plain path string or an object exposing its path as
    ``.name`` (Gradio passes either, depending on version). The file is
    reopened by path and closed afterwards. ``.dcm``/``.ima`` files are
    decoded with ``dicom_to_image``.
    """
    path = getattr(file_obj, "name", file_obj)
    if str(path).lower().endswith((".dcm", ".ima")):
        with open(path, "rb") as fh:
            return dicom_to_image(fh.read())
    with Image.open(path) as im:
        return im.convert("RGB")


def _sample_frames(images, n_frames):
    """Return exactly ``n_frames`` items taken from the list ``images``.

    Longer lists are subsampled at evenly spaced indices spanning the whole
    list; shorter lists are padded by repeating their last item.
    """
    if not images:
        raise ValueError("at least one image is required")
    if len(images) >= n_frames:
        indices = np.linspace(0, len(images) - 1, n_frames, dtype=int)
        return [images[i] for i in indices]
    return images + [images[-1]] * (n_frames - len(images))


def predict(files):
    """Classify an uploaded image series and generate a report for it.

    Args:
        files: the uploads from ``gr.File(file_count="multiple")``; each item
            is a path or an object whose ``.name`` is a path.

    Returns:
        ``(class_name, report_text)``, or ``("No images uploaded.", "")`` when
        nothing was uploaded.
    """
    if not files:
        return "No images uploaded.", ""

    n_frames = 16  # clip length fed to the video backbone
    clip = _sample_frames([_load_image(f) for f in files], n_frames)

    # Stacked frames are (T, C, H, W); permute to (C, T, H, W) and add a batch dim.
    frames = torch.stack([image_transform(img) for img in clip])
    input_tensor = frames.permute(1, 0, 2, 3).unsqueeze(0).to(device)

    with torch.no_grad():
        class_logits, report, _ = combined_model(input_tensor)
        class_pred = torch.argmax(class_logits, dim=1).item()
        class_name = class_names[class_pred]

    # Release per-request allocations promptly between requests.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    report_text = report[0] if report else "No report generated."
    return class_name, report_text

# Gradio UI: one multi-file upload feeding `predict`, which returns the
# predicted class and the generated report.
demo = gr.Interface(
    fn=predict,
    inputs=gr.File(
        file_count="multiple",
        # Keep in sync with the extensions `predict` handles; ".ima" files are
        # decoded as DICOM there, so they must be selectable here too.
        file_types=[".dcm", ".ima", ".jpg", ".jpeg", ".png"],
    ),
    outputs=[gr.Textbox(label="Predicted Class"), gr.Textbox(label="Generated Report")],
    title="🩺 Phronesis Medical Report Generator",
    description="Upload CT scan images to generate a medical report"
)

demo.launch()