import os
import torch
import gradio as gr
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM

# Globals that hold the currently selected model and tokenizer
current_model = None
current_tokenizer = None

def load_model_on_selection(model_name, progress=gr.Progress(track_tqdm=False)):
    global current_model, current_tokenizer
    token = os.getenv("HF_TOKEN")

    progress(0, desc="Loading tokenizer...")
    current_tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=token)

    progress(0.5, desc="Loading model...")
    current_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="cpu",
        use_auth_token=token
    )

    progress(1, desc="Model ready.")
    return f"{model_name} loaded and ready!"

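# @spaces.GPU asks Hugging Face Spaces (ZeroGPU) to attach a GPU for the duration
# of this call, so the model can be moved to CUDA inside generate_text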
@spaces.GPU
def generate_text(prompt):
    global current_model, current_tokenizer

    if current_model is None or current_tokenizer is None:
        yield "⚠️ No model loaded yet. Please select a model first."
        return

    # Move the model onto the GPU allocated for this call
    current_model.to("cuda")

    inputs = current_tokenizer(prompt, return_tensors="pt").to(current_model.device)
    prompt_len = inputs["input_ids"].shape[1]

    output = current_model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=False,
        return_dict_in_generate=True,
        output_scores=False
    )

    # Stream the decoded text back to the UI, skipping the prompt tokens
    # that generate() echoes at the start of the returned sequence
    output_ids = []
    for token_id in output.sequences[0][prompt_len:]:
        output_ids.append(token_id.item())
        yield current_tokenizer.decode(output_ids, skip_special_tokens=True)

# Model options
model_choices = [
"meta-llama/Llama-3.2-3B-Instruct",
"deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
"google/gemma-7b"
]
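
# Note: the meta-llama and google/gemma checkpoints are gated on the Hugging Face Hub,
# so the HF_TOKEN read in load_model_on_selection must belong to an account with access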

# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## Clinical Text Testing with LLaMA, DeepSeek, and Gemma")

    model_selector = gr.Dropdown(choices=model_choices, label="Select Model")
    model_status = gr.Textbox(label="Model Status", interactive=False)

    input_text = gr.Textbox(label="Input Clinical Text")
    generate_btn = gr.Button("Generate")
    output_text = gr.Textbox(label="Generated Output")

    # Load model on dropdown change
    model_selector.change(fn=load_model_on_selection, inputs=model_selector, outputs=model_status)

    # Generate with current model
    generate_btn.click(fn=generate_text, inputs=input_text, outputs=output_text)
    input_text.submit(fn=generate_text, inputs=input_text, outputs=output_text)
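
# Preload the default model at startup so the app is usable before a selection is made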
load_model_on_selection("meta-llama/Llama-3.2-3B-Instruct")
demo.launch()