# Scraped from a Hugging Face Spaces page (status banner read: "Sleeping").
# app.py | |
import torch | |
import numpy as np | |
from PIL import Image | |
import io | |
import gradio as gr | |
from torchvision import models, transforms | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
from huggingface_hub import hf_hub_download | |
from model import CombinedModel, ImageToTextProjector | |
import pydicom | |
import os | |
import gc | |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
HF_TOKEN = os.getenv("HF_TOKEN")  # None when the env var is unset
os.environ["HF_HOME"] = "/tmp/huggingface_cache"
# transformers / huggingface_hub resolve their cache location from HF_HOME when
# they are first imported, and both are imported above this line, so the
# assignment alone does not redirect their downloads. Pass the equivalent hub
# cache directory ("<HF_HOME>/hub") explicitly to every download call instead.
HF_CACHE_DIR = os.path.join(os.environ["HF_HOME"], "hub")

# --- Model loading ----------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(
    "baliddeki/phronesis-ml", token=HF_TOKEN, cache_dir=HF_CACHE_DIR
)

# NOTE(review): these Kinetics-400 weights appear to be fully overwritten by
# the strict load_state_dict() below; if so, weights=None would skip a
# redundant download — confirm against CombinedModel before changing.
video_model = models.video.r3d_18(weights="KINETICS400_V1")
# Swap the classification head for a 512-unit linear output layer.
video_model.fc = torch.nn.Linear(video_model.fc.in_features, 512)

report_generator = AutoModelForSeq2SeqLM.from_pretrained(
    "GanjinZero/biobart-v2-base", cache_dir=HF_CACHE_DIR
)
# Presumably projects the 512-d video features into the text model's hidden
# size (config.d_model) — see ImageToTextProjector in model.py.
projector = ImageToTextProjector(512, report_generator.config.d_model)

class_names = ["acute", "normal", "chronic", "lacunar"]
# Derived from the label list so the classifier width and labels cannot drift.
num_classes = len(class_names)
combined_model = CombinedModel(video_model, report_generator, num_classes, projector, tokenizer)

model_file = hf_hub_download(
    "baliddeki/phronesis-ml", "pytorch_model.bin", token=HF_TOKEN, cache_dir=HF_CACHE_DIR
)
# NOTE(review): torch.load unpickles this remotely downloaded checkpoint. Only a
# state_dict is expected, so weights_only=True (torch >= 1.13) would avoid
# executing arbitrary pickled code — confirm the installed torch version first.
state_dict = torch.load(model_file, map_location=device)
combined_model.load_state_dict(state_dict)
combined_model.to(device)
combined_model.eval()

# Per-frame preprocessing: resize to 112x112, convert to a tensor, normalize
# per channel. The mean/std appear to be the Kinetics-400 statistics used by
# torchvision's video models — revisit if the backbone changes.
image_transform = transforms.Compose([
    transforms.Resize((112, 112)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.43216, 0.394666, 0.37645], std=[0.22803, 0.22145, 0.216989]),
])
def dicom_to_image(file_bytes):
    """Decode raw DICOM bytes into an 8-bit RGB PIL image.

    Pixel values are min-max scaled to the 0-255 range. A constant-valued
    image (zero dynamic range) becomes an all-black image instead of NaNs.

    Args:
        file_bytes: the complete contents of a DICOM file.

    Returns:
        A ``PIL.Image.Image`` in "RGB" mode.
    """
    dicom_file = pydicom.dcmread(io.BytesIO(file_bytes))
    pixel_array = dicom_file.pixel_array.astype(np.float32)
    # np.ptp (the function) instead of ndarray.ptp: the method was removed
    # in NumPy 2.0.
    value_range = float(np.ptp(pixel_array))
    if value_range > 0:
        pixel_array = (pixel_array - pixel_array.min()) / value_range * 255.0
    else:
        # Avoid 0/0 -> NaN, whose cast to uint8 is undefined.
        pixel_array = np.zeros_like(pixel_array)
    return Image.fromarray(pixel_array.astype(np.uint8)).convert("RGB")
def _open_upload_as_rgb(file_obj):
    """Load one uploaded file (DICOM or ordinary image) as an RGB PIL image.

    Accepts either a path string or an object that exposes its on-disk path
    as ``.name`` (Gradio has passed uploads in both forms across versions).
    The file is always re-opened by path, so no ``.read()`` method is needed.
    """
    path = getattr(file_obj, "name", file_obj)
    if str(path).lower().endswith((".dcm", ".ima")):
        with open(path, "rb") as fh:
            return dicom_to_image(fh.read())
    # The context manager closes the file handle; convert() loads the pixel
    # data first, so the returned image does not depend on the open file.
    with Image.open(path) as img:
        return img.convert("RGB")


def _sample_frames(images, n_frames=16):
    """Return exactly ``n_frames`` images from a non-empty list.

    Longer sequences are subsampled at evenly spaced indices. Shorter ones are
    padded by repeating the last image.
    """
    if len(images) >= n_frames:
        indices = np.linspace(0, len(images) - 1, n_frames, dtype=int)
        return [images[i] for i in indices]
    return images + [images[-1]] * (n_frames - len(images))


def predict(files):
    """Classify an uploaded image series and generate a text report.

    Args:
        files: uploads from the ``gr.File`` component (paths, or objects with
            a ``.name`` path). ``.dcm``/``.ima`` files are decoded as DICOM;
            anything else is opened with PIL.

    Returns:
        ``(class_name, report_text)``. If nothing was uploaded, returns
        ``("No images uploaded.", "")``.
    """
    if not files:
        return "No images uploaded.", ""

    processed_imgs = [_open_upload_as_rgb(f) for f in files]
    images_sampled = _sample_frames(processed_imgs, n_frames=16)

    # Stack frames to (T, C, H, W), reorder to (C, T, H, W), add a batch dim.
    tensor_imgs = [image_transform(img) for img in images_sampled]
    input_tensor = torch.stack(tensor_imgs).permute(1, 0, 2, 3).unsqueeze(0).to(device)

    with torch.no_grad():
        class_logits, report, _ = combined_model(input_tensor)
        class_pred = torch.argmax(class_logits, dim=1).item()
    class_name = class_names[class_pred]

    # Release per-request memory before returning.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # NOTE(review): `report` is presumably a list of generated strings (one per
    # batch item) — confirm against CombinedModel.forward.
    return class_name, report[0] if report else "No report generated."
# Gradio UI: one multi-file upload feeds predict(); the predicted class and the
# generated report are shown in two text boxes.
demo = gr.Interface(
    fn=predict,
    # ".ima" is listed because predict() also decodes .ima files as DICOM.
    inputs=gr.File(file_count="multiple", file_types=[".dcm", ".ima", ".jpg", ".jpeg", ".png"]),
    outputs=[gr.Textbox(label="Predicted Class"), gr.Textbox(label="Generated Report")],
    title="🩺 Phronesis Medical Report Generator",
    description="Upload CT scan images to generate a medical report",
)
demo.launch()