DurgaDeepak commited on
Commit
758a6cc
·
verified ·
1 Parent(s): 9adbe41

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -39
app.py CHANGED
@@ -1,85 +1,105 @@
1
- import gradio as gr
2
- import spaces
 
3
  import faiss
4
  import numpy as np
5
- from datasets import load_dataset
 
 
 
 
6
  from sentence_transformers import SentenceTransformer
7
  from transformers import RagTokenizer, RagSequenceForGeneration
8
 
9
- # β€” Config β€”
10
- DATASET_NAME = "DurgaDeepak/meal_plans"
11
- INDEX_PATH = "mealplan.index"
12
- MODEL_NAME = "facebook/rag-sequence-nq"
 
13
 
14
- # β€” Load chunks & FAISS index β€”
15
- ds = load_dataset(DATASET_NAME, split="train")
16
- texts = ds["text"]
17
- sources = ds["source"]
18
- pages = ds["page"]
19
 
20
- # β€” Embeddings embedder & FAISS β€”
21
- embedder = SentenceTransformer("all-MiniLM-L6-v2")
22
- chunk_embeddings = embedder.encode(texts, convert_to_numpy=True)
23
- index = faiss.read_index(INDEX_PATH)
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
- # β€” RAG generator β€”
 
 
 
 
 
 
 
26
  tokenizer = RagTokenizer.from_pretrained(MODEL_NAME)
27
- model = RagSequenceForGeneration.from_pretrained(MODEL_NAME)
28
 
29
  @spaces.GPU
30
  def respond(
31
  message: str,
32
- history: list[tuple[str, str]],
33
  goal: str,
34
  diet: list[str],
35
  meals: int,
36
  avoid: str,
37
- weeks: str,
38
  ):
39
- # Parse preferences
40
  avoid_list = [a.strip() for a in avoid.split(",") if a.strip()]
41
  prefs = (
42
  f"Goal={goal}; Diet={','.join(diet)}; "
43
  f"Meals={meals}/day; Avoid={','.join(avoid_list)}; Duration={weeks}"
44
  )
45
-
46
- # 1) Query embedding & FAISS search
47
  q_emb = embedder.encode([message], convert_to_numpy=True)
48
- D, I = index.search(q_emb, 5) # top-5
49
- ctx_chunks = [
50
- f"[{sources[i]} p{pages[i]}] {texts[i]}" for i in I[0]
51
- ]
52
- context = "\n".join(ctx_chunks)
53
 
54
- # 2) Build prompt
55
  prompt = (
56
  "SYSTEM: Only answer using the provided CONTEXT. "
57
- "If it’s not there, say \"I'm sorry, I don't know.\"\n"
58
  f"PREFS: {prefs}\n"
59
  f"CONTEXT:\n{context}\n"
60
  f"Q: {message}\n"
61
  )
62
 
63
- # 3) Generate
64
  inputs = tokenizer([prompt], return_tensors="pt")
65
- outputs = model.generate(**inputs, num_beams=2, max_new_tokens=200)
66
  answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
67
 
68
- # 4) Update history
69
  history = history or []
70
  history.append((message, answer))
71
  return history
72
 
73
- # β€” Build Gradio chat interface β€”
74
- goal = gr.Dropdown(["Lose weight","Bulk","Maintain"], value="Lose weight", label="Goal")
75
  diet = gr.CheckboxGroup(["Omnivore","Vegetarian","Vegan","Keto","Paleo","Low-Carb"], label="Diet Style")
76
- meals = gr.Slider(1,6,value=3,step=1,label="Meals per day")
77
- avoid = gr.Textbox(placeholder="e.g. Gluten, Dairy, Nuts...", label="Avoidances (comma-separated)")
78
- weeks = gr.Dropdown(["1 week","2 weeks","3 weeks","4 weeks"], value="1 week", label="Plan Length")
79
 
80
  demo = gr.ChatInterface(
81
  fn=respond,
82
- additional_inputs=[goal, diet, meals, avoid, weeks],
83
  )
84
 
85
  if __name__ == "__main__":
 
1
+ # app.py
2
+ import os
3
+ import glob
4
  import faiss
5
  import numpy as np
6
+
7
+ import gradio as gr
8
+ import spaces
9
+
10
+ from unstructured.partition.pdf import partition_pdf
11
  from sentence_transformers import SentenceTransformer
12
  from transformers import RagTokenizer, RagSequenceForGeneration
13
 
14
# ─── Configuration ─────────────────────────────────────────────
PDF_FOLDER = "meal_plans"          # folder scanned for *.pdf meal-plan books
MODEL_NAME = "facebook/rag-sequence-nq"
EMBED_MODEL = "all-MiniLM-L6-v2"
TOP_K = 5                          # chunks retrieved per query

# ─── 1) LOAD + CHUNK ALL PDFs ──────────────────────────────────
# One shared tokenizer is used both for chunking here and for generation
# below, so chunk sizes line up with the generator's input budget.
rag_tokenizer = RagTokenizer.from_pretrained(MODEL_NAME)
texts, sources, pages = [], [], []

for pdf_path in glob.glob(f"{PDF_FOLDER}/*.pdf"):
    book = os.path.basename(pdf_path)
    # partition_pdf() yields document *elements* (titles, paragraphs, ...),
    # not whole pages — so read the true page number from element metadata
    # instead of the enumeration index, which mislabeled every citation.
    for elem in partition_pdf(filename=pdf_path):
        text = (elem.text or "").strip()
        if not text:
            continue  # skip empty elements; tokenizing "" is wasted work
        pg_num = getattr(elem.metadata, "page_number", None) or 1
        # Split long elements into overlapping ~800-token windows
        # (stride=50 keeps context across window boundaries).
        enc = rag_tokenizer(
            text,
            max_length=800,
            truncation=True,
            return_overflowing_tokens=True,
            stride=50,
            return_tensors="pt",
        )
        for token_ids in enc["input_ids"]:
            chunk = rag_tokenizer.decode(token_ids, skip_special_tokens=True)
            texts.append(chunk)
            sources.append(book)
            pages.append(pg_num)
41
 
42
# ─── 2) EMBED + BUILD FAISS INDEX ─────────────────────────────
embedder = SentenceTransformer(EMBED_MODEL)
# Ask the model for its output width instead of peeking at the embedding
# matrix: `embeddings.shape[1]` raised IndexError when the corpus was empty
# (e.g. PDF_FOLDER missing or containing no PDFs).
dim = embedder.get_sentence_embedding_dimension()
embeddings = (
    embedder.encode(texts, convert_to_numpy=True)
    if texts
    else np.zeros((0, dim), dtype=np.float32)
)
index = faiss.IndexFlatL2(dim)  # exact L2 search over all chunks
index.add(embeddings)
48
+
49
# ─── 3) LOAD RAG GENERATOR ────────────────────────────────────
# Reuse the tokenizer already loaded in step 1 (same pretrained id) rather
# than downloading and constructing an identical one a second time.
tokenizer = rag_tokenizer
generator = RagSequenceForGeneration.from_pretrained(MODEL_NAME)
52
 
53
@spaces.GPU
def respond(
    message: str,
    history: list[tuple[str, str]],
    goal: str,
    diet: list[str],
    meals: int,
    avoid: str,
    weeks: str,
):
    """Answer one chat turn with retrieval-augmented generation.

    Embeds the user message, retrieves the TOP_K nearest corpus chunks from
    the FAISS index, builds a guard-railed prompt from the user's preferences
    plus the retrieved context, generates an answer, and returns the chat
    history extended with the new (message, answer) pair.
    """
    # Normalize the comma-separated avoidance string into clean tokens.
    avoidances = []
    for item in avoid.split(","):
        item = item.strip()
        if item:
            avoidances.append(item)
    prefs = (
        f"Goal={goal}; Diet={','.join(diet)}; "
        f"Meals={meals}/day; Avoid={','.join(avoidances)}; Duration={weeks}"
    )

    # 1) Retrieve: nearest-neighbor search over the chunk embeddings.
    query_vec = embedder.encode([message], convert_to_numpy=True)
    _, hit_ids = index.search(query_vec, TOP_K)
    cited = [f"[{sources[i]} p{pages[i]}] {texts[i]}" for i in hit_ids[0]]
    context = "\n".join(cited)

    # 2) Prompt with a guardrail telling the model to stay in-context.
    prompt = (
        "SYSTEM: Only answer using the provided CONTEXT. "
        "If it’s not there, say \"I'm sorry, I don't know.\" \n"
        f"PREFS: {prefs}\n"
        f"CONTEXT:\n{context}\n"
        f"Q: {message}\n"
    )

    # 3) Generate and decode a single beam-searched answer.
    batch = tokenizer([prompt], return_tensors="pt")
    generated = generator.generate(**batch, num_beams=2, max_new_tokens=200)
    answer = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]

    # Append this turn and hand the full history back to Gradio.
    history = history or []
    history.append((message, answer))
    return history
92
 
93
# ─── 4) BUILD UI ────────────────────────────────────────────────
# Preference widgets, handed to the chat callback as additional inputs
# (their order must match respond()'s goal/diet/meals/avoid/weeks params).
goal = gr.Dropdown(
    choices=["Lose weight", "Bulk", "Maintain"],
    value="Lose weight",
    label="Goal",
)
diet = gr.CheckboxGroup(
    choices=["Omnivore", "Vegetarian", "Vegan", "Keto", "Paleo", "Low-Carb"],
    label="Diet Style",
)
meals = gr.Slider(minimum=1, maximum=6, value=3, step=1, label="Meals per day")
avoid = gr.Textbox(
    placeholder="e.g. Gluten, Dairy, Nuts…",
    label="Avoidances (comma-separated)",
)
weeks = gr.Dropdown(
    choices=["1 week", "2 weeks", "3 weeks", "4 weeks"],
    value="1 week",
    label="Plan Length",
)

demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[goal, diet, meals, avoid, weeks],
)
104
 
105
  if __name__ == "__main__":