import os
import gradio as gr
from huggingface_hub import InferenceClient
from huggingface_hub.utils import HfHubHTTPError

# Preferred model
PREFERRED_MODEL = os.environ.get("MODEL_ID", "mistralai/Mistral-7B-Instruct-v0.2")
# Updated fallback model
FALLBACK_MODEL = os.environ.get("FALLBACK_MODEL", "unsloth/Llama-3.2-3B-Instruct")

# Token comes from the Space's HF_TOKEN secret (or a local env var)
token = os.environ.get("HF_TOKEN")

def _extract_text_from_response(resp):
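    """Best-effort extraction of the generated text from an InferenceClient response.

    The response may be a plain string, an object exposing `generated_text`/`text`,
    or a dict (including OpenAI-style `choices`); falls back to `str(resp)`.
    """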
    if isinstance(resp, str):
        return resp
    try:
        if hasattr(resp, "generated_text"):
            return getattr(resp, "generated_text") or ""
        if hasattr(resp, "text"):
            return getattr(resp, "text") or ""
    except Exception:
        pass
    if isinstance(resp, dict):
        for key in ("generated_text", "generated_texts", "text", "output_text", "result"):
            if key in resp:
                v = resp[key]
                if isinstance(v, list) and v:
                    return v[0] if isinstance(v[0], str) else str(v[0])
                if isinstance(v, str):
                    return v
        if "choices" in resp and isinstance(resp["choices"], list) and resp["choices"]:
            first = resp["choices"][0]
            if isinstance(first, dict):
                if "message" in first and isinstance(first["message"], dict) and "content" in first["message"]:
                    maybe = first["message"]["content"]
                    if isinstance(maybe, str):
                        return maybe
                for k in ("text", "content", "generated_text"):
                    if k in first and isinstance(first[k], str):
                        return first[k]
    try:
        return str(resp)
    except Exception:
        return "<unable to decode response>"

def _call_model(model_id, prompt, max_new_tokens, temperature, top_p):
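    """Run a single text-generation request against `model_id` via the Inference API."""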
    client = InferenceClient(model=model_id, token=token)
    return client.text_generation(
        prompt,
        max_new_tokens=int(max_new_tokens),
        temperature=float(temperature),
        top_p=float(top_p),
        do_sample=True,
    )

def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
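    """ChatInterface callback: build a plain-text prompt from the chat history,
    query the preferred model, and fall back to FALLBACK_MODEL on a 404.
    """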
    if not token:
        yield "ERRO: variável HF_TOKEN não encontrada. Adicione o secret HF_TOKEN no Settings do Space."
        return

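    # Flatten the system message and chat history into a single plain-text prompt.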
    prompt = f"{system_message}\n\n"
    for user_msg, bot_msg in history:
        if user_msg:
            prompt += f"User: {user_msg}\n"
        if bot_msg:
            prompt += f"Assistant: {bot_msg}\n"
    prompt += f"User: {message}\nAssistant:"

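    # Try the preferred model first; on a 404 (model not served by the
    # Inference API), retry once with the fallback model.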
    try:
        out = _call_model(PREFERRED_MODEL, prompt, max_tokens, temperature, top_p)
    except HfHubHTTPError as e:
        try:
            code = e.response.status_code if e.response is not None else None
        except Exception:
            code = None

        if code == 404:
            yield f"Aviso: modelo `{PREFERRED_MODEL}` não disponível via Inference API (404). Tentando fallback para `{FALLBACK_MODEL}`..."
            try:
                out = _call_model(FALLBACK_MODEL, prompt, max_tokens, temperature, top_p)
            except Exception as e2:
                yield f"Falha no fallback para {FALLBACK_MODEL}: {e2}"
                return
        else:
            yield f"ERRO na chamada de inferência: {e}\n(verifique HF_TOKEN, permissões e se o modelo está disponível via Inference API)"
            return
    except Exception as e:
        yield f"Erro inesperado ao chamar a API: {e}"
        return

    text = _extract_text_from_response(out)
    yield text

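# Chat UI: the slider values are passed to respond() as max_tokens, temperature and top_p.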
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a helpful assistant.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.0, maximum=2.0, value=0.7, step=0.05, label="Temperature"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
    title="Chat (Mistral fallback com Llama 3.2 3B)",
)

if __name__ == "__main__":
    demo.launch()