import os
import torch
import gradio as gr
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM
# Use global variables to hold the currently selected model and tokenizer
current_model = None
current_tokenizer = None
def load_model_on_selection(model_name, progress=gr.Progress(track_tqdm=False)):
    """Load the selected checkpoint into the module-level globals."""
    global current_model, current_tokenizer
    token = os.getenv("HF_TOKEN")

    progress(0, desc="Loading tokenizer...")
    current_tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)

    progress(0.5, desc="Loading model...")
    current_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="cpu",  # keep the model on CPU between requests
        token=token,
    )

    progress(1, desc="Model ready.")
    return f"{model_name} loaded and ready!"
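
# On Hugging Face Spaces ZeroGPU, @spaces.GPU allocates a GPU only for the
# duration of the decorated call, so the model stays on CPU between requests
# and is moved to CUDA inside the handler.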
@spaces.GPU
def generate_text(prompt):
    global current_model, current_tokenizer

    if current_model is None or current_tokenizer is None:
        yield "⚠️ No model loaded yet. Please select a model first."
        return

    # Move the CPU-resident model onto the GPU allocated for this call.
    current_model.to("cuda")
    inputs = current_tokenizer(prompt, return_tensors="pt").to(current_model.device)

    # Greedy generation. The full sequence is produced first and then revealed
    # token by token below (pseudo-streaming); for true incremental decoding,
    # TextIteratorStreamer with generation running in a background thread could
    # be used instead.
    output = current_model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=False,
        return_dict_in_generate=True,
    )
    # Drop the prompt tokens so only newly generated text is shown.
    generated_ids = output.sequences[0][inputs["input_ids"].shape[-1]:]

    output_ids = []
    for token_id in generated_ids:
        output_ids.append(token_id.item())
        yield current_tokenizer.decode(output_ids, skip_special_tokens=True)
# Model options
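# Note: the Llama and Gemma checkpoints are gated on the Hugging Face Hub, so
# HF_TOKEN must belong to an account that has accepted their licenses.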
model_choices = [
    "meta-llama/Llama-3.2-3B-Instruct",
    "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
    "google/gemma-7b",
]
# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## Clinical Text Testing with LLaMA, DeepSeek, and Gemma")

    # Default to the model that is pre-loaded below.
    model_selector = gr.Dropdown(choices=model_choices, value=model_choices[0], label="Select Model")
    model_status = gr.Textbox(label="Model Status", interactive=False)
    input_text = gr.Textbox(label="Input Clinical Text")
    generate_btn = gr.Button("Generate")
    output_text = gr.Textbox(label="Generated Output")

    # Load model on dropdown change
    model_selector.change(fn=load_model_on_selection, inputs=model_selector, outputs=model_status)

    # Generate with the current model. generate_text is a generator, so Gradio
    # streams its partial outputs to the textbox automatically.
    generate_btn.click(fn=generate_text, inputs=input_text, outputs=output_text)
    input_text.submit(fn=generate_text, inputs=input_text, outputs=output_text)

# Pre-load the default model at startup so the app is usable immediately.
load_model_on_selection("meta-llama/Llama-3.2-3B-Instruct")

demo.launch()