import os

import faiss
import gradio as gr
import numpy as np
from huggingface_hub import login, snapshot_download
from llama_cpp import Llama
from sklearn.feature_extraction.text import TfidfVectorizer

# ------------------ Model Setup ------------------
MODEL_REPO = "google/gemma-3-1b-it-qat-q4_0-gguf"
MODEL_DIR = "./gemma-3-1b-it-qat-q4_0"
MODEL_PATH = "./gemma-3-1b-it-qat-q4_0/gemma-3-1b-it-q4_0.gguf"

DEFAULT_SYSTEM_PROMPT = (
    "You are a Wise Mentor. Speak in a calm and concise manner. "
    "If asked for advice, give a maximum of 3 actionable steps. "
    "Avoid unnecessary elaboration. Decline unethical or harmful requests."
)

# Hugging Face token and model download: the GGUF weights are fetched once
# and reused from the local directory on later runs. A valid HF_TOKEN is
# required because the Gemma repo is gated.
hf_token = os.environ.get("HF_TOKEN")
if not os.path.exists(MODEL_PATH):
    if not hf_token:
        raise ValueError("HF_TOKEN is missing. Set it as an environment variable.")
    login(hf_token)
    snapshot_download(repo_id=MODEL_REPO, local_dir=MODEL_DIR, local_dir_use_symlinks=False)

# ------------------ RAG Setup ------------------
documents = []  # stores (user, bot) tuples from past turns
vectorizer = TfidfVectorizer()
index = None


def update_rag_index():
    """Re-fit TF-IDF on all stored turns and rebuild the FAISS index from scratch."""
    global index
    if not documents:
        return
    flat_docs = [f"user: {u} bot: {b}" for u, b in documents]
    vectors = vectorizer.fit_transform(flat_docs).toarray().astype("float32")
    index = faiss.IndexFlatL2(vectors.shape[1])
    index.add(vectors)


def retrieve_relevant_docs(query, k=3):
    """Return up to k past (user, bot) turns most similar to the query."""
    if not documents or index is None:
        return []
    query_vec = vectorizer.transform([query]).toarray().astype("float32")
    _, I = index.search(query_vec, k)
    # FAISS pads the result with -1 when fewer than k vectors are indexed,
    # so filter on 0 <= i to avoid negative indexing into documents.
    return [documents[i] for i in I[0] if 0 <= i < len(documents)]


# ------------------ Context Estimation ------------------
def estimate_n_ctx(full_prompt, buffer=500):
    """Rough context-window estimate: whitespace word count plus a safety
    buffer, capped at 3500. A crude proxy for the true token count."""
    total_tokens = len(full_prompt.split())
    return min(3500, total_tokens + buffer)


# ------------------ Chat Function ------------------
def chat(user_input, history, system_prompt):
    relevant_context = retrieve_relevant_docs(user_input)
    # Fold retrieved turns back into the prompt as prior instruction/response
    # exchanges, then wrap the system prompt in Llama-2-style <<SYS>> tags.
    # (Note: Gemma instruct models ship their own chat template in the GGUF;
    # llm.create_chat_completion() would apply it automatically.)
    formatted_turns = "".join(f"[INST] {u} [/INST] {b} " for u, b in relevant_context)
    full_prompt = (
        f"[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n"
        f"{formatted_turns}{user_input} [/INST]"
    )
    n_ctx = estimate_n_ctx(full_prompt)
    # The model is reloaded every turn so n_ctx can track the prompt length;
    # this trades startup latency for a smaller memory footprint.
    llm = Llama(model_path=MODEL_PATH, n_ctx=n_ctx, n_threads=2, n_batch=128)
    output = llm(full_prompt, max_tokens=256, stop=["</s>", "[INST]"])
    bot_reply = output["choices"][0]["text"].strip()

    documents.append((user_input, bot_reply))
    update_rag_index()
    history.append((user_input, bot_reply))
    return "", history


# ------------------ UI ------------------
with gr.Blocks() as demo:
    gr.Markdown("## 🤖 Prompt-Engineered Persona Agent with Mini-RAG")
    system_prompt_box = gr.Textbox(label="System Prompt", value=DEFAULT_SYSTEM_PROMPT, lines=3)
    chatbot = gr.Chatbot()  # tuple-style history: list of (user, bot) pairs
    msg = gr.Textbox(label="Your Message")
    clear = gr.Button("🗑️ Clear Chat")

    msg.submit(chat, [msg, chatbot, system_prompt_box], [msg, chatbot])
    clear.click(lambda: [], None, chatbot)

demo.launch()
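
# ------------------ Running (usage sketch) ------------------
# A minimal sketch of launching the app locally, assuming this file is saved
# as app.py and the shell has a Hugging Face token with access to the gated
# Gemma repo (the token value below is a hypothetical placeholder):
#
#   export HF_TOKEN=hf_xxx
#   python app.py
#
# demo.launch() then serves the Gradio UI, by default on http://127.0.0.1:7860.
# The RAG helpers can also be smoke-tested without loading the model:
#
#   documents.append(("What is FAISS?", "A vector similarity search library."))
#   update_rag_index()
#   print(retrieve_relevant_docs("Tell me about FAISS"))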