import os
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import threading

import app_math  # local project module (not a PyPI package)

# ---- Model setup ----
HF_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
MODEL_ID = "HuggingFaceH4/zephyr-7b-beta"
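# A token is only required for gated/private repos; zephyr-7b-beta is public,
# so HF_TOKEN may be None here without breaking the download.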

# Load the tokenizer (it carries Zephyr's chat template)
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,
)

# Automatically map the model across available devices (GPU/CPU)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    low_cpu_mem_usage=True,
    token=HF_TOKEN,
)

# Ensure a pad token is set: Zephyr's tokenizer ships without one, and
# generate() warns when no pad token is defined.
if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
    tokenizer.pad_token_id = tokenizer.eos_token_id


def respond(message, history: list[tuple[str, str]], system_message, max_tokens, temperature, top_p):
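    """Stream a reply for `message` given the running `history`.

    `history` arrives as (user, assistant) tuples, Gradio's classic
    ChatInterface format, and is unpacked into chat-template messages.
    """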
    # Build chat messages
    messages = [{"role": "system", "content": system_message}]
    for u, a in history:
        if u:
            messages.append({"role": "user", "content": u})
        if a:
            messages.append({"role": "assistant", "content": a})
    messages.append({"role": "user", "content": message})

    # Tokenize with Zephyr's chat template; return_dict=True also returns the
    # attention mask that generate() expects alongside the input IDs.
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

    # TextIteratorStreamer yields decoded text as tokens arrive;
    # skip_prompt=True keeps the echoed prompt out of the stream.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    gen_kwargs = {
        **inputs,  # input_ids and attention_mask
        "max_new_tokens": int(max_tokens),
        "do_sample": True,
        "temperature": float(temperature),
        "top_p": float(top_p),
        "eos_token_id": tokenizer.eos_token_id,
        "pad_token_id": tokenizer.pad_token_id,
        "streamer": streamer,
    }

    thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    # Yield the growing reply so the UI updates token by token
    partial = ""
    for new_text in streamer:
        partial += new_text
        yield partial

    thread.join()


# ---- Gradio UI ----
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)
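# ChatInterface streams automatically because `respond` is a generator:
# each yielded string replaces the in-progress assistant message.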

if __name__ == "__main__":
    demo.launch()
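# Run locally with `python app.py`; on a Hugging Face Space this file is the
# default entry point and launches automatically.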