import os
import threading
import time
import torch
import gradio as gr
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)

import spaces
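# `spaces.GPU` (Hugging Face ZeroGPU) allocates a GPU only while the decorated
# function runs, so the Space can idle on CPU between requests.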


MODEL_ID = os.getenv("MODEL_ID", "yasserrmd/SoftwareArchitecture-Instruct-v1")

# -------- Load model & tokenizer --------
print(f"Loading model: {MODEL_ID}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype="auto",
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)
model.eval()

# Ensure a pad token to avoid warnings on some bases
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token

TITLE = "SoftwareArchitecture-Instruct v1 — Chat"
DESCRIPTION = (
    "An instruction-tuned LLM for **software architecture**. "
    "Built on LiquidAI/LFM2-1.2B, fine-tuned with the Software-Architecture dataset. "
    "Designed for technical professionals: accurate, detailed, and on-topic answers."
)

SAMPLES = [
    "Explain the API Gateway pattern and when to use it.",
    "CQRS vs Event Sourcing — how do they relate, and when would you combine them?",
    "Design a resilient payment workflow with retries, idempotency keys, and DLQ.",
    "Rate limiting strategies for a public REST API: token bucket vs sliding window.",
    "Multi-tenant SaaS: compare shared DB, schema, and dedicated DB for isolation.",
    "Blue/green vs canary deployments — trade-offs and where each fits best.",
]

def format_history_as_messages(history):
    """
    Convert Gradio chat history into OpenAI-style messages for apply_chat_template.
    history: list of tuples (user, assistant)
    """
    messages = []
    for (u, a) in history:
        if u:
            messages.append({"role": "user", "content": u})
        if a:
            messages.append({"role": "assistant", "content": a})
    return messages

@spaces.GPU
def stream_generate(messages, max_new_tokens, temperature, top_p, repetition_penalty, seed=None):
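    # A non-negative seed makes sampling reproducible; -1 (the UI default) leaves it unseeded.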
    if seed is not None and seed >= 0:
        torch.manual_seed(seed)

    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        tokenize=True,
        return_dict=True,
    )

    # Keep only what the model expects
    allowed = {"input_ids", "attention_mask"}  # no token_type_ids for causal LMs
    inputs = {k: v.to(model.device) for k, v in inputs.items() if k in allowed}

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = dict(
        **inputs,
        max_new_tokens=int(max_new_tokens),
        temperature=float(temperature),
        top_p=float(top_p),
        repetition_penalty=float(repetition_penalty),
        do_sample=temperature > 0,  # fall back to greedy decoding when temperature is 0
        use_cache=True,
        streamer=streamer,
    )

    # Run generation on a background thread; the streamer feeds decoded text back to this generator.
    thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    partial = ""
    for chunk in streamer:
        partial += chunk
        yield partial

# -------- Gradio callbacks --------
@spaces.GPU
def chat_respond(user_msg, chat_history, max_new_tokens, temperature, top_p, repetition_penalty, seed):
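    # Gradio generator callback: each yield is a (chatbot update, textbox value) pair,
    # so the assistant reply streams into the chat while the input box is cleared.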
    if not user_msg or not user_msg.strip():
        # Empty input: leave the chat and the textbox unchanged.
        yield gr.update(), gr.update()
        return

    # Add user turn
    chat_history = chat_history + [(user_msg, None)]

    # Build messages from full history
    messages = format_history_as_messages(chat_history)

    # Stream assistant output
    stream = stream_generate(
        messages=messages,
        max_new_tokens=int(max_new_tokens),
        temperature=float(temperature),
        top_p=float(top_p),
        repetition_penalty=float(repetition_penalty),
        seed=int(seed) if seed is not None else None,
    )

    # Yield progressive updates for the last assistant turn
    final_assistant_text = ""
    for chunk in stream:
        final_assistant_text = chunk
        yield gr.update(value=chat_history[:-1] + [(user_msg, final_assistant_text)]), ""

    # Ensure final state returned
    chat_history[-1] = (user_msg, final_assistant_text)
    yield gr.update(value=chat_history), ""

def use_sample(sample, chat_history):
    return sample, chat_history

def clear_chat():
    return []

# -------- UI --------

CUSTOM_CSS = """
:root {
  --brand: #0ea5e9; /* cyan-500 */
  --ink: #0b1220;
}
.gradio-container {
  font-family: Inter, ui-sans-serif, system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial, "Apple Color Emoji","Segoe UI Emoji";
}
#title h1 {
  font-weight: 700;
  letter-spacing: -0.02em;
}
#desc {
  opacity: 0.9;
}
footer {visibility: hidden}
"""

with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Soft(primary_hue="cyan")) as demo:
    with gr.Row():
        with gr.Column():
            gr.HTML(f"<div id='title'><h1>{TITLE}</h1></div>")
            gr.Markdown(f"<div id='desc'>{DESCRIPTION}</div>", elem_id="desc")

    with gr.Row():
        with gr.Column(scale=4):
            chat = gr.Chatbot(
                label="SoftwareArchitecture-Instruct v1",
                avatar_images=(None, None),
                height=480,
                bubble_full_width=False,
                sanitize_html=False,
            )
            with gr.Row():
                user_box = gr.Textbox(
                    placeholder="Ask about software architecture…",
                    show_label=False,
                    lines=3,
                    autofocus=True,
                    scale=4,
                )
                send_btn = gr.Button("Send", variant="primary", scale=1)

            with gr.Accordion("Generation Settings", open=False):
                max_new_tokens = gr.Slider(64, 1024, value=256, step=16, label="Max new tokens")
                temperature = gr.Slider(0.0, 1.5, value=0.3, step=0.05, label="Temperature")
                top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
                repetition_penalty = gr.Slider(1.0, 1.5, value=1.05, step=0.01, label="Repetition penalty")
                seed = gr.Number(value=-1, precision=0, label="Seed (-1 for random)")

            with gr.Row():
                clear_btn = gr.Button("Clear", variant="secondary")
                # sample buttons
                sample_dropdown = gr.Dropdown(choices=SAMPLES, label="Samples", value=None)
                use_sample_btn = gr.Button("Use Sample")

        with gr.Column(scale=2):
            gr.Markdown("### Samples")
            gr.Markdown("\n".join([f"• {s}" for s in SAMPLES]))
            gr.Markdown("—\n**Tip:** Increase *Max new tokens* for longer, more complete answers.")

    # Events
    send_btn.click(
        chat_respond,
        inputs=[user_box, chat, max_new_tokens, temperature, top_p, repetition_penalty, seed],
        outputs=[chat, user_box],
        queue=True,
        show_progress=True,
    )
    user_box.submit(
        chat_respond,
        inputs=[user_box, chat, max_new_tokens, temperature, top_p, repetition_penalty, seed],
        outputs=[chat, user_box],
        queue=True,
        show_progress=True,
    )
    clear_btn.click(fn=clear_chat, outputs=chat)

    use_sample_btn.click(use_sample, inputs=[sample_dropdown, chat], outputs=[user_box, chat])

    gr.Markdown(
        "—\nBuilt for engineers and architects. Base model: **LiquidAI/LFM2-1.2B** · Fine-tuned: **Software-Architecture** dataset."
    )

if __name__ == "__main__":
    demo.queue().launch()