import os
from threading import Thread

import torch
import gradio as gr
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer


# Global holders for the currently selected model and tokenizer
current_model = None
current_tokenizer = None

def load_model_on_selection(model_name, progress=gr.Progress(track_tqdm=False)):
    global current_model, current_tokenizer
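    # HF_TOKEN is expected as a Space secret / environment variable so that
    # gated checkpoints can be downloaded.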
    token = os.getenv("HF_TOKEN")

    progress(0, desc="Loading tokenizer...")
    current_tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)

    progress(0.5, desc="Loading model...")
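    # Load on CPU in half precision; the model is moved to the GPU only inside
    # the @spaces.GPU-decorated generation function.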
    current_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="cpu",
        token=token
    )

    progress(1, desc="Model ready.")
    return f"{model_name} loaded and ready!"

# ZeroGPU: a GPU is attached only for the duration of this call.
@spaces.GPU
def generate_text(prompt):
    global current_model, current_tokenizer
    if current_model is None or current_tokenizer is None:
        yield "⚠️ No model loaded yet. Please select a model first."
        return

    current_model.to("cuda")
    inputs = current_tokenizer(prompt, return_tensors="pt").to(current_model.device)

    # Stream tokens as they are generated instead of decoding the finished
    # sequence token by token: generate() runs in a background thread and
    # pushes newly decoded text into the streamer, which this generator drains.
    streamer = TextIteratorStreamer(
        current_tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs = dict(
        **inputs,
        max_new_tokens=256,
        do_sample=False,
        streamer=streamer,
    )
    Thread(target=current_model.generate, kwargs=generation_kwargs).start()

    partial_output = ""
    for new_text in streamer:
        partial_output += new_text
        yield partial_output


# Model options
model_choices = [
    "meta-llama/Llama-3.2-3B-Instruct",
    "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
    "google/gemma-7b"
]
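# Note: the Llama and Gemma repositories are gated on the Hugging Face Hub and
# require a token from an account that has accepted their licenses.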

# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## Clinical Text Testing with LLaMA, DeepSeek, and Gemma")

    model_selector = gr.Dropdown(choices=model_choices, label="Select Model")
    model_status = gr.Textbox(label="Model Status", interactive=False)

    input_text = gr.Textbox(label="Input Clinical Text")
    generate_btn = gr.Button("Generate")

    output_text = gr.Textbox(label="Generated Output")

    # Load model on dropdown change
    model_selector.change(fn=load_model_on_selection, inputs=model_selector, outputs=model_status)

    # Generate with the current model; generate_text is a generator function,
    # so Gradio streams its partial outputs to the textbox automatically.
    generate_btn.click(fn=generate_text, inputs=input_text, outputs=output_text)
    input_text.submit(fn=generate_text, inputs=input_text, outputs=output_text)


# Pre-load the default model at startup so the app is usable immediately.
load_model_on_selection("meta-llama/Llama-3.2-3B-Instruct")
demo.launch()