File size: 3,245 Bytes
459ed9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os

import faiss  # fixed: was misspelled "fiass", which fails with ModuleNotFoundError at startup
import gradio as gr
import numpy as np
from huggingface_hub import snapshot_download, login
from llama_cpp import Llama
from sklearn.feature_extraction.text import TfidfVectorizer

#--------------------MODEL SETUP--------------------
# Hugging Face repo and matching local paths for the quantised Gemma GGUF model.
MODEL_REPO = "google/gemma-3-1b-it-qat-q4_0-gguf"
MODEL_PATH = "./gemma-3-1b-it-qat-q4_0/gemma-3-1b-it-q4_0.gguf"
MODEL_DIR = "./gemma-3-1b-it-qat-q4_0"
# Persona used unless the user edits the system-prompt textbox in the UI.
DEFAULT_SYSTEM_PROMPT = (
    "You are a Wise Mentor. Speak in a calm and concise manner. "
    "If asked for advice, give a maximum of 3 actionable steps. "
    "Avoid unnecessary elaboration. Decline unethical or harmful requests."
)

# Huggingface Token and download
# One-time model download: skipped when the GGUF file already exists locally.
# An HF_TOKEN environment variable is required to log in before downloading.
hf_token = os.environ.get("HF_TOKEN")
if not os.path.exists(MODEL_PATH):
    if not hf_token:
        raise ValueError("HF_TOKEN is missing. Set it as an environment variable")

    login(hf_token)
    snapshot_download(repo_id=MODEL_REPO, local_dir=MODEL_DIR, local_dir_use_symlinks=False)

#--------------------RAG SETUP------------------------
documents = [] # stores all chat turns
vectorizer = TfidfVectorizer()  # re-fitted over the whole corpus by update_rag_index()
index = None  # nearest-neighbour index; built lazily once the first turn is stored

def update_rag_index():
    """Rebuild the vector index over all stored chat turns.

    Re-fits the TF-IDF vectorizer on the full ``documents`` list and
    rebuilds a flat L2 FAISS index from scratch. No-op while the chat
    history is empty.
    """
    global index
    if not documents:
        return
    # Fix: original called the nonexistent `.asype(...)` (AttributeError on
    # first use) and referenced the misspelled module name `fiass`.
    vectors = vectorizer.fit_transform(documents).toarray().astype('float32')
    index = faiss.IndexFlatL2(vectors.shape[1])
    index.add(vectors)

def retrive_relvant_docs(query, k=2):
    """Return up to ``k`` stored chat turns most similar to ``query``.

    Transforms the query with the TF-IDF vectorizer fitted in
    update_rag_index() and searches the FAISS L2 index. Returns an empty
    string while nothing has been indexed yet. Results are joined with
    newlines. (Function name kept misspelled for backward compatibility
    with existing callers.)
    """
    if not documents or index is None:
        return ""

    query_vec = vectorizer.transform([query]).toarray().astype('float32')
    # Never ask for more neighbours than documents exist; FAISS pads the
    # shortfall with -1 indices.
    _, neighbour_ids = index.search(query_vec, min(k, len(documents)))
    # Fix: the original `i < len(documents)` check let the -1 padding through,
    # silently (and wrongly) returning documents[-1].
    return "\n".join(documents[i] for i in neighbour_ids[0] if 0 <= i < len(documents))


#-----------------------CONTEXT LENGTH ESTIMATION---------------------
def estimate_n_ctx(full_prompt, buffer=500):
    """Roughly size the model context window for *full_prompt*.

    The token count is approximated by the whitespace word count; *buffer*
    leaves head-room for the generated reply. The result is capped at 3500
    so the window never exceeds what the model comfortably supports.
    """
    approx_tokens = len(full_prompt.split())
    return min(3500, approx_tokens + buffer)

#-----------------------CHAT FUNCTION-----------------------
def chat(user_input, history, system_prompt):
    """Handle one Gradio chat turn.

    Retrieves RAG context for ``user_input``, builds the instruction prompt,
    runs the model, stores the finished turn in the RAG corpus, and returns
    ``("", history)`` so the input box is cleared and the chat updates.
    """
    relevant_context = retrive_relvant_docs(user_input)

    # Fix: the original iterated `relevent_context` (a string) as (user, bot)
    # pairs, which raises ValueError as soon as any context exists. The prior
    # conversation turns actually live in `history`.
    formatted_turns = "".join(f"<user>{u}</user><bot>{b}</bot>" for u, b in history)

    full_prompt = (
        f"<s>[INST] <<SYS>>\n{system_prompt}\nContext:\n{relevant_context}\n<</SYS>>\n"
        f"{formatted_turns}<user>{user_input}[/INST]"
    )

    # Size the context window to this prompt instead of always paying for the max.
    n_ctx = estimate_n_ctx(full_prompt=full_prompt)

    # NOTE(review): the model is re-loaded on every turn because n_ctx is
    # dynamic; a cached instance with a fixed n_ctx would be much faster.
    llm = Llama(
        model_path=MODEL_PATH,
        n_ctx=n_ctx,
        n_threads=2,
        n_batch=128,
    )

    output = llm(full_prompt, max_tokens=256, stop=["</s>", "<user>"])
    bot_reply = output["choices"][0]["text"].strip()

    # Persist this turn for future retrieval, then rebuild the index.
    documents.append(f"user: {user_input} bot: {bot_reply}")
    update_rag_index()

    history.append((user_input, bot_reply))
    return "", history

#-----------------------UI---------------------
#-----------------------UI---------------------
with gr.Blocks() as app:
    # Layout: header, editable persona prompt, chat window, input box, clear button.
    gr.Markdown("# πŸ€– Persona Agent with Mini-RAG + Dynamic Context")
    with gr.Row():
        persona_box = gr.Textbox(
            label="System Prompt",
            value=DEFAULT_SYSTEM_PROMPT,
            lines=3,
        )
    chat_window = gr.Chatbot()
    user_box = gr.Textbox(label="Your Message")
    clear_btn = gr.Button("πŸ—‘οΈ Clear")

    # Pressing Enter in the message box runs one chat turn; the returned ""
    # clears the input and the returned history refreshes the chat window.
    user_box.submit(chat, [user_box, chat_window, persona_box], [user_box, chat_window])
    # Clearing simply replaces the chat history with an empty list.
    clear_btn.click(lambda: [], None, chat_window)

app.launch()