import os
import threading

import torch
import gradio as gr
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)
import spaces

MODEL_ID = os.getenv("MODEL_ID", "yasserrmd/SoftwareArchitecture-Instruct-v1")

# -------- Load model & tokenizer --------
print(f"Loading model: {MODEL_ID}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype="auto",
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)
model.eval()

# Ensure a pad token to avoid warnings on some bases
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token

TITLE = "SoftwareArchitecture-Instruct v1 — Chat"
DESCRIPTION = (
    "An instruction-tuned LLM for **software architecture**. "
    "Built on LiquidAI/LFM2-1.2B, fine-tuned with the Software-Architecture dataset. "
    "Designed for technical professionals: accurate, detailed, and on-topic answers."
)

SAMPLES = [
    "Explain the API Gateway pattern and when to use it.",
    "CQRS vs Event Sourcing — how do they relate, and when would you combine them?",
    "Design a resilient payment workflow with retries, idempotency keys, and DLQ.",
    "Rate limiting strategies for a public REST API: token bucket vs sliding window.",
    "Multi-tenant SaaS: compare shared DB, schema, and dedicated DB for isolation.",
    "Blue/green vs canary deployments — trade-offs and where each fits best.",
]


def format_history_as_messages(history):
    """
    Convert Gradio chat history into OpenAI-style messages for apply_chat_template.
    history: list of tuples (user, assistant)
    """
    messages = []
    for (u, a) in history:
        if u:
            messages.append({"role": "user", "content": u})
        if a:
            messages.append({"role": "assistant", "content": a})
    return messages


@spaces.GPU
def stream_generate(messages, max_new_tokens, temperature, top_p, repetition_penalty, seed=None):
    if seed is not None and seed >= 0:
        torch.manual_seed(seed)

    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        tokenize=True,
        return_dict=True,
    )
    # Keep only what the model expects
    allowed = {"input_ids", "attention_mask"}  # no token_type_ids for causal LMs
    inputs = {k: v.to(model.device) for k, v in inputs.items() if k in allowed}

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = dict(
        **inputs,
        max_new_tokens=int(max_new_tokens),
        temperature=float(temperature),
        top_p=float(top_p),
        repetition_penalty=float(repetition_penalty),
        do_sample=temperature > 0,
        use_cache=True,
        streamer=streamer,
    )
    thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    partial = ""
    for chunk in streamer:
        partial += chunk
        yield partial
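# Optional smoke test — a sketch, not part of the app. It drives stream_generate
# directly, outside Gradio, to confirm streaming works end to end. The message
# content and sampling values below are illustrative assumptions; uncomment and
# run manually (e.g. `SMOKE_TEST=1 python app.py`) rather than at import time.
#
# if __name__ == "__main__" and os.getenv("SMOKE_TEST"):
#     msgs = [{"role": "user", "content": "Explain the API Gateway pattern."}]
#     last = ""
#     for last in stream_generate(msgs, max_new_tokens=64, temperature=0.3,
#                                 top_p=0.95, repetition_penalty=1.05, seed=42):
#         pass  # each iteration yields the accumulated text so far
#     print(last)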
# -------- Gradio callbacks --------
@spaces.GPU
def chat_respond(user_msg, chat_history, max_new_tokens, temperature, top_p, repetition_penalty, seed):
    if not user_msg or not user_msg.strip():
        # This is a generator function, so yield (not return) the no-op outputs;
        # a bare `return value` in a generator never reaches Gradio.
        yield gr.update(), gr.update()
        return

    # Add user turn
    chat_history = chat_history + [(user_msg, None)]

    # Build messages from full history
    messages = format_history_as_messages(chat_history)

    # Stream assistant output
    stream = stream_generate(
        messages=messages,
        max_new_tokens=int(max_new_tokens),
        temperature=float(temperature),
        top_p=float(top_p),
        repetition_penalty=float(repetition_penalty),
        seed=int(seed) if seed is not None else None,
    )

    # Yield progressive updates for the last assistant turn
    final_assistant_text = ""
    for chunk in stream:
        final_assistant_text = chunk
        yield gr.update(value=chat_history[:-1] + [(user_msg, final_assistant_text)]), ""

    # Ensure final state is returned
    chat_history[-1] = (user_msg, final_assistant_text)
    yield gr.update(value=chat_history), ""


def use_sample(sample, chat_history):
    return sample, chat_history


def clear_chat():
    return []


# -------- UI --------
CUSTOM_CSS = """
:root {
  --brand: #0ea5e9; /* cyan-500 */
  --ink: #0b1220;
}
.gradio-container {
  font-family: Inter, ui-sans-serif, system-ui, -apple-system, Segoe UI, Roboto,
    Helvetica, Arial, "Apple Color Emoji", "Segoe UI Emoji";
}
#title h1 { font-weight: 700; letter-spacing: -0.02em; }
#desc { opacity: 0.9; }
footer { visibility: hidden; }
"""

with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Soft(primary_hue="cyan")) as demo:
    with gr.Row():
        with gr.Column():
            gr.HTML(f"