# app.py
import os
import glob

import faiss
import numpy as np
import gradio as gr
import spaces  # Hugging Face Spaces GPU helper (provides the @spaces.GPU decorator)
import torch
from unstructured.partition.pdf import partition_pdf
from sentence_transformers import SentenceTransformer
from transformers import RagTokenizer, RagSequenceForGeneration
# ─── Configuration ─────────────────────────────────────────────
PDF_FOLDER = "meal_plans"
MODEL_NAME = "facebook/rag-sequence-nq"
EMBED_MODEL = "all-MiniLM-L6-v2"
TOP_K = 5
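# Sizing note: the chunks built below are up to 800 tokens each and TOP_K of
# them are packed into a single prompt, which can exceed the 1024-token limit
# of the BART generator inside facebook/rag-sequence-nq; the prompt is
# therefore truncated at generation time (see respond()).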
# ─── 1) LOAD + CHUNK ALL PDFs ──────────────────────────────────
rag_tokenizer = RagTokenizer.from_pretrained(MODEL_NAME)

texts, sources, pages = [], [], []
for pdf_path in glob.glob(f"{PDF_FOLDER}/*.pdf"):
    book = os.path.basename(pdf_path)
    # partition_pdf returns document elements (titles, paragraphs, ...), not
    # pages; the page number lives on each element's metadata.
    elements = partition_pdf(filename=pdf_path)
    for idx, element in enumerate(elements, start=1):
        if not element.text:
            continue
        pg_num = getattr(element.metadata, "page_number", None) or idx
        # Chunk with the question-encoder tokenizer; RagTokenizer.decode()
        # delegates to the generator tokenizer, which has a different vocab,
        # so encode and decode with the same sub-tokenizer here.
        enc = rag_tokenizer.question_encoder(
            element.text,
            max_length=800,
            truncation=True,
            return_overflowing_tokens=True,
            stride=50,
            padding=True,  # overflow chunks must share a length to stack as tensors
            return_tensors="pt",
        )
        for token_ids in enc["input_ids"]:
            chunk = rag_tokenizer.question_encoder.decode(token_ids, skip_special_tokens=True)
            texts.append(chunk)
            sources.append(book)
            pages.append(pg_num)
# ─── 2) EMBED + BUILD FAISS INDEX ──────────────────────────────
embedder = SentenceTransformer(EMBED_MODEL)
embeddings = embedder.encode(texts, convert_to_numpy=True)
dim = embeddings.shape[1]
index = faiss.IndexFlatL2(dim)
index.add(embeddings)
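# Design note: IndexFlatL2 ranks chunks by Euclidean distance. If cosine
# similarity is preferred (a common choice with MiniLM embeddings; not part
# of the original code), one option is to normalize the vectors and use an
# inner-product index instead:
#   faiss.normalize_L2(embeddings)
#   index = faiss.IndexFlatIP(dim)
#   index.add(embeddings)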
# ─── 3) LOAD RAG GENERATOR ─────────────────────────────────────
tokenizer = RagTokenizer.from_pretrained(MODEL_NAME)
generator = RagSequenceForGeneration.from_pretrained(MODEL_NAME)
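# Note: no RagRetriever is attached here, so the model's built-in retrieval is
# unused; retrieval happens against the FAISS index above, and the retrieved
# chunks are handed to the generator as pre-built context (see respond()).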
def respond(
    message: str,
    history: list[tuple[str, str]],
    goal: str,
    diet: list[str],
    meals: int,
    avoid: str,
    weeks: str,
):
    # Build a preference string from the form inputs
    avoid_list = [a.strip() for a in avoid.split(",") if a.strip()]
    prefs = (
        f"Goal={goal}; Diet={','.join(diet)}; "
        f"Meals={meals}/day; Avoid={','.join(avoid_list)}; Duration={weeks}"
    )

    # 1) RETRIEVE top-k chunks from the FAISS index
    q_emb = embedder.encode([message], convert_to_numpy=True)
    D, I = index.search(q_emb, TOP_K)
    context = "\n".join(f"[{sources[i]} p{pages[i]}] {texts[i]}" for i in I[0])

    # 2) BUILD PROMPT with a guardrail against answering outside the context
    prompt = (
        "SYSTEM: Only answer using the provided CONTEXT. "
        "If it's not there, say \"I'm sorry, I don't know.\"\n"
        f"PREFS: {prefs}\n"
        f"CONTEXT:\n{context}\n"
        f"Q: {message}\n"
    )

    # 3) GENERATE. With no RagRetriever attached, RagSequenceForGeneration
    # expects pre-built context_input_ids / doc_scores rather than raw
    # question input_ids, so tokenize the prompt with the generator tokenizer.
    ctx = tokenizer.generator(
        [prompt], return_tensors="pt", truncation=True, max_length=1024
    )
    outputs = generator.generate(
        context_input_ids=ctx["input_ids"],
        context_attention_mask=ctx["attention_mask"],
        doc_scores=torch.ones((1, 1)),  # single pseudo-document: the prompt itself
        n_docs=1,
        num_beams=2,
        max_new_tokens=200,
    )
    answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

    # gr.ChatInterface manages the chat history itself; the handler only
    # needs to return the assistant's reply.
    return answer
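# Quick local smoke test (illustrative values only):
#   print(respond("Suggest a high-protein breakfast", [], "Bulk",
#                 ["Omnivore"], 3, "Nuts", "1 week"))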
# ─── 4) BUILD UI ───────────────────────────────────────────────
goal = gr.Dropdown(["Lose weight", "Bulk", "Maintain"], label="Goal", value="Lose weight")
diet = gr.CheckboxGroup(["Omnivore", "Vegetarian", "Vegan", "Keto", "Paleo", "Low-Carb"], label="Diet Style")
meals = gr.Slider(1, 6, step=1, value=3, label="Meals per day")
avoid = gr.Textbox(placeholder="e.g. Gluten, Dairy, Nuts…", label="Avoidances (comma-separated)")
weeks = gr.Dropdown(["1 week", "2 weeks", "3 weeks", "4 weeks"], label="Plan Length", value="1 week")

demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[goal, diet, meals, avoid, weeks],
)

if __name__ == "__main__":
    demo.launch()