import subprocess
import sys

# Bootstrap dependency install at startup.
def install(package):
    subprocess.check_call([sys.executable, "-m", "pip", "install", package])

install("transformers")
import os
import torch
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer
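# `spaces` supplies the @spaces.GPU decorator used below; on ZeroGPU Spaces
# it allocates a GPU only for the duration of each decorated call.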
# Global cache for loaded models
model_cache = {}
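# Maps model name -> (tokenizer, model) so reselecting a model skips the reload.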
# Load and cache a model, reporting progress in the UI
def load_model(model_name, progress=gr.Progress(track_tqdm=False)):
    if model_name not in model_cache:
        token = os.getenv("HF_TOKEN")
        progress(0, desc="Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
        progress(0.5, desc="Loading model...")
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            token=token,
        )
        model_cache[model_name] = (tokenizer, model)
        progress(1, desc="Model ready.")
        return f"{model_name} loaded and ready!"
    else:
        return f"{model_name} already loaded."
# Inference function using GPU
@spaces.GPU
def generate_text(model_name, prompt):
    # Guard against generating before the dropdown's change event has loaded the model.
    if model_name not in model_cache:
        load_model(model_name)
    tokenizer, model = model_cache[model_name]
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=256)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
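# generate() above decodes greedily; a sampled variant (illustrative
# hyperparameters, not tuned for clinical text) would be:
#   outputs = model.generate(**inputs, max_new_tokens=256,
#                            do_sample=True, temperature=0.7, top_p=0.9)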
# Available models
model_choices = [
    "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
    "meta-llama/Llama-3.2-3B-Instruct",
    "google/gemma-7b",
]
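# The Llama and Gemma repos are gated on the Hugging Face Hub: HF_TOKEN must
# belong to an account that has accepted each model's license.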
# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("## Clinical Text Analysis with LLMs (LLaMA, DeepSeek, Gemma)")
    with gr.Row():
        model_selector = gr.Dropdown(choices=model_choices, label="Select Model")
        model_status = gr.Textbox(label="Model Status", interactive=False)
    input_text = gr.Textbox(label="Input Clinical Text")
    output_text = gr.Textbox(label="Generated Output")
    analyze_button = gr.Button("Analyze")

    # Load the model whenever the selection changes
    model_selector.change(fn=load_model, inputs=model_selector, outputs=model_status)

    # Generate output on click
    analyze_button.click(fn=generate_text, inputs=[model_selector, input_text], outputs=output_text)

demo.launch()
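# If several users may hit the app at once, demo.queue().launch() would route
# requests through Gradio's queue so GPU calls don't overlap (assuming the
# default queue settings are acceptable).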