# Source: baliddeki/phronesis-ml Hugging Face Space, commit 18a45f9 ("fix 3").
# app.py
import torch
import numpy as np
from PIL import Image
import io
import gradio as gr
from torchvision import models, transforms
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from huggingface_hub import hf_hub_download
from model import CombinedModel, ImageToTextProjector
import pydicom
import os
import gc
# Run on GPU when available; all weights and inputs are moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Access token for the model repo; None when the env var is unset.
HF_TOKEN = os.getenv("HF_TOKEN")
# NOTE(review): huggingface_hub typically resolves HF_HOME at import time,
# and it is imported above — setting it here may have no effect. Consider
# exporting it before the imports (or in the Space settings) instead.
os.environ["HF_HOME"] = "/tmp/huggingface_cache"
# Model loading
# Tokenizer shared by the report decoder and the combined model.
tokenizer = AutoTokenizer.from_pretrained("baliddeki/phronesis-ml", token=HF_TOKEN)
# Video backbone: R3D-18 pretrained on Kinetics-400, with its classification
# head replaced by a 512-d feature projection.
video_model = models.video.r3d_18(weights="KINETICS400_V1")
video_model.fc = torch.nn.Linear(video_model.fc.in_features, 512)
# Seq2seq model that decodes projected image features into report text.
report_generator = AutoModelForSeq2SeqLM.from_pretrained("GanjinZero/biobart-v2-base")
# Maps the 512-d video features into the decoder's embedding dimension.
projector = ImageToTextProjector(512, report_generator.config.d_model)
num_classes = 4
class_names = ["acute", "normal", "chronic", "lacunar"]
combined_model = CombinedModel(video_model, report_generator, num_classes, projector, tokenizer)
# Download the fine-tuned checkpoint and load it into the assembled model.
model_file = hf_hub_download("baliddeki/phronesis-ml", "pytorch_model.bin", token=HF_TOKEN)
# NOTE(review): torch.load on a downloaded .bin unpickles arbitrary objects;
# prefer weights_only=True (torch >= 2.0) for checkpoints you don't control.
state_dict = torch.load(model_file, map_location=device)
combined_model.load_state_dict(state_dict)
combined_model.to(device)
combined_model.eval()
# Per-frame preprocessing: resize to the R3D-18 input resolution and
# normalize with the Kinetics-400 statistics of the pretrained backbone.
_KINETICS_MEAN = [0.43216, 0.394666, 0.37645]
_KINETICS_STD = [0.22803, 0.22145, 0.216989]
image_transform = transforms.Compose(
    [
        transforms.Resize((112, 112)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_KINETICS_MEAN, std=_KINETICS_STD),
    ]
)
def dicom_to_image(file_bytes):
    """Convert raw DICOM bytes into an 8-bit RGB PIL image.

    Pixel values are min-max scaled to the 0-255 range. A constant-valued
    image (zero dynamic range) maps to all black instead of producing NaNs.

    Parameters:
        file_bytes: the complete DICOM file contents as ``bytes``.

    Returns:
        PIL.Image.Image in RGB mode.
    """
    dicom_file = pydicom.dcmread(io.BytesIO(file_bytes))
    pixel_array = dicom_file.pixel_array.astype(np.float32)
    # Use np.ptp(arr), not arr.ptp(): the ndarray method was removed in NumPy 2.0.
    value_range = float(np.ptp(pixel_array))
    if value_range > 0:
        pixel_array = ((pixel_array - pixel_array.min()) / value_range) * 255.0
    else:
        # Flat image: avoid 0/0 -> NaN; render as black.
        pixel_array = np.zeros_like(pixel_array)
    return Image.fromarray(pixel_array.astype(np.uint8)).convert("RGB")
def predict(files):
    """Classify an uploaded CT study and generate a textual report.

    Parameters:
        files: list of uploaded file objects from ``gr.File`` with
            ``file_count="multiple"``; each exposes a ``.name`` path.

    Returns:
        (class_name, report_text) tuple of strings; a friendly message
        pair when no files were uploaded.
    """
    if not files:
        return "No images uploaded.", ""

    processed_imgs = []
    for file_obj in files:
        path = file_obj.name
        if path.lower().endswith((".dcm", ".ima")):
            # Read from the temp-file path in binary mode: Gradio file
            # wrappers do not reliably support .read() (newer versions pass
            # path-like objects), and dcmread needs raw bytes.
            with open(path, "rb") as fh:
                img = dicom_to_image(fh.read())
        else:
            img = Image.open(path).convert("RGB")
        processed_imgs.append(img)

    # R3D-18 expects a fixed-length clip: uniformly subsample when there are
    # enough slices, otherwise pad by repeating the last slice.
    n_frames = 16
    if len(processed_imgs) >= n_frames:
        indices = np.linspace(0, len(processed_imgs) - 1, n_frames, dtype=int)
        images_sampled = [processed_imgs[i] for i in indices]
    else:
        padding = [processed_imgs[-1]] * (n_frames - len(processed_imgs))
        images_sampled = processed_imgs + padding

    tensor_imgs = [image_transform(i) for i in images_sampled]
    # (T, C, H, W) -> (1, C, T, H, W), the layout the video backbone expects.
    input_tensor = torch.stack(tensor_imgs).permute(1, 0, 2, 3).unsqueeze(0).to(device)

    with torch.no_grad():
        class_logits, report, _ = combined_model(input_tensor)
        class_pred = torch.argmax(class_logits, dim=1).item()
        class_name = class_names[class_pred]

    # Release transient tensors between requests to limit memory growth.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return class_name, report[0] if report else "No report generated."
# Gradio UI: a plain Interface wired directly to predict().
demo = gr.Interface(
    fn=predict,
    # Include ".ima" so the accepted extensions match what predict() handles.
    inputs=gr.File(
        file_count="multiple",
        file_types=[".dcm", ".ima", ".jpg", ".jpeg", ".png"],
    ),
    outputs=[gr.Textbox(label="Predicted Class"), gr.Textbox(label="Generated Report")],
    title="🩺 Phronesis Medical Report Generator",
    description="Upload CT scan images to generate a medical report",
)

# Guard the launch so importing this module (e.g. for tests) has no side effects;
# Spaces runs app.py as __main__, so behavior there is unchanged.
if __name__ == "__main__":
    demo.launch()