import os
import torch
import gradio as gr
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM

# Globals holding the currently loaded model and tokenizer
current_model = None
current_tokenizer = None

def load_model_on_selection(model_name, progress=gr.Progress(track_tqdm=False)):
    """Load the selected model and tokenizer, replacing any previously loaded pair."""
    global current_model, current_tokenizer
    token = os.getenv("HF_TOKEN")

    # Drop the previous model first so two checkpoints never sit in GPU memory at once
    if current_model is not None:
        current_model = None
        torch.cuda.empty_cache()

    progress(0, desc="Loading tokenizer...")
    current_tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)

    progress(0.5, desc="Loading model...")
    current_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="cuda",
        token=token,
    )

    progress(1, desc="Model ready.")
    return f"{model_name} loaded and ready!"

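# On ZeroGPU Spaces, @spaces.GPU attaches a GPU only for the duration of the
# decorated call; on persistent-GPU hardware the decorator is a no-op.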
@spaces.GPU
def generate_text(prompt):
    global current_model, current_tokenizer
    if current_model is None or current_tokenizer is None:
        return "⚠️ No model loaded yet. Please select a model first."
    
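    # Tokenize the prompt, move it to the model's device, and greedy-decode.
    # Note the decoded string includes the prompt, since the output is not sliced.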
    inputs = current_tokenizer(prompt, return_tensors="pt").to(current_model.device)
    outputs = current_model.generate(**inputs, max_new_tokens=256)
    return current_tokenizer.decode(outputs[0], skip_special_tokens=True)

# Model options
model_choices = [
    "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
    "meta-llama/Llama-3.2-3B-Instruct",
    "google/gemma-7b"
]
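
# Note: the Llama and Gemma repos are gated on the Hub; the account behind
# HF_TOKEN must have accepted their licenses or loading will fail.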

# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## Clinical Text Testing with LLaMA, DeepSeek, and Gemma")

    model_selector = gr.Dropdown(choices=model_choices, label="Select Model")
    model_status = gr.Textbox(label="Model Status", interactive=False)

    input_text = gr.Textbox(label="Input Clinical Text")
    output_text = gr.Textbox(label="Generated Output")

    generate_btn = gr.Button("Generate")

    # Load model on dropdown change
    model_selector.change(fn=load_model_on_selection, inputs=model_selector, outputs=model_status)

    # Generate with current model
    generate_btn.click(fn=generate_text, inputs=input_text, outputs=output_text)

demo.launch()