Spaces:
Running
Running
Commit
·
4363820
1
Parent(s):
7da6e45
"LOL"
Browse files- app.py +31 -0
- chatbot.py +85 -0
- data/restaurants.json +162 -0
- embeddings/embedder.py +17 -0
- generation/llm.py +149 -0
- llm.py +149 -0
- requirements.txt +7 -0
- retrieval/hybrid_search.py +49 -0
- retrieval/keyword_filter.py +41 -0
- retrieval/vector_store.py +34 -0
- src/chatbot.py +85 -0
- src/embeddings/embedder.py +17 -0
- src/generation/llm.py +149 -0
- src/retrieval/hybrid_search.py +49 -0
- src/retrieval/keyword_filter.py +41 -0
- src/retrieval/vector_store.py +34 -0
- src/utils/data_loader.py +29 -0
- src/utils/query_parser.py +76 -0
- utils/data_loader.py +29 -0
- utils/query_parser.py +76 -0
app.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from src.chatbot import RestaurantChatbot
|
3 |
+
|
4 |
+
chatbot = RestaurantChatbot()
|
5 |
+
chat_history = []
|
6 |
+
|
7 |
+
def respond(user_message, history):
|
8 |
+
response, retrieved_docs = chatbot.answer(user_message)
|
9 |
+
|
10 |
+
bot_response = f"{response}\n\n**Nhà hàng gợi ý:**\n"
|
11 |
+
if retrieved_docs:
|
12 |
+
for doc in retrieved_docs:
|
13 |
+
bot_response += (
|
14 |
+
f"- **{doc['name']} ({doc['cuisine']})**\n"
|
15 |
+
f" - Món ăn: {', '.join(doc['dishes'])}\n"
|
16 |
+
f" - Giá: {doc['price_range']}\n"
|
17 |
+
f" - Khoảng cách: {doc['distance']} km\n"
|
18 |
+
f" - Đánh giá: {doc['rating']}\n"
|
19 |
+
f" - Địa chỉ: {doc['address']}\n"
|
20 |
+
f" - Mô tả: {doc['description']}\n"
|
21 |
+
)
|
22 |
+
else:
|
23 |
+
bot_response += "- Không tìm thấy nhà hàng phù hợp."
|
24 |
+
|
25 |
+
return bot_response
|
26 |
+
|
27 |
+
with gr.Blocks() as demo:
|
28 |
+
gr.Markdown("## Chatbot Gợi ý Quán ăn")
|
29 |
+
chatbot_ui = gr.ChatInterface(fn=respond, chatbot=gr.Chatbot())
|
30 |
+
|
31 |
+
demo.launch()
|
chatbot.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
from typing import Tuple, List, Dict, Any
|
3 |
+
from src.utils.data_loader import load_restaurant_data
|
4 |
+
from src.utils.query_parser import QueryParser
|
5 |
+
from src.embeddings.embedder import Embedder
|
6 |
+
from src.retrieval.vector_store import VectorStore
|
7 |
+
from src.retrieval.keyword_filter import filter_restaurants
|
8 |
+
from src.retrieval.hybrid_search import HybridRetriever
|
9 |
+
from src.generation.llm import LLM
|
10 |
+
from langchain_core.embeddings import Embeddings
|
11 |
+
|
12 |
+
class LangChainEmbeddingWrapper(Embeddings):
|
13 |
+
def __init__(self, embedder):
|
14 |
+
self.embedder = embedder
|
15 |
+
|
16 |
+
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
17 |
+
return self.embedder.embed(texts).tolist()
|
18 |
+
|
19 |
+
def embed_query(self, text: str) -> List[float]:
|
20 |
+
return self.embedder.embed([text])[0].tolist()
|
21 |
+
|
22 |
+
class RestaurantChatbot:
|
23 |
+
def __init__(self, data_path: str = "data/restaurants.json"):
|
24 |
+
"""
|
25 |
+
Initialize the restaurant chatbot.
|
26 |
+
|
27 |
+
Args:
|
28 |
+
data_path (str): Path to the restaurant JSON file.
|
29 |
+
"""
|
30 |
+
self.df = load_restaurant_data(data_path)
|
31 |
+
self.embedder = Embedder()
|
32 |
+
self.embedding_wrapper = LangChainEmbeddingWrapper(self.embedder)
|
33 |
+
self.vector_store = VectorStore(embedding_function=self.embedding_wrapper)
|
34 |
+
self.llm = LLM()
|
35 |
+
self.parser = QueryParser(self.df)
|
36 |
+
|
37 |
+
embeddings = self.embedder.embed(self.df['text'].tolist())
|
38 |
+
self.vector_store.add_documents(
|
39 |
+
documents=self.df['text'].tolist(),
|
40 |
+
embeddings=embeddings.tolist(),
|
41 |
+
ids=[str(i) for i in self.df['id']]
|
42 |
+
)
|
43 |
+
|
44 |
+
self.retriever = HybridRetriever(self.df, self.vector_store, self.embedder)
|
45 |
+
|
46 |
+
def answer(self, query: str) -> Tuple[str, List[Dict[str, Any]]]:
|
47 |
+
"""
|
48 |
+
Process a user query and return a natural, concise response with recommended restaurants.
|
49 |
+
|
50 |
+
Args:
|
51 |
+
query (str): User query.
|
52 |
+
|
53 |
+
Returns:
|
54 |
+
Tuple[str, List[Dict[str, Any]]]: Natural response text and list of recommended restaurants.
|
55 |
+
"""
|
56 |
+
parsed_query = self.parser.parse_query(query)
|
57 |
+
filtered_df = filter_restaurants(self.df, parsed_query)
|
58 |
+
description = parsed_query["description"] if parsed_query["description"] else query
|
59 |
+
|
60 |
+
if filtered_df.empty:
|
61 |
+
retrieved_docs = self.retriever.retrieve(description, self.df, top_k=3)
|
62 |
+
else:
|
63 |
+
retrieved_docs = self.retriever.retrieve(description, filtered_df, top_k=3)
|
64 |
+
|
65 |
+
if not retrieved_docs:
|
66 |
+
return "Mình không tìm được nhà hàng nào phù hợp. Bạn thử đổi tiêu chí xem, như mở rộng khoảng cách hoặc loại món ăn nhé!", []
|
67 |
+
|
68 |
+
# Create context for LLM
|
69 |
+
context = "\n".join([
|
70 |
+
f"- {doc['name']} ({doc['cuisine']}): {', '.join(doc['dishes'])}. "
|
71 |
+
f"Price: {doc['price_range']}, Distance: {doc['distance']} km, Rating: {doc['rating']}. "
|
72 |
+
f"Description: {doc['description']}"
|
73 |
+
for doc in retrieved_docs
|
74 |
+
])
|
75 |
+
|
76 |
+
# Prompt for natural, consultant-like response
|
77 |
+
prompt = (
|
78 |
+
f"Bạn là một người tư vấn nhà hàng thân thiện. Dựa trên truy vấn và danh sách nhà hàng, hãy gợi ý ngắn gọn, tự nhiên, như trò chuyện với bạn bè, giải thích tại sao chọn các nhà hàng này (tập trung vào món ăn, giá, khoảng cách, hoặc đánh giá phù hợp với truy vấn). Không lặp lại truy vấn hoặc dùng ngôn ngữ kỹ thuật. Chỉ dùng thông tin từ danh sách nhà hàng.\n\n"
|
79 |
+
f"Truy vấn: {query}\n\n"
|
80 |
+
f"Danh sách nhà hàng:\n{context}\n\n"
|
81 |
+
f"Phản hồi:"
|
82 |
+
)
|
83 |
+
|
84 |
+
response = self.llm.generate(prompt, max_length=200)
|
85 |
+
return response, retrieved_docs
|
data/restaurants.json
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[
|
2 |
+
{
|
3 |
+
"id": 1,
|
4 |
+
"name": "Phở Hà Nội",
|
5 |
+
"cuisine": "Vietnamese",
|
6 |
+
"distance": 1.2,
|
7 |
+
"price_range": "medium",
|
8 |
+
"dishes": ["phở", "bún chả", "bánh mì", "gỏi cuốn"],
|
9 |
+
"rating": 4.5,
|
10 |
+
"address": "123 Nguyễn Huệ, Quận 1, TP.HCM, Việt Nam",
|
11 |
+
"description": "A cozy spot offering authentic Vietnamese phở with rich broth and fresh herbs.",
|
12 |
+
"dietary_options": ["vegetarian"],
|
13 |
+
"service_type": ["dine-in", "takeout"],
|
14 |
+
"kid_friendly": true,
|
15 |
+
"outdoor_seating": false,
|
16 |
+
"alcohol": false
|
17 |
+
},
|
18 |
+
{
|
19 |
+
"id": 2,
|
20 |
+
"name": "Pizza 4P's",
|
21 |
+
"cuisine": "Italian",
|
22 |
+
"distance": 2.5,
|
23 |
+
"price_range": "high",
|
24 |
+
"dishes": ["pizza", "pasta", "tiramisu", "salad"],
|
25 |
+
"rating": 4.8,
|
26 |
+
"address": "8/15 Lê Thánh Tôn, Quận 1, TP.HCM, Việt Nam",
|
27 |
+
"description": "Upscale Italian dining with artisanal pizzas and a vibrant atmosphere.",
|
28 |
+
"dietary_options": ["vegetarian", "gluten-free"],
|
29 |
+
"service_type": ["dine-in", "delivery"],
|
30 |
+
"kid_friendly": true,
|
31 |
+
"outdoor_seating": true,
|
32 |
+
"alcohol": true
|
33 |
+
},
|
34 |
+
{
|
35 |
+
"id": 3,
|
36 |
+
"name": "Sushi Bar",
|
37 |
+
"cuisine": "Japanese",
|
38 |
+
"distance": 3.0,
|
39 |
+
"price_range": "high",
|
40 |
+
"dishes": ["sushi", "sashimi", "ramen", "tempura"],
|
41 |
+
"rating": 4.3,
|
42 |
+
"address": "45 Lý Tự Trọng, Quận 1, TP.HCM, Việt Nam",
|
43 |
+
"description": "Authentic Japanese sushi with fresh fish and a modern ambiance.",
|
44 |
+
"dietary_options": [],
|
45 |
+
"service_type": ["dine-in"],
|
46 |
+
"kid_friendly": false,
|
47 |
+
"outdoor_seating": false,
|
48 |
+
"alcohol": true
|
49 |
+
},
|
50 |
+
{
|
51 |
+
"id": 4,
|
52 |
+
"name": "Bún Bò Huế An Nam",
|
53 |
+
"cuisine": "Vietnamese",
|
54 |
+
"distance": 0.8,
|
55 |
+
"price_range": "low",
|
56 |
+
"dishes": ["bún bò huế", "bánh bèo", "bánh nậm"],
|
57 |
+
"rating": 4.0,
|
58 |
+
"address": "78 Pasteur, Quận 1, TP.HCM, Việt Nam",
|
59 |
+
"description": "A local favorite for spicy Huế-style noodle soups with bold flavors.",
|
60 |
+
"dietary_options": [],
|
61 |
+
"service_type": ["dine-in", "takeout"],
|
62 |
+
"kid_friendly": true,
|
63 |
+
"outdoor_seating": false,
|
64 |
+
"alcohol": false
|
65 |
+
},
|
66 |
+
{
|
67 |
+
"id": 5,
|
68 |
+
"name": "Thai Spice",
|
69 |
+
"cuisine": "Thai",
|
70 |
+
"distance": 4.5,
|
71 |
+
"price_range": "medium",
|
72 |
+
"dishes": ["tom yum", "pad thai", "green curry", "mango sticky rice"],
|
73 |
+
"rating": 4.2,
|
74 |
+
"address": "12 Nguyễn Đình Chiểu, Đa Kao, Quận 1, TP.HCM, Việt Nam",
|
75 |
+
"description": "Vibrant Thai restaurant with spicy curries and refreshing tropical desserts.",
|
76 |
+
"dietary_options": ["vegetarian", "vegan"],
|
77 |
+
"service_type": ["dine-in", "delivery", "takeout"],
|
78 |
+
"kid_friendly": true,
|
79 |
+
"outdoor_seating": true,
|
80 |
+
"alcohol": true
|
81 |
+
},
|
82 |
+
{
|
83 |
+
"id": 6,
|
84 |
+
"name": "La Fiesta",
|
85 |
+
"cuisine": "Mexican",
|
86 |
+
"distance": 2.0,
|
87 |
+
"price_range": "medium",
|
88 |
+
"dishes": ["tacos", "burrito", "quesadilla", "guacamole"],
|
89 |
+
"rating": 4.1,
|
90 |
+
"address": "33 Tôn Thất Thiệp, Quận 1, TP.HCM, Việt Nam",
|
91 |
+
"description": "Lively Mexican eatery with colorful decor and flavorful tacos.",
|
92 |
+
"dietary_options": ["vegetarian"],
|
93 |
+
"service_type": ["dine-in", "delivery"],
|
94 |
+
"kid_friendly": true,
|
95 |
+
"outdoor_seating": true,
|
96 |
+
"alcohol": true
|
97 |
+
},
|
98 |
+
{
|
99 |
+
"id": 7,
|
100 |
+
"name": "The Vegan Garden",
|
101 |
+
"cuisine": "Vegan",
|
102 |
+
"distance": 5.0,
|
103 |
+
"price_range": "medium",
|
104 |
+
"dishes": ["vegan phở", "tofu curry", "quinoa salad"],
|
105 |
+
"rating": 4.6,
|
106 |
+
"address": "20 Võ Văn Tần, Quận 3, TP.HCM, Việt Nam",
|
107 |
+
"description": "A serene vegan restaurant offering plant-based Vietnamese and international dishes.",
|
108 |
+
"dietary_options": ["vegan", "gluten-free"],
|
109 |
+
"service_type": ["dine-in", "takeout"],
|
110 |
+
"kid_friendly": true,
|
111 |
+
"outdoor_seating": false,
|
112 |
+
"alcohol": false
|
113 |
+
},
|
114 |
+
{
|
115 |
+
"id": 8,
|
116 |
+
"name": "Le Petit Paris",
|
117 |
+
"cuisine": "French",
|
118 |
+
"distance": 1.5,
|
119 |
+
"price_range": "high",
|
120 |
+
"dishes": ["croissant", "coq au vin", "crème brûlée"],
|
121 |
+
"rating": 4.7,
|
122 |
+
"address": "56 Đồng Khởi, Quận 1, TP.HCM, Việt Nam",
|
123 |
+
"description": "Elegant French bistro with classic dishes and a romantic ambiance.",
|
124 |
+
"dietary_options": [],
|
125 |
+
"service_type": ["dine-in"],
|
126 |
+
"kid_friendly": false,
|
127 |
+
"outdoor_seating": true,
|
128 |
+
"alcohol": true
|
129 |
+
},
|
130 |
+
{
|
131 |
+
"id": 9,
|
132 |
+
"name": "Chaat House",
|
133 |
+
"cuisine": "Indian",
|
134 |
+
"distance": 3.8,
|
135 |
+
"price_range": "medium",
|
136 |
+
"dishes": ["butter chicken", "naan", "paneer tikka", "biryani"],
|
137 |
+
"rating": 4.4,
|
138 |
+
"address": "15 Lê Lợi, Quận 1, TP.HCM, Việt Nam",
|
139 |
+
"description": "Authentic Indian cuisine with rich spices and a warm atmosphere.",
|
140 |
+
"dietary_options": ["vegetarian", "vegan"],
|
141 |
+
"service_type": ["dine-in", "delivery", "takeout"],
|
142 |
+
"kid_friendly": true,
|
143 |
+
"outdoor_seating": false,
|
144 |
+
"alcohol": false
|
145 |
+
},
|
146 |
+
{
|
147 |
+
"id": 10,
|
148 |
+
"name": "BBQ Haven",
|
149 |
+
"cuisine": "American",
|
150 |
+
"distance": 6.0,
|
151 |
+
"price_range": "medium",
|
152 |
+
"dishes": ["ribs", "burger", "fries", "coleslaw"],
|
153 |
+
"rating": 4.0,
|
154 |
+
"address": "88 Nguyễn Trãi, Quận 5, TP.HCM, Việt Nam",
|
155 |
+
"description": "Casual American BBQ spot with smoky ribs and hearty burgers.",
|
156 |
+
"dietary_options": [],
|
157 |
+
"service_type": ["dine-in", "delivery"],
|
158 |
+
"kid_friendly": true,
|
159 |
+
"outdoor_seating": true,
|
160 |
+
"alcohol": true
|
161 |
+
}
|
162 |
+
]
|
embeddings/embedder.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import AutoTokenizer, AutoModel
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
from typing import List
|
5 |
+
|
6 |
+
class Embedder:
|
7 |
+
def __init__(self, model_name: str = "BAAI/bge-m3"):
|
8 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
9 |
+
self.model = AutoModel.from_pretrained(model_name)
|
10 |
+
|
11 |
+
def embed(self, texts: List[str]) -> np.ndarray:
|
12 |
+
inputs = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
|
13 |
+
with torch.no_grad():
|
14 |
+
outputs = self.model(**inputs)
|
15 |
+
embeddings = outputs.last_hidden_state[:, 0] # lấy embedding từ CLS token
|
16 |
+
embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
|
17 |
+
return embeddings.cpu().numpy()
|
generation/llm.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
2 |
+
from langchain_core.prompts import PromptTemplate
|
3 |
+
import os
|
4 |
+
from typing import List
|
5 |
+
|
6 |
+
class LLM:
|
7 |
+
def __init__(self, model_repo: str = "Qwen/Qwen2-1.5B-Instruct",
|
8 |
+
local_path: str = "models"):
|
9 |
+
"""
|
10 |
+
Initialize the LLM with Qwen2-1.5B-Instruct using Hugging Face Transformers.
|
11 |
+
|
12 |
+
Args:
|
13 |
+
model_repo (str): Hugging Face repository ID for the model.
|
14 |
+
local_path (str): Local directory to store the model.
|
15 |
+
"""
|
16 |
+
os.makedirs(local_path, exist_ok=True)
|
17 |
+
|
18 |
+
try:
|
19 |
+
# Load the model
|
20 |
+
self.llm = AutoModelForCausalLM.from_pretrained(
|
21 |
+
model_repo,
|
22 |
+
device_map="auto", # Automatically map to CPU
|
23 |
+
cache_dir=local_path,
|
24 |
+
trust_remote_code=True
|
25 |
+
)
|
26 |
+
|
27 |
+
# Load the tokenizer
|
28 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
29 |
+
model_repo,
|
30 |
+
cache_dir=local_path,
|
31 |
+
trust_remote_code=True
|
32 |
+
)
|
33 |
+
print(f"Model successfully loaded from {model_repo}")
|
34 |
+
except Exception as e:
|
35 |
+
raise RuntimeError(
|
36 |
+
f"Failed to initialize model from {model_repo}. "
|
37 |
+
f"Please ensure the model is available at https://huggingface.co/{model_repo}. "
|
38 |
+
f"Error: {str(e)}"
|
39 |
+
)
|
40 |
+
|
41 |
+
# Define prompt template for query parsing (used in query_parser.py)
|
42 |
+
self.prompt_template = PromptTemplate(
|
43 |
+
template="""Bạn là một trợ lý phân tích truy vấn nhà hàng. Phân tích truy vấn sau và trích xuất các đặc trưng: cuisine, menu, price_range, distance, rating, và description. Chỉ trích xuất các giá trị khớp chính xác với danh sách giá trị hợp lệ. Nếu không tìm thấy giá trị khớp, trả về null (hoặc [] cho menu). Loại bỏ các từ khóa đã trích xuất khỏi description. Trả về kết quả dưới dạng JSON.
|
44 |
+
|
45 |
+
**Danh sách giá trị hợp lệ**:
|
46 |
+
- cuisine: {cuisines}
|
47 |
+
- menu: {dishes}
|
48 |
+
- price_range: {price_ranges}
|
49 |
+
|
50 |
+
**Hướng dẫn**:
|
51 |
+
- cuisine: Chỉ chọn giá trị từ danh sách cuisine. Ví dụ, "Viet" → "Vietnamese".
|
52 |
+
- menu: Chỉ chọn các món khớp chính xác với danh sách menu. Ví dụ, "phở bò" → "phở", "sushi" → [].
|
53 |
+
- price_range: Chỉ chọn {price_ranges}. Ví dụ, "cheap" → "low".
|
54 |
+
- distance: Trích xuất số km (e.g., "2 km" → 2.0) hoặc từ khóa ["nearby", "close" → 2.0, "far" → 10.0]. Nếu không rõ, trả về null.
|
55 |
+
- rating: Trích xuất số (e.g., "4 stars" → 4.0). Nếu không rõ, trả về null.
|
56 |
+
- description: Phần còn lại sau khi loại bỏ các từ khóa đã trích xuất. Nếu rỗng, trả về truy vấn gốc.
|
57 |
+
|
58 |
+
**Truy vấn**: {query}
|
59 |
+
|
60 |
+
**Định dạng đầu ra**:
|
61 |
+
{{
|
62 |
+
"cuisine": null | "tên loại ẩm thực",
|
63 |
+
"menu": [],
|
64 |
+
"price_range": null | "low" | "medium" | "high",
|
65 |
+
"distance": null | số km | "nearby" | "close" | "far",
|
66 |
+
"rating": null | số,
|
67 |
+
"description": "phần mô tả còn lại"
|
68 |
+
}}
|
69 |
+
""",
|
70 |
+
input_variables=["cuisines", "dishes", "price_ranges", "query"]
|
71 |
+
)
|
72 |
+
|
73 |
+
def generate(self, prompt: str, max_length: int = 1000) -> str:
|
74 |
+
"""
|
75 |
+
Generate text using the LLM.
|
76 |
+
|
77 |
+
Args:
|
78 |
+
prompt (str): Input prompt.
|
79 |
+
max_length (int): Maximum length of the generated text.
|
80 |
+
|
81 |
+
Returns:
|
82 |
+
str: Generated text.
|
83 |
+
"""
|
84 |
+
try:
|
85 |
+
# Apply chat template for instruction-tuned Qwen model
|
86 |
+
messages = [{"role": "user", "content": prompt}]
|
87 |
+
prompt_with_template = self.tokenizer.apply_chat_template(
|
88 |
+
messages, tokenize=False, add_generation_prompt=True
|
89 |
+
)
|
90 |
+
# Tokenize input prompt
|
91 |
+
inputs = self.tokenizer(prompt_with_template, return_tensors="pt").to(self.llm.device)
|
92 |
+
# Generate text
|
93 |
+
outputs = self.llm.generate(
|
94 |
+
**inputs,
|
95 |
+
max_new_tokens=max_length,
|
96 |
+
temperature=0.7,
|
97 |
+
do_sample=True,
|
98 |
+
pad_token_id=self.tokenizer.eos_token_id
|
99 |
+
)
|
100 |
+
# Decode the generated tokens
|
101 |
+
response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
102 |
+
print("Response generated successfully!")
|
103 |
+
return response.strip()
|
104 |
+
except Exception as e:
|
105 |
+
raise RuntimeError(f"Failed to generate response: {str(e)}")
|
106 |
+
|
107 |
+
def format_query_prompt(self, query: str, cuisines: List[str], dishes: List[str], price_ranges: List[str]) -> str:
|
108 |
+
"""
|
109 |
+
Format the prompt for query parsing using the prompt template.
|
110 |
+
|
111 |
+
Args:
|
112 |
+
query (str): User query.
|
113 |
+
cuisines (list): List of valid cuisines.
|
114 |
+
dishes (list): List of valid dishes.
|
115 |
+
price_ranges (list): List of valid price ranges.
|
116 |
+
|
117 |
+
Returns:
|
118 |
+
str: Formatted prompt.
|
119 |
+
"""
|
120 |
+
return self.prompt_template.format(
|
121 |
+
cuisines=cuisines,
|
122 |
+
dishes=dishes,
|
123 |
+
price_ranges=price_ranges,
|
124 |
+
query=query
|
125 |
+
)
|
126 |
+
|
127 |
+
if __name__ == "__main__":
|
128 |
+
# Khởi tạo đối tượng LLM với model_repo và local_path
|
129 |
+
local_path = 'models'
|
130 |
+
|
131 |
+
try:
|
132 |
+
# Khởi tạo đối tượng LLM
|
133 |
+
llm = LLM(local_path=local_path)
|
134 |
+
|
135 |
+
# Định nghĩa một truy vấn và các tham số cần thiết
|
136 |
+
query = "Tìm quán ăn Việt Nam gần đây, giá rẻ với món phở và cơm tấm"
|
137 |
+
cuisines = ["Vietnamese", "Chinese", "Italian"]
|
138 |
+
dishes = ["phở", "sushi", "pasta", "cơm tấm"]
|
139 |
+
price_ranges = ["low", "medium", "high"]
|
140 |
+
|
141 |
+
# Sử dụng hàm generate để tạo câu trả lời từ truy vấn
|
142 |
+
generated_text = llm.generate(query, max_length=300)
|
143 |
+
|
144 |
+
# In kết quả ra màn hình
|
145 |
+
print("Generated text:")
|
146 |
+
print(generated_text)
|
147 |
+
|
148 |
+
except Exception as e:
|
149 |
+
print(f"Error: {str(e)}")
|
llm.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
2 |
+
from langchain_core.prompts import PromptTemplate
|
3 |
+
import os
|
4 |
+
from typing import List
|
5 |
+
|
6 |
+
class LLM:
|
7 |
+
def __init__(self, model_repo: str = "Qwen/Qwen2-1.5B-Instruct",
|
8 |
+
local_path: str = "models"):
|
9 |
+
"""
|
10 |
+
Initialize the LLM with Qwen2-1.5B-Instruct using Hugging Face Transformers.
|
11 |
+
|
12 |
+
Args:
|
13 |
+
model_repo (str): Hugging Face repository ID for the model.
|
14 |
+
local_path (str): Local directory to store the model.
|
15 |
+
"""
|
16 |
+
os.makedirs(local_path, exist_ok=True)
|
17 |
+
|
18 |
+
try:
|
19 |
+
# Load the model
|
20 |
+
self.llm = AutoModelForCausalLM.from_pretrained(
|
21 |
+
model_repo,
|
22 |
+
device_map="auto", # Automatically map to CPU
|
23 |
+
cache_dir=local_path,
|
24 |
+
trust_remote_code=True
|
25 |
+
)
|
26 |
+
|
27 |
+
# Load the tokenizer
|
28 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
29 |
+
model_repo,
|
30 |
+
cache_dir=local_path,
|
31 |
+
trust_remote_code=True
|
32 |
+
)
|
33 |
+
print(f"Model successfully loaded from {model_repo}")
|
34 |
+
except Exception as e:
|
35 |
+
raise RuntimeError(
|
36 |
+
f"Failed to initialize model from {model_repo}. "
|
37 |
+
f"Please ensure the model is available at https://huggingface.co/{model_repo}. "
|
38 |
+
f"Error: {str(e)}"
|
39 |
+
)
|
40 |
+
|
41 |
+
# Define prompt template for query parsing (used in query_parser.py)
|
42 |
+
self.prompt_template = PromptTemplate(
|
43 |
+
template="""Bạn là một trợ lý phân tích truy vấn nhà hàng. Phân tích truy vấn sau và trích xuất các đặc trưng: cuisine, menu, price_range, distance, rating, và description. Chỉ trích xuất các giá trị khớp chính xác với danh sách giá trị hợp lệ. Nếu không tìm thấy giá trị khớp, trả về null (hoặc [] cho menu). Loại bỏ các từ khóa đã trích xuất khỏi description. Trả về kết quả dưới dạng JSON.
|
44 |
+
|
45 |
+
**Danh sách giá trị hợp lệ**:
|
46 |
+
- cuisine: {cuisines}
|
47 |
+
- menu: {dishes}
|
48 |
+
- price_range: {price_ranges}
|
49 |
+
|
50 |
+
**Hướng dẫn**:
|
51 |
+
- cuisine: Chỉ chọn giá trị từ danh sách cuisine. Ví dụ, "Viet" → "Vietnamese".
|
52 |
+
- menu: Chỉ chọn các món khớp chính xác với danh sách menu. Ví dụ, "phở bò" → "phở", "sushi" → [].
|
53 |
+
- price_range: Chỉ chọn {price_ranges}. Ví dụ, "cheap" → "low".
|
54 |
+
- distance: Trích xuất số km (e.g., "2 km" → 2.0) hoặc từ khóa ["nearby", "close" → 2.0, "far" → 10.0]. Nếu không rõ, trả về null.
|
55 |
+
- rating: Trích xuất số (e.g., "4 stars" → 4.0). Nếu không rõ, trả về null.
|
56 |
+
- description: Phần còn lại sau khi loại bỏ các từ khóa đã trích xuất. Nếu rỗng, trả về truy vấn gốc.
|
57 |
+
|
58 |
+
**Truy vấn**: {query}
|
59 |
+
|
60 |
+
**Định dạng đầu ra**:
|
61 |
+
{{
|
62 |
+
"cuisine": null | "tên loại ẩm thực",
|
63 |
+
"menu": [],
|
64 |
+
"price_range": null | "low" | "medium" | "high",
|
65 |
+
"distance": null | số km | "nearby" | "close" | "far",
|
66 |
+
"rating": null | số,
|
67 |
+
"description": "phần mô tả còn lại"
|
68 |
+
}}
|
69 |
+
""",
|
70 |
+
input_variables=["cuisines", "dishes", "price_ranges", "query"]
|
71 |
+
)
|
72 |
+
|
73 |
+
def generate(self, prompt: str, max_length: int = 1000) -> str:
|
74 |
+
"""
|
75 |
+
Generate text using the LLM.
|
76 |
+
|
77 |
+
Args:
|
78 |
+
prompt (str): Input prompt.
|
79 |
+
max_length (int): Maximum length of the generated text.
|
80 |
+
|
81 |
+
Returns:
|
82 |
+
str: Generated text.
|
83 |
+
"""
|
84 |
+
try:
|
85 |
+
# Apply chat template for instruction-tuned Qwen model
|
86 |
+
messages = [{"role": "user", "content": prompt}]
|
87 |
+
prompt_with_template = self.tokenizer.apply_chat_template(
|
88 |
+
messages, tokenize=False, add_generation_prompt=True
|
89 |
+
)
|
90 |
+
# Tokenize input prompt
|
91 |
+
inputs = self.tokenizer(prompt_with_template, return_tensors="pt").to(self.llm.device)
|
92 |
+
# Generate text
|
93 |
+
outputs = self.llm.generate(
|
94 |
+
**inputs,
|
95 |
+
max_new_tokens=max_length,
|
96 |
+
temperature=0.7,
|
97 |
+
do_sample=True,
|
98 |
+
pad_token_id=self.tokenizer.eos_token_id
|
99 |
+
)
|
100 |
+
# Decode the generated tokens
|
101 |
+
response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
102 |
+
print("Response generated successfully!")
|
103 |
+
return response.strip()
|
104 |
+
except Exception as e:
|
105 |
+
raise RuntimeError(f"Failed to generate response: {str(e)}")
|
106 |
+
|
107 |
+
def format_query_prompt(self, query: str, cuisines: List[str], dishes: List[str], price_ranges: List[str]) -> str:
|
108 |
+
"""
|
109 |
+
Format the prompt for query parsing using the prompt template.
|
110 |
+
|
111 |
+
Args:
|
112 |
+
query (str): User query.
|
113 |
+
cuisines (list): List of valid cuisines.
|
114 |
+
dishes (list): List of valid dishes.
|
115 |
+
price_ranges (list): List of valid price ranges.
|
116 |
+
|
117 |
+
Returns:
|
118 |
+
str: Formatted prompt.
|
119 |
+
"""
|
120 |
+
return self.prompt_template.format(
|
121 |
+
cuisines=cuisines,
|
122 |
+
dishes=dishes,
|
123 |
+
price_ranges=price_ranges,
|
124 |
+
query=query
|
125 |
+
)
|
126 |
+
|
127 |
+
if __name__ == "__main__":
|
128 |
+
# Khởi tạo đối tượng LLM với model_repo và local_path
|
129 |
+
local_path = 'models'
|
130 |
+
|
131 |
+
try:
|
132 |
+
# Khởi tạo đối tượng LLM
|
133 |
+
llm = LLM(local_path=local_path)
|
134 |
+
|
135 |
+
# Định nghĩa một truy vấn và các tham số cần thiết
|
136 |
+
query = "Tìm quán ăn Việt Nam gần đây, giá rẻ với món phở và cơm tấm"
|
137 |
+
cuisines = ["Vietnamese", "Chinese", "Italian"]
|
138 |
+
dishes = ["phở", "sushi", "pasta", "cơm tấm"]
|
139 |
+
price_ranges = ["low", "medium", "high"]
|
140 |
+
|
141 |
+
# Sử dụng hàm generate để tạo câu trả lời từ truy vấn
|
142 |
+
generated_text = llm.generate(query, max_length=300)
|
143 |
+
|
144 |
+
# In kết quả ra màn hình
|
145 |
+
print("Generated text:")
|
146 |
+
print(generated_text)
|
147 |
+
|
148 |
+
except Exception as e:
|
149 |
+
print(f"Error: {str(e)}")
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio
|
2 |
+
pandas
|
3 |
+
torch
|
4 |
+
transformers
|
5 |
+
chromadb
|
6 |
+
langchain-core
|
7 |
+
rank-bm25
|
retrieval/hybrid_search.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# src/retrieval/hybrid_search.py
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
from rank_bm25 import BM25Okapi
|
5 |
+
from typing import List, Dict, Any
|
6 |
+
import pandas as pd
|
7 |
+
from src.embeddings.embedder import Embedder
|
8 |
+
from src.retrieval.vector_store import VectorStore
|
9 |
+
|
10 |
+
class HybridRetriever:
|
11 |
+
def __init__(self, df: pd.DataFrame, vector_store: VectorStore, embedder: Embedder, alpha: float = 0.5):
|
12 |
+
self.df = df
|
13 |
+
self.vector_store = vector_store
|
14 |
+
self.embedder = embedder
|
15 |
+
self.alpha = alpha
|
16 |
+
tokenized_corpus = [doc.lower().split() for doc in df['description']]
|
17 |
+
self.bm25 = BM25Okapi(tokenized_corpus)
|
18 |
+
|
19 |
+
def retrieve(self, query: str, filtered_df: pd.DataFrame, top_k: int = 3) -> List[Dict[str, Any]]:
|
20 |
+
filtered_indices = filtered_df.index.tolist()
|
21 |
+
filtered_texts = filtered_df['description'].tolist()
|
22 |
+
filtered_ids = [str(row['id']) for _, row in filtered_df.iterrows()]
|
23 |
+
|
24 |
+
if not filtered_texts:
|
25 |
+
return []
|
26 |
+
|
27 |
+
query_embedding = self.embedder.embed([query])[0]
|
28 |
+
dense_results = self.vector_store.query(query_embedding, top_k=top_k * 2)
|
29 |
+
dense_ids = [id for id in dense_results['ids'][0] if id in filtered_ids]
|
30 |
+
dense_scores = [1 - dist for dist, id in zip(dense_results['distances'][0], dense_results['ids'][0]) if id in filtered_ids]
|
31 |
+
|
32 |
+
tokenized_query = query.lower().split()
|
33 |
+
bm25_scores = self.bm25.get_scores(tokenized_query)
|
34 |
+
bm25_scores_filtered = [bm25_scores[i] for i in filtered_indices]
|
35 |
+
bm25_top_k = np.argsort(bm25_scores_filtered)[::-1][:top_k * 2]
|
36 |
+
bm25_ids = [filtered_ids[i] for i in bm25_top_k]
|
37 |
+
bm25_scores = [bm25_scores_filtered[i] for i in bm25_top_k]
|
38 |
+
|
39 |
+
dense_scores = np.array(dense_scores) / np.max(dense_scores) if dense_scores else dense_scores
|
40 |
+
bm25_scores = np.array(bm25_scores) / np.max(bm25_scores) if bm25_scores else bm25_scores
|
41 |
+
|
42 |
+
combined_scores = {}
|
43 |
+
for idx, dense_id in enumerate(dense_ids):
|
44 |
+
combined_scores[int(dense_id)] = combined_scores.get(int(dense_id), 0) + self.alpha * dense_scores[idx]
|
45 |
+
for idx, bm25_id in enumerate(bm25_ids):
|
46 |
+
combined_scores[int(bm25_id)] = combined_scores.get(int(bm25_id), 0) + (1 - self.alpha) * bm25_scores[idx]
|
47 |
+
|
48 |
+
sorted_ids = sorted(combined_scores, key=combined_scores.get, reverse=True)[:top_k]
|
49 |
+
return [self.df[self.df['id'] == id].iloc[0].to_dict() for id in sorted_ids]
|
retrieval/keyword_filter.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# src/retrieval/keyword_filter.py
|
2 |
+
|
3 |
+
import pandas as pd
|
4 |
+
from typing import Dict, Any
|
5 |
+
|
6 |
+
def filter_restaurants(df: pd.DataFrame, parsed_query: Dict[str, Any]) -> pd.DataFrame:
|
7 |
+
"""
|
8 |
+
Filter restaurants based on extracted features from the query.
|
9 |
+
|
10 |
+
Args:
|
11 |
+
df (pd.DataFrame): DataFrame containing restaurant data.
|
12 |
+
parsed_query (Dict[str, Any]): Parsed query with features.
|
13 |
+
|
14 |
+
Returns:
|
15 |
+
pd.DataFrame: Filtered DataFrame.
|
16 |
+
"""
|
17 |
+
filtered_df = df.copy()
|
18 |
+
|
19 |
+
if parsed_query.get("cuisine"):
|
20 |
+
filtered_df = filtered_df[filtered_df["cuisine"].str.lower() == parsed_query["cuisine"].lower()]
|
21 |
+
|
22 |
+
if parsed_query.get("menu"):
|
23 |
+
filtered_df = filtered_df[filtered_df["dishes"].apply(
|
24 |
+
lambda dishes: any(item.lower() in [d.lower() for d in dishes] for item in parsed_query["menu"])
|
25 |
+
)]
|
26 |
+
|
27 |
+
if parsed_query.get("price_range"):
|
28 |
+
filtered_df = filtered_df[filtered_df["price_range"].str.lower() == parsed_query["price_range"].lower()]
|
29 |
+
|
30 |
+
distance = parsed_query.get("distance")
|
31 |
+
if isinstance(distance, (int, float)):
|
32 |
+
filtered_df = filtered_df[filtered_df["distance"] <= distance]
|
33 |
+
elif distance in ["nearby", "close"]:
|
34 |
+
filtered_df = filtered_df[filtered_df["distance"] <= 2.0]
|
35 |
+
elif distance == "far":
|
36 |
+
filtered_df = filtered_df[filtered_df["distance"] <= 10.0]
|
37 |
+
|
38 |
+
if parsed_query.get("rating"):
|
39 |
+
filtered_df = filtered_df[filtered_df["rating"] >= parsed_query["rating"]]
|
40 |
+
|
41 |
+
return filtered_df
|
retrieval/vector_store.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# src/retrieval/vector_store.py
|
2 |
+
|
3 |
+
from langchain_community.vectorstores import Chroma
|
4 |
+
from langchain_core.documents import Document
|
5 |
+
import numpy as np
|
6 |
+
from typing import List, Dict, Any
|
7 |
+
|
8 |
+
class VectorStore:
|
9 |
+
def __init__(self, embedding_function):
|
10 |
+
self.embedding_function = embedding_function
|
11 |
+
self.collection = None
|
12 |
+
|
13 |
+
def add_documents(self, documents: List[str], embeddings: List[np.ndarray], ids: List[str]):
|
14 |
+
langchain_docs = [Document(page_content=doc, metadata={"id": id}) for doc, id in zip(documents, ids)]
|
15 |
+
self.collection = Chroma.from_documents(
|
16 |
+
documents=langchain_docs,
|
17 |
+
embedding=self.embedding_function,
|
18 |
+
ids=ids,
|
19 |
+
persist_directory="./chroma_db"
|
20 |
+
)
|
21 |
+
self.collection.persist()
|
22 |
+
|
23 |
+
def query(self, query_embedding: np.ndarray, top_k: int = 5) -> Dict[str, Any]:
|
24 |
+
results = self.collection.similarity_search_by_vector(
|
25 |
+
embedding=query_embedding,
|
26 |
+
k=top_k
|
27 |
+
)
|
28 |
+
ids = [doc.metadata["id"] for doc in results]
|
29 |
+
distances = [1 - np.dot(query_embedding, doc.vector) / (np.linalg.norm(query_embedding) * np.linalg.norm(doc.vector))
|
30 |
+
if hasattr(doc, "vector") else 1.0 for doc in results]
|
31 |
+
return {
|
32 |
+
"ids": [ids],
|
33 |
+
"distances": [distances]
|
34 |
+
}
|
src/chatbot.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
from typing import Tuple, List, Dict, Any
|
3 |
+
from src.utils.data_loader import load_restaurant_data
|
4 |
+
from src.utils.query_parser import QueryParser
|
5 |
+
from src.embeddings.embedder import Embedder
|
6 |
+
from src.retrieval.vector_store import VectorStore
|
7 |
+
from src.retrieval.keyword_filter import filter_restaurants
|
8 |
+
from src.retrieval.hybrid_search import HybridRetriever
|
9 |
+
from src.generation.llm import LLM
|
10 |
+
from langchain_core.embeddings import Embeddings
|
11 |
+
|
12 |
+
class LangChainEmbeddingWrapper(Embeddings):
|
13 |
+
def __init__(self, embedder):
|
14 |
+
self.embedder = embedder
|
15 |
+
|
16 |
+
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
17 |
+
return self.embedder.embed(texts).tolist()
|
18 |
+
|
19 |
+
def embed_query(self, text: str) -> List[float]:
|
20 |
+
return self.embedder.embed([text])[0].tolist()
|
21 |
+
|
22 |
+
class RestaurantChatbot:
|
23 |
+
def __init__(self, data_path: str = "data/restaurants.json"):
|
24 |
+
"""
|
25 |
+
Initialize the restaurant chatbot.
|
26 |
+
|
27 |
+
Args:
|
28 |
+
data_path (str): Path to the restaurant JSON file.
|
29 |
+
"""
|
30 |
+
self.df = load_restaurant_data(data_path)
|
31 |
+
self.embedder = Embedder()
|
32 |
+
self.embedding_wrapper = LangChainEmbeddingWrapper(self.embedder)
|
33 |
+
self.vector_store = VectorStore(embedding_function=self.embedding_wrapper)
|
34 |
+
self.llm = LLM()
|
35 |
+
self.parser = QueryParser(self.df)
|
36 |
+
|
37 |
+
embeddings = self.embedder.embed(self.df['text'].tolist())
|
38 |
+
self.vector_store.add_documents(
|
39 |
+
documents=self.df['text'].tolist(),
|
40 |
+
embeddings=embeddings.tolist(),
|
41 |
+
ids=[str(i) for i in self.df['id']]
|
42 |
+
)
|
43 |
+
|
44 |
+
self.retriever = HybridRetriever(self.df, self.vector_store, self.embedder)
|
45 |
+
|
46 |
+
def answer(self, query: str) -> Tuple[str, List[Dict[str, Any]]]:
|
47 |
+
"""
|
48 |
+
Process a user query and return a natural, concise response with recommended restaurants.
|
49 |
+
|
50 |
+
Args:
|
51 |
+
query (str): User query.
|
52 |
+
|
53 |
+
Returns:
|
54 |
+
Tuple[str, List[Dict[str, Any]]]: Natural response text and list of recommended restaurants.
|
55 |
+
"""
|
56 |
+
parsed_query = self.parser.parse_query(query)
|
57 |
+
filtered_df = filter_restaurants(self.df, parsed_query)
|
58 |
+
description = parsed_query["description"] if parsed_query["description"] else query
|
59 |
+
|
60 |
+
if filtered_df.empty:
|
61 |
+
retrieved_docs = self.retriever.retrieve(description, self.df, top_k=3)
|
62 |
+
else:
|
63 |
+
retrieved_docs = self.retriever.retrieve(description, filtered_df, top_k=3)
|
64 |
+
|
65 |
+
if not retrieved_docs:
|
66 |
+
return "Mình không tìm được nhà hàng nào phù hợp. Bạn thử đổi tiêu chí xem, như mở rộng khoảng cách hoặc loại món ăn nhé!", []
|
67 |
+
|
68 |
+
# Create context for LLM
|
69 |
+
context = "\n".join([
|
70 |
+
f"- {doc['name']} ({doc['cuisine']}): {', '.join(doc['dishes'])}. "
|
71 |
+
f"Price: {doc['price_range']}, Distance: {doc['distance']} km, Rating: {doc['rating']}. "
|
72 |
+
f"Description: {doc['description']}"
|
73 |
+
for doc in retrieved_docs
|
74 |
+
])
|
75 |
+
|
76 |
+
# Prompt for natural, consultant-like response
|
77 |
+
prompt = (
|
78 |
+
f"Bạn là một người tư vấn nhà hàng thân thiện. Dựa trên truy vấn và danh sách nhà hàng, hãy gợi ý ngắn gọn, tự nhiên, như trò chuyện với bạn bè, giải thích tại sao chọn các nhà hàng này (tập trung vào món ăn, giá, khoảng cách, hoặc đánh giá phù hợp với truy vấn). Không lặp lại truy vấn hoặc dùng ngôn ngữ kỹ thuật. Chỉ dùng thông tin từ danh sách nhà hàng.\n\n"
|
79 |
+
f"Truy vấn: {query}\n\n"
|
80 |
+
f"Danh sách nhà hàng:\n{context}\n\n"
|
81 |
+
f"Phản hồi:"
|
82 |
+
)
|
83 |
+
|
84 |
+
response = self.llm.generate(prompt, max_length=200)
|
85 |
+
return response, retrieved_docs
|
src/embeddings/embedder.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import AutoTokenizer, AutoModel
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
from typing import List
|
5 |
+
|
6 |
+
class Embedder:
|
7 |
+
def __init__(self, model_name: str = "BAAI/bge-m3"):
|
8 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
9 |
+
self.model = AutoModel.from_pretrained(model_name)
|
10 |
+
|
11 |
+
def embed(self, texts: List[str]) -> np.ndarray:
|
12 |
+
inputs = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
|
13 |
+
with torch.no_grad():
|
14 |
+
outputs = self.model(**inputs)
|
15 |
+
embeddings = outputs.last_hidden_state[:, 0] # lấy embedding từ CLS token
|
16 |
+
embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
|
17 |
+
return embeddings.cpu().numpy()
|
src/generation/llm.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
2 |
+
from langchain_core.prompts import PromptTemplate
|
3 |
+
import os
|
4 |
+
from typing import List
|
5 |
+
|
6 |
+
class LLM:
|
7 |
+
def __init__(self, model_repo: str = "Qwen/Qwen2-1.5B-Instruct",
|
8 |
+
local_path: str = "models"):
|
9 |
+
"""
|
10 |
+
Initialize the LLM with Qwen2-1.5B-Instruct using Hugging Face Transformers.
|
11 |
+
|
12 |
+
Args:
|
13 |
+
model_repo (str): Hugging Face repository ID for the model.
|
14 |
+
local_path (str): Local directory to store the model.
|
15 |
+
"""
|
16 |
+
os.makedirs(local_path, exist_ok=True)
|
17 |
+
|
18 |
+
try:
|
19 |
+
# Load the model
|
20 |
+
self.llm = AutoModelForCausalLM.from_pretrained(
|
21 |
+
model_repo,
|
22 |
+
device_map="auto", # Automatically map to CPU
|
23 |
+
cache_dir=local_path,
|
24 |
+
trust_remote_code=True
|
25 |
+
)
|
26 |
+
|
27 |
+
# Load the tokenizer
|
28 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
29 |
+
model_repo,
|
30 |
+
cache_dir=local_path,
|
31 |
+
trust_remote_code=True
|
32 |
+
)
|
33 |
+
print(f"Model successfully loaded from {model_repo}")
|
34 |
+
except Exception as e:
|
35 |
+
raise RuntimeError(
|
36 |
+
f"Failed to initialize model from {model_repo}. "
|
37 |
+
f"Please ensure the model is available at https://huggingface.co/{model_repo}. "
|
38 |
+
f"Error: {str(e)}"
|
39 |
+
)
|
40 |
+
|
41 |
+
# Define prompt template for query parsing (used in query_parser.py)
|
42 |
+
self.prompt_template = PromptTemplate(
|
43 |
+
template="""Bạn là một trợ lý phân tích truy vấn nhà hàng. Phân tích truy vấn sau và trích xuất các đặc trưng: cuisine, menu, price_range, distance, rating, và description. Chỉ trích xuất các giá trị khớp chính xác với danh sách giá trị hợp lệ. Nếu không tìm thấy giá trị khớp, trả về null (hoặc [] cho menu). Loại bỏ các từ khóa đã trích xuất khỏi description. Trả về kết quả dưới dạng JSON.
|
44 |
+
|
45 |
+
**Danh sách giá trị hợp lệ**:
|
46 |
+
- cuisine: {cuisines}
|
47 |
+
- menu: {dishes}
|
48 |
+
- price_range: {price_ranges}
|
49 |
+
|
50 |
+
**Hướng dẫn**:
|
51 |
+
- cuisine: Chỉ chọn giá trị từ danh sách cuisine. Ví dụ, "Viet" → "Vietnamese".
|
52 |
+
- menu: Chỉ chọn các món khớp chính xác với danh sách menu. Ví dụ, "phở bò" → "phở", "sushi" → [].
|
53 |
+
- price_range: Chỉ chọn {price_ranges}. Ví dụ, "cheap" → "low".
|
54 |
+
- distance: Trích xuất số km (e.g., "2 km" → 2.0) hoặc từ khóa ["nearby", "close" → 2.0, "far" → 10.0]. Nếu không rõ, trả về null.
|
55 |
+
- rating: Trích xuất số (e.g., "4 stars" → 4.0). Nếu không rõ, trả về null.
|
56 |
+
- description: Phần còn lại sau khi loại bỏ các từ khóa đã trích xuất. Nếu rỗng, trả về truy vấn gốc.
|
57 |
+
|
58 |
+
**Truy vấn**: {query}
|
59 |
+
|
60 |
+
**Định dạng đầu ra**:
|
61 |
+
{{
|
62 |
+
"cuisine": null | "tên loại ẩm thực",
|
63 |
+
"menu": [],
|
64 |
+
"price_range": null | "low" | "medium" | "high",
|
65 |
+
"distance": null | số km | "nearby" | "close" | "far",
|
66 |
+
"rating": null | số,
|
67 |
+
"description": "phần mô tả còn lại"
|
68 |
+
}}
|
69 |
+
""",
|
70 |
+
input_variables=["cuisines", "dishes", "price_ranges", "query"]
|
71 |
+
)
|
72 |
+
|
73 |
+
def generate(self, prompt: str, max_length: int = 1000) -> str:
|
74 |
+
"""
|
75 |
+
Generate text using the LLM.
|
76 |
+
|
77 |
+
Args:
|
78 |
+
prompt (str): Input prompt.
|
79 |
+
max_length (int): Maximum length of the generated text.
|
80 |
+
|
81 |
+
Returns:
|
82 |
+
str: Generated text.
|
83 |
+
"""
|
84 |
+
try:
|
85 |
+
# Apply chat template for instruction-tuned Qwen model
|
86 |
+
messages = [{"role": "user", "content": prompt}]
|
87 |
+
prompt_with_template = self.tokenizer.apply_chat_template(
|
88 |
+
messages, tokenize=False, add_generation_prompt=True
|
89 |
+
)
|
90 |
+
# Tokenize input prompt
|
91 |
+
inputs = self.tokenizer(prompt_with_template, return_tensors="pt").to(self.llm.device)
|
92 |
+
# Generate text
|
93 |
+
outputs = self.llm.generate(
|
94 |
+
**inputs,
|
95 |
+
max_new_tokens=max_length,
|
96 |
+
temperature=0.7,
|
97 |
+
do_sample=True,
|
98 |
+
pad_token_id=self.tokenizer.eos_token_id
|
99 |
+
)
|
100 |
+
# Decode the generated tokens
|
101 |
+
response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
102 |
+
print("Response generated successfully!")
|
103 |
+
return response.strip()
|
104 |
+
except Exception as e:
|
105 |
+
raise RuntimeError(f"Failed to generate response: {str(e)}")
|
106 |
+
|
107 |
+
def format_query_prompt(self, query: str, cuisines: List[str], dishes: List[str], price_ranges: List[str]) -> str:
|
108 |
+
"""
|
109 |
+
Format the prompt for query parsing using the prompt template.
|
110 |
+
|
111 |
+
Args:
|
112 |
+
query (str): User query.
|
113 |
+
cuisines (list): List of valid cuisines.
|
114 |
+
dishes (list): List of valid dishes.
|
115 |
+
price_ranges (list): List of valid price ranges.
|
116 |
+
|
117 |
+
Returns:
|
118 |
+
str: Formatted prompt.
|
119 |
+
"""
|
120 |
+
return self.prompt_template.format(
|
121 |
+
cuisines=cuisines,
|
122 |
+
dishes=dishes,
|
123 |
+
price_ranges=price_ranges,
|
124 |
+
query=query
|
125 |
+
)
|
126 |
+
|
127 |
+
if __name__ == "__main__":
|
128 |
+
# Khởi tạo đối tượng LLM với model_repo và local_path
|
129 |
+
local_path = 'models'
|
130 |
+
|
131 |
+
try:
|
132 |
+
# Khởi tạo đối tượng LLM
|
133 |
+
llm = LLM(local_path=local_path)
|
134 |
+
|
135 |
+
# Định nghĩa một truy vấn và các tham số cần thiết
|
136 |
+
query = "Tìm quán ăn Việt Nam gần đây, giá rẻ với món phở và cơm tấm"
|
137 |
+
cuisines = ["Vietnamese", "Chinese", "Italian"]
|
138 |
+
dishes = ["phở", "sushi", "pasta", "cơm tấm"]
|
139 |
+
price_ranges = ["low", "medium", "high"]
|
140 |
+
|
141 |
+
# Sử dụng hàm generate để tạo câu trả lời từ truy vấn
|
142 |
+
generated_text = llm.generate(query, max_length=300)
|
143 |
+
|
144 |
+
# In kết quả ra màn hình
|
145 |
+
print("Generated text:")
|
146 |
+
print(generated_text)
|
147 |
+
|
148 |
+
except Exception as e:
|
149 |
+
print(f"Error: {str(e)}")
|
src/retrieval/hybrid_search.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# src/retrieval/hybrid_search.py
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
from rank_bm25 import BM25Okapi
|
5 |
+
from typing import List, Dict, Any
|
6 |
+
import pandas as pd
|
7 |
+
from src.embeddings.embedder import Embedder
|
8 |
+
from src.retrieval.vector_store import VectorStore
|
9 |
+
|
10 |
+
class HybridRetriever:
|
11 |
+
def __init__(self, df: pd.DataFrame, vector_store: VectorStore, embedder: Embedder, alpha: float = 0.5):
|
12 |
+
self.df = df
|
13 |
+
self.vector_store = vector_store
|
14 |
+
self.embedder = embedder
|
15 |
+
self.alpha = alpha
|
16 |
+
tokenized_corpus = [doc.lower().split() for doc in df['description']]
|
17 |
+
self.bm25 = BM25Okapi(tokenized_corpus)
|
18 |
+
|
19 |
+
def retrieve(self, query: str, filtered_df: pd.DataFrame, top_k: int = 3) -> List[Dict[str, Any]]:
|
20 |
+
filtered_indices = filtered_df.index.tolist()
|
21 |
+
filtered_texts = filtered_df['description'].tolist()
|
22 |
+
filtered_ids = [str(row['id']) for _, row in filtered_df.iterrows()]
|
23 |
+
|
24 |
+
if not filtered_texts:
|
25 |
+
return []
|
26 |
+
|
27 |
+
query_embedding = self.embedder.embed([query])[0]
|
28 |
+
dense_results = self.vector_store.query(query_embedding, top_k=top_k * 2)
|
29 |
+
dense_ids = [id for id in dense_results['ids'][0] if id in filtered_ids]
|
30 |
+
dense_scores = [1 - dist for dist, id in zip(dense_results['distances'][0], dense_results['ids'][0]) if id in filtered_ids]
|
31 |
+
|
32 |
+
tokenized_query = query.lower().split()
|
33 |
+
bm25_scores = self.bm25.get_scores(tokenized_query)
|
34 |
+
bm25_scores_filtered = [bm25_scores[i] for i in filtered_indices]
|
35 |
+
bm25_top_k = np.argsort(bm25_scores_filtered)[::-1][:top_k * 2]
|
36 |
+
bm25_ids = [filtered_ids[i] for i in bm25_top_k]
|
37 |
+
bm25_scores = [bm25_scores_filtered[i] for i in bm25_top_k]
|
38 |
+
|
39 |
+
dense_scores = np.array(dense_scores) / np.max(dense_scores) if dense_scores else dense_scores
|
40 |
+
bm25_scores = np.array(bm25_scores) / np.max(bm25_scores) if bm25_scores else bm25_scores
|
41 |
+
|
42 |
+
combined_scores = {}
|
43 |
+
for idx, dense_id in enumerate(dense_ids):
|
44 |
+
combined_scores[int(dense_id)] = combined_scores.get(int(dense_id), 0) + self.alpha * dense_scores[idx]
|
45 |
+
for idx, bm25_id in enumerate(bm25_ids):
|
46 |
+
combined_scores[int(bm25_id)] = combined_scores.get(int(bm25_id), 0) + (1 - self.alpha) * bm25_scores[idx]
|
47 |
+
|
48 |
+
sorted_ids = sorted(combined_scores, key=combined_scores.get, reverse=True)[:top_k]
|
49 |
+
return [self.df[self.df['id'] == id].iloc[0].to_dict() for id in sorted_ids]
|
src/retrieval/keyword_filter.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# src/retrieval/keyword_filter.py
|
2 |
+
|
3 |
+
import pandas as pd
|
4 |
+
from typing import Dict, Any
|
5 |
+
|
6 |
+
def filter_restaurants(df: pd.DataFrame, parsed_query: Dict[str, Any]) -> pd.DataFrame:
|
7 |
+
"""
|
8 |
+
Filter restaurants based on extracted features from the query.
|
9 |
+
|
10 |
+
Args:
|
11 |
+
df (pd.DataFrame): DataFrame containing restaurant data.
|
12 |
+
parsed_query (Dict[str, Any]): Parsed query with features.
|
13 |
+
|
14 |
+
Returns:
|
15 |
+
pd.DataFrame: Filtered DataFrame.
|
16 |
+
"""
|
17 |
+
filtered_df = df.copy()
|
18 |
+
|
19 |
+
if parsed_query.get("cuisine"):
|
20 |
+
filtered_df = filtered_df[filtered_df["cuisine"].str.lower() == parsed_query["cuisine"].lower()]
|
21 |
+
|
22 |
+
if parsed_query.get("menu"):
|
23 |
+
filtered_df = filtered_df[filtered_df["dishes"].apply(
|
24 |
+
lambda dishes: any(item.lower() in [d.lower() for d in dishes] for item in parsed_query["menu"])
|
25 |
+
)]
|
26 |
+
|
27 |
+
if parsed_query.get("price_range"):
|
28 |
+
filtered_df = filtered_df[filtered_df["price_range"].str.lower() == parsed_query["price_range"].lower()]
|
29 |
+
|
30 |
+
distance = parsed_query.get("distance")
|
31 |
+
if isinstance(distance, (int, float)):
|
32 |
+
filtered_df = filtered_df[filtered_df["distance"] <= distance]
|
33 |
+
elif distance in ["nearby", "close"]:
|
34 |
+
filtered_df = filtered_df[filtered_df["distance"] <= 2.0]
|
35 |
+
elif distance == "far":
|
36 |
+
filtered_df = filtered_df[filtered_df["distance"] <= 10.0]
|
37 |
+
|
38 |
+
if parsed_query.get("rating"):
|
39 |
+
filtered_df = filtered_df[filtered_df["rating"] >= parsed_query["rating"]]
|
40 |
+
|
41 |
+
return filtered_df
|
src/retrieval/vector_store.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# src/retrieval/vector_store.py
|
2 |
+
|
3 |
+
from langchain_community.vectorstores import Chroma
|
4 |
+
from langchain_core.documents import Document
|
5 |
+
import numpy as np
|
6 |
+
from typing import List, Dict, Any
|
7 |
+
|
8 |
+
class VectorStore:
|
9 |
+
def __init__(self, embedding_function):
|
10 |
+
self.embedding_function = embedding_function
|
11 |
+
self.collection = None
|
12 |
+
|
13 |
+
def add_documents(self, documents: List[str], embeddings: List[np.ndarray], ids: List[str]):
|
14 |
+
langchain_docs = [Document(page_content=doc, metadata={"id": id}) for doc, id in zip(documents, ids)]
|
15 |
+
self.collection = Chroma.from_documents(
|
16 |
+
documents=langchain_docs,
|
17 |
+
embedding=self.embedding_function,
|
18 |
+
ids=ids,
|
19 |
+
persist_directory="./chroma_db"
|
20 |
+
)
|
21 |
+
self.collection.persist()
|
22 |
+
|
23 |
+
def query(self, query_embedding: np.ndarray, top_k: int = 5) -> Dict[str, Any]:
|
24 |
+
results = self.collection.similarity_search_by_vector(
|
25 |
+
embedding=query_embedding,
|
26 |
+
k=top_k
|
27 |
+
)
|
28 |
+
ids = [doc.metadata["id"] for doc in results]
|
29 |
+
distances = [1 - np.dot(query_embedding, doc.vector) / (np.linalg.norm(query_embedding) * np.linalg.norm(doc.vector))
|
30 |
+
if hasattr(doc, "vector") else 1.0 for doc in results]
|
31 |
+
return {
|
32 |
+
"ids": [ids],
|
33 |
+
"distances": [distances]
|
34 |
+
}
|
src/utils/data_loader.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# src/utils/data_loader.py
|
2 |
+
|
3 |
+
import pandas as pd
|
4 |
+
import json
|
5 |
+
|
6 |
+
def load_restaurant_data(file_path: str) -> pd.DataFrame:
|
7 |
+
"""
|
8 |
+
Load restaurant data from JSON file into a DataFrame.
|
9 |
+
|
10 |
+
Args:
|
11 |
+
file_path (str): Path to the JSON file.
|
12 |
+
|
13 |
+
Returns:
|
14 |
+
pd.DataFrame: DataFrame containing restaurant data.
|
15 |
+
"""
|
16 |
+
with open(file_path, 'r', encoding='utf-8') as f:
|
17 |
+
data = json.load(f)
|
18 |
+
|
19 |
+
df = pd.DataFrame(data)
|
20 |
+
# Create a text field for embedding and BM25
|
21 |
+
df['text'] = df.apply(
|
22 |
+
lambda row: f"{row['name']} ({row['cuisine']}): {', '.join(row['dishes'])}. "
|
23 |
+
f"Price: {row['price_range']}, Distance: {row['distance']} km, "
|
24 |
+
f"Rating: {row['rating']}. Description: {row['description']}",
|
25 |
+
axis=1
|
26 |
+
)
|
27 |
+
return df
|
28 |
+
if __name__ == "__main__":
|
29 |
+
print(load_restaurant_data("./data/restaurants.json"))
|
src/utils/query_parser.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# src/utils/query_parser.py
|
2 |
+
import pandas as pd
|
3 |
+
import json
|
4 |
+
from typing import Dict, Any
|
5 |
+
from src.generation.llm import LLM
|
6 |
+
|
7 |
+
class QueryParser:
|
8 |
+
def __init__(self, df: pd.DataFrame):
|
9 |
+
"""
|
10 |
+
Initialize the query parser with restaurant data.
|
11 |
+
|
12 |
+
Args:
|
13 |
+
df (pd.DataFrame): DataFrame containing restaurant data.
|
14 |
+
"""
|
15 |
+
self.llm = LLM()
|
16 |
+
self.df = df
|
17 |
+
self.valid_cuisines = sorted(self.df['cuisine'].unique().tolist())
|
18 |
+
self.valid_price_ranges = sorted(self.df['price_range'].unique().tolist())
|
19 |
+
self.valid_dishes = sorted(set([dish for dishes in self.df['dishes'] for dish in dishes]))
|
20 |
+
|
21 |
+
def parse_query(self, query: str) -> Dict[str, Any]:
|
22 |
+
"""
|
23 |
+
Parse the query to extract features.
|
24 |
+
|
25 |
+
Args:
|
26 |
+
query (str): User query.
|
27 |
+
|
28 |
+
Returns:
|
29 |
+
Dict[str, Any]: Parsed features.
|
30 |
+
"""
|
31 |
+
# Format prompt using LLM's prompt template
|
32 |
+
prompt = self.llm.format_query_prompt(
|
33 |
+
query=query,
|
34 |
+
cuisines=self.valid_cuisines,
|
35 |
+
dishes=self.valid_dishes,
|
36 |
+
price_ranges=self.valid_price_ranges
|
37 |
+
)
|
38 |
+
|
39 |
+
# Generate response
|
40 |
+
response = self.llm.generate(prompt)
|
41 |
+
|
42 |
+
# Parse JSON response
|
43 |
+
try:
|
44 |
+
json_start = response.find("{")
|
45 |
+
json_end = response.rfind("}") + 1
|
46 |
+
parsed = json.loads(response[json_start:json_end])
|
47 |
+
return parsed
|
48 |
+
except json.JSONDecodeError:
|
49 |
+
return {
|
50 |
+
"cuisine": None,
|
51 |
+
"menu": [],
|
52 |
+
"price_range": None,
|
53 |
+
"distance": None,
|
54 |
+
"rating": None,
|
55 |
+
"description": query
|
56 |
+
}
|
57 |
+
|
58 |
+
|
59 |
+
# Quick test block for QueryParser
|
60 |
+
if __name__ == "__main__":
|
61 |
+
import pandas as pd
|
62 |
+
|
63 |
+
sample_data = {
|
64 |
+
"cuisine": ["Italian", "Japanese"],
|
65 |
+
"price_range": ["$", "$$"],
|
66 |
+
"dishes": [["pizza", "pasta"], ["sushi", "ramen"]]
|
67 |
+
}
|
68 |
+
df = pd.DataFrame(sample_data)
|
69 |
+
parser = QueryParser(df)
|
70 |
+
|
71 |
+
user_query = "I want cheap sushi"
|
72 |
+
result = parser.parse_query(user_query)
|
73 |
+
|
74 |
+
print("Parsed Query Result:")
|
75 |
+
print(result)
|
76 |
+
|
utils/data_loader.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# src/utils/data_loader.py
|
2 |
+
|
3 |
+
import pandas as pd
|
4 |
+
import json
|
5 |
+
|
6 |
+
def load_restaurant_data(file_path: str) -> pd.DataFrame:
|
7 |
+
"""
|
8 |
+
Load restaurant data from JSON file into a DataFrame.
|
9 |
+
|
10 |
+
Args:
|
11 |
+
file_path (str): Path to the JSON file.
|
12 |
+
|
13 |
+
Returns:
|
14 |
+
pd.DataFrame: DataFrame containing restaurant data.
|
15 |
+
"""
|
16 |
+
with open(file_path, 'r', encoding='utf-8') as f:
|
17 |
+
data = json.load(f)
|
18 |
+
|
19 |
+
df = pd.DataFrame(data)
|
20 |
+
# Create a text field for embedding and BM25
|
21 |
+
df['text'] = df.apply(
|
22 |
+
lambda row: f"{row['name']} ({row['cuisine']}): {', '.join(row['dishes'])}. "
|
23 |
+
f"Price: {row['price_range']}, Distance: {row['distance']} km, "
|
24 |
+
f"Rating: {row['rating']}. Description: {row['description']}",
|
25 |
+
axis=1
|
26 |
+
)
|
27 |
+
return df
|
28 |
+
if __name__ == "__main__":
|
29 |
+
print(load_restaurant_data("./data/restaurants.json"))
|
utils/query_parser.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# src/utils/query_parser.py
|
2 |
+
import pandas as pd
|
3 |
+
import json
|
4 |
+
from typing import Dict, Any
|
5 |
+
from src.generation.llm import LLM
|
6 |
+
|
7 |
+
class QueryParser:
|
8 |
+
def __init__(self, df: pd.DataFrame):
|
9 |
+
"""
|
10 |
+
Initialize the query parser with restaurant data.
|
11 |
+
|
12 |
+
Args:
|
13 |
+
df (pd.DataFrame): DataFrame containing restaurant data.
|
14 |
+
"""
|
15 |
+
self.llm = LLM()
|
16 |
+
self.df = df
|
17 |
+
self.valid_cuisines = sorted(self.df['cuisine'].unique().tolist())
|
18 |
+
self.valid_price_ranges = sorted(self.df['price_range'].unique().tolist())
|
19 |
+
self.valid_dishes = sorted(set([dish for dishes in self.df['dishes'] for dish in dishes]))
|
20 |
+
|
21 |
+
def parse_query(self, query: str) -> Dict[str, Any]:
|
22 |
+
"""
|
23 |
+
Parse the query to extract features.
|
24 |
+
|
25 |
+
Args:
|
26 |
+
query (str): User query.
|
27 |
+
|
28 |
+
Returns:
|
29 |
+
Dict[str, Any]: Parsed features.
|
30 |
+
"""
|
31 |
+
# Format prompt using LLM's prompt template
|
32 |
+
prompt = self.llm.format_query_prompt(
|
33 |
+
query=query,
|
34 |
+
cuisines=self.valid_cuisines,
|
35 |
+
dishes=self.valid_dishes,
|
36 |
+
price_ranges=self.valid_price_ranges
|
37 |
+
)
|
38 |
+
|
39 |
+
# Generate response
|
40 |
+
response = self.llm.generate(prompt)
|
41 |
+
|
42 |
+
# Parse JSON response
|
43 |
+
try:
|
44 |
+
json_start = response.find("{")
|
45 |
+
json_end = response.rfind("}") + 1
|
46 |
+
parsed = json.loads(response[json_start:json_end])
|
47 |
+
return parsed
|
48 |
+
except json.JSONDecodeError:
|
49 |
+
return {
|
50 |
+
"cuisine": None,
|
51 |
+
"menu": [],
|
52 |
+
"price_range": None,
|
53 |
+
"distance": None,
|
54 |
+
"rating": None,
|
55 |
+
"description": query
|
56 |
+
}
|
57 |
+
|
58 |
+
|
59 |
+
# Quick test block for QueryParser
|
60 |
+
if __name__ == "__main__":
|
61 |
+
import pandas as pd
|
62 |
+
|
63 |
+
sample_data = {
|
64 |
+
"cuisine": ["Italian", "Japanese"],
|
65 |
+
"price_range": ["$", "$$"],
|
66 |
+
"dishes": [["pizza", "pasta"], ["sushi", "ramen"]]
|
67 |
+
}
|
68 |
+
df = pd.DataFrame(sample_data)
|
69 |
+
parser = QueryParser(df)
|
70 |
+
|
71 |
+
user_query = "I want cheap sushi"
|
72 |
+
result = parser.parse_query(user_query)
|
73 |
+
|
74 |
+
print("Parsed Query Result:")
|
75 |
+
print(result)
|
76 |
+
|