# app.py
import torch
import numpy as np
from PIL import Image
import io
import gradio as gr
from torchvision import models, transforms
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from huggingface_hub import hf_hub_download
from model import CombinedModel, ImageToTextProjector
import pydicom
import os
import gc
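# Use the GPU when available; otherwise fall back to CPU.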
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
HF_TOKEN = os.getenv("HF_TOKEN")
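# Caveat: huggingface_hub resolves HF_HOME when it is imported, so setting it
# here (after the imports above) may not affect the default cache location.
# Setting it in the Space's environment variables is the more reliable option.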
os.environ["HF_HOME"] = "/tmp/huggingface_cache"
# Model loading
tokenizer = AutoTokenizer.from_pretrained("baliddeki/phronesis-ml", token=HF_TOKEN)
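# Video encoder: 3D ResNet-18 pretrained on Kinetics-400, with its
# classification head replaced by a 512-dimensional feature projection.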
video_model = models.video.r3d_18(weights="KINETICS400_V1")
video_model.fc = torch.nn.Linear(video_model.fc.in_features, 512)
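# BioBART provides the seq2seq report decoder; the projector maps the 512-d
# video features into its embedding space.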
report_generator = AutoModelForSeq2SeqLM.from_pretrained("GanjinZero/biobart-v2-base")
projector = ImageToTextProjector(512, report_generator.config.d_model)
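# Class labels, in the index order the classifier head was trained with.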
num_classes = 4
class_names = ["acute", "normal", "chronic", "lacunar"]
combined_model = CombinedModel(video_model, report_generator, num_classes, projector, tokenizer)
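# Fetch the fine-tuned combined weights from the Hub and load them.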
model_file = hf_hub_download("baliddeki/phronesis-ml", "pytorch_model.bin", token=HF_TOKEN)
state_dict = torch.load(model_file, map_location=device)
combined_model.load_state_dict(state_dict)
combined_model.to(device)
combined_model.eval()
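# Per-frame preprocessing; the mean/std values are the standard Kinetics-400
# statistics that torchvision's video models were trained with.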
image_transform = transforms.Compose([
    transforms.Resize((112, 112)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.43216, 0.394666, 0.37645], std=[0.22803, 0.22145, 0.216989]),
])
def dicom_to_image(file_bytes):
    """Convert a DICOM file's pixel data to an 8-bit RGB PIL image."""
    dicom_file = pydicom.dcmread(io.BytesIO(file_bytes))
    pixel_array = dicom_file.pixel_array.astype(np.float32)
    # Min-max normalize to [0, 255]. np.ptp() is used because the ndarray.ptp()
    # method was removed in NumPy 2.0, and max() guards against division by
    # zero on constant-valued slices.
    value_range = max(np.ptp(pixel_array), 1e-6)
    pixel_array = ((pixel_array - pixel_array.min()) / value_range) * 255.0
    pixel_array = pixel_array.astype(np.uint8)
    return Image.fromarray(pixel_array).convert("RGB")
def predict(files):
    if not files:
        return "No images uploaded.", ""
    processed_imgs = []
    for file_obj in files:
        # Depending on the Gradio version, gr.File yields plain file paths
        # or file-like objects; resolve to a path string either way.
        path = file_obj if isinstance(file_obj, str) else file_obj.name
        if path.lower().endswith((".dcm", ".ima")):
            with open(path, "rb") as f:
                img = dicom_to_image(f.read())
        else:
            img = Image.open(path).convert("RGB")
        processed_imgs.append(img)
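    # The 3D CNN expects a fixed clip length: uniformly subsample longer
    # series and pad shorter ones by repeating the last slice.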
    n_frames = 16
    if len(processed_imgs) >= n_frames:
        images_sampled = [
            processed_imgs[i]
            for i in np.linspace(0, len(processed_imgs) - 1, n_frames, dtype=int)
        ]
    else:
        images_sampled = processed_imgs + [processed_imgs[-1]] * (n_frames - len(processed_imgs))
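    # Stack to (T, C, H, W), permute to (C, T, H, W), then add a batch
    # dimension: torchvision video models take input of shape (N, C, T, H, W).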
    tensor_imgs = [image_transform(i) for i in images_sampled]
    input_tensor = torch.stack(tensor_imgs).permute(1, 0, 2, 3).unsqueeze(0).to(device)
    with torch.no_grad():
        class_logits, report, _ = combined_model(input_tensor)
        class_pred = torch.argmax(class_logits, dim=1).item()
        class_name = class_names[class_pred]
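    # Free what we can between requests; Spaces instances have tight memory budgets.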
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return class_name, report[0] if report else "No report generated."
# Gradio UI: a simple Interface with multi-file upload.
demo = gr.Interface(
    fn=predict,
    # .ima is accepted here as well, since predict() handles it above.
    inputs=gr.File(file_count="multiple", file_types=[".dcm", ".ima", ".jpg", ".jpeg", ".png"]),
    outputs=[gr.Textbox(label="Predicted Class"), gr.Textbox(label="Generated Report")],
    title="🩺 Phronesis Medical Report Generator",
    description="Upload CT scan images to generate a medical report.",
)
demo.launch()
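# Minimal sketch of exercising the pipeline without the UI; the paths are
# hypothetical placeholders for local CT slices:
#   class_name, report = predict(["slice_001.dcm", "slice_002.dcm"])
#   print(class_name, report)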