import os
import queue
import threading

import torch
import gradio as gr
import spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer


class RichTextStreamer(TextIteratorStreamer):
    """Streamer that yields per-token metadata (id, decoded text, special-token flag) instead of plain text."""

    def __init__(self, tokenizer, **kwargs):
        super().__init__(tokenizer, **kwargs)
        self.token_queue = queue.Queue()
        self.end_of_generation = threading.Event()

    def put(self, value):
        # Respect skip_prompt (set when the streamer is constructed): the first put() call carries the prompt ids.
        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return
        if isinstance(value, torch.Tensor):
            token_ids = value.view(-1).tolist()
        elif isinstance(value, list):
            token_ids = value
        else:
            token_ids = [value]
        for token_id in token_ids:
            token_str = self.tokenizer.decode([token_id], **self.decode_kwargs)
            is_special = token_id in self.tokenizer.all_special_ids
            self.token_queue.put({
                "token_id": token_id,
                "token": token_str,
                "is_special": is_special,
            })

    def end(self):
        # Called by generate() when it finishes; lets __iter__ stop waiting on the queue.
        self.end_of_generation.set()

    def __iter__(self):
        while True:
            try:
                # Poll with a finite timeout so the loop can notice when generation has ended.
                token_info = self.token_queue.get(timeout=self.timeout or 0.05)
                yield token_info
            except queue.Empty:
                if self.end_of_generation.is_set():
                    break


@spaces.GPU
def chat_with_model(messages):
    global current_model, current_tokenizer

    if current_model is None or current_tokenizer is None:
        yield messages + [{"role": "assistant", "content": "⚠️ No model loaded."}]
        return

    pad_id = current_tokenizer.pad_token_id
    eos_id = current_tokenizer.eos_token_id
    if pad_id is None:
        pad_id = current_tokenizer.unk_token_id or 0

    prompt = format_prompt(messages)

    device = torch.device("cuda")
    current_model.to(device).half()

    inputs = current_tokenizer(prompt, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    streamer = RichTextStreamer(current_tokenizer, skip_prompt=True, skip_special_tokens=False)

    max_new_tokens = 256
    generated_tokens = 0
    output_text = ""
    in_think = False

    generation_kwargs = dict(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        streamer=streamer,
        eos_token_id=eos_id,
        pad_token_id=pad_id,
    )

    thread = threading.Thread(target=current_model.generate, kwargs=generation_kwargs)
    thread.start()

    messages = messages.copy()
    messages.append({"role": "assistant", "content": ""})
    print(f'Step 1: {messages}')

    prompt_text = current_tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=False)

    for token_info in streamer:
        token_str = token_info["token"]
        token_id = token_info["token_id"]
        is_special = token_info["is_special"]

        # Stop immediately at EOS
        if token_id == eos_id:
            break

        # Detect a reasoning block (models such as DeepSeek-R1 wrap their chain of thought in <think>...</think>)
        if "<think>" in token_str:
            in_think = True
            token_str = token_str.replace("<think>", "")
            output_text += "*"

        if "</think>" in token_str:
            in_think = False
            token_str = token_str.replace("</think>", "")
            output_text += token_str + "*"
        else:
            output_text += token_str

        # Early stopping if the model starts a new user turn
        if "\nUser:" in output_text:
            output_text = output_text.split("\nUser:")[0].rstrip()
            messages[-1]["content"] = output_text
            break

        # Strip the prompt from the start of the generated output if it leaked through
        if output_text.startswith(prompt_text):
            output_text = output_text[len(prompt_text):]

        generated_tokens += 1
        if generated_tokens >= max_new_tokens:
            break

        messages[-1]["content"] = output_text
        print(f'Step 2: {messages}')
        yield messages

    # Close an unterminated italic reasoning block
    if in_think:
        output_text += "*"
        messages[-1]["content"] = output_text

    # Wait for the generation thread to finish
    thread.join(timeout=1.0)

    current_model.to("cpu")
    torch.cuda.empty_cache()

    messages[-1]["content"] = output_text
    print(f'Step 3: {messages}')
    yield messages


# Globals
current_model = None
current_tokenizer = None


def load_model_on_selection(model_name, progress=gr.Progress(track_tqdm=False)):
    global current_model, current_tokenizer
    token = os.getenv("HF_TOKEN")

    progress(0, desc="Loading tokenizer...")
    current_tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=token)

    progress(0.5, desc="Loading model...")
    current_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="cpu",  # loaded to CPU initially
        use_auth_token=token,
    )

    progress(1, desc="Model ready.")
    return f"{model_name} loaded and ready!"


# Format conversation as plain text
def format_prompt(messages):
    prompt = ""
    for msg in messages:
        role = msg["role"]
        if role == "user":
            prompt += f"User: {msg['content'].strip()}\n"
        elif role == "assistant":
            prompt += f"Assistant: {msg['content'].strip()}\n"
    prompt += "Assistant:"
    return prompt


def add_user_message(user_input, history):
    return "", history + [{"role": "user", "content": user_input}]


# Curated models
model_choices = [
    "meta-llama/Llama-3.2-3B-Instruct",
    "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
    "google/gemma-7b",
    "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
]

with gr.Blocks() as demo:
    gr.Markdown("## Clinical Chatbot (Streaming)")

    default_model = gr.State(model_choices[0])

    with gr.Row():
        mode = gr.Radio(["Choose from list", "Enter custom model"], value="Choose from list", label="Model Input Mode")
        model_selector = gr.Dropdown(choices=model_choices, label="Select Predefined Model")
        model_textbox = gr.Textbox(label="Or Enter HF Model Name")

    model_status = gr.Textbox(label="Model Status", interactive=False)
    chatbot = gr.Chatbot(label="Chat", type="messages")
    msg = gr.Textbox(label="Your message", placeholder="Enter clinical input...", show_label=False)
    clear = gr.Button("Clear")

    def resolve_model_choice(mode, dropdown_value, textbox_value):
        return textbox_value.strip() if mode == "Enter custom model" else dropdown_value

    # Load on launch
    demo.load(fn=load_model_on_selection, inputs=default_model, outputs=model_status)

    # Load on user selection
    mode.select(fn=resolve_model_choice, inputs=[mode, model_selector, model_textbox], outputs=default_model).then(
        load_model_on_selection, inputs=default_model, outputs=model_status
    )

    model_selector.change(fn=resolve_model_choice, inputs=[mode, model_selector, model_textbox], outputs=default_model).then(
        load_model_on_selection, inputs=default_model, outputs=model_status
    )

    model_textbox.submit(fn=resolve_model_choice, inputs=[mode, model_selector, model_textbox], outputs=default_model).then(
        load_model_on_selection, inputs=default_model, outputs=model_status
    )

    msg.submit(add_user_message, [msg, chatbot], [msg, chatbot], queue=False).then(
        chat_with_model, chatbot, chatbot
    )

    clear.click(lambda: [], None, chatbot, queue=False)

demo.launch()