import os
import torch
import gradio as gr
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM

# Global model/tokenizer
current_model = None
current_tokenizer = None

# Load model when selected
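# Note: this appears to follow the ZeroGPU Spaces pattern, where weights are loaded on
# CPU at startup and moved to the GPU only inside the @spaces.GPU-decorated handler;
# hence device_map="cpu" below.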
def load_model_on_selection(model_name, progress=gr.Progress(track_tqdm=False)):
    global current_model, current_tokenizer
    token = os.getenv("HF_TOKEN")

    progress(0, desc="Loading tokenizer...")
    current_tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)

    progress(0.5, desc="Loading model...")
    current_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="cpu",
        token=token
    )

    progress(1, desc="Model ready.")
    return f"{model_name} loaded and ready!"

# Inference - yields response token-by-token
@spaces.GPU
def chat_with_model(history):
    global current_model, current_tokenizer
    if current_model is None or current_tokenizer is None:
        yield history + [("⚠️ No model loaded.", "")]
        return

    current_model.to("cuda")

    # Build the prompt from completed turns, then append the newest user message
    prompt = ""
    for user_msg, bot_msg in history[:-1]:
        prompt += f"[INST] {user_msg.strip()} [/INST] {bot_msg.strip()} "
    prompt += f"[INST] {history[-1][0].strip()} [/INST]"

    inputs = current_tokenizer(prompt, return_tensors="pt").to(current_model.device)
    prompt_len = inputs["input_ids"].shape[-1]
    output_ids = []

    # Clone history to avoid mutating during yield
    updated_history = history.copy()
    updated_history[-1] = (history[-1][0], "")

    # generate() runs to completion; the loop below replays the new tokens one at a
    # time so the Chatbot updates incrementally. The prompt tokens echoed back at the
    # start of the returned sequence are sliced off first.
    generated_ids = current_model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=False,
        return_dict_in_generate=True,
        output_scores=False
    ).sequences[0][prompt_len:]

    for token_id in generated_ids:
        output_ids.append(token_id.item())
        decoded = current_tokenizer.decode(output_ids, skip_special_tokens=True)
        updated_history[-1] = (history[-1][0], decoded)
        yield updated_history

# When the user submits a message: clear the textbox and append a (user, "") pair for the reply to stream into
def add_user_message(message, history):
    return "", history + [(message, "")]

# Model choices
model_choices = [
    "meta-llama/Llama-3.2-3B-Instruct",
    "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
    "google/gemma-7b"
]
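# Note: the Llama and Gemma repos above are gated on the Hugging Face Hub, so the
# HF_TOKEN secret likely needs to belong to an account that has been granted access.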

# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## Clinical Chatbot — LLaMA, DeepSeek, Gemma")

    default_model = gr.State("meta-llama/Llama-3.2-3B-Instruct")

    with gr.Row():
        model_selector = gr.Dropdown(choices=model_choices, value="meta-llama/Llama-3.2-3B-Instruct", label="Select Model")
        model_status = gr.Textbox(label="Model Status", interactive=False)

    chatbot = gr.Chatbot(label="Chat")
    msg = gr.Textbox(label="Your Message", placeholder="Enter your clinical query...", show_label=False)
    clear_btn = gr.Button("Clear Chat")

    # Load model on launch
    demo.load(fn=load_model_on_selection, inputs=default_model, outputs=model_status)

    # Load model on dropdown selection
    model_selector.change(fn=load_model_on_selection, inputs=model_selector, outputs=model_status)

    # On message submit: update history, then stream bot reply
    msg.submit(add_user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
        fn=chat_with_model, inputs=chatbot, outputs=chatbot
    )

    # Clear chat
    clear_btn.click(lambda: [], None, chatbot, queue=False)

demo.launch()
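
# Assumed runtime dependencies (not pinned in this file): gradio, spaces, torch,
# transformers, and accelerate (needed for device_map), with HF_TOKEN set as a Space secret.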