Futuresony commited on
Commit
242f03e
·
verified ·
1 Parent(s): 7afa3d4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +630 -0
app.py ADDED
@@ -0,0 +1,630 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This is my app.py
2
+
3
+ import os
4
+ import torch
5
+ import re
6
+ import warnings
7
+ import time
8
+ import json
9
+ from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
10
+ from sentence_transformers import SentenceTransformer, util
11
+ import gspread
12
+ from google.auth import default
13
+ from tqdm import tqdm
14
+ from duckduckgo_search import DDGS
15
+ import spacy
16
+ from pathlib import Path
17
+ import base64
18
+
19
+ # Suppress warnings
20
+ warnings.filterwarnings("ignore", category=UserWarning)
21
+
22
+ # --- Configuration ---
23
+ SHEET_ID = "19ipxC2vHYhpXCefpxpIkpeYdI43a1Ku2kYwecgUULIw" # Your Google Sheet ID
24
+ HF_TOKEN = os.getenv("HF_TOKEN") # Get Hugging Face token from Space Secrets
25
+ GOOGLE_SERVICE_ACCOUNT_KEY_BASE64 = os.getenv("GOOGLE_SERVICE_ACCOUNT_KEY_BASE64")
26
+
27
+ # Changed model_id to Gemma 2B for CPU
28
+ model_id = "google/gemma-2b" # Using Gemma 2B
29
+
30
+ # --- Constants for Prompting and Validation ---
31
+ SEARCH_MARKER = "ACTION: SEARCH:"
32
+ BUSINESS_LOOKUP_MARKER = "ACTION: LOOKUP_BUSINESS_INFO:"
33
+ ANSWER_DIRECTLY_MARKER = "ACTION: ANSWER_DIRECTLY:"
34
+ BUSINESS_LOOKUP_VALIDATION_THRESHOLD = 0.6
35
+ SEARCH_VALIDATION_THRESHOLD = 0.6
36
+ PRE_PASS1_BUSINESS_PART_LOOKUP_THRESHOLD = 0.5
37
+
38
+ # --- Global variables to load once ---
39
+ tokenizer = None
40
+ model = None
41
+ nlp = None # SpaCy model
42
+ embedder = None # Sentence Transformer
43
+ data = [] # Google Sheet data
44
+ descriptions = []
45
+ embeddings = torch.tensor([]) # Google Sheet embeddings
46
+
47
+ # --- Loading Functions (Run once on startup) ---
48
+
49
+ def load_spacy_model():
50
+ """Loads or downloads the spaCy model."""
51
+ model_name = "en_core_web_sm"
52
+ try:
53
+ print(f"Loading spaCy model '{model_name}'...")
54
+ nlp_model = spacy.load(model_name)
55
+ print(f"SpaCy model '{model_name}' loaded.")
56
+ return nlp_model
57
+ except OSError:
58
+ print(f"SpaCy model '{model_name}' not found locally. Attempting download...")
59
+ # For HF Spaces, ensuring it's in requirements.txt is key.
60
+ # We'll assume requirements.txt handles installation, and try loading again.
61
+ print("Assuming 'en_core_web_sm' is installed via requirements.txt. Attempting to load...")
62
+ try:
63
+ nlp_model = spacy.load(model_name)
64
+ print(f"SpaCy model '{model_name}' loaded after assumed installation.")
65
+ return nlp_model
66
+ except Exception as e:
67
+ print(f"Failed to load spaCy model '{model_name}' after assumed installation: {e}")
68
+ print("SpaCy will not be available.")
69
+ return None # Return None if loading fails
70
+
71
+ def load_sentence_transformer():
72
+ """Loads the Sentence Transformer model."""
73
+ print("Loading Sentence Transformer...")
74
+ try:
75
+ embedder_model = SentenceTransformer("all-MiniLM-L6-v2")
76
+ print("Sentence Transformer loaded.")
77
+ return embedder_model
78
+ except Exception as e:
79
+ print(f"Error loading Sentence Transformer: {e}")
80
+ return None
81
+
82
+ def load_google_sheet_data(sheet_id, service_account_key_base64):
83
+ """Authenticates and loads data from Google Sheet."""
84
+ print(f"Attempting to load Google Sheet data from ID: {sheet_id}")
85
+ if not service_account_key_base64:
86
+ print("Warning: GOOGLE_SERVICE_ACCOUNT_KEY_BASE64 secret is not set. Cannot access Google Sheets.")
87
+ return [], [], torch.tensor([])
88
+
89
+ try:
90
+ # Decode the base64 key
91
+ key_bytes = base64.b64decode(service_account_key_base64)
92
+ key_dict = json.loads(key_bytes)
93
+
94
+ # Authenticate using the service account key
95
+ # Use service_account.Credentials.from_service_account_info directly
96
+ from google.oauth2 import service_account
97
+ creds = service_account.Credentials.from_service_account_info(key_dict)
98
+ client = gspread.authorize(creds)
99
+
100
+ sheet = client.open_by_key(sheet_id).sheet1
101
+ print(f"Successfully opened Google Sheet with ID: {sheet_id}")
102
+ sheet_data = sheet.get_all_records()
103
+
104
+ if not sheet_data:
105
+ print(f"Warning: No data records found in Google Sheet with ID: {sheet_id}")
106
+ return [], [], torch.tensor([])
107
+
108
+ filtered_data = [row for row in sheet_data if row.get('Service') and row.get('Description')]
109
+ if not filtered_data:
110
+ print("Warning: Filtered data is empty after checking for 'Service' and 'Description'.")
111
+ return [], [], torch.tensor([])
112
+
113
+ if not filtered_data or 'Service' not in filtered_data[0] or 'Description' not in filtered_data[0]:
114
+ print("Error: Filtered Google Sheet data must contain 'Service' and 'Description' columns.")
115
+ return [], [], torch.tensor([])
116
+
117
+ services = [row["Service"] for row in filtered_data]
118
+ descriptions = [row["Description"] for row in filtered_data]
119
+ print(f"Loaded {len(descriptions)} entries from Google Sheet for embedding.")
120
+
121
+ # embeddings will be encoded after embedder is loaded
122
+ return filtered_data, descriptions, None # Return descriptions, embeddings encoded later
123
+
124
+ except gspread.exceptions.SpreadsheetNotFound:
125
+ print(f"Error: Google Sheet with ID '{sheet_id}' not found.")
126
+ print("Please check the SHEET_ID and ensure the service account has access.")
127
+ return [], [], torch.tensor([])
128
+ except Exception as e:
129
+ print(f"An error occurred while accessing the Google Sheet: {e}")
130
+ return [], [], torch.tensor([])
131
+
132
+
133
+ def load_llm_model(model_id, hf_token):
134
+ """Loads the LLM in full precision (for CPU).""" # Modified description
135
+ print(f"Loading model {model_id} in full precision...")
136
+ if not hf_token:
137
+ print("Error: HF_TOKEN secret is not set. Cannot load Hugging Face model.")
138
+ return None, None
139
+
140
+ try:
141
+ llm_tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
142
+ if llm_tokenizer.pad_token is None:
143
+ llm_tokenizer.pad_token = llm_tokenizer.eos_token
144
+
145
+ # Load the model without quantization config
146
+ llm_model = AutoModelForCausalLM.from_pretrained(
147
+ model_id,
148
+ token=hf_token,
149
+ device_map="auto", # This will likely map to 'cpu'
150
+ # Removed quantization_config=bnb_config
151
+ )
152
+
153
+ print(f"Model {model_id} loaded in full precision.")
154
+ return llm_model, llm_tokenizer
155
+
156
+ except Exception as e:
157
+ print(f"Error loading model {model_id}: {e}")
158
+ # Removed specific bitsandbytes message
159
+ print("Please ensure transformers, trl, peft, and accelerate are installed.")
160
+ print("Check your Hugging Face token.")
161
+ # Do not raise, return None to allow app to start without LLM
162
+ return None, None
163
+
164
+ # --- Load all assets on startup ---
165
+ print("Loading assets...")
166
+ nlp = load_spacy_model()
167
+ embedder = load_sentence_transformer()
168
+ data, descriptions, _ = load_google_sheet_data(SHEET_ID, GOOGLE_SERVICE_ACCOUNT_KEY_BASE64) # Load data and descriptions first
169
+
170
+ if embedder and descriptions:
171
+ print("Encoding Google Sheet descriptions...")
172
+ try:
173
+ embeddings = embedder.encode(descriptions, convert_to_tensor=True)
174
+ print("Encoding complete.")
175
+ except Exception as e:
176
+ print(f"Error during embedding: {e}")
177
+ embeddings = torch.tensor([]) # Ensure embeddings is an empty tensor on error
178
+ else:
179
+ print("Skipping embedding due to missing embedder or descriptions.")
180
+ embeddings = torch.tensor([]) # Ensure embeddings is an empty tensor if no descriptions
181
+
182
+ model, tokenizer = load_llm_model(model_id, HF_TOKEN)
183
+
184
+ # Check if essential components loaded
185
+ if not model or not tokenizer or not embedder or not nlp:
186
+ print("\nERROR: Essential components failed to load. The application may not function correctly.")
187
+ if not model: print("- LLM Model failed to load.")
188
+ if not tokenizer: print("- LLM Tokenizer failed to load.")
189
+ if not embedder: print("- Sentence Embedder failed to load.")
190
+ if not nlp: print("- spaCy Model failed to load.")
191
+ # Continue, but the main inference function will need checks
192
+
193
+ # --- Helper Functions (from your script) ---
194
+
195
+ # Function to perform DuckDuckGo Search and return results with URLs
196
+ def perform_duckduckgo_search(query, max_results=3):
197
+ """
198
+ Performs a search using DuckDuckGo and returns a list of dictionaries.
199
+ Includes a delay to avoid rate limits.
200
+ """
201
+ search_results_list = []
202
+ try:
203
+ time.sleep(1) # Add a delay before each search
204
+ with DDGS() as ddgs:
205
+ for r in ddgs.text(query, max_results=max_results):
206
+ search_results_list.append(r) # Append the dictionary directly
207
+ except Exception as e:
208
+ print(f"Error during DuckDuckgo search for '{query}': {e}")
209
+ return []
210
+ return search_results_list
211
+
212
+ # Function to retrieve relevant business info
213
+ def retrieve_business_info(query, data, embeddings, embedder, threshold=0.50):
214
+ """
215
+ Retrieves relevant business information based on query similarity.
216
+ Returns a dictionary if a match above threshold is found, otherwise None.
217
+ Also returns the similarity score.
218
+ Uses the global embedder, data, and embeddings.
219
+ """
220
+ if not data or (embeddings is None or embeddings.numel() == 0) or embedder is None:
221
+ print("Skipping business info retrieval: Data, embeddings or embedder not available.")
222
+ return None, 0.0
223
+
224
+ try:
225
+ user_embedding = embedder.encode(query, convert_to_tensor=True)
226
+ cos_scores = util.cos_sim(user_embedding, embeddings)[0]
227
+ best_score = cos_scores.max().item()
228
+
229
+ if best_score > threshold:
230
+ best_match_idx = cos_scores.argmax().item()
231
+ best_match = data[best_match_idx]
232
+ return best_match, best_score
233
+ else:
234
+ return None, best_score
235
+ except Exception as e:
236
+ print(f"Error during business information retrieval: {e}")
237
+ return None, 0.0
238
+
239
+ # Function to split user query into potential sub-queries using spaCy
240
+ def split_query(query):
241
+ """Splits a user query into potential sub-queries using spaCy."""
242
+ if nlp is None:
243
+ print("SpaCy model not loaded. Cannot split query.")
244
+ return [query] # Return original query if nlp is not available
245
+
246
+ try:
247
+ doc = nlp(query)
248
+ sentences = [sent.text.strip() for sent in doc.sents]
249
+ if len(sentences) == 1:
250
+ parts = re.split(r',| and (who|what|where|when|why|how|is|are|can|tell me about)|;', query, flags=re.IGNORECASE)
251
+ parts = [part.strip() for part in parts if part is not None and part.strip()]
252
+ if len(parts) <= 1:
253
+ return [query]
254
+ return parts
255
+ return sentences
256
+ except Exception as e:
257
+ print(f"Error during query splitting: {e}")
258
+ return [query] # Return original query on error
259
+
260
+ # --- Pass 1 System Prompt ---
261
+ pass1_instructions_action = """You are a helpful assistant for a business. Your primary goal in this first step is to analyze the user's query and decide which actions are needed to answer it.
262
+
263
+ You have analyzed the user's query and potentially broken it down into parts. For each part, a preliminary check was done to see if it matches known business information. The results of this check are provided below.
264
+
265
+ {business_check_summary}
266
+
267
+ Based on the user's query and the results of the business info check for each part, identify if you need to perform actions.
268
+
269
+ Output one or more actions, each on a new line, in the format:
270
+ ACTION: [ACTION_TYPE]: [Argument/Query for the action]
271
+
272
+ Possible actions:
273
+ 1. **LOOKUP_BUSINESS_INFO**: If a part of the query asks about the business's services, prices, availability, or individuals mentioned in the business context, *and* the business info check for that part indicates a high relevance ({PRE_PASS1_BUSINESS_PART_LOOKUP_THRESHOLD:.2f} or higher). The argument should be the specific phrase or name to look up.
274
+ 2. **SEARCH**: If a part of the query asks for current external information (e.g., current events, real-time data, general facts not in business info), *or* if a part that seems like it could be business info did *not* have a high relevance score in the preliminary check (below {PRE_PASS1_BUSINESS_PART_LOOKUP_THRESHOLD:.2f}). The argument should be the precise search query.
275
+ 3. **ANSWER_DIRECTLY**: If the overall query is a simple greeting or can be answered from your general knowledge without lookup or search, *and* the business info check results indicate low relevance for all parts. The argument should be the direct answer here.
276
+
277
+ **Crucially:**
278
+ - **Prioritize LOOKUP_BUSINESS_INFO** for any part of the query where the preliminary business info check score was {PRE_PASS1_BUSINESS_PART_LOOKUP_THRESHOLD:.2f} or higher.
279
+ - Use **SEARCH** for parts about external information or where the business info check score was below {PRE_PASS1_BUSINESS_PART_LOOKUP_THRESHOLD:.2f}.
280
+ - If a part of the query is clearly external (like asking about current events or famous people) even if its business info score wasn't zero, you should likely use SEARCH for it.
281
+ - Do NOT output any other text besides the ACTION lines.
282
+ - If the results suggest a direct answer is sufficient, use ANSWER_DIRECTLY.
283
+
284
+ Now, analyze the following user query, considering the business info check results provided above, and output the required actions:
285
+ """
286
+
287
+ # --- Pass 2 System Prompt ---
288
+ pass2_instructions_synthesize = """You are a helpful assistant for a business. You have been provided with the original user query, relevant Business Information (if found), and results from external searches (if performed).
289
+
290
+ Your task is to synthesize ALL the provided information to answer the user's original question concisely and accurately.
291
+
292
+ **Prioritize Business Information** for details about the business, its services, or individuals mentioned within that context.
293
+ Use the Search Results for current external information that was requested.
294
+ If information for a specific part of the question was not found in either Business Information or Search Results, use your general knowledge if possible, or state that the information could not be found.
295
+
296
+ Synthesize the information into a natural language response. Do NOT copy and paste raw context or strings like 'Business Information:' or 'SEARCH RESULTS:' or 'ACTION:' or the raw user query.
297
+
298
+ After your answer, generate a few concise follow-up questions that a user might ask based on the previous turn's conversation and your response. List these questions clearly at the end of your response.
299
+ When search results were used to answer the question, list the URLs from the search results you used under a "Sources:" heading at the very end.
300
+ """
301
+
302
+ # --- Main Inference Function for Gradio ---
303
+ # This function will be called every time the user submits a query
304
+ # chat_history is now a parameter managed by Gradio's State
305
+ def respond(user_input, chat_history):
306
+ """
307
+ Processes user input, performs actions (lookup/search), and generates a response.
308
+ Manages chat history within Gradio state.
309
+ """
310
+ # Check if models loaded successfully
311
+ if model is None or tokenizer is None or embedder is None or nlp is None:
312
+ return "", chat_history + [(user_input, "Sorry, the application failed to load necessary components. Please try again later or contact the administrator.")] # Return empty string for input, updated history
313
+
314
+ original_user_input = user_input
315
+
316
+ # Initialize action results containers for this turn
317
+ search_results_dicts = []
318
+ business_lookup_results_formatted = []
319
+ response_pass1_raw = "" # To store the raw actions generated by Pass 1
320
+
321
+ # --- Pre-Pass 1: Programmatic Business Info Check for Query Parts ---
322
+ query_parts = split_query(original_user_input)
323
+ business_check_results = []
324
+ overall_pre_pass1_score = 0.0
325
+
326
+ print("\n--- Processing new user query ---")
327
+ print(f"User: {user_input}")
328
+ print("Performing programmatic business info check on query parts...")
329
+
330
+ if query_parts:
331
+ for i, part in enumerate(query_parts):
332
+ match, score = retrieve_business_info(part, data, embeddings, embedder, threshold=0.0)
333
+ business_check_results.append({"part": part, "score": score, "match": match})
334
+ print(f"- Part '{part}': Score {score:.4f}")
335
+ overall_pre_pass1_score = max(overall_pre_pass1_score, score)
336
+ else:
337
+ match, score = retrieve_business_info(original_user_input, data, embeddings, embedder, threshold=0.0)
338
+ business_check_results.append({"part": original_user_input, "score": score, "match": match})
339
+ print(f"- Part '{original_user_input}': Score {score:.4f}")
340
+ overall_pre_pass1_score = score
341
+
342
+ is_likely_direct_answer = overall_pre_pass1_score < PRE_PASS1_BUSINESS_PART_LOOKUP_THRESHOLD and len(query_parts) <= 2
343
+
344
+ # Format business check summary for Pass 1 prompt
345
+ business_check_summary = "Business Info Check Results for Query Parts:\n"
346
+ if business_check_results:
347
+ for result in business_check_results:
348
+ status = "High Relevance" if result['score'] >= PRE_PASS1_BUSINESS_PART_LOOKUP_THRESHOLD else "Low Relevance"
349
+ business_check_summary += f"- Part '{result['part']}': Score {result['score']:.4f} ({status})\n"
350
+ else:
351
+ business_check_summary += "- No parts identified or check skipped.\n"
352
+ business_check_summary += "\n"
353
+
354
+ # --- Pass 1: Action Identification (if not direct answer) ---
355
+ requested_actions = []
356
+ answer_directly_provided = None
357
+
358
+ if is_likely_direct_answer:
359
+ print("Programmatically determined likely direct answer.")
360
+ response_pass1_raw = f"ACTION: ANSWER_DIRECTLY: " # Signal Pass 2
361
+
362
+ else:
363
+ pass1_user_message_content = pass1_instructions_action.format(
364
+ business_check_summary=business_check_summary,
365
+ PRE_PASS1_BUSINESS_PART_LOOKUP_THRESHOLD=PRE_PASS1_BUSINESS_PART_LOOKUP_THRESHOLD # Pass threshold to prompt
366
+ ) + "\n\nUser Query: " + user_input
367
+
368
+ # Create a temporary history for Pass 1 focusing only on the current turn's user query and instructions
369
+ temp_chat_history_pass1 = [{"role": "user", "content": pass1_user_message_content}]
370
+
371
+ try:
372
+ prompt_pass1 = tokenizer.apply_chat_template(
373
+ temp_chat_history_pass1,
374
+ tokenize=False,
375
+ add_generation_prompt=True
376
+ )
377
+ # print("\n--- Pass 1 Prompt ---") # Debug print
378
+ # print(prompt_pass1)
379
+ # print("---------------------")
380
+
381
+ generation_config_pass1 = GenerationConfig(
382
+ max_new_tokens=200,
383
+ do_sample=False,
384
+ temperature=0.1,
385
+ eos_token_id=tokenizer.eos_token_id,
386
+ pad_token_id=tokenizer.pad_token_id,
387
+ use_cache=True
388
+ )
389
+
390
+ input_ids_pass1 = tokenizer(prompt_pass1, return_tensors="pt").input_ids # Removed .to(model.device) as device_map="auto" handles it
391
+ if model and input_ids_pass1.numel() > 0: # Added check for model
392
+ outputs_pass1 = model.generate(
393
+ input_ids=input_ids_pass1,
394
+ generation_config=generation_config_pass1,
395
+ )
396
+ prompt_length_pass1 = input_ids_pass1.shape[1]
397
+ if outputs_pass1.shape[1] > prompt_length_pass1:
398
+ generated_tokens_pass1 = outputs_pass1[0, prompt_length_pass1:]
399
+ response_pass1_raw = tokenizer.decode(generated_tokens_pass1, skip_special_tokens=True).strip()
400
+ else:
401
+ response_pass1_raw = "" # No actions generated
402
+ else:
403
+ response_pass1_raw = "" # Empty input or model not loaded
404
+
405
+ # print("\n--- Raw Pass 1 Response ---") # Debug print
406
+ # print(response_pass1_raw)
407
+ # print("--------------------------")
408
+
409
+
410
+ except Exception as e:
411
+ print(f"Error during Pass 1 (Action Identification): {e}")
412
+ # If Pass 1 fails, fallback to attempting a direct answer in Pass 2
413
+ response_pass1_raw = f"ACTION: ANSWER_DIRECTLY: Error in Pass 1 - {e}"
414
+
415
+
416
+ # --- Parse Model's Requested Actions with Validation ---
417
+ # Always parse even if flagged for direct answer to handle potential Pass 1 errors
418
+ if response_pass1_raw:
419
+ lines = response_pass1_raw.strip().split('\n')
420
+ for line in lines:
421
+ line = line.strip()
422
+ if line.startswith(SEARCH_MARKER):
423
+ query = line[len(SEARCH_MARKER):].strip()
424
+ if query:
425
+ # Validate SEARCH Action
426
+ _, score = retrieve_business_info(query, data, embeddings, embedder, threshold=0.0)
427
+ if score < SEARCH_VALIDATION_THRESHOLD:
428
+ requested_actions.append(("SEARCH", query))
429
+ print(f"Validated Search Action for '{query}' (Score: {score:.4f})")
430
+ else:
431
+ print(f"Rejected Search Action for '{query}' (Score: {score:.4f}) - Too similar to business data.")
432
+ elif line.startswith(BUSINESS_LOOKUP_MARKER):
433
+ query = line[len(BUSINESS_LOOKUP_MARKER):].strip()
434
+ if query:
435
+ # Validate Business Lookup Query
436
+ match, score = retrieve_business_info(query, data, embeddings, embedder, threshold=0.0) # Use low threshold for scoring
437
+ if score > BUSINESS_LOOKUP_VALIDATION_THRESHOLD:
438
+ requested_actions.append(("LOOKUP_BUSINESS_INFO", query))
439
+ print(f"Validated Business Lookup Action for '{query}' (Score: {score:.4f})")
440
+ else:
441
+ print(f"Rejected Business Lookup Action for '{query}' (Score: {score:.4f}) - Below validation threshold.")
442
+ elif line.startswith(ANSWER_DIRECTLY_MARKER):
443
+ answer = line[len(ANSWER_DIRECTLY_MARKER):].strip()
444
+ answer_directly_provided = answer if answer else original_user_input # Use explicit answer if provided, else original query hint
445
+ requested_actions = [] # Clear other actions if DIRECT_ANSWER is given
446
+ break # Exit action parsing loop
447
+
448
+ # --- Execute Actions (Search and Lookup) ---
449
+ # Only execute actions if ANSWER_DIRECTLY was NOT the primary outcome of Pass 1
450
+ # and there are validated requested actions.
451
+ context_for_pass2 = ""
452
+
453
+ if requested_actions:
454
+ print("Executing requested actions...")
455
+ for action_type, query in requested_actions:
456
+ if action_type == "SEARCH":
457
+ print(f"Performing search for: '{query}'")
458
+ results = perform_duckduckgo_search(query)
459
+ if results:
460
+ search_results_dicts.extend(results)
461
+ print(f"Found {len(results)} search results.")
462
+ else:
463
+ print(f"No search results found for '{query}'.")
464
+
465
+ elif action_type == "LOOKUP_BUSINESS_INFO":
466
+ print(f"Performing business info lookup for: '{query}'")
467
+ match, score = retrieve_business_info(query, data, embeddings, embedder, threshold=retrieve_business_info.__defaults__[0]) # Use default threshold for retrieval
468
+ print(f"Actual lookup score for '{query}': {score:.4f} (Threshold: {retrieve_business_info.__defaults__[0]})")
469
+ if match:
470
+ formatted_match = f"""Service: {match.get('Service', 'N/A')}
471
+ Description: {match.get('Description', 'N/A')}
472
+ Price: {match.get('Price', 'N/A')}
473
+ Available: {match.get('Available', 'N/A')}"""
474
+ business_lookup_results_formatted.append(formatted_match)
475
+ print(f"Found business info match.")
476
+ else:
477
+ print(f"No business info match found for '{query}' at threshold {retrieve_business_info.__defaults__[0]}.")
478
+
479
+ # --- Prepare Context for Pass 2 based on executed actions ---
480
+ if business_lookup_results_formatted:
481
+ context_for_pass2 += "Business Information (Use this for questions about the business):\n"
482
+ context_for_pass2 += "\n---\n".join(business_lookup_results_formatted)
483
+ context_for_pass2 += "\n\n"
484
+
485
+ if search_results_dicts:
486
+ context_for_pass2 += "SEARCH RESULTS (Use this for current external information):\n"
487
+ aggregated_search_results_formatted = []
488
+ for result in search_results_dicts:
489
+ aggregated_search_results_formatted.append(f"Title: {result.get('title', 'N/A')}\nSnippet: {result.get('body', 'N/A')}\nURL: {result.get('href', 'N/A')}")
490
+ context_for_pass2 += "\n---\n".join(aggregated_search_results_formatted) + "\n\n"
491
+
492
+ if requested_actions and not business_lookup_results_formatted and not search_results_dicts:
493
+ context_for_pass2 = "Note: No relevant information was found in Business Information or via Search for your query."
494
+ print("Note: No results were found for the requested actions.")
495
+
496
+ # If ANSWER_DIRECTLY was determined (either programmatically or by Pass 1 model output)
497
+ if answer_directly_provided is not None:
498
+ print(f"Handling as direct answer: {answer_directly_provided}")
499
+ # Provide a simple context indicating it's a direct answer scenario
500
+ context_for_pass2 = "Note: This query is a simple request or greeting."
501
+ if answer_directly_provided != original_user_input and answer_directly_provided != "":
502
+ context_for_pass2 += f" Initial suggestion from action step: {answer_directly_provided}"
503
+ # Ensure no search/lookup results are included if it was flagged as direct answer
504
+ search_results_dicts = []
505
+ business_lookup_results_formatted = []
506
+
507
+
508
+ # If no actions were requested or direct answer flagged, and no results found...
509
+ # This handles cases where Pass 1 failed or generated nothing useful
510
+ if not requested_actions and answer_directly_provided is None:
511
+ if response_pass1_raw.strip():
512
+ print("Warning: Pass 1 did not result in valid actions or a direct answer.")
513
+ context_for_pass2 = f"Error: Could not determine actions from Pass 1 response: '{response_pass1_raw}'."
514
+ else:
515
+ print("Warning: Pass 1 generated an empty response.")
516
+ context_for_pass2 = "Error: Pass 1 generated an empty response."
517
+ # In this case, we will still try Pass 2 with the limited context
518
+
519
+
520
+ # --- Pass 2: Synthesize and Respond ---
521
+ final_response = "Sorry, I couldn't generate a response." # Default response on error
522
+
523
+ if model is not None and tokenizer is not None:
524
+ pass2_user_message_content = pass2_instructions_synthesize + "\n\nOriginal User Query: " + original_user_input + "\n\n" + context_for_pass2
525
+
526
+ # --- Chat History Management for Pass 2 ---
527
+ # Gradio's chat history state is [(User1, Bot1), (User2, Bot2), ...]
528
+ # We need to format the history correctly for the model template
529
+ # The Pass 2 prompt should build upon the *actual* conversation history, not just the Pass 2 context message.
530
+ # Let's build the chat history for the model template
531
+ model_chat_history = []
532
+ for user_msg, bot_msg in chat_history:
533
+ model_chat_history.append({"role": "user", "content": user_msg})
534
+ model_chat_history.append({"role": "assistant", "content": bot_msg})
535
+
536
+ # Add the *current* user query and the Pass 2 specific content as the latest turn
537
+ # The Pass 2 instructions and context are part of the *current* user turn's input to the model
538
+ model_chat_history.append({"role": "user", "content": pass2_user_message_content})
539
+
540
+ try:
541
+ prompt_pass2 = tokenizer.apply_chat_template(
542
+ model_chat_history,
543
+ tokenize=False,
544
+ add_generation_prompt=True # Add the assistant prompt token to start the response
545
+ )
546
+ # print("\n--- Pass 2 Prompt ---") # Debug print
547
+ # print(prompt_pass2)
548
+ # print("---------------------")
549
+
550
+
551
+ generation_config_pass2 = GenerationConfig(
552
+ max_new_tokens=1500, # Generate a longer response
553
+ do_sample=True,
554
+ temperature=0.7,
555
+ top_k=50,
556
+ top_p=0.95,
557
+ repetition_penalty=1.1,
558
+ eos_token_id=tokenizer.eos_token_id,
559
+ pad_token_id=tokenizer.pad_token_id,
560
+ use_cache=True
561
+ )
562
+
563
+ input_ids_pass2 = tokenizer(prompt_pass2, return_tensors="pt").input_ids # Removed .to(model.device)
564
+ if model and input_ids_pass2.numel() > 0: # Added check for model
565
+ outputs_pass2 = model.generate(
566
+ input_ids=input_ids_pass2,
567
+ generation_config=generation_config_pass2,
568
+ )
569
+
570
+ prompt_length_pass2 = input_ids_pass2.shape[1]
571
+ if outputs_pass2.shape[1] > prompt_length_pass2:
572
+ generated_tokens_pass2 = outputs_pass2[0, prompt_length_pass2:]
573
+ final_response = tokenizer.decode(generated_tokens_pass2, skip_special_tokens=True).strip()
574
+ else:
575
+ final_response = "..." # Indicate potentially empty response
576
+ else:
577
+ final_response = "Error: Model or empty input for Pass 2." # Indicate model not loaded or empty input
578
+
579
+
580
+ except Exception as gen_error:
581
+ print(f"Error during model generation in Pass 2: {gen_error}")
582
+ final_response = "Error generating response in Pass 2."
583
+
584
+
585
+ # --- Post-process Final Response from Pass 2 ---
586
+ cleaned_response = final_response
587
+ # Filter out the Pass 2 instructions and context markers that might bleed through
588
+ lines = cleaned_response.split('\n')
589
+ cleaned_lines = [line for line in lines if not line.strip().lower().startswith("business information")
590
+ and not line.strip().lower().startswith("search results")
591
+ and not line.strip().startswith("---")
592
+ and not line.strip().lower().startswith("original user query:")
593
+ and not line.strip().lower().startswith("you are a helpful assistant for a business.")]
594
+
595
+ cleaned_response = "\n".join(cleaned_lines).strip()
596
+
597
+ # Extract and list URLs from the search results that were actually used
598
+ # This assumes the model uses the provided snippets with URLs
599
+ urls_to_list = [result.get('href') for result in search_results_dicts if result.get('href')]
600
+ urls_to_list = list(dict.fromkeys(urls_to_list)) # Remove duplicates
601
+
602
+ # Only add Sources if search was performed AND results were found
603
+ if search_results_dicts and urls_to_list:
604
+ cleaned_response += "\n\nSources:\n" + "\n".join(urls_to_list)
605
+
606
+ final_response = cleaned_response
607
+
608
+ # Check if the final response is empty or just whitespace after cleaning
609
+ if not final_response.strip():
610
+ final_response = "Sorry, I couldn't generate a meaningful response based on the information found."
611
+ print("Warning: Final response was empty after cleaning.")
612
+
613
+ else: # Model or tokenizer not loaded (this check is at the very beginning of the function)
614
+ final_response = "Sorry, the core language model is not available."
615
+ print("Error: LLM model or tokenizer not loaded for Pass 2.")
616
+
617
+
618
+ # --- Update Chat History for Gradio ---
619
+ # Append the user's original message and the final bot response to the history state
620
+ # The format is (user_input, bot_response)
621
+ updated_chat_history = chat_history + [(original_user_input, final_response)]
622
+
623
+ # Optional: Manage history length
624
+ max_history_pairs = 10 # Keep last 10 turns (20 messages total)
625
+ if len(updated_chat_history) > max_history_pairs:
626
+ updated_chat_history = updated_chat_history[-max_history_pairs:]
627
+ # print(f"History truncated. Keeping last {len(updated_chat_history)} turns.") # Debug print
628
+
629
+ # Return the updated history state and an empty string for the input box
630
+ return "", updated_chat_history