File size: 3,980 Bytes
758a6cc
 
 
9b1fba6
 
758a6cc
 
 
 
 
9b1fba6
 
6edbce0
758a6cc
 
 
 
 
9b1fba6
758a6cc
 
 
9b1fba6
758a6cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6edbce0
758a6cc
 
 
 
 
 
 
 
e05063c
758a6cc
e05063c
30f2776
6edbce0
e05063c
758a6cc
e05063c
 
 
 
758a6cc
6edbce0
758a6cc
e05063c
 
9b1fba6
 
e05063c
758a6cc
9b1fba6
758a6cc
 
9b1fba6
758a6cc
e05063c
 
758a6cc
e05063c
9b1fba6
e05063c
 
9b1fba6
758a6cc
e05063c
758a6cc
e05063c
6edbce0
758a6cc
e05063c
 
 
6edbce0
758a6cc
 
e05063c
758a6cc
 
 
30f2776
6edbce0
e05063c
758a6cc
6edbce0
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
# app.py
import os
import glob
import faiss
import numpy as np

import gradio as gr
import spaces

from unstructured.partition.pdf import partition_pdf
from sentence_transformers import SentenceTransformer
from transformers import RagTokenizer, RagSequenceForGeneration

# ─── Configuration ─────────────────────────────────────────────
PDF_FOLDER = "meal_plans"  # directory scanned (non-recursively) for *.pdf books
MODEL_NAME = "facebook/rag-sequence-nq"  # HF checkpoint used for tokenizer + generator
EMBED_MODEL = "all-MiniLM-L6-v2"  # SentenceTransformer model for retrieval embeddings
TOP_K = 5  # number of chunks retrieved per user query

# ─── 1) LOAD + CHUNK ALL PDFs ──────────────────────────────────
# Build three parallel lists: chunk text, source book filename, and a
# per-book position index, by splitting each PDF into overlapping
# tokenizer-sized windows (max 800 tokens, 50-token stride).
rag_tokenizer = RagTokenizer.from_pretrained(MODEL_NAME)
texts, sources, pages = [], [], []

for pdf_path in glob.glob(f"{PDF_FOLDER}/*.pdf"):
    book = os.path.basename(pdf_path)
    # NOTE(review): partition_pdf returns document *elements* (titles,
    # paragraphs, ...), not whole pages — so `pg_num` below is really the
    # element's ordinal within the book, not a printed page number.
    # Confirm against unstructured's metadata.page_number if true pages
    # are wanted in the citations.
    pages_data = partition_pdf(filename=pdf_path)
    for pg_num, page in enumerate(pages_data, start=1):
        # Overlapping windows: overflowing tokens are returned as extra
        # rows in `input_ids`, each sharing a 50-token stride of context.
        enc = rag_tokenizer(
            page.text,
            max_length=800,
            truncation=True,
            return_overflowing_tokens=True,
            stride=50,
            return_tensors="pt"
        )
        for token_ids in enc["input_ids"]:
            # Decode each window back to plain text for embedding/display.
            chunk = rag_tokenizer.decode(token_ids, skip_special_tokens=True)
            texts.append(chunk)
            sources.append(book)
            pages.append(pg_num)

# ─── 2) EMBED + BUILD FAISS INDEX ─────────────────────────────
# Embed every chunk once at startup and index them with an exact
# (brute-force) L2 index; TOP_K nearest chunks are fetched per query.
embedder = SentenceTransformer(EMBED_MODEL)
embeddings = embedder.encode(texts, convert_to_numpy=True)
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)

# ─── 3) LOAD RAG GENERATOR ────────────────────────────────────
# NOTE(review): this loads the same tokenizer checkpoint a second time
# (see `rag_tokenizer` above); the two objects are interchangeable here.
tokenizer = RagTokenizer.from_pretrained(MODEL_NAME)
generator = RagSequenceForGeneration.from_pretrained(MODEL_NAME)

@spaces.GPU
def respond(
    message: str,
    history: list[tuple[str, str]],
    goal: str,
    diet: list[str],
    meals: int,
    avoid: str,
    weeks: str
) -> str:
    """Answer one chat turn via retrieval over the indexed PDF chunks.

    Args:
        message: The user's latest chat message.
        history: Prior (user, bot) turns; supplied and re-rendered by
            gr.ChatInterface, so it is not consumed or mutated here.
        goal, diet, meals, avoid, weeks: Preference widget values passed
            through ``additional_inputs``.

    Returns:
        The generated answer string. gr.ChatInterface appends the reply
        to the chat itself, so the callback must return only the new
        answer — not the whole history list (the previous behavior).
    """
    # Build a compact preferences line for the prompt.
    avoid_list = [a.strip() for a in avoid.split(",") if a.strip()]
    prefs = (
        f"Goal={goal}; Diet={','.join(diet)}; "
        f"Meals={meals}/day; Avoid={','.join(avoid_list)}; Duration={weeks}"
    )

    # 1) RETRIEVE top-k chunks. Distances are not needed — only the ids.
    q_emb = embedder.encode([message], convert_to_numpy=True)
    _, ids = index.search(q_emb, TOP_K)
    context = "\n".join(f"[{sources[i]} p{pages[i]}] {texts[i]}" for i in ids[0])

    # 2) BUILD PROMPT with guardrail
    prompt = (
        "SYSTEM: Only answer using the provided CONTEXT. "
        "If it’s not there, say \"I'm sorry, I don't know.\" \n"
        f"PREFS: {prefs}\n"
        f"CONTEXT:\n{context}\n"
        f"Q: {message}\n"
    )

    # 3) GENERATE. Truncate so a long retrieved context can never push
    # the prompt past the encoder's maximum input length.
    inputs = tokenizer([prompt], return_tensors="pt", truncation=True)
    outputs = generator.generate(**inputs, num_beams=2, max_new_tokens=200)
    answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

    return answer

# ─── 4) BUILD UI ────────────────────────────────────────────────
# Preference widgets surfaced as extra inputs alongside the chat box.
goal = gr.Dropdown(
    ["Lose weight", "Bulk", "Maintain"], value="Lose weight", label="Goal"
)
diet = gr.CheckboxGroup(
    ["Omnivore", "Vegetarian", "Vegan", "Keto", "Paleo", "Low-Carb"],
    label="Diet Style",
)
meals = gr.Slider(1, 6, value=3, step=1, label="Meals per day")
avoid = gr.Textbox(
    label="Avoidances (comma-separated)",
    placeholder="e.g. Gluten, Dairy, Nuts…",
)
weeks = gr.Dropdown(
    ["1 week", "2 weeks", "3 weeks", "4 weeks"], value="1 week", label="Plan Length"
)

# Chat panel wired to `respond`; the widgets above ride along as
# additional inputs on every call.
extra_inputs = [goal, diet, meals, avoid, weeks]
demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=extra_inputs,
)

if __name__ == "__main__":
    demo.launch()