# app.py
import io
import gc
import os

# Point the Hugging Face cache at a writable location *before* importing
# transformers/huggingface_hub, which read HF_HOME at import time.
os.environ["HF_HOME"] = "/tmp/huggingface_cache"

import numpy as np
import pydicom
import torch
import gradio as gr
from PIL import Image
from torchvision import models, transforms
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from huggingface_hub import hf_hub_download

from model import CombinedModel, ImageToTextProjector

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
HF_TOKEN = os.getenv("HF_TOKEN")

# Load tokenizer and models.
tokenizer = AutoTokenizer.from_pretrained("baliddeki/phronesis-ml", token=HF_TOKEN)

# 3D ResNet-18 video backbone; its classifier head is replaced with a
# 512-dim projection that feeds the report generator.
video_model = models.video.r3d_18(weights="KINETICS400_V1")
video_model.fc = torch.nn.Linear(video_model.fc.in_features, 512)

report_generator = AutoModelForSeq2SeqLM.from_pretrained("GanjinZero/biobart-v2-base")
projector = ImageToTextProjector(512, report_generator.config.d_model)

num_classes = 4
class_names = ["acute", "normal", "chronic", "lacunar"]
combined_model = CombinedModel(
    video_model, report_generator, num_classes, projector, tokenizer
)

# Fetch the fine-tuned weights and load them into the combined model.
model_file = hf_hub_download(
    "baliddeki/phronesis-ml", "pytorch_model.bin", token=HF_TOKEN
)
state_dict = torch.load(model_file, map_location=device)
combined_model.load_state_dict(state_dict)
combined_model.to(device)
combined_model.eval()

# Per-frame transforms matching the Kinetics-400 preprocessing used by r3d_18.
image_transform = transforms.Compose(
    [
        transforms.Resize((112, 112)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.43216, 0.394666, 0.37645], std=[0.22803, 0.22145, 0.216989]
        ),
    ]
)


def dicom_to_image(file_bytes):
    """Decode a DICOM file and min-max normalize it to an 8-bit RGB image."""
    dicom_file = pydicom.dcmread(io.BytesIO(file_bytes))
    pixel_array = dicom_file.pixel_array.astype(np.float32)
    # Use np.ptp() (the ndarray .ptp() method was removed in NumPy 2.0) and
    # guard against constant-valued slices to avoid division by zero.
    value_range = np.ptp(pixel_array)
    if value_range == 0:
        value_range = 1.0
    pixel_array = ((pixel_array - pixel_array.min()) / value_range) * 255.0
    pixel_array = pixel_array.astype(np.uint8)
    return Image.fromarray(pixel_array).convert("RGB")


def predict(files):
    if not files:
        return "No image uploaded.", ""

    processed_imgs = []
    for file in files:
        # gr.File yields plain paths in recent Gradio versions and tempfile
        # wrappers (with a .name attribute) in older ones; handle both.
        path = file if isinstance(file, str) else file.name
        if path.lower().endswith((".dcm", ".ima")):
            with open(path, "rb") as f:
                processed_imgs.append(dicom_to_image(f.read()))
        else:
            processed_imgs.append(Image.open(path).convert("RGB"))

    # The video backbone expects a fixed clip length: uniformly sample 16
    # frames, or pad by repeating the last slice if fewer were uploaded.
    n_frames = 16
    if len(processed_imgs) >= n_frames:
        indices = np.linspace(0, len(processed_imgs) - 1, n_frames, dtype=int)
        images_sampled = [processed_imgs[i] for i in indices]
    else:
        images_sampled = processed_imgs + [processed_imgs[-1]] * (
            n_frames - len(processed_imgs)
        )

    # Stack to (C, T, H, W), then add a batch dimension: (1, C, T, H, W).
    tensor_imgs = [image_transform(img) for img in images_sampled]
    input_tensor = torch.stack(tensor_imgs).permute(1, 0, 2, 3).unsqueeze(0).to(device)

    with torch.no_grad():
        class_logits, report, _ = combined_model(input_tensor)

    class_pred = torch.argmax(class_logits, dim=1).item()
    class_name = class_names[class_pred]

    # Release memory between requests (useful on small Spaces hardware).
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return class_name, report[0] if report else "No report generated."


# Gradio Blocks UI (set up explicitly).
with gr.Blocks() as demo:
    gr.Markdown("## 🩺 Phronesis Medical Report Generator")
    file_input = gr.File(
        file_count="multiple",
        file_types=[".dcm", ".ima", ".jpg", ".jpeg", ".png"],
        label="Upload CT Scan Images",
    )
    btn = gr.Button("Generate Report")
    class_output = gr.Textbox(label="Predicted Class")
    report_output = gr.Textbox(label="Generated Report")
    btn.click(fn=predict, inputs=file_input, outputs=[class_output, report_output])

demo.launch()
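
# Local smoke test (illustrative sketch, not part of the app): one way to
# exercise `predict` without the UI. Assumes a hypothetical folder of JPEG
# slices at "sample_scans/"; recent Gradio passes plain file paths, so bare
# strings suffice here. Kept commented out because demo.launch() blocks —
# run it from a separate script or before launch when debugging.
#
# import glob
# test_paths = sorted(glob.glob("sample_scans/*.jpg"))
# label, report = predict(test_paths)
# print("class:", label)
# print("report:", report)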