Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,85 +1,105 @@
|
|
1 |
-
|
2 |
-
import
|
|
|
3 |
import faiss
|
4 |
import numpy as np
|
5 |
-
|
|
|
|
|
|
|
|
|
6 |
from sentence_transformers import SentenceTransformer
|
7 |
from transformers import RagTokenizer, RagSequenceForGeneration
|
8 |
|
9 |
-
#
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
13 |
|
14 |
-
#
|
15 |
-
|
16 |
-
texts
|
17 |
-
sources = ds["source"]
|
18 |
-
pages = ds["page"]
|
19 |
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
tokenizer = RagTokenizer.from_pretrained(MODEL_NAME)
|
27 |
-
|
28 |
|
29 |
@spaces.GPU
|
30 |
def respond(
|
31 |
message: str,
|
32 |
-
history: list[tuple[str,
|
33 |
goal: str,
|
34 |
diet: list[str],
|
35 |
meals: int,
|
36 |
avoid: str,
|
37 |
-
weeks: str
|
38 |
):
|
39 |
-
#
|
40 |
avoid_list = [a.strip() for a in avoid.split(",") if a.strip()]
|
41 |
prefs = (
|
42 |
f"Goal={goal}; Diet={','.join(diet)}; "
|
43 |
f"Meals={meals}/day; Avoid={','.join(avoid_list)}; Duration={weeks}"
|
44 |
)
|
45 |
-
|
46 |
-
# 1) Query embedding & FAISS search
|
47 |
q_emb = embedder.encode([message], convert_to_numpy=True)
|
48 |
-
D, I = index.search(q_emb,
|
49 |
-
|
50 |
-
f"[{sources[i]} p{pages[i]}] {texts[i]}" for i in I[0]
|
51 |
-
]
|
52 |
-
context = "\n".join(ctx_chunks)
|
53 |
|
54 |
-
# 2)
|
55 |
prompt = (
|
56 |
"SYSTEM: Only answer using the provided CONTEXT. "
|
57 |
-
"If itβs not there, say \"I'm sorry, I don't know.\"\n"
|
58 |
f"PREFS: {prefs}\n"
|
59 |
f"CONTEXT:\n{context}\n"
|
60 |
f"Q: {message}\n"
|
61 |
)
|
62 |
|
63 |
-
# 3)
|
64 |
inputs = tokenizer([prompt], return_tensors="pt")
|
65 |
-
outputs =
|
66 |
answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
|
67 |
|
68 |
-
#
|
69 |
history = history or []
|
70 |
history.append((message, answer))
|
71 |
return history
|
72 |
|
73 |
-
#
|
74 |
-
goal = gr.Dropdown(["Lose weight","Bulk","Maintain"],
|
75 |
diet = gr.CheckboxGroup(["Omnivore","Vegetarian","Vegan","Keto","Paleo","Low-Carb"], label="Diet Style")
|
76 |
-
meals = gr.Slider(1,6,value=3,
|
77 |
-
avoid = gr.Textbox(placeholder="e.g. Gluten, Dairy, Nuts
|
78 |
-
weeks = gr.Dropdown(["1 week","2 weeks","3 weeks","4 weeks"],
|
79 |
|
80 |
demo = gr.ChatInterface(
|
81 |
fn=respond,
|
82 |
-
additional_inputs=[goal, diet, meals, avoid, weeks]
|
83 |
)
|
84 |
|
85 |
if __name__ == "__main__":
|
|
|
1 |
+
# app.py
|
2 |
+
import os
|
3 |
+
import glob
|
4 |
import faiss
|
5 |
import numpy as np
|
6 |
+
|
7 |
+
import gradio as gr
|
8 |
+
import spaces
|
9 |
+
|
10 |
+
from unstructured.partition.pdf import partition_pdf
|
11 |
from sentence_transformers import SentenceTransformer
|
12 |
from transformers import RagTokenizer, RagSequenceForGeneration
|
13 |
|
14 |
+
# βββ Configuration βββββββββββββββββββββββββββββββββββββββββββββ
|
15 |
+
PDF_FOLDER = "meal_plans"
|
16 |
+
MODEL_NAME = "facebook/rag-sequence-nq"
|
17 |
+
EMBED_MODEL = "all-MiniLM-L6-v2"
|
18 |
+
TOP_K = 5
|
19 |
|
20 |
+
# βββ 1) LOAD + CHUNK ALL PDFs ββββββββββββββββββββββββββββββββββ
|
21 |
+
rag_tokenizer = RagTokenizer.from_pretrained(MODEL_NAME)
|
22 |
+
texts, sources, pages = [], [], []
|
|
|
|
|
23 |
|
24 |
+
for pdf_path in glob.glob(f"{PDF_FOLDER}/*.pdf"):
|
25 |
+
book = os.path.basename(pdf_path)
|
26 |
+
pages_data = partition_pdf(filename=pdf_path)
|
27 |
+
for pg_num, page in enumerate(pages_data, start=1):
|
28 |
+
enc = rag_tokenizer(
|
29 |
+
page.text,
|
30 |
+
max_length=800,
|
31 |
+
truncation=True,
|
32 |
+
return_overflowing_tokens=True,
|
33 |
+
stride=50,
|
34 |
+
return_tensors="pt"
|
35 |
+
)
|
36 |
+
for token_ids in enc["input_ids"]:
|
37 |
+
chunk = rag_tokenizer.decode(token_ids, skip_special_tokens=True)
|
38 |
+
texts.append(chunk)
|
39 |
+
sources.append(book)
|
40 |
+
pages.append(pg_num)
|
41 |
|
42 |
+
# βββ 2) EMBED + BUILD FAISS INDEX βββββββββββββββββββββββββββββ
|
43 |
+
embedder = SentenceTransformer(EMBED_MODEL)
|
44 |
+
embeddings = embedder.encode(texts, convert_to_numpy=True)
|
45 |
+
dim = embeddings.shape[1]
|
46 |
+
index = faiss.IndexFlatL2(dim)
|
47 |
+
index.add(embeddings)
|
48 |
+
|
49 |
+
# βββ 3) LOAD RAG GENERATOR ββββββββββββββββββββββββββββββββββββ
|
50 |
tokenizer = RagTokenizer.from_pretrained(MODEL_NAME)
|
51 |
+
generator = RagSequenceForGeneration.from_pretrained(MODEL_NAME)
|
52 |
|
53 |
@spaces.GPU
|
54 |
def respond(
|
55 |
message: str,
|
56 |
+
history: list[tuple[str,str]],
|
57 |
goal: str,
|
58 |
diet: list[str],
|
59 |
meals: int,
|
60 |
avoid: str,
|
61 |
+
weeks: str
|
62 |
):
|
63 |
+
# build prefs string
|
64 |
avoid_list = [a.strip() for a in avoid.split(",") if a.strip()]
|
65 |
prefs = (
|
66 |
f"Goal={goal}; Diet={','.join(diet)}; "
|
67 |
f"Meals={meals}/day; Avoid={','.join(avoid_list)}; Duration={weeks}"
|
68 |
)
|
69 |
+
# 1) RETRIEVE top-k chunks
|
|
|
70 |
q_emb = embedder.encode([message], convert_to_numpy=True)
|
71 |
+
D, I = index.search(q_emb, TOP_K)
|
72 |
+
context = "\n".join(f"[{sources[i]} p{pages[i]}] {texts[i]}" for i in I[0])
|
|
|
|
|
|
|
73 |
|
74 |
+
# 2) BUILD PROMPT with guardrail
|
75 |
prompt = (
|
76 |
"SYSTEM: Only answer using the provided CONTEXT. "
|
77 |
+
"If itβs not there, say \"I'm sorry, I don't know.\" \n"
|
78 |
f"PREFS: {prefs}\n"
|
79 |
f"CONTEXT:\n{context}\n"
|
80 |
f"Q: {message}\n"
|
81 |
)
|
82 |
|
83 |
+
# 3) GENERATE
|
84 |
inputs = tokenizer([prompt], return_tensors="pt")
|
85 |
+
outputs = generator.generate(**inputs, num_beams=2, max_new_tokens=200)
|
86 |
answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
|
87 |
|
88 |
+
# update chat history
|
89 |
history = history or []
|
90 |
history.append((message, answer))
|
91 |
return history
|
92 |
|
93 |
+
# βββ 4) BUILD UI ββββββββββββββββββββββββββββββββββββββββββββββββ
|
94 |
+
goal = gr.Dropdown(["Lose weight","Bulk","Maintain"], label="Goal", value="Lose weight")
|
95 |
diet = gr.CheckboxGroup(["Omnivore","Vegetarian","Vegan","Keto","Paleo","Low-Carb"], label="Diet Style")
|
96 |
+
meals = gr.Slider(1, 6, step=1, value=3, label="Meals per day")
|
97 |
+
avoid = gr.Textbox(placeholder="e.g. Gluten, Dairy, Nutsβ¦", label="Avoidances (comma-separated)")
|
98 |
+
weeks = gr.Dropdown(["1 week","2 weeks","3 weeks","4 weeks"], label="Plan Length", value="1 week")
|
99 |
|
100 |
demo = gr.ChatInterface(
|
101 |
fn=respond,
|
102 |
+
additional_inputs=[goal, diet, meals, avoid, weeks]
|
103 |
)
|
104 |
|
105 |
if __name__ == "__main__":
|