import os
import torch
import gradio as gr
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM
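# Note: `spaces` is the Hugging Face Spaces helper package; its @spaces.GPU
# decorator (used below) requests ZeroGPU hardware for the decorated call.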
# Global model/tokenizer
current_model = None
current_tokenizer = None
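# Note: these module-level globals are shared by every visitor of the app, so
# switching models in one session changes the model used by all open chats.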
# Load model when selected
def load_model_on_selection(model_name, progress=gr.Progress(track_tqdm=False)):
    global current_model, current_tokenizer
    token = os.getenv("HF_TOKEN")

    progress(0, desc="Loading tokenizer...")
    current_tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)

    progress(0.5, desc="Loading model...")
    # Weights stay on CPU in float16 until chat_with_model() moves them to the GPU
    current_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="cpu",
        token=token
    )

    progress(1, desc="Model ready.")
    return f"{model_name} loaded and ready!"
# Inference - yields response token-by-token
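# Note: this is pseudo-streaming; generate() finishes first and the completed
# output is re-decoded token by token. For true incremental streaming, one
# option is transformers.TextIteratorStreamer with generate() in a worker thread.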
@spaces.GPU
def chat_with_model(history):
    global current_model, current_tokenizer

    if current_model is None or current_tokenizer is None:
        yield history + [("⚠️ No model loaded.", "")]
        return

    current_model.to("cuda")

    # Combine conversation history into a [INST]-style prompt.
    # The last turn was just added with an empty bot reply, so it is excluded
    # from the loop and appended as the final instruction.
    prompt = ""
    for user_msg, bot_msg in history[:-1]:
        prompt += f"[INST] {user_msg.strip()} [/INST] {bot_msg.strip()} "
    prompt += f"[INST] {history[-1][0].strip()} [/INST]"

    inputs = current_tokenizer(prompt, return_tensors="pt").to(current_model.device)
    prompt_len = inputs["input_ids"].shape[-1]

    output_ids = []

    # Clone history to avoid mutating during yield
    updated_history = history.copy()
    updated_history[-1] = (history[-1][0], "")

    # Generate the full reply, then re-decode it one token at a time,
    # skipping the prompt tokens so they don't appear in the chat window
    generated_ids = current_model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=False,
        return_dict_in_generate=True,
        output_scores=False
    ).sequences[0][prompt_len:]

    for token_id in generated_ids:
        output_ids.append(token_id.item())
        decoded = current_tokenizer.decode(output_ids, skip_special_tokens=True)
        updated_history[-1] = (history[-1][0], decoded)
        yield updated_history
# When user submits a message
def add_user_message(message, history):
return "", history + [(message, "")]
# Model choices
model_choices = [
"meta-llama/Llama-3.2-3B-Instruct",
"deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
"google/gemma-7b"
]
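# Note: the meta-llama and google/gemma repositories are gated on the Hugging Face
# Hub, so HF_TOKEN must belong to an account that has accepted their licenses.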
# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## Clinical Chatbot — LLaMA, DeepSeek, Gemma")

    default_model = gr.State("meta-llama/Llama-3.2-3B-Instruct")

    with gr.Row():
        model_selector = gr.Dropdown(choices=model_choices, label="Select Model")
        model_status = gr.Textbox(label="Model Status", interactive=False)

    chatbot = gr.Chatbot(label="Chat")
    msg = gr.Textbox(label="Your Message", placeholder="Enter your clinical query...", show_label=False)
    clear_btn = gr.Button("Clear Chat")

    # Load model on launch
    demo.load(fn=load_model_on_selection, inputs=default_model, outputs=model_status)

    # Load model on dropdown selection
    model_selector.change(fn=load_model_on_selection, inputs=model_selector, outputs=model_status)

    # On message submit: update history, then stream bot reply
    msg.submit(add_user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
        fn=chat_with_model, inputs=chatbot, outputs=chatbot
    )

    # Clear chat
    clear_btn.click(lambda: [], None, chatbot, queue=False)
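# Streaming (generator) outputs go through the Gradio queue; recent Gradio
# versions enable it by default, but on older 3.x releases call demo.queue()
# here before launching.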
demo.launch()