import os
import gradio as gr
from llama_cpp import Llama
from huggingface_hub import snapshot_download, login
from sklearn.feature_extraction.text import TfidfVectorizer
import faiss
import numpy as np
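# Persona chatbot: a GGUF Gemma model served through llama-cpp-python, with a
# "mini-RAG" layer that retrieves past chat turns via TF-IDF vectors in a
# FAISS index and feeds them back into the prompt as context.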
# ------------------ Model Setup ------------------
MODEL_REPO = "google/gemma-3-1b-it-qat-q4_0-gguf"
MODEL_PATH = "./gemma-3-1b-it-qat-q4_0/gemma-3-1b-it-q4_0.gguf"
MODEL_DIR = "./gemma-3-1b-it-qat-q4_0"
DEFAULT_SYSTEM_PROMPT = (
    "You are a Wise Mentor. Speak in a calm and concise manner. "
    "If asked for advice, give a maximum of 3 actionable steps. "
    "Avoid unnecessary elaboration. Decline unethical or harmful requests."
)
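# The persona lives entirely in this system prompt; the UI below exposes it
# in an editable textbox, so it can be swapped out at runtime.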
# Hugging Face Token and model download
hf_token = os.environ.get("HF_TOKEN")
if not os.path.exists(MODEL_PATH):
    if not hf_token:
        raise ValueError("HF_TOKEN is missing. Set it as an environment variable.")
    login(hf_token)
    snapshot_download(repo_id=MODEL_REPO, local_dir=MODEL_DIR, local_dir_use_symlinks=False)
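# snapshot_download pulls the whole GGUF repo into MODEL_DIR on first run;
# the Gemma repo is gated, hence the HF_TOKEN login. Recent huggingface_hub
# versions deprecate local_dir_use_symlinks and ignore it, so it is kept here
# only for compatibility with older installs.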
# ------------------ RAG Setup ------------------
documents = [] # stores (user, bot) tuples
vectorizer = TfidfVectorizer()
index = None
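# RAG state: the vectorizer is re-fit over the full history on every index
# rebuild, so document vectors and query vectors always share one vocabulary
# and one dimensionality.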
def update_rag_index():
    global index
    if not documents:
        return
    flat_docs = [f"user: {u} bot: {b}" for u, b in documents]
    vectors = vectorizer.fit_transform(flat_docs).toarray().astype('float32')
    index = faiss.IndexFlatL2(vectors.shape[1])
    index.add(vectors)
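# Rebuilding from scratch is O(n) per turn, which is fine at chat scale; a
# persistent index with incremental adds would be the assumed upgrade path
# for larger corpora.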
def retrieve_relevant_docs(query, k=3):
    if not documents or index is None:
        return []
    query_vec = vectorizer.transform([query]).toarray().astype('float32')
    D, I = index.search(query_vec, k)
    # FAISS pads the result with -1 when k exceeds the number of indexed
    # vectors, so guard against negative indices as well as out-of-range ones.
    return [documents[i] for i in I[0] if 0 <= i < len(documents)]
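# IndexFlatL2 ranks by Euclidean distance over raw TF-IDF vectors. Cosine
# similarity (faiss.IndexFlatIP over L2-normalized vectors) is a common
# alternative for TF-IDF, but exact L2 search is adequate at this size.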
# ------------------ Context Estimation ------------------
def estimate_n_ctx(full_prompt, buffer=500):
    total_tokens = len(full_prompt.split())
    return min(3500, total_tokens + buffer)
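# Whitespace word count is only a rough proxy for the true token count;
# llm.tokenize() would be exact but requires an already-loaded model. The
# 3500 cap is assumed to reflect the memory budget of a small CPU host, not
# a hard limit of the model itself.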
# ------------------ Chat Function ------------------
def chat(user_input, history, system_prompt):
    relevant_context = retrieve_relevant_docs(user_input)
    formatted_turns = "".join([f"<user>{u}</user><bot>{b}</bot>" for u, b in relevant_context])
    full_prompt = (
        f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n"
        f"{formatted_turns}<user>{user_input}[/INST]"
    )
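    # Llama-2 style [INST]/<<SYS>> template with ad-hoc <user>/<bot> tags for
    # the retrieved turns. Gemma's native chat format uses <start_of_turn>
    # markers instead, so this template is a prompt-engineering choice here
    # rather than the model's built-in one.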
    n_ctx = estimate_n_ctx(full_prompt)
    # A fresh model instance per call is slow, but it lets n_ctx track the
    # current prompt size instead of reserving the maximum up front.
    llm = Llama(
        model_path=MODEL_PATH,
        n_ctx=n_ctx,
        n_threads=2,
        n_batch=128
    )
    output = llm(full_prompt, max_tokens=256, stop=["</s>", "<user>"])
    bot_reply = output["choices"][0]["text"].strip()
    documents.append((user_input, bot_reply))
    update_rag_index()
    history.append((user_input, bot_reply))
    return "", history  # the empty string clears the input textbox
# ------------------ UI ------------------
with gr.Blocks() as demo:
    gr.Markdown("## 🤖 Prompt-Engineered Persona Agent with Mini-RAG")
    system_prompt_box = gr.Textbox(label="System Prompt", value=DEFAULT_SYSTEM_PROMPT, lines=3)
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Your Message")
    clear = gr.Button("🗑️ Clear Chat")
    msg.submit(chat, [msg, chatbot, system_prompt_box], [msg, chatbot])
    clear.click(lambda: [], None, chatbot)
demo.launch()
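# Run the script directly (e.g. as app.py on a Hugging Face Space);
# demo.launch() starts the Gradio server on its default port, 7860.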