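"""Gradio chat app: a prompt-engineered persona agent with a mini TF-IDF + FAISS
RAG memory, backed by a quantized Gemma 3 1B GGUF model via llama-cpp-python."""
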
import os
import gradio as gr
from llama_cpp import Llama
from huggingface_hub import snapshot_download, login
from sklearn.feature_extraction.text import TfidfVectorizer
import faiss

# ------------------ Model Setup ------------------

MODEL_REPO = "google/gemma-3-1b-it-qat-q4_0-gguf"
MODEL_PATH = "./gemma-3-1b-it-qat-q4_0/gemma-3-1b-it-q4_0.gguf"
MODEL_DIR = "./gemma-3-1b-it-qat-q4_0"

DEFAULT_SYSTEM_PROMPT = (
    "You are a Wise Mentor. Speak in a calm and concise manner. "
    "If asked for advice, give a maximum of 3 actionable steps. "
    "Avoid unnecessary elaboration. Decline unethical or harmful requests."
)

# Hugging Face Token and model download
hf_token = os.environ.get("HF_TOKEN")
if not os.path.exists(MODEL_PATH):
    if not hf_token:
        raise ValueError("HF_TOKEN is missing. Set it as an environment variable.")
    login(hf_token)
    snapshot_download(repo_id=MODEL_REPO, local_dir=MODEL_DIR)

# ------------------ RAG Setup ------------------

documents = []  # stores (user, bot) tuples
vectorizer = TfidfVectorizer()
index = None

def update_rag_index():
    global index
    if not documents:
        return
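    # Re-fit TF-IDF over the full history and rebuild the exact L2 index so the
    # vocabulary and the vector dimensionality stay in sync.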
    flat_docs = [f"user: {u} bot: {b}" for u, b in documents]
    vectors = vectorizer.fit_transform(flat_docs).toarray().astype('float32')
    index = faiss.IndexFlatL2(vectors.shape[1])
    index.add(vectors)

def retrieve_relevant_docs(query, k=3):
    if not documents or index is None:
        return []
    query_vec = vectorizer.transform([query]).toarray().astype('float32')
    # FAISS pads results with -1 when k exceeds the number of stored vectors,
    # so clamp k and guard against out-of-range ids.
    _, I = index.search(query_vec, min(k, len(documents)))
    return [documents[i] for i in I[0] if 0 <= i < len(documents)]

# ------------------ Context Estimation ------------------

def estimate_n_ctx(full_prompt, buffer=500):
    # Rough heuristic: treat whitespace-separated words as tokens (this usually
    # undercounts, hence the generous buffer) and cap at 3500 to stay inside
    # the model's context window.
    total_tokens = len(full_prompt.split())
    return min(3500, total_tokens + buffer)

# ------------------ Chat Function ------------------

def chat(user_input, history, system_prompt):
    relevant_context = retrieve_relevant_docs(user_input)
    # Replay retrieved turns in Gemma's chat format; the model was not trained
    # on Llama-2-style [INST]/<<SYS>> markers.
    formatted_turns = "".join(
        f"<start_of_turn>user\n{u}<end_of_turn>\n<start_of_turn>model\n{b}<end_of_turn>\n"
        for u, b in relevant_context
    )

    # Gemma has no dedicated system role, so the system prompt is folded into
    # the current user turn.
    full_prompt = (
        f"{formatted_turns}"
        f"<start_of_turn>user\n{system_prompt}\n\n{user_input}<end_of_turn>\n"
        f"<start_of_turn>model\n"
    )

    n_ctx = estimate_n_ctx(full_prompt)

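    # Design note: the model is re-created on every turn so n_ctx can track the
    # prompt size; this trades per-message latency for a smaller memory
    # footprint. Caching a single Llama instance with a fixed n_ctx would be faster.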
    llm = Llama(
        model_path=MODEL_PATH,
        n_ctx=n_ctx,
        n_threads=2,
        n_batch=128
    )

    output = llm(full_prompt, max_tokens=256, stop=["<end_of_turn>"])
    bot_reply = output["choices"][0]["text"].strip()

    documents.append((user_input, bot_reply))
    update_rag_index()

    history.append((user_input, bot_reply))
    return "", history

# ------------------ UI ------------------

with gr.Blocks() as demo:
    gr.Markdown("## πŸ€– Prompt-Engineered Persona Agent with Mini-RAG")
    system_prompt_box = gr.Textbox(label="System Prompt", value=DEFAULT_SYSTEM_PROMPT, lines=3)
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Your Message")
    clear = gr.Button("πŸ—‘οΈ Clear Chat")

    msg.submit(chat, [msg, chatbot, system_prompt_box], [msg, chatbot])
    clear.click(lambda: [], None, chatbot)

demo.launch()