import os
import gradio as gr
from llama_cpp import Llama
from huggingface_hub import snapshot_download, login
from sklearn.feature_extraction.text import TfidfVectorizer
import faiss
# ------------------ Model Setup ------------------
MODEL_REPO = "google/gemma-3-1b-it-qat-q4_0-gguf"
MODEL_PATH = "./gemma-3-1b-it-qat-q4_0/gemma-3-1b-it-q4_0.gguf"
MODEL_DIR = "./gemma-3-1b-it-qat-q4_0"
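# MODEL_PATH points at the .gguf file that snapshot_download places inside MODEL_DIR;
# the existence check below skips the download on warm restarts.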
DEFAULT_SYSTEM_PROMPT = (
    "You are a Wise Mentor. Speak in a calm and concise manner. "
    "If asked for advice, give a maximum of 3 actionable steps. "
    "Avoid unnecessary elaboration. Decline unethical or harmful requests."
)
# Hugging Face token and model download
hf_token = os.environ.get("HF_TOKEN")
if not os.path.exists(MODEL_PATH):
    if not hf_token:
        raise ValueError("HF_TOKEN is missing. Set it as an environment variable.")
    login(hf_token)
    snapshot_download(repo_id=MODEL_REPO, local_dir=MODEL_DIR, local_dir_use_symlinks=False)
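# Note: Gemma weights are gated on the Hub, so the token must belong to an account
# that has accepted Google's license for this repo. local_dir_use_symlinks is
# deprecated (and ignored) in recent huggingface_hub releases; it is kept here
# only for compatibility with older versions.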
# ------------------ RAG Setup ------------------
documents = [] # stores (user, bot) tuples
vectorizer = TfidfVectorizer()
index = None
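# Mini-RAG store: each (user, bot) turn is TF-IDF-vectorized and kept in a flat
# FAISS L2 index. The same vectorizer is shared between indexing and querying,
# so it is re-fit over the full history whenever the index is rebuilt.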
def update_rag_index():
    global index
    if not documents:
        return
    flat_docs = [f"user: {u} bot: {b}" for u, b in documents]
    vectors = vectorizer.fit_transform(flat_docs).toarray().astype('float32')
    index = faiss.IndexFlatL2(vectors.shape[1])
    index.add(vectors)
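# IndexFlatL2 does exact nearest-neighbour search; rebuilding it from scratch each
# turn is O(n) in the number of stored turns, which is cheap at chat-history scale.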
def retrieve_relevant_docs(query, k=3):
    if not documents or index is None:
        return []
    query_vec = vectorizer.transform([query]).toarray().astype('float32')
    D, I = index.search(query_vec, k)
    # FAISS pads the result with -1 when fewer than k vectors are indexed,
    # so guard against negative indices as well as out-of-range ones
    return [documents[i] for i in I[0] if 0 <= i < len(documents)]
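# Returned (user, bot) tuples are injected verbatim into the prompt as prior turns
# (see chat() below).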
# ------------------ Context Estimation ------------------
def estimate_n_ctx(full_prompt, buffer=500):
    total_tokens = len(full_prompt.split())
    return min(3500, total_tokens + buffer)
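# Whitespace splitting undercounts real subword tokens, so the +500 buffer and the
# 3500 cap are rough safety margins rather than exact accounting.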
# ------------------ Chat Function ------------------
def chat(user_input, history, system_prompt):
    relevant_context = retrieve_relevant_docs(user_input)
    # Replay retrieved turns using Gemma's chat format rather than the Llama-2
    # [INST]/<<SYS>> template, which this model was not trained on
    formatted_turns = "".join(
        f"<start_of_turn>user\n{u}<end_of_turn>\n<start_of_turn>model\n{b}<end_of_turn>\n"
        for u, b in relevant_context
    )
    # Gemma has no dedicated system role, so the system prompt is folded into
    # the current user turn
    full_prompt = (
        f"{formatted_turns}"
        f"<start_of_turn>user\n{system_prompt}\n\n{user_input}<end_of_turn>\n"
        f"<start_of_turn>model\n"
    )
    n_ctx = estimate_n_ctx(full_prompt)
    # The model is reloaded on every call so n_ctx can be sized to the prompt,
    # trading latency for a smaller memory footprint
    llm = Llama(
        model_path=MODEL_PATH,
        n_ctx=n_ctx,
        n_threads=2,
        n_batch=128
    )
    output = llm(full_prompt, max_tokens=256, stop=["<end_of_turn>", "<start_of_turn>"])
    bot_reply = output["choices"][0]["text"].strip()
    documents.append((user_input, bot_reply))
    update_rag_index()
    history.append((user_input, bot_reply))
    return "", history
# ------------------ UI ------------------
with gr.Blocks() as demo:
    gr.Markdown("## 🤖 Prompt-Engineered Persona Agent with Mini-RAG")
    system_prompt_box = gr.Textbox(label="System Prompt", value=DEFAULT_SYSTEM_PROMPT, lines=3)
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Your Message")
    clear = gr.Button("🗑️ Clear Chat")
    msg.submit(chat, [msg, chatbot, system_prompt_box], [msg, chatbot])
    clear.click(lambda: [], None, chatbot)

demo.launch()
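# Assuming a local run: `python app.py` serves the UI on http://localhost:7860,
# Gradio's default port.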