# eat2fit / app.py — Hugging Face Space by DurgaDeepak (commit 758a6cc)
# NOTE: the HF "raw" page chrome (picture/history/blame/size lines) was
# captured into this file; it is commented out here so the file is valid Python.
# app.py
import os
import glob
import faiss
import numpy as np
import gradio as gr
import spaces
from unstructured.partition.pdf import partition_pdf
from sentence_transformers import SentenceTransformer
from transformers import RagTokenizer, RagSequenceForGeneration
# ─── Configuration ─────────────────────────────────────────────
PDF_FOLDER = "meal_plans"                 # directory scanned for *.pdf meal plans
MODEL_NAME = "facebook/rag-sequence-nq"   # HF hub id for both tokenizer and generator
EMBED_MODEL = "all-MiniLM-L6-v2"          # sentence-transformers model for retrieval embeddings
TOP_K = 5                                 # number of chunks retrieved per user query
# ─── 1) LOAD + CHUNK ALL PDFs ──────────────────────────────────
# One tokenizer instance; it is reused later for generation as well.
rag_tokenizer = RagTokenizer.from_pretrained(MODEL_NAME)

# Parallel lists: chunk text, source book filename, 1-based element index.
texts, sources, pages = [], [], []
for pdf_path in glob.glob(f"{PDF_FOLDER}/*.pdf"):
    book = os.path.basename(pdf_path)
    # NOTE(review): partition_pdf returns document *elements* (titles,
    # paragraphs, images, ...), not pages — so `pg_num` is really an
    # element index. Confirm whether true page numbers are wanted for the
    # "[book pN]" citations built in respond().
    pages_data = partition_pdf(filename=pdf_path)
    for pg_num, page in enumerate(pages_data, start=1):
        # Skip elements with no usable text (e.g. images/figures); the
        # original tokenized empty strings and indexed junk chunks.
        text = (page.text or "").strip()
        if not text:
            continue
        # Sliding-window chunking: 800-token windows, 50-token overlap.
        enc = rag_tokenizer(
            text,
            max_length=800,
            truncation=True,
            return_overflowing_tokens=True,
            stride=50,
            return_tensors="pt",
        )
        # Decode each window back to plain text for embedding/display.
        for token_ids in enc["input_ids"]:
            chunk = rag_tokenizer.decode(token_ids, skip_special_tokens=True)
            texts.append(chunk)
            sources.append(book)
            pages.append(pg_num)
# ─── 2) EMBED + BUILD FAISS INDEX ─────────────────────────────
# Fail fast with an actionable message instead of the opaque IndexError
# that `embeddings.shape[1]` raises when no PDF text was extracted.
if not texts:
    raise RuntimeError(
        f"No text chunks extracted from '{PDF_FOLDER}/*.pdf' — "
        "add meal-plan PDFs before starting the app."
    )
embedder = SentenceTransformer(EMBED_MODEL)
# encode() returns a float32 numpy array, which is what faiss expects.
embeddings = embedder.encode(texts, convert_to_numpy=True)
dim = embeddings.shape[1]          # embedding dimensionality
index = faiss.IndexFlatL2(dim)     # exact L2 search; fine at this corpus size
index.add(embeddings)
# ─── 3) LOAD RAG GENERATOR ────────────────────────────────────
# Reuse the tokenizer already loaded in step 1 instead of constructing a
# second identical instance from the hub (duplicate download/memory).
tokenizer = rag_tokenizer
generator = RagSequenceForGeneration.from_pretrained(MODEL_NAME)
@spaces.GPU
def respond(
    message: str,
    history: list[tuple[str, str]],
    goal: str,
    diet: list[str],
    meals: int,
    avoid: str,
    weeks: str,
) -> str:
    """Answer a chat message using chunks retrieved from the meal-plan PDFs.

    Args:
        message: The user's current question.
        history: Prior (user, assistant) turns; managed by gr.ChatInterface.
        goal, diet, meals, avoid, weeks: Preference widget values from the UI.

    Returns:
        The assistant's answer string. gr.ChatInterface appends the return
        value to the chat history itself — the original returned the
        `history` list, which renders as a Python list of tuples instead
        of the answer text.
    """
    # Compact preference string injected into the prompt.
    avoid_list = [a.strip() for a in avoid.split(",") if a.strip()]
    prefs = (
        f"Goal={goal}; Diet={','.join(diet)}; "
        f"Meals={meals}/day; Avoid={','.join(avoid_list)}; Duration={weeks}"
    )
    # 1) RETRIEVE top-k chunks from the FAISS index built at startup.
    q_emb = embedder.encode([message], convert_to_numpy=True)
    _, hit_ids = index.search(q_emb, TOP_K)
    context = "\n".join(
        f"[{sources[i]} p{pages[i]}] {texts[i]}" for i in hit_ids[0]
    )
    # 2) BUILD PROMPT with guardrail
    prompt = (
        "SYSTEM: Only answer using the provided CONTEXT. "
        "If it’s not there, say \"I'm sorry, I don't know.\" \n"
        f"PREFS: {prefs}\n"
        f"CONTEXT:\n{context}\n"
        f"Q: {message}\n"
    )
    # 3) GENERATE
    # NOTE(review): RagSequenceForGeneration.generate normally expects a
    # retriever or explicit context_input_ids; feeding a plain prompt via
    # input_ids may not behave as intended — verify, or swap in a plain
    # seq2seq model since retrieval is already done manually above.
    inputs = tokenizer([prompt], return_tensors="pt")
    outputs = generator.generate(**inputs, num_beams=2, max_new_tokens=200)
    answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    # Return only the answer; ChatInterface maintains the history itself.
    return answer
# ─── 4) BUILD UI ────────────────────────────────────────────────
# Preference widgets, ordered to match respond()'s extra parameters
# (goal, diet, meals, avoid, weeks).
preference_inputs = [
    gr.Dropdown(["Lose weight","Bulk","Maintain"], label="Goal", value="Lose weight"),
    gr.CheckboxGroup(["Omnivore","Vegetarian","Vegan","Keto","Paleo","Low-Carb"], label="Diet Style"),
    gr.Slider(1, 6, step=1, value=3, label="Meals per day"),
    gr.Textbox(placeholder="e.g. Gluten, Dairy, Nuts…", label="Avoidances (comma-separated)"),
    gr.Dropdown(["1 week","2 weeks","3 weeks","4 weeks"], label="Plan Length", value="1 week"),
]

# Chat UI: respond() receives (message, history, *preference_inputs).
demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=preference_inputs,
)

if __name__ == "__main__":
    demo.launch()