hieu-nguyen2208 commited on
Commit
4363820
·
1 Parent(s): 7da6e45
app.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from src.chatbot import RestaurantChatbot
3
+
4
+ chatbot = RestaurantChatbot()
5
+ chat_history = []
6
+
7
+ def respond(user_message, history):
8
+ response, retrieved_docs = chatbot.answer(user_message)
9
+
10
+ bot_response = f"{response}\n\n**Nhà hàng gợi ý:**\n"
11
+ if retrieved_docs:
12
+ for doc in retrieved_docs:
13
+ bot_response += (
14
+ f"- **{doc['name']} ({doc['cuisine']})**\n"
15
+ f" - Món ăn: {', '.join(doc['dishes'])}\n"
16
+ f" - Giá: {doc['price_range']}\n"
17
+ f" - Khoảng cách: {doc['distance']} km\n"
18
+ f" - Đánh giá: {doc['rating']}\n"
19
+ f" - Địa chỉ: {doc['address']}\n"
20
+ f" - Mô tả: {doc['description']}\n"
21
+ )
22
+ else:
23
+ bot_response += "- Không tìm thấy nhà hàng phù hợp."
24
+
25
+ return bot_response
26
+
27
+ with gr.Blocks() as demo:
28
+ gr.Markdown("## Chatbot Gợi ý Quán ăn")
29
+ chatbot_ui = gr.ChatInterface(fn=respond, chatbot=gr.Chatbot())
30
+
31
+ demo.launch()
chatbot.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ from typing import Tuple, List, Dict, Any
3
+ from src.utils.data_loader import load_restaurant_data
4
+ from src.utils.query_parser import QueryParser
5
+ from src.embeddings.embedder import Embedder
6
+ from src.retrieval.vector_store import VectorStore
7
+ from src.retrieval.keyword_filter import filter_restaurants
8
+ from src.retrieval.hybrid_search import HybridRetriever
9
+ from src.generation.llm import LLM
10
+ from langchain_core.embeddings import Embeddings
11
+
12
+ class LangChainEmbeddingWrapper(Embeddings):
13
+ def __init__(self, embedder):
14
+ self.embedder = embedder
15
+
16
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
17
+ return self.embedder.embed(texts).tolist()
18
+
19
+ def embed_query(self, text: str) -> List[float]:
20
+ return self.embedder.embed([text])[0].tolist()
21
+
22
+ class RestaurantChatbot:
23
+ def __init__(self, data_path: str = "data/restaurants.json"):
24
+ """
25
+ Initialize the restaurant chatbot.
26
+
27
+ Args:
28
+ data_path (str): Path to the restaurant JSON file.
29
+ """
30
+ self.df = load_restaurant_data(data_path)
31
+ self.embedder = Embedder()
32
+ self.embedding_wrapper = LangChainEmbeddingWrapper(self.embedder)
33
+ self.vector_store = VectorStore(embedding_function=self.embedding_wrapper)
34
+ self.llm = LLM()
35
+ self.parser = QueryParser(self.df)
36
+
37
+ embeddings = self.embedder.embed(self.df['text'].tolist())
38
+ self.vector_store.add_documents(
39
+ documents=self.df['text'].tolist(),
40
+ embeddings=embeddings.tolist(),
41
+ ids=[str(i) for i in self.df['id']]
42
+ )
43
+
44
+ self.retriever = HybridRetriever(self.df, self.vector_store, self.embedder)
45
+
46
+ def answer(self, query: str) -> Tuple[str, List[Dict[str, Any]]]:
47
+ """
48
+ Process a user query and return a natural, concise response with recommended restaurants.
49
+
50
+ Args:
51
+ query (str): User query.
52
+
53
+ Returns:
54
+ Tuple[str, List[Dict[str, Any]]]: Natural response text and list of recommended restaurants.
55
+ """
56
+ parsed_query = self.parser.parse_query(query)
57
+ filtered_df = filter_restaurants(self.df, parsed_query)
58
+ description = parsed_query["description"] if parsed_query["description"] else query
59
+
60
+ if filtered_df.empty:
61
+ retrieved_docs = self.retriever.retrieve(description, self.df, top_k=3)
62
+ else:
63
+ retrieved_docs = self.retriever.retrieve(description, filtered_df, top_k=3)
64
+
65
+ if not retrieved_docs:
66
+ return "Mình không tìm được nhà hàng nào phù hợp. Bạn thử đổi tiêu chí xem, như mở rộng khoảng cách hoặc loại món ăn nhé!", []
67
+
68
+ # Create context for LLM
69
+ context = "\n".join([
70
+ f"- {doc['name']} ({doc['cuisine']}): {', '.join(doc['dishes'])}. "
71
+ f"Price: {doc['price_range']}, Distance: {doc['distance']} km, Rating: {doc['rating']}. "
72
+ f"Description: {doc['description']}"
73
+ for doc in retrieved_docs
74
+ ])
75
+
76
+ # Prompt for natural, consultant-like response
77
+ prompt = (
78
+ f"Bạn là một người tư vấn nhà hàng thân thiện. Dựa trên truy vấn và danh sách nhà hàng, hãy gợi ý ngắn gọn, tự nhiên, như trò chuyện với bạn bè, giải thích tại sao chọn các nhà hàng này (tập trung vào món ăn, giá, khoảng cách, hoặc đánh giá phù hợp với truy vấn). Không lặp lại truy vấn hoặc dùng ngôn ngữ kỹ thuật. Chỉ dùng thông tin từ danh sách nhà hàng.\n\n"
79
+ f"Truy vấn: {query}\n\n"
80
+ f"Danh sách nhà hàng:\n{context}\n\n"
81
+ f"Phản hồi:"
82
+ )
83
+
84
+ response = self.llm.generate(prompt, max_length=200)
85
+ return response, retrieved_docs
data/restaurants.json ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "id": 1,
4
+ "name": "Phở Hà Nội",
5
+ "cuisine": "Vietnamese",
6
+ "distance": 1.2,
7
+ "price_range": "medium",
8
+ "dishes": ["phở", "bún chả", "bánh mì", "gỏi cuốn"],
9
+ "rating": 4.5,
10
+ "address": "123 Nguyễn Huệ, Quận 1, TP.HCM, Việt Nam",
11
+ "description": "A cozy spot offering authentic Vietnamese phở with rich broth and fresh herbs.",
12
+ "dietary_options": ["vegetarian"],
13
+ "service_type": ["dine-in", "takeout"],
14
+ "kid_friendly": true,
15
+ "outdoor_seating": false,
16
+ "alcohol": false
17
+ },
18
+ {
19
+ "id": 2,
20
+ "name": "Pizza 4P's",
21
+ "cuisine": "Italian",
22
+ "distance": 2.5,
23
+ "price_range": "high",
24
+ "dishes": ["pizza", "pasta", "tiramisu", "salad"],
25
+ "rating": 4.8,
26
+ "address": "8/15 Lê Thánh Tôn, Quận 1, TP.HCM, Việt Nam",
27
+ "description": "Upscale Italian dining with artisanal pizzas and a vibrant atmosphere.",
28
+ "dietary_options": ["vegetarian", "gluten-free"],
29
+ "service_type": ["dine-in", "delivery"],
30
+ "kid_friendly": true,
31
+ "outdoor_seating": true,
32
+ "alcohol": true
33
+ },
34
+ {
35
+ "id": 3,
36
+ "name": "Sushi Bar",
37
+ "cuisine": "Japanese",
38
+ "distance": 3.0,
39
+ "price_range": "high",
40
+ "dishes": ["sushi", "sashimi", "ramen", "tempura"],
41
+ "rating": 4.3,
42
+ "address": "45 Lý Tự Trọng, Quận 1, TP.HCM, Việt Nam",
43
+ "description": "Authentic Japanese sushi with fresh fish and a modern ambiance.",
44
+ "dietary_options": [],
45
+ "service_type": ["dine-in"],
46
+ "kid_friendly": false,
47
+ "outdoor_seating": false,
48
+ "alcohol": true
49
+ },
50
+ {
51
+ "id": 4,
52
+ "name": "Bún Bò Huế An Nam",
53
+ "cuisine": "Vietnamese",
54
+ "distance": 0.8,
55
+ "price_range": "low",
56
+ "dishes": ["bún bò huế", "bánh bèo", "bánh nậm"],
57
+ "rating": 4.0,
58
+ "address": "78 Pasteur, Quận 1, TP.HCM, Việt Nam",
59
+ "description": "A local favorite for spicy Huế-style noodle soups with bold flavors.",
60
+ "dietary_options": [],
61
+ "service_type": ["dine-in", "takeout"],
62
+ "kid_friendly": true,
63
+ "outdoor_seating": false,
64
+ "alcohol": false
65
+ },
66
+ {
67
+ "id": 5,
68
+ "name": "Thai Spice",
69
+ "cuisine": "Thai",
70
+ "distance": 4.5,
71
+ "price_range": "medium",
72
+ "dishes": ["tom yum", "pad thai", "green curry", "mango sticky rice"],
73
+ "rating": 4.2,
74
+ "address": "12 Nguyễn Đình Chiểu, Đa Kao, Quận 1, TP.HCM, Việt Nam",
75
+ "description": "Vibrant Thai restaurant with spicy curries and refreshing tropical desserts.",
76
+ "dietary_options": ["vegetarian", "vegan"],
77
+ "service_type": ["dine-in", "delivery", "takeout"],
78
+ "kid_friendly": true,
79
+ "outdoor_seating": true,
80
+ "alcohol": true
81
+ },
82
+ {
83
+ "id": 6,
84
+ "name": "La Fiesta",
85
+ "cuisine": "Mexican",
86
+ "distance": 2.0,
87
+ "price_range": "medium",
88
+ "dishes": ["tacos", "burrito", "quesadilla", "guacamole"],
89
+ "rating": 4.1,
90
+ "address": "33 Tôn Thất Thiệp, Quận 1, TP.HCM, Việt Nam",
91
+ "description": "Lively Mexican eatery with colorful decor and flavorful tacos.",
92
+ "dietary_options": ["vegetarian"],
93
+ "service_type": ["dine-in", "delivery"],
94
+ "kid_friendly": true,
95
+ "outdoor_seating": true,
96
+ "alcohol": true
97
+ },
98
+ {
99
+ "id": 7,
100
+ "name": "The Vegan Garden",
101
+ "cuisine": "Vegan",
102
+ "distance": 5.0,
103
+ "price_range": "medium",
104
+ "dishes": ["vegan phở", "tofu curry", "quinoa salad"],
105
+ "rating": 4.6,
106
+ "address": "20 Võ Văn Tần, Quận 3, TP.HCM, Việt Nam",
107
+ "description": "A serene vegan restaurant offering plant-based Vietnamese and international dishes.",
108
+ "dietary_options": ["vegan", "gluten-free"],
109
+ "service_type": ["dine-in", "takeout"],
110
+ "kid_friendly": true,
111
+ "outdoor_seating": false,
112
+ "alcohol": false
113
+ },
114
+ {
115
+ "id": 8,
116
+ "name": "Le Petit Paris",
117
+ "cuisine": "French",
118
+ "distance": 1.5,
119
+ "price_range": "high",
120
+ "dishes": ["croissant", "coq au vin", "crème brûlée"],
121
+ "rating": 4.7,
122
+ "address": "56 Đồng Khởi, Quận 1, TP.HCM, Việt Nam",
123
+ "description": "Elegant French bistro with classic dishes and a romantic ambiance.",
124
+ "dietary_options": [],
125
+ "service_type": ["dine-in"],
126
+ "kid_friendly": false,
127
+ "outdoor_seating": true,
128
+ "alcohol": true
129
+ },
130
+ {
131
+ "id": 9,
132
+ "name": "Chaat House",
133
+ "cuisine": "Indian",
134
+ "distance": 3.8,
135
+ "price_range": "medium",
136
+ "dishes": ["butter chicken", "naan", "paneer tikka", "biryani"],
137
+ "rating": 4.4,
138
+ "address": "15 Lê Lợi, Quận 1, TP.HCM, Việt Nam",
139
+ "description": "Authentic Indian cuisine with rich spices and a warm atmosphere.",
140
+ "dietary_options": ["vegetarian", "vegan"],
141
+ "service_type": ["dine-in", "delivery", "takeout"],
142
+ "kid_friendly": true,
143
+ "outdoor_seating": false,
144
+ "alcohol": false
145
+ },
146
+ {
147
+ "id": 10,
148
+ "name": "BBQ Haven",
149
+ "cuisine": "American",
150
+ "distance": 6.0,
151
+ "price_range": "medium",
152
+ "dishes": ["ribs", "burger", "fries", "coleslaw"],
153
+ "rating": 4.0,
154
+ "address": "88 Nguyễn Trãi, Quận 5, TP.HCM, Việt Nam",
155
+ "description": "Casual American BBQ spot with smoky ribs and hearty burgers.",
156
+ "dietary_options": [],
157
+ "service_type": ["dine-in", "delivery"],
158
+ "kid_friendly": true,
159
+ "outdoor_seating": true,
160
+ "alcohol": true
161
+ }
162
+ ]
embeddings/embedder.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModel
2
+ import torch
3
+ import numpy as np
4
+ from typing import List
5
+
6
+ class Embedder:
7
+ def __init__(self, model_name: str = "BAAI/bge-m3"):
8
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
9
+ self.model = AutoModel.from_pretrained(model_name)
10
+
11
+ def embed(self, texts: List[str]) -> np.ndarray:
12
+ inputs = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
13
+ with torch.no_grad():
14
+ outputs = self.model(**inputs)
15
+ embeddings = outputs.last_hidden_state[:, 0] # lấy embedding từ CLS token
16
+ embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
17
+ return embeddings.cpu().numpy()
generation/llm.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForCausalLM, AutoTokenizer
2
+ from langchain_core.prompts import PromptTemplate
3
+ import os
4
+ from typing import List
5
+
6
+ class LLM:
7
+ def __init__(self, model_repo: str = "Qwen/Qwen2-1.5B-Instruct",
8
+ local_path: str = "models"):
9
+ """
10
+ Initialize the LLM with Qwen2-1.5B-Instruct using Hugging Face Transformers.
11
+
12
+ Args:
13
+ model_repo (str): Hugging Face repository ID for the model.
14
+ local_path (str): Local directory to store the model.
15
+ """
16
+ os.makedirs(local_path, exist_ok=True)
17
+
18
+ try:
19
+ # Load the model
20
+ self.llm = AutoModelForCausalLM.from_pretrained(
21
+ model_repo,
22
+ device_map="auto", # Automatically map to CPU
23
+ cache_dir=local_path,
24
+ trust_remote_code=True
25
+ )
26
+
27
+ # Load the tokenizer
28
+ self.tokenizer = AutoTokenizer.from_pretrained(
29
+ model_repo,
30
+ cache_dir=local_path,
31
+ trust_remote_code=True
32
+ )
33
+ print(f"Model successfully loaded from {model_repo}")
34
+ except Exception as e:
35
+ raise RuntimeError(
36
+ f"Failed to initialize model from {model_repo}. "
37
+ f"Please ensure the model is available at https://huggingface.co/{model_repo}. "
38
+ f"Error: {str(e)}"
39
+ )
40
+
41
+ # Define prompt template for query parsing (used in query_parser.py)
42
+ self.prompt_template = PromptTemplate(
43
+ template="""Bạn là một trợ lý phân tích truy vấn nhà hàng. Phân tích truy vấn sau và trích xuất các đặc trưng: cuisine, menu, price_range, distance, rating, và description. Chỉ trích xuất các giá trị khớp chính xác với danh sách giá trị hợp lệ. Nếu không tìm thấy giá trị khớp, trả về null (hoặc [] cho menu). Loại bỏ các từ khóa đã trích xuất khỏi description. Trả về kết quả dưới dạng JSON.
44
+
45
+ **Danh sách giá trị hợp lệ**:
46
+ - cuisine: {cuisines}
47
+ - menu: {dishes}
48
+ - price_range: {price_ranges}
49
+
50
+ **Hướng dẫn**:
51
+ - cuisine: Chỉ chọn giá trị từ danh sách cuisine. Ví dụ, "Viet" → "Vietnamese".
52
+ - menu: Chỉ chọn các món khớp chính xác với danh sách menu. Ví dụ, "phở bò" → "phở", "sushi" → [].
53
+ - price_range: Chỉ chọn {price_ranges}. Ví dụ, "cheap" → "low".
54
+ - distance: Trích xuất số km (e.g., "2 km" → 2.0) hoặc từ khóa ["nearby", "close" → 2.0, "far" → 10.0]. Nếu không rõ, trả về null.
55
+ - rating: Trích xuất số (e.g., "4 stars" → 4.0). Nếu không rõ, trả về null.
56
+ - description: Phần còn lại sau khi loại bỏ các từ khóa đã trích xuất. Nếu rỗng, trả về truy vấn gốc.
57
+
58
+ **Truy vấn**: {query}
59
+
60
+ **Định dạng đầu ra**:
61
+ {{
62
+ "cuisine": null | "tên loại ẩm thực",
63
+ "menu": [],
64
+ "price_range": null | "low" | "medium" | "high",
65
+ "distance": null | số km | "nearby" | "close" | "far",
66
+ "rating": null | số,
67
+ "description": "phần mô tả còn lại"
68
+ }}
69
+ """,
70
+ input_variables=["cuisines", "dishes", "price_ranges", "query"]
71
+ )
72
+
73
+ def generate(self, prompt: str, max_length: int = 1000) -> str:
74
+ """
75
+ Generate text using the LLM.
76
+
77
+ Args:
78
+ prompt (str): Input prompt.
79
+ max_length (int): Maximum length of the generated text.
80
+
81
+ Returns:
82
+ str: Generated text.
83
+ """
84
+ try:
85
+ # Apply chat template for instruction-tuned Qwen model
86
+ messages = [{"role": "user", "content": prompt}]
87
+ prompt_with_template = self.tokenizer.apply_chat_template(
88
+ messages, tokenize=False, add_generation_prompt=True
89
+ )
90
+ # Tokenize input prompt
91
+ inputs = self.tokenizer(prompt_with_template, return_tensors="pt").to(self.llm.device)
92
+ # Generate text
93
+ outputs = self.llm.generate(
94
+ **inputs,
95
+ max_new_tokens=max_length,
96
+ temperature=0.7,
97
+ do_sample=True,
98
+ pad_token_id=self.tokenizer.eos_token_id
99
+ )
100
+ # Decode the generated tokens
101
+ response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
102
+ print("Response generated successfully!")
103
+ return response.strip()
104
+ except Exception as e:
105
+ raise RuntimeError(f"Failed to generate response: {str(e)}")
106
+
107
+ def format_query_prompt(self, query: str, cuisines: List[str], dishes: List[str], price_ranges: List[str]) -> str:
108
+ """
109
+ Format the prompt for query parsing using the prompt template.
110
+
111
+ Args:
112
+ query (str): User query.
113
+ cuisines (list): List of valid cuisines.
114
+ dishes (list): List of valid dishes.
115
+ price_ranges (list): List of valid price ranges.
116
+
117
+ Returns:
118
+ str: Formatted prompt.
119
+ """
120
+ return self.prompt_template.format(
121
+ cuisines=cuisines,
122
+ dishes=dishes,
123
+ price_ranges=price_ranges,
124
+ query=query
125
+ )
126
+
127
+ if __name__ == "__main__":
128
+ # Khởi tạo đối tượng LLM với model_repo và local_path
129
+ local_path = 'models'
130
+
131
+ try:
132
+ # Khởi tạo đối tượng LLM
133
+ llm = LLM(local_path=local_path)
134
+
135
+ # Định nghĩa một truy vấn và các tham số cần thiết
136
+ query = "Tìm quán ăn Việt Nam gần đây, giá rẻ với món phở và cơm tấm"
137
+ cuisines = ["Vietnamese", "Chinese", "Italian"]
138
+ dishes = ["phở", "sushi", "pasta", "cơm tấm"]
139
+ price_ranges = ["low", "medium", "high"]
140
+
141
+ # Sử dụng hàm generate để tạo câu trả lời từ truy vấn
142
+ generated_text = llm.generate(query, max_length=300)
143
+
144
+ # In kết quả ra màn hình
145
+ print("Generated text:")
146
+ print(generated_text)
147
+
148
+ except Exception as e:
149
+ print(f"Error: {str(e)}")
llm.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForCausalLM, AutoTokenizer
2
+ from langchain_core.prompts import PromptTemplate
3
+ import os
4
+ from typing import List
5
+
6
+ class LLM:
7
+ def __init__(self, model_repo: str = "Qwen/Qwen2-1.5B-Instruct",
8
+ local_path: str = "models"):
9
+ """
10
+ Initialize the LLM with Qwen2-1.5B-Instruct using Hugging Face Transformers.
11
+
12
+ Args:
13
+ model_repo (str): Hugging Face repository ID for the model.
14
+ local_path (str): Local directory to store the model.
15
+ """
16
+ os.makedirs(local_path, exist_ok=True)
17
+
18
+ try:
19
+ # Load the model
20
+ self.llm = AutoModelForCausalLM.from_pretrained(
21
+ model_repo,
22
+ device_map="auto", # Automatically map to CPU
23
+ cache_dir=local_path,
24
+ trust_remote_code=True
25
+ )
26
+
27
+ # Load the tokenizer
28
+ self.tokenizer = AutoTokenizer.from_pretrained(
29
+ model_repo,
30
+ cache_dir=local_path,
31
+ trust_remote_code=True
32
+ )
33
+ print(f"Model successfully loaded from {model_repo}")
34
+ except Exception as e:
35
+ raise RuntimeError(
36
+ f"Failed to initialize model from {model_repo}. "
37
+ f"Please ensure the model is available at https://huggingface.co/{model_repo}. "
38
+ f"Error: {str(e)}"
39
+ )
40
+
41
+ # Define prompt template for query parsing (used in query_parser.py)
42
+ self.prompt_template = PromptTemplate(
43
+ template="""Bạn là một trợ lý phân tích truy vấn nhà hàng. Phân tích truy vấn sau và trích xuất các đặc trưng: cuisine, menu, price_range, distance, rating, và description. Chỉ trích xuất các giá trị khớp chính xác với danh sách giá trị hợp lệ. Nếu không tìm thấy giá trị khớp, trả về null (hoặc [] cho menu). Loại bỏ các từ khóa đã trích xuất khỏi description. Trả về kết quả dưới dạng JSON.
44
+
45
+ **Danh sách giá trị hợp lệ**:
46
+ - cuisine: {cuisines}
47
+ - menu: {dishes}
48
+ - price_range: {price_ranges}
49
+
50
+ **Hướng dẫn**:
51
+ - cuisine: Chỉ chọn giá trị từ danh sách cuisine. Ví dụ, "Viet" → "Vietnamese".
52
+ - menu: Chỉ chọn các món khớp chính xác với danh sách menu. Ví dụ, "phở bò" → "phở", "sushi" → [].
53
+ - price_range: Chỉ chọn {price_ranges}. Ví dụ, "cheap" → "low".
54
+ - distance: Trích xuất số km (e.g., "2 km" → 2.0) hoặc từ khóa ["nearby", "close" → 2.0, "far" → 10.0]. Nếu không rõ, trả về null.
55
+ - rating: Trích xuất số (e.g., "4 stars" → 4.0). Nếu không rõ, trả về null.
56
+ - description: Phần còn lại sau khi loại bỏ các từ khóa đã trích xuất. Nếu rỗng, trả về truy vấn gốc.
57
+
58
+ **Truy vấn**: {query}
59
+
60
+ **Định dạng đầu ra**:
61
+ {{
62
+ "cuisine": null | "tên loại ẩm thực",
63
+ "menu": [],
64
+ "price_range": null | "low" | "medium" | "high",
65
+ "distance": null | số km | "nearby" | "close" | "far",
66
+ "rating": null | số,
67
+ "description": "phần mô tả còn lại"
68
+ }}
69
+ """,
70
+ input_variables=["cuisines", "dishes", "price_ranges", "query"]
71
+ )
72
+
73
+ def generate(self, prompt: str, max_length: int = 1000) -> str:
74
+ """
75
+ Generate text using the LLM.
76
+
77
+ Args:
78
+ prompt (str): Input prompt.
79
+ max_length (int): Maximum length of the generated text.
80
+
81
+ Returns:
82
+ str: Generated text.
83
+ """
84
+ try:
85
+ # Apply chat template for instruction-tuned Qwen model
86
+ messages = [{"role": "user", "content": prompt}]
87
+ prompt_with_template = self.tokenizer.apply_chat_template(
88
+ messages, tokenize=False, add_generation_prompt=True
89
+ )
90
+ # Tokenize input prompt
91
+ inputs = self.tokenizer(prompt_with_template, return_tensors="pt").to(self.llm.device)
92
+ # Generate text
93
+ outputs = self.llm.generate(
94
+ **inputs,
95
+ max_new_tokens=max_length,
96
+ temperature=0.7,
97
+ do_sample=True,
98
+ pad_token_id=self.tokenizer.eos_token_id
99
+ )
100
+ # Decode the generated tokens
101
+ response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
102
+ print("Response generated successfully!")
103
+ return response.strip()
104
+ except Exception as e:
105
+ raise RuntimeError(f"Failed to generate response: {str(e)}")
106
+
107
+ def format_query_prompt(self, query: str, cuisines: List[str], dishes: List[str], price_ranges: List[str]) -> str:
108
+ """
109
+ Format the prompt for query parsing using the prompt template.
110
+
111
+ Args:
112
+ query (str): User query.
113
+ cuisines (list): List of valid cuisines.
114
+ dishes (list): List of valid dishes.
115
+ price_ranges (list): List of valid price ranges.
116
+
117
+ Returns:
118
+ str: Formatted prompt.
119
+ """
120
+ return self.prompt_template.format(
121
+ cuisines=cuisines,
122
+ dishes=dishes,
123
+ price_ranges=price_ranges,
124
+ query=query
125
+ )
126
+
127
+ if __name__ == "__main__":
128
+ # Khởi tạo đối tượng LLM với model_repo và local_path
129
+ local_path = 'models'
130
+
131
+ try:
132
+ # Khởi tạo đối tượng LLM
133
+ llm = LLM(local_path=local_path)
134
+
135
+ # Định nghĩa một truy vấn và các tham số cần thiết
136
+ query = "Tìm quán ăn Việt Nam gần đây, giá rẻ với món phở và cơm tấm"
137
+ cuisines = ["Vietnamese", "Chinese", "Italian"]
138
+ dishes = ["phở", "sushi", "pasta", "cơm tấm"]
139
+ price_ranges = ["low", "medium", "high"]
140
+
141
+ # Sử dụng hàm generate để tạo câu trả lời từ truy vấn
142
+ generated_text = llm.generate(query, max_length=300)
143
+
144
+ # In kết quả ra màn hình
145
+ print("Generated text:")
146
+ print(generated_text)
147
+
148
+ except Exception as e:
149
+ print(f"Error: {str(e)}")
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio
2
+ pandas
3
+ torch
4
+ transformers
5
+ chromadb
6
+ langchain-core
7
+ rank-bm25
retrieval/hybrid_search.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/retrieval/hybrid_search.py
2
+
3
+ import numpy as np
4
+ from rank_bm25 import BM25Okapi
5
+ from typing import List, Dict, Any
6
+ import pandas as pd
7
+ from src.embeddings.embedder import Embedder
8
+ from src.retrieval.vector_store import VectorStore
9
+
10
+ class HybridRetriever:
11
+ def __init__(self, df: pd.DataFrame, vector_store: VectorStore, embedder: Embedder, alpha: float = 0.5):
12
+ self.df = df
13
+ self.vector_store = vector_store
14
+ self.embedder = embedder
15
+ self.alpha = alpha
16
+ tokenized_corpus = [doc.lower().split() for doc in df['description']]
17
+ self.bm25 = BM25Okapi(tokenized_corpus)
18
+
19
+ def retrieve(self, query: str, filtered_df: pd.DataFrame, top_k: int = 3) -> List[Dict[str, Any]]:
20
+ filtered_indices = filtered_df.index.tolist()
21
+ filtered_texts = filtered_df['description'].tolist()
22
+ filtered_ids = [str(row['id']) for _, row in filtered_df.iterrows()]
23
+
24
+ if not filtered_texts:
25
+ return []
26
+
27
+ query_embedding = self.embedder.embed([query])[0]
28
+ dense_results = self.vector_store.query(query_embedding, top_k=top_k * 2)
29
+ dense_ids = [id for id in dense_results['ids'][0] if id in filtered_ids]
30
+ dense_scores = [1 - dist for dist, id in zip(dense_results['distances'][0], dense_results['ids'][0]) if id in filtered_ids]
31
+
32
+ tokenized_query = query.lower().split()
33
+ bm25_scores = self.bm25.get_scores(tokenized_query)
34
+ bm25_scores_filtered = [bm25_scores[i] for i in filtered_indices]
35
+ bm25_top_k = np.argsort(bm25_scores_filtered)[::-1][:top_k * 2]
36
+ bm25_ids = [filtered_ids[i] for i in bm25_top_k]
37
+ bm25_scores = [bm25_scores_filtered[i] for i in bm25_top_k]
38
+
39
+ dense_scores = np.array(dense_scores) / np.max(dense_scores) if dense_scores else dense_scores
40
+ bm25_scores = np.array(bm25_scores) / np.max(bm25_scores) if bm25_scores else bm25_scores
41
+
42
+ combined_scores = {}
43
+ for idx, dense_id in enumerate(dense_ids):
44
+ combined_scores[int(dense_id)] = combined_scores.get(int(dense_id), 0) + self.alpha * dense_scores[idx]
45
+ for idx, bm25_id in enumerate(bm25_ids):
46
+ combined_scores[int(bm25_id)] = combined_scores.get(int(bm25_id), 0) + (1 - self.alpha) * bm25_scores[idx]
47
+
48
+ sorted_ids = sorted(combined_scores, key=combined_scores.get, reverse=True)[:top_k]
49
+ return [self.df[self.df['id'] == id].iloc[0].to_dict() for id in sorted_ids]
retrieval/keyword_filter.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/retrieval/keyword_filter.py
2
+
3
+ import pandas as pd
4
+ from typing import Dict, Any
5
+
6
+ def filter_restaurants(df: pd.DataFrame, parsed_query: Dict[str, Any]) -> pd.DataFrame:
7
+ """
8
+ Filter restaurants based on extracted features from the query.
9
+
10
+ Args:
11
+ df (pd.DataFrame): DataFrame containing restaurant data.
12
+ parsed_query (Dict[str, Any]): Parsed query with features.
13
+
14
+ Returns:
15
+ pd.DataFrame: Filtered DataFrame.
16
+ """
17
+ filtered_df = df.copy()
18
+
19
+ if parsed_query.get("cuisine"):
20
+ filtered_df = filtered_df[filtered_df["cuisine"].str.lower() == parsed_query["cuisine"].lower()]
21
+
22
+ if parsed_query.get("menu"):
23
+ filtered_df = filtered_df[filtered_df["dishes"].apply(
24
+ lambda dishes: any(item.lower() in [d.lower() for d in dishes] for item in parsed_query["menu"])
25
+ )]
26
+
27
+ if parsed_query.get("price_range"):
28
+ filtered_df = filtered_df[filtered_df["price_range"].str.lower() == parsed_query["price_range"].lower()]
29
+
30
+ distance = parsed_query.get("distance")
31
+ if isinstance(distance, (int, float)):
32
+ filtered_df = filtered_df[filtered_df["distance"] <= distance]
33
+ elif distance in ["nearby", "close"]:
34
+ filtered_df = filtered_df[filtered_df["distance"] <= 2.0]
35
+ elif distance == "far":
36
+ filtered_df = filtered_df[filtered_df["distance"] <= 10.0]
37
+
38
+ if parsed_query.get("rating"):
39
+ filtered_df = filtered_df[filtered_df["rating"] >= parsed_query["rating"]]
40
+
41
+ return filtered_df
retrieval/vector_store.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/retrieval/vector_store.py
2
+
3
+ from langchain_community.vectorstores import Chroma
4
+ from langchain_core.documents import Document
5
+ import numpy as np
6
+ from typing import List, Dict, Any
7
+
8
+ class VectorStore:
9
+ def __init__(self, embedding_function):
10
+ self.embedding_function = embedding_function
11
+ self.collection = None
12
+
13
+ def add_documents(self, documents: List[str], embeddings: List[np.ndarray], ids: List[str]):
14
+ langchain_docs = [Document(page_content=doc, metadata={"id": id}) for doc, id in zip(documents, ids)]
15
+ self.collection = Chroma.from_documents(
16
+ documents=langchain_docs,
17
+ embedding=self.embedding_function,
18
+ ids=ids,
19
+ persist_directory="./chroma_db"
20
+ )
21
+ self.collection.persist()
22
+
23
+ def query(self, query_embedding: np.ndarray, top_k: int = 5) -> Dict[str, Any]:
24
+ results = self.collection.similarity_search_by_vector(
25
+ embedding=query_embedding,
26
+ k=top_k
27
+ )
28
+ ids = [doc.metadata["id"] for doc in results]
29
+ distances = [1 - np.dot(query_embedding, doc.vector) / (np.linalg.norm(query_embedding) * np.linalg.norm(doc.vector))
30
+ if hasattr(doc, "vector") else 1.0 for doc in results]
31
+ return {
32
+ "ids": [ids],
33
+ "distances": [distances]
34
+ }
src/chatbot.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ from typing import Tuple, List, Dict, Any
3
+ from src.utils.data_loader import load_restaurant_data
4
+ from src.utils.query_parser import QueryParser
5
+ from src.embeddings.embedder import Embedder
6
+ from src.retrieval.vector_store import VectorStore
7
+ from src.retrieval.keyword_filter import filter_restaurants
8
+ from src.retrieval.hybrid_search import HybridRetriever
9
+ from src.generation.llm import LLM
10
+ from langchain_core.embeddings import Embeddings
11
+
12
+ class LangChainEmbeddingWrapper(Embeddings):
13
+ def __init__(self, embedder):
14
+ self.embedder = embedder
15
+
16
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
17
+ return self.embedder.embed(texts).tolist()
18
+
19
+ def embed_query(self, text: str) -> List[float]:
20
+ return self.embedder.embed([text])[0].tolist()
21
+
22
+ class RestaurantChatbot:
23
+ def __init__(self, data_path: str = "data/restaurants.json"):
24
+ """
25
+ Initialize the restaurant chatbot.
26
+
27
+ Args:
28
+ data_path (str): Path to the restaurant JSON file.
29
+ """
30
+ self.df = load_restaurant_data(data_path)
31
+ self.embedder = Embedder()
32
+ self.embedding_wrapper = LangChainEmbeddingWrapper(self.embedder)
33
+ self.vector_store = VectorStore(embedding_function=self.embedding_wrapper)
34
+ self.llm = LLM()
35
+ self.parser = QueryParser(self.df)
36
+
37
+ embeddings = self.embedder.embed(self.df['text'].tolist())
38
+ self.vector_store.add_documents(
39
+ documents=self.df['text'].tolist(),
40
+ embeddings=embeddings.tolist(),
41
+ ids=[str(i) for i in self.df['id']]
42
+ )
43
+
44
+ self.retriever = HybridRetriever(self.df, self.vector_store, self.embedder)
45
+
46
+ def answer(self, query: str) -> Tuple[str, List[Dict[str, Any]]]:
47
+ """
48
+ Process a user query and return a natural, concise response with recommended restaurants.
49
+
50
+ Args:
51
+ query (str): User query.
52
+
53
+ Returns:
54
+ Tuple[str, List[Dict[str, Any]]]: Natural response text and list of recommended restaurants.
55
+ """
56
+ parsed_query = self.parser.parse_query(query)
57
+ filtered_df = filter_restaurants(self.df, parsed_query)
58
+ description = parsed_query["description"] if parsed_query["description"] else query
59
+
60
+ if filtered_df.empty:
61
+ retrieved_docs = self.retriever.retrieve(description, self.df, top_k=3)
62
+ else:
63
+ retrieved_docs = self.retriever.retrieve(description, filtered_df, top_k=3)
64
+
65
+ if not retrieved_docs:
66
+ return "Mình không tìm được nhà hàng nào phù hợp. Bạn thử đổi tiêu chí xem, như mở rộng khoảng cách hoặc loại món ăn nhé!", []
67
+
68
+ # Create context for LLM
69
+ context = "\n".join([
70
+ f"- {doc['name']} ({doc['cuisine']}): {', '.join(doc['dishes'])}. "
71
+ f"Price: {doc['price_range']}, Distance: {doc['distance']} km, Rating: {doc['rating']}. "
72
+ f"Description: {doc['description']}"
73
+ for doc in retrieved_docs
74
+ ])
75
+
76
+ # Prompt for natural, consultant-like response
77
+ prompt = (
78
+ f"Bạn là một người tư vấn nhà hàng thân thiện. Dựa trên truy vấn và danh sách nhà hàng, hãy gợi ý ngắn gọn, tự nhiên, như trò chuyện với bạn bè, giải thích tại sao chọn các nhà hàng này (tập trung vào món ăn, giá, khoảng cách, hoặc đánh giá phù hợp với truy vấn). Không lặp lại truy vấn hoặc dùng ngôn ngữ kỹ thuật. Chỉ dùng thông tin từ danh sách nhà hàng.\n\n"
79
+ f"Truy vấn: {query}\n\n"
80
+ f"Danh sách nhà hàng:\n{context}\n\n"
81
+ f"Phản hồi:"
82
+ )
83
+
84
+ response = self.llm.generate(prompt, max_length=200)
85
+ return response, retrieved_docs
src/embeddings/embedder.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModel
2
+ import torch
3
+ import numpy as np
4
+ from typing import List
5
+
6
+ class Embedder:
7
+ def __init__(self, model_name: str = "BAAI/bge-m3"):
8
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
9
+ self.model = AutoModel.from_pretrained(model_name)
10
+
11
+ def embed(self, texts: List[str]) -> np.ndarray:
12
+ inputs = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
13
+ with torch.no_grad():
14
+ outputs = self.model(**inputs)
15
+ embeddings = outputs.last_hidden_state[:, 0] # lấy embedding từ CLS token
16
+ embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
17
+ return embeddings.cpu().numpy()
src/generation/llm.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForCausalLM, AutoTokenizer
2
+ from langchain_core.prompts import PromptTemplate
3
+ import os
4
+ from typing import List
5
+
6
+ class LLM:
7
+ def __init__(self, model_repo: str = "Qwen/Qwen2-1.5B-Instruct",
8
+ local_path: str = "models"):
9
+ """
10
+ Initialize the LLM with Qwen2-1.5B-Instruct using Hugging Face Transformers.
11
+
12
+ Args:
13
+ model_repo (str): Hugging Face repository ID for the model.
14
+ local_path (str): Local directory to store the model.
15
+ """
16
+ os.makedirs(local_path, exist_ok=True)
17
+
18
+ try:
19
+ # Load the model
20
+ self.llm = AutoModelForCausalLM.from_pretrained(
21
+ model_repo,
22
+ device_map="auto", # Automatically map to CPU
23
+ cache_dir=local_path,
24
+ trust_remote_code=True
25
+ )
26
+
27
+ # Load the tokenizer
28
+ self.tokenizer = AutoTokenizer.from_pretrained(
29
+ model_repo,
30
+ cache_dir=local_path,
31
+ trust_remote_code=True
32
+ )
33
+ print(f"Model successfully loaded from {model_repo}")
34
+ except Exception as e:
35
+ raise RuntimeError(
36
+ f"Failed to initialize model from {model_repo}. "
37
+ f"Please ensure the model is available at https://huggingface.co/{model_repo}. "
38
+ f"Error: {str(e)}"
39
+ )
40
+
41
+ # Define prompt template for query parsing (used in query_parser.py)
42
+ self.prompt_template = PromptTemplate(
43
+ template="""Bạn là một trợ lý phân tích truy vấn nhà hàng. Phân tích truy vấn sau và trích xuất các đặc trưng: cuisine, menu, price_range, distance, rating, và description. Chỉ trích xuất các giá trị khớp chính xác với danh sách giá trị hợp lệ. Nếu không tìm thấy giá trị khớp, trả về null (hoặc [] cho menu). Loại bỏ các từ khóa đã trích xuất khỏi description. Trả về kết quả dưới dạng JSON.
44
+
45
+ **Danh sách giá trị hợp lệ**:
46
+ - cuisine: {cuisines}
47
+ - menu: {dishes}
48
+ - price_range: {price_ranges}
49
+
50
+ **Hướng dẫn**:
51
+ - cuisine: Chỉ chọn giá trị từ danh sách cuisine. Ví dụ, "Viet" → "Vietnamese".
52
+ - menu: Chỉ chọn các món khớp chính xác với danh sách menu. Ví dụ, "phở bò" → "phở", "sushi" → [].
53
+ - price_range: Chỉ chọn {price_ranges}. Ví dụ, "cheap" → "low".
54
+ - distance: Trích xuất số km (e.g., "2 km" → 2.0) hoặc từ khóa ["nearby", "close" → 2.0, "far" → 10.0]. Nếu không rõ, trả về null.
55
+ - rating: Trích xuất số (e.g., "4 stars" → 4.0). Nếu không rõ, trả về null.
56
+ - description: Phần còn lại sau khi loại bỏ các từ khóa đã trích xuất. Nếu rỗng, trả về truy vấn gốc.
57
+
58
+ **Truy vấn**: {query}
59
+
60
+ **Định dạng đầu ra**:
61
+ {{
62
+ "cuisine": null | "tên loại ẩm thực",
63
+ "menu": [],
64
+ "price_range": null | "low" | "medium" | "high",
65
+ "distance": null | số km | "nearby" | "close" | "far",
66
+ "rating": null | số,
67
+ "description": "phần mô tả còn lại"
68
+ }}
69
+ """,
70
+ input_variables=["cuisines", "dishes", "price_ranges", "query"]
71
+ )
72
+
73
+ def generate(self, prompt: str, max_length: int = 1000) -> str:
74
+ """
75
+ Generate text using the LLM.
76
+
77
+ Args:
78
+ prompt (str): Input prompt.
79
+ max_length (int): Maximum length of the generated text.
80
+
81
+ Returns:
82
+ str: Generated text.
83
+ """
84
+ try:
85
+ # Apply chat template for instruction-tuned Qwen model
86
+ messages = [{"role": "user", "content": prompt}]
87
+ prompt_with_template = self.tokenizer.apply_chat_template(
88
+ messages, tokenize=False, add_generation_prompt=True
89
+ )
90
+ # Tokenize input prompt
91
+ inputs = self.tokenizer(prompt_with_template, return_tensors="pt").to(self.llm.device)
92
+ # Generate text
93
+ outputs = self.llm.generate(
94
+ **inputs,
95
+ max_new_tokens=max_length,
96
+ temperature=0.7,
97
+ do_sample=True,
98
+ pad_token_id=self.tokenizer.eos_token_id
99
+ )
100
+ # Decode the generated tokens
101
+ response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
102
+ print("Response generated successfully!")
103
+ return response.strip()
104
+ except Exception as e:
105
+ raise RuntimeError(f"Failed to generate response: {str(e)}")
106
+
107
+ def format_query_prompt(self, query: str, cuisines: List[str], dishes: List[str], price_ranges: List[str]) -> str:
108
+ """
109
+ Format the prompt for query parsing using the prompt template.
110
+
111
+ Args:
112
+ query (str): User query.
113
+ cuisines (list): List of valid cuisines.
114
+ dishes (list): List of valid dishes.
115
+ price_ranges (list): List of valid price ranges.
116
+
117
+ Returns:
118
+ str: Formatted prompt.
119
+ """
120
+ return self.prompt_template.format(
121
+ cuisines=cuisines,
122
+ dishes=dishes,
123
+ price_ranges=price_ranges,
124
+ query=query
125
+ )
126
+
127
+ if __name__ == "__main__":
128
+ # Khởi tạo đối tượng LLM với model_repo và local_path
129
+ local_path = 'models'
130
+
131
+ try:
132
+ # Khởi tạo đối tượng LLM
133
+ llm = LLM(local_path=local_path)
134
+
135
+ # Định nghĩa một truy vấn và các tham số cần thiết
136
+ query = "Tìm quán ăn Việt Nam gần đây, giá rẻ với món phở và cơm tấm"
137
+ cuisines = ["Vietnamese", "Chinese", "Italian"]
138
+ dishes = ["phở", "sushi", "pasta", "cơm tấm"]
139
+ price_ranges = ["low", "medium", "high"]
140
+
141
+ # Sử dụng hàm generate để tạo câu trả lời từ truy vấn
142
+ generated_text = llm.generate(query, max_length=300)
143
+
144
+ # In kết quả ra màn hình
145
+ print("Generated text:")
146
+ print(generated_text)
147
+
148
+ except Exception as e:
149
+ print(f"Error: {str(e)}")
src/retrieval/hybrid_search.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/retrieval/hybrid_search.py
2
+
3
+ import numpy as np
4
+ from rank_bm25 import BM25Okapi
5
+ from typing import List, Dict, Any
6
+ import pandas as pd
7
+ from src.embeddings.embedder import Embedder
8
+ from src.retrieval.vector_store import VectorStore
9
+
10
+ class HybridRetriever:
11
+ def __init__(self, df: pd.DataFrame, vector_store: VectorStore, embedder: Embedder, alpha: float = 0.5):
12
+ self.df = df
13
+ self.vector_store = vector_store
14
+ self.embedder = embedder
15
+ self.alpha = alpha
16
+ tokenized_corpus = [doc.lower().split() for doc in df['description']]
17
+ self.bm25 = BM25Okapi(tokenized_corpus)
18
+
19
+ def retrieve(self, query: str, filtered_df: pd.DataFrame, top_k: int = 3) -> List[Dict[str, Any]]:
20
+ filtered_indices = filtered_df.index.tolist()
21
+ filtered_texts = filtered_df['description'].tolist()
22
+ filtered_ids = [str(row['id']) for _, row in filtered_df.iterrows()]
23
+
24
+ if not filtered_texts:
25
+ return []
26
+
27
+ query_embedding = self.embedder.embed([query])[0]
28
+ dense_results = self.vector_store.query(query_embedding, top_k=top_k * 2)
29
+ dense_ids = [id for id in dense_results['ids'][0] if id in filtered_ids]
30
+ dense_scores = [1 - dist for dist, id in zip(dense_results['distances'][0], dense_results['ids'][0]) if id in filtered_ids]
31
+
32
+ tokenized_query = query.lower().split()
33
+ bm25_scores = self.bm25.get_scores(tokenized_query)
34
+ bm25_scores_filtered = [bm25_scores[i] for i in filtered_indices]
35
+ bm25_top_k = np.argsort(bm25_scores_filtered)[::-1][:top_k * 2]
36
+ bm25_ids = [filtered_ids[i] for i in bm25_top_k]
37
+ bm25_scores = [bm25_scores_filtered[i] for i in bm25_top_k]
38
+
39
+ dense_scores = np.array(dense_scores) / np.max(dense_scores) if dense_scores else dense_scores
40
+ bm25_scores = np.array(bm25_scores) / np.max(bm25_scores) if bm25_scores else bm25_scores
41
+
42
+ combined_scores = {}
43
+ for idx, dense_id in enumerate(dense_ids):
44
+ combined_scores[int(dense_id)] = combined_scores.get(int(dense_id), 0) + self.alpha * dense_scores[idx]
45
+ for idx, bm25_id in enumerate(bm25_ids):
46
+ combined_scores[int(bm25_id)] = combined_scores.get(int(bm25_id), 0) + (1 - self.alpha) * bm25_scores[idx]
47
+
48
+ sorted_ids = sorted(combined_scores, key=combined_scores.get, reverse=True)[:top_k]
49
+ return [self.df[self.df['id'] == id].iloc[0].to_dict() for id in sorted_ids]
src/retrieval/keyword_filter.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/retrieval/keyword_filter.py
2
+
3
+ import pandas as pd
4
+ from typing import Dict, Any
5
+
6
+ def filter_restaurants(df: pd.DataFrame, parsed_query: Dict[str, Any]) -> pd.DataFrame:
7
+ """
8
+ Filter restaurants based on extracted features from the query.
9
+
10
+ Args:
11
+ df (pd.DataFrame): DataFrame containing restaurant data.
12
+ parsed_query (Dict[str, Any]): Parsed query with features.
13
+
14
+ Returns:
15
+ pd.DataFrame: Filtered DataFrame.
16
+ """
17
+ filtered_df = df.copy()
18
+
19
+ if parsed_query.get("cuisine"):
20
+ filtered_df = filtered_df[filtered_df["cuisine"].str.lower() == parsed_query["cuisine"].lower()]
21
+
22
+ if parsed_query.get("menu"):
23
+ filtered_df = filtered_df[filtered_df["dishes"].apply(
24
+ lambda dishes: any(item.lower() in [d.lower() for d in dishes] for item in parsed_query["menu"])
25
+ )]
26
+
27
+ if parsed_query.get("price_range"):
28
+ filtered_df = filtered_df[filtered_df["price_range"].str.lower() == parsed_query["price_range"].lower()]
29
+
30
+ distance = parsed_query.get("distance")
31
+ if isinstance(distance, (int, float)):
32
+ filtered_df = filtered_df[filtered_df["distance"] <= distance]
33
+ elif distance in ["nearby", "close"]:
34
+ filtered_df = filtered_df[filtered_df["distance"] <= 2.0]
35
+ elif distance == "far":
36
+ filtered_df = filtered_df[filtered_df["distance"] <= 10.0]
37
+
38
+ if parsed_query.get("rating"):
39
+ filtered_df = filtered_df[filtered_df["rating"] >= parsed_query["rating"]]
40
+
41
+ return filtered_df
src/retrieval/vector_store.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/retrieval/vector_store.py
2
+
3
+ from langchain_community.vectorstores import Chroma
4
+ from langchain_core.documents import Document
5
+ import numpy as np
6
+ from typing import List, Dict, Any
7
+
8
+ class VectorStore:
9
+ def __init__(self, embedding_function):
10
+ self.embedding_function = embedding_function
11
+ self.collection = None
12
+
13
+ def add_documents(self, documents: List[str], embeddings: List[np.ndarray], ids: List[str]):
14
+ langchain_docs = [Document(page_content=doc, metadata={"id": id}) for doc, id in zip(documents, ids)]
15
+ self.collection = Chroma.from_documents(
16
+ documents=langchain_docs,
17
+ embedding=self.embedding_function,
18
+ ids=ids,
19
+ persist_directory="./chroma_db"
20
+ )
21
+ self.collection.persist()
22
+
23
+ def query(self, query_embedding: np.ndarray, top_k: int = 5) -> Dict[str, Any]:
24
+ results = self.collection.similarity_search_by_vector(
25
+ embedding=query_embedding,
26
+ k=top_k
27
+ )
28
+ ids = [doc.metadata["id"] for doc in results]
29
+ distances = [1 - np.dot(query_embedding, doc.vector) / (np.linalg.norm(query_embedding) * np.linalg.norm(doc.vector))
30
+ if hasattr(doc, "vector") else 1.0 for doc in results]
31
+ return {
32
+ "ids": [ids],
33
+ "distances": [distances]
34
+ }
src/utils/data_loader.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/utils/data_loader.py
2
+
3
+ import pandas as pd
4
+ import json
5
+
6
+ def load_restaurant_data(file_path: str) -> pd.DataFrame:
7
+ """
8
+ Load restaurant data from JSON file into a DataFrame.
9
+
10
+ Args:
11
+ file_path (str): Path to the JSON file.
12
+
13
+ Returns:
14
+ pd.DataFrame: DataFrame containing restaurant data.
15
+ """
16
+ with open(file_path, 'r', encoding='utf-8') as f:
17
+ data = json.load(f)
18
+
19
+ df = pd.DataFrame(data)
20
+ # Create a text field for embedding and BM25
21
+ df['text'] = df.apply(
22
+ lambda row: f"{row['name']} ({row['cuisine']}): {', '.join(row['dishes'])}. "
23
+ f"Price: {row['price_range']}, Distance: {row['distance']} km, "
24
+ f"Rating: {row['rating']}. Description: {row['description']}",
25
+ axis=1
26
+ )
27
+ return df
28
+ if __name__ == "__main__":
29
+ print(load_restaurant_data("./data/restaurants.json"))
src/utils/query_parser.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/utils/query_parser.py
2
+ import pandas as pd
3
+ import json
4
+ from typing import Dict, Any
5
+ from src.generation.llm import LLM
6
+
7
+ class QueryParser:
8
+ def __init__(self, df: pd.DataFrame):
9
+ """
10
+ Initialize the query parser with restaurant data.
11
+
12
+ Args:
13
+ df (pd.DataFrame): DataFrame containing restaurant data.
14
+ """
15
+ self.llm = LLM()
16
+ self.df = df
17
+ self.valid_cuisines = sorted(self.df['cuisine'].unique().tolist())
18
+ self.valid_price_ranges = sorted(self.df['price_range'].unique().tolist())
19
+ self.valid_dishes = sorted(set([dish for dishes in self.df['dishes'] for dish in dishes]))
20
+
21
+ def parse_query(self, query: str) -> Dict[str, Any]:
22
+ """
23
+ Parse the query to extract features.
24
+
25
+ Args:
26
+ query (str): User query.
27
+
28
+ Returns:
29
+ Dict[str, Any]: Parsed features.
30
+ """
31
+ # Format prompt using LLM's prompt template
32
+ prompt = self.llm.format_query_prompt(
33
+ query=query,
34
+ cuisines=self.valid_cuisines,
35
+ dishes=self.valid_dishes,
36
+ price_ranges=self.valid_price_ranges
37
+ )
38
+
39
+ # Generate response
40
+ response = self.llm.generate(prompt)
41
+
42
+ # Parse JSON response
43
+ try:
44
+ json_start = response.find("{")
45
+ json_end = response.rfind("}") + 1
46
+ parsed = json.loads(response[json_start:json_end])
47
+ return parsed
48
+ except json.JSONDecodeError:
49
+ return {
50
+ "cuisine": None,
51
+ "menu": [],
52
+ "price_range": None,
53
+ "distance": None,
54
+ "rating": None,
55
+ "description": query
56
+ }
57
+
58
+
59
+ # Quick test block for QueryParser
60
+ if __name__ == "__main__":
61
+ import pandas as pd
62
+
63
+ sample_data = {
64
+ "cuisine": ["Italian", "Japanese"],
65
+ "price_range": ["$", "$$"],
66
+ "dishes": [["pizza", "pasta"], ["sushi", "ramen"]]
67
+ }
68
+ df = pd.DataFrame(sample_data)
69
+ parser = QueryParser(df)
70
+
71
+ user_query = "I want cheap sushi"
72
+ result = parser.parse_query(user_query)
73
+
74
+ print("Parsed Query Result:")
75
+ print(result)
76
+
utils/data_loader.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/utils/data_loader.py
2
+
3
+ import pandas as pd
4
+ import json
5
+
6
+ def load_restaurant_data(file_path: str) -> pd.DataFrame:
7
+ """
8
+ Load restaurant data from JSON file into a DataFrame.
9
+
10
+ Args:
11
+ file_path (str): Path to the JSON file.
12
+
13
+ Returns:
14
+ pd.DataFrame: DataFrame containing restaurant data.
15
+ """
16
+ with open(file_path, 'r', encoding='utf-8') as f:
17
+ data = json.load(f)
18
+
19
+ df = pd.DataFrame(data)
20
+ # Create a text field for embedding and BM25
21
+ df['text'] = df.apply(
22
+ lambda row: f"{row['name']} ({row['cuisine']}): {', '.join(row['dishes'])}. "
23
+ f"Price: {row['price_range']}, Distance: {row['distance']} km, "
24
+ f"Rating: {row['rating']}. Description: {row['description']}",
25
+ axis=1
26
+ )
27
+ return df
28
+ if __name__ == "__main__":
29
+ print(load_restaurant_data("./data/restaurants.json"))
utils/query_parser.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/utils/query_parser.py
2
+ import pandas as pd
3
+ import json
4
+ from typing import Dict, Any
5
+ from src.generation.llm import LLM
6
+
7
+ class QueryParser:
8
+ def __init__(self, df: pd.DataFrame):
9
+ """
10
+ Initialize the query parser with restaurant data.
11
+
12
+ Args:
13
+ df (pd.DataFrame): DataFrame containing restaurant data.
14
+ """
15
+ self.llm = LLM()
16
+ self.df = df
17
+ self.valid_cuisines = sorted(self.df['cuisine'].unique().tolist())
18
+ self.valid_price_ranges = sorted(self.df['price_range'].unique().tolist())
19
+ self.valid_dishes = sorted(set([dish for dishes in self.df['dishes'] for dish in dishes]))
20
+
21
+ def parse_query(self, query: str) -> Dict[str, Any]:
22
+ """
23
+ Parse the query to extract features.
24
+
25
+ Args:
26
+ query (str): User query.
27
+
28
+ Returns:
29
+ Dict[str, Any]: Parsed features.
30
+ """
31
+ # Format prompt using LLM's prompt template
32
+ prompt = self.llm.format_query_prompt(
33
+ query=query,
34
+ cuisines=self.valid_cuisines,
35
+ dishes=self.valid_dishes,
36
+ price_ranges=self.valid_price_ranges
37
+ )
38
+
39
+ # Generate response
40
+ response = self.llm.generate(prompt)
41
+
42
+ # Parse JSON response
43
+ try:
44
+ json_start = response.find("{")
45
+ json_end = response.rfind("}") + 1
46
+ parsed = json.loads(response[json_start:json_end])
47
+ return parsed
48
+ except json.JSONDecodeError:
49
+ return {
50
+ "cuisine": None,
51
+ "menu": [],
52
+ "price_range": None,
53
+ "distance": None,
54
+ "rating": None,
55
+ "description": query
56
+ }
57
+
58
+
59
+ # Quick test block for QueryParser
60
+ if __name__ == "__main__":
61
+ import pandas as pd
62
+
63
+ sample_data = {
64
+ "cuisine": ["Italian", "Japanese"],
65
+ "price_range": ["$", "$$"],
66
+ "dishes": [["pizza", "pasta"], ["sushi", "ramen"]]
67
+ }
68
+ df = pd.DataFrame(sample_data)
69
+ parser = QueryParser(df)
70
+
71
+ user_query = "I want cheap sushi"
72
+ result = parser.parse_query(user_query)
73
+
74
+ print("Parsed Query Result:")
75
+ print(result)
76
+