# radiolm / app.py
import os
import subprocess
import sys

# Install transformers at runtime (common in minimal Space images that
# don't pin it in requirements.txt).
def install(package):
    subprocess.check_call([sys.executable, "-m", "pip", "install", package])

install("transformers")
import torch
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer
# Global cache for loaded models
model_cache = {}
# Load a model with a progress bar; each model is loaded only once and cached.
def load_model(model_name, progress=gr.Progress(track_tqdm=False)):
    if model_name not in model_cache:
        token = os.getenv("HF_TOKEN")
        progress(0, desc="Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
        progress(0.5, desc="Loading model...")
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            token=token,
        )
        model_cache[model_name] = (tokenizer, model)
        progress(1, desc="Model ready.")
        return f"{model_name} loaded and ready!"
    else:
        return f"{model_name} already loaded."
# Inference function; runs on GPU via the Spaces decorator
@spaces.GPU
def generate_text(model_name, prompt):
    # Lazily load the model in case Analyze is clicked before the
    # dropdown's change event has finished loading it.
    if model_name not in model_cache:
        load_model(model_name)
    tokenizer, model = model_cache[model_name]
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=256)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
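
# Helper for long inputs (an illustrative sketch, not wired into
# generate_text above): clinical notes can exceed a model's context
# window, and letting the tokenizer truncate avoids a hard error.
# The 2048-token budget is an assumed value, not taken from any of
# the models below.
def tokenize_clipped(tokenizer, prompt, max_length=2048):
    # truncation=True drops tokens past max_length from the end of the prompt
    return tokenizer(prompt, return_tensors="pt",
                     truncation=True, max_length=max_length)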
# Available models
model_choices = [
    "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
    "meta-llama/Llama-3.2-3B-Instruct",
    "google/gemma-7b",
]
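
# Optional warm-up (a sketch, not part of the original app): preloading every
# model at startup hides the first-selection delay, at the cost of holding all
# of them in memory at once. PRELOAD_MODELS is a hypothetical opt-in env var,
# so the app's default behaviour is unchanged.
if os.getenv("PRELOAD_MODELS") == "1":
    for name in model_choices:
        load_model(name)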
# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## Clinical Text Analysis with LLMs (LLaMA, DeepSeek, Gemma)")

    with gr.Row():
        model_selector = gr.Dropdown(choices=model_choices, label="Select Model")
        model_status = gr.Textbox(label="Model Status", interactive=False)

    input_text = gr.Textbox(label="Input Clinical Text")
    output_text = gr.Textbox(label="Generated Output")
    analyze_button = gr.Button("Analyze")

    # Load the selected model when the dropdown selection changes
    model_selector.change(fn=load_model, inputs=model_selector, outputs=model_status)

    # Run generation on button click
    analyze_button.click(fn=generate_text, inputs=[model_selector, input_text], outputs=output_text)
demo.launch()
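
# Running locally: the gated checkpoints above (Llama, Gemma) require an
# HF_TOKEN with access already granted, e.g.
#   export HF_TOKEN=...   (your own token)
#   python app.py
# demo.launch() then prints the local URL.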