"""Prompt-engineered persona chatbot with a mini TF-IDF + FAISS RAG memory.

Downloads a quantized Gemma GGUF model from the Hugging Face Hub, retrieves
past conversation turns as context for each new message, and serves a Gradio
chat UI.
"""

import os

import faiss
import gradio as gr
from huggingface_hub import login, snapshot_download
from llama_cpp import Llama
from sklearn.feature_extraction.text import TfidfVectorizer


# Quantized Gemma 3 1B instruct model (QAT, Q4_0 GGUF) on the Hugging Face Hub.
MODEL_REPO = "google/gemma-3-1b-it-qat-q4_0-gguf"
MODEL_DIR = "./gemma-3-1b-it-qat-q4_0"
MODEL_PATH = "./gemma-3-1b-it-qat-q4_0/gemma-3-1b-it-q4_0.gguf"

DEFAULT_SYSTEM_PROMPT = (
    "You are a Wise Mentor. Speak in a calm and concise manner. "
    "If asked for advice, give a maximum of 3 actionable steps. "
    "Avoid unnecessary elaboration. Decline unethical or harmful requests."
)


# Download the GGUF weights on first run. The Gemma weights are gated on the
# Hub, so an access token must be supplied via the HF_TOKEN environment variable.
hf_token = os.environ.get("HF_TOKEN")
if not os.path.exists(MODEL_PATH):
    if not hf_token:
        raise ValueError("HF_TOKEN is missing. Set it as an environment variable.")
    login(hf_token)
    snapshot_download(repo_id=MODEL_REPO, local_dir=MODEL_DIR)


# Mini-RAG store: past (user, bot) turns, vectorized with TF-IDF and indexed
# in FAISS for nearest-neighbor retrieval.
documents = []
vectorizer = TfidfVectorizer()
index = None


def update_rag_index():
    """Rebuild the FAISS index over all stored turns.

    The TF-IDF vocabulary grows as turns accumulate, so the vectorizer is
    refit and the index rebuilt from scratch on every update; O(n) per turn,
    which is fine for a small in-memory history.
    """
    global index
    if not documents:
        return
    flat_docs = [f"user: {u} bot: {b}" for u, b in documents]
    vectors = vectorizer.fit_transform(flat_docs).toarray().astype("float32")
    index = faiss.IndexFlatL2(vectors.shape[1])
    index.add(vectors)


def retrieve_relevant_docs(query, k=3):
    """Return up to k stored turns closest to the query in TF-IDF space."""
    if not documents or index is None:
        return []
    query_vec = vectorizer.transform([query]).toarray().astype("float32")
    _, indices = index.search(query_vec, k)
    # FAISS pads results with -1 when fewer than k vectors are indexed.
    return [documents[i] for i in indices[0] if 0 <= i < len(documents)]
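

# Alternative, not wired in: TF-IDF retrieval is more often ranked by cosine
# similarity than by raw L2 distance. A minimal sketch of that variant
# (hypothetical helper, assuming the same globals as update_rag_index) uses
# FAISS's inner-product index over L2-normalized vectors:
def update_rag_index_cosine():
    """Sketch of a cosine-similarity variant of update_rag_index()."""
    global index
    if not documents:
        return
    flat_docs = [f"user: {u} bot: {b}" for u, b in documents]
    vectors = vectorizer.fit_transform(flat_docs).toarray().astype("float32")
    faiss.normalize_L2(vectors)  # unit-length rows make inner product == cosine
    index = faiss.IndexFlatIP(vectors.shape[1])
    index.add(vectors)
    # Queries would need the same normalize_L2 step before index.search().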


def estimate_n_ctx(full_prompt, buffer=500):
    """Rough context-size estimate: whitespace word count plus a buffer.

    Whitespace splitting undercounts real tokens (subword tokenizers emit
    more tokens than words), which the buffer is meant to absorb; the result
    is capped at 3500.
    """
    total_tokens = len(full_prompt.split())
    return min(3500, total_tokens + buffer)
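

# For an exact count, one could tokenize with the loaded model instead of
# splitting on whitespace. A sketch (hypothetical helper; it needs a Llama
# instance, which this script only constructs inside chat()):
def exact_token_count(llm, text):
    """Count tokens using llama.cpp's own tokenizer."""
    return len(llm.tokenize(text.encode("utf-8")))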


def chat(user_input, history, system_prompt):
    """Answer one message, injecting retrieved past turns as context."""
    relevant_context = retrieve_relevant_docs(user_input)
    formatted_turns = "".join(
        f"<start_of_turn>user\n{u}<end_of_turn>\n"
        f"<start_of_turn>model\n{b}<end_of_turn>\n"
        for u, b in relevant_context
    )

    # Gemma's chat template: <start_of_turn>{role} ... <end_of_turn>. There is
    # no system role, so the system prompt is sent as the first user turn.
    full_prompt = (
        f"<start_of_turn>user\n{system_prompt}<end_of_turn>\n"
        f"{formatted_turns}"
        f"<start_of_turn>user\n{user_input}<end_of_turn>\n"
        f"<start_of_turn>model\n"
    )

    n_ctx = estimate_n_ctx(full_prompt)
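
    # NOTE: constructing Llama here reloads the GGUF weights on every message.
    # Per-call construction keeps n_ctx matched to the prompt size; a single
    # long-lived instance with a fixed n_ctx would be the faster choice.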
    llm = Llama(
        model_path=MODEL_PATH,
        n_ctx=n_ctx,
        n_threads=2,
        n_batch=128,
    )

    output = llm(full_prompt, max_tokens=256, stop=["<end_of_turn>"])
    bot_reply = output["choices"][0]["text"].strip()

    documents.append((user_input, bot_reply))
    update_rag_index()

    history.append((user_input, bot_reply))
    return "", history
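

# Alternative to hand-rolled templating (not wired in): llama-cpp-python can
# apply the chat template embedded in the GGUF itself. A sketch:
#
#     output = llm.create_chat_completion(
#         messages=[
#             {"role": "system", "content": system_prompt},
#             {"role": "user", "content": user_input},
#         ],
#         max_tokens=256,
#     )
#     bot_reply = output["choices"][0]["message"]["content"]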


with gr.Blocks() as demo:
    gr.Markdown("## 🤖 Prompt-Engineered Persona Agent with Mini-RAG")
    system_prompt_box = gr.Textbox(label="System Prompt", value=DEFAULT_SYSTEM_PROMPT, lines=3)
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Your Message")
    clear = gr.Button("🗑️ Clear Chat")

    # chat() returns ("", history): clear the textbox, update the chat pane.
    msg.submit(chat, [msg, chatbot, system_prompt_box], [msg, chatbot])
    clear.click(lambda: [], None, chatbot)


if __name__ == "__main__":
    demo.launch()
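
# To run: set HF_TOKEN and start the script, e.g. `HF_TOKEN=hf_... python app.py`
# (app.py is a placeholder name). Assumed pip package names: gradio,
# llama-cpp-python, faiss-cpu, scikit-learn, huggingface_hub.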
|
|