Futuresony committed
Commit 3b1ec5c · verified · 1 Parent(s): 0a01327

Update app.py

Files changed (1)
  1. app.py +658 -28
app.py CHANGED
@@ -1,38 +1,668 @@
- import gradio as gr
- from llama_cpp import Llama
- from huggingface_hub import hf_hub_download
-
- # Download GGUF model from Hugging Face Hub
- MODEL_REPO = "Futuresony/gemma2-2b-gguf-q4_k_m"
- MODEL_FILENAME = "unsloth.Q4_K_M.gguf"  # Or check exact filename on the repo
-
- model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
-
- # Load model
- llm = Llama(model_path=model_path, n_ctx=2048, n_threads=4, verbose=True)
-
- # Format prompt as Alpaca-style
- def format_prompt(user_message):
-     return f"""### Instruction:
- {user_message}
-
- ### Response:"""
-
- # Chat handler
- def respond(user_message, chat_history):
-     prompt = format_prompt(user_message)
-     output = llm(prompt, max_tokens=300, stop=["###"])
-     response = output["choices"][0]["text"].strip()
-     chat_history.append((user_message, response))
-     return "", chat_history
-
- # Gradio UI
  with gr.Blocks() as demo:
-     gr.Markdown("## 🤖 DStv AI Assistant - Powered by Gemma 2B GGUF")
-     chatbot = gr.Chatbot()
-     msg = gr.Textbox(placeholder="Ask your question...")
-     state = gr.State([])
-
-     msg.submit(respond, [msg, state], [msg, chatbot])
-
- demo.launch()
+ import os
+ import base64  # needed below to decode the service-account key
+ import torch
+ import re
+ import warnings
+ import time
+ import json
+ from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, BitsAndBytesConfig
+ from sentence_transformers import SentenceTransformer, util
+ import gspread
+ from google.oauth2 import service_account  # service-account credentials for gspread
+ from tqdm import tqdm
+ from duckduckgo_search import DDGS
+ import spacy
+ import gradio as gr
+ from pathlib import Path  # For handling the spaCy model path
+
+ # Suppress warnings
+ warnings.filterwarnings("ignore", category=UserWarning)
 
+ # --- Configuration ---
+ SHEET_ID = "19ipxC2vHYhpXCefpxpIkpeYdI43a1Ku2kYwecgUULIw"  # Your Google Sheet ID
+ HF_TOKEN = os.getenv("HF_TOKEN")  # Hugging Face token from Space Secrets
+ # It's highly recommended to use a Google Service Account for GSheets on Spaces.
+ # Store the JSON key as a base64-encoded string in a Space Secret (e.g., GOOGLE_SERVICE_ACCOUNT_KEY_BASE64).
+ GOOGLE_SERVICE_ACCOUNT_KEY_BASE64 = os.getenv("GOOGLE_SERVICE_ACCOUNT_KEY_BASE64")
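+ # (Sketch, not part of the app itself: one way to produce that secret value locally is
+ #   base64 -w0 service_account.json        # GNU coreutils; flag differs on macOS
+ # or in Python: base64.b64encode(Path("service_account.json").read_bytes()).decode()
+ # — "service_account.json" is an illustrative filename.)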
 
+ # Changed model_id to Gemma 2 9B
+ model_id = "google/gemma-2-9b-it"  # Ensure this model is accessible with your HF_TOKEN
+
+ # --- Constants for Prompting and Validation ---
+ SEARCH_MARKER = "ACTION: SEARCH:"
+ BUSINESS_LOOKUP_MARKER = "ACTION: LOOKUP_BUSINESS_INFO:"
+ ANSWER_DIRECTLY_MARKER = "ACTION: ANSWER_DIRECTLY:"
+ BUSINESS_LOOKUP_VALIDATION_THRESHOLD = 0.6
+ SEARCH_VALIDATION_THRESHOLD = 0.6
+ PRE_PASS1_BUSINESS_PART_LOOKUP_THRESHOLD = 0.5
+
+ # --- Global variables, loaded once on startup ---
+ tokenizer = None
+ model = None
+ nlp = None  # spaCy model
+ embedder = None  # Sentence Transformer
+ data = []  # Google Sheet data
+ descriptions = []
+ embeddings = torch.tensor([])  # Google Sheet embeddings
+
+ # --- Loading Functions (run once on startup) ---
+
+ def load_spacy_model():
+     """Loads the spaCy model, retrying once if it is not found locally."""
+     model_name = "en_core_web_sm"
+     try:
+         print(f"Loading spaCy model '{model_name}'...")
+         nlp_model = spacy.load(model_name)
+         print(f"SpaCy model '{model_name}' loaded.")
+         return nlp_model
+     except OSError:
+         print(f"SpaCy model '{model_name}' not found locally.")
+         # Downloading at runtime via subprocess/os.system is fragile in production;
+         # ideally pre-download the model in the Space build or via requirements.txt.
+         print("Please ensure 'en_core_web_sm' is installed (e.g., `python -m spacy download en_core_web_sm`).")
+         print("Attempting to load again, assuming it was installed via requirements.txt...")
+         try:
+             nlp_model = spacy.load(model_name)
+             print(f"SpaCy model '{model_name}' loaded after assumed installation.")
+             return nlp_model
+         except Exception as e:
+             print(f"Failed to load spaCy model '{model_name}': {e}")
+             print("SpaCy will not be available.")
+             return None  # Return None if loading fails
+
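+ # (Sketch of the runtime-download fallback alluded to above, assuming spaCy's CLI is
+ # usable inside the Space:
+ #   import subprocess, sys
+ #   subprocess.run([sys.executable, "-m", "spacy", "download", "en_core_web_sm"], check=True)
+ # Pre-installing the model via requirements.txt remains the safer option.)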
+ def load_sentence_transformer():
+     """Loads the Sentence Transformer model."""
+     print("Loading Sentence Transformer...")
+     try:
+         embedder_model = SentenceTransformer("all-MiniLM-L6-v2")
+         print("Sentence Transformer loaded.")
+         return embedder_model
+     except Exception as e:
+         print(f"Error loading Sentence Transformer: {e}")
+         return None
+
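+ # (Note: "all-MiniLM-L6-v2" is fetched from the Hugging Face Hub on first use and
+ # cached locally, so the first startup is slower than subsequent ones.)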
+ def load_google_sheet_data(sheet_id, service_account_key_base64):
+     """Authenticates with a service account and loads data from the Google Sheet."""
+     print(f"Attempting to load Google Sheet data from ID: {sheet_id}")
+     if not service_account_key_base64:
+         print("Warning: GOOGLE_SERVICE_ACCOUNT_KEY_BASE64 secret is not set. Cannot access Google Sheets.")
+         return [], [], torch.tensor([])
+
+     try:
+         # Decode the base64 key and build service-account credentials
+         key_bytes = base64.b64decode(service_account_key_base64)
+         key_dict = json.loads(key_bytes)
+         creds = service_account.Credentials.from_service_account_info(
+             key_dict,
+             scopes=["https://www.googleapis.com/auth/spreadsheets",
+                     "https://www.googleapis.com/auth/drive"],  # scopes gspread needs
+         )
+         client = gspread.authorize(creds)
+
+         sheet = client.open_by_key(sheet_id).sheet1
+         print(f"Successfully opened Google Sheet with ID: {sheet_id}")
+         sheet_data = sheet.get_all_records()
+
+         if not sheet_data:
+             print(f"Warning: No data records found in Google Sheet with ID: {sheet_id}")
+             return [], [], torch.tensor([])
+
+         filtered_data = [row for row in sheet_data if row.get('Service') and row.get('Description')]
+         if not filtered_data:
+             print("Warning: No rows with both 'Service' and 'Description' values were found.")
+             return [], [], torch.tensor([])
+
+         services = [row["Service"] for row in filtered_data]
+         descriptions = [row["Description"] for row in filtered_data]
+         print(f"Loaded {len(descriptions)} entries from Google Sheet for embedding.")
+
+         # Descriptions are encoded later, once the embedder is loaded
+         return filtered_data, descriptions, None
+
+     except gspread.exceptions.SpreadsheetNotFound:
+         print(f"Error: Google Sheet with ID '{sheet_id}' not found.")
+         print("Please check the SHEET_ID and ensure the service account has access.")
+         return [], [], torch.tensor([])
+     except Exception as e:
+         print(f"An error occurred while accessing the Google Sheet: {e}")
+         return [], [], torch.tensor([])
+
+ def load_llm_model(model_id, hf_token):
+     """Loads the LLM with 4-bit quantization."""
+     print(f"Loading model {model_id}...")
+     if not hf_token:
+         print("Error: HF_TOKEN secret is not set. Cannot load Hugging Face model.")
+         return None, None
+
+     bnb_config = BitsAndBytesConfig(
+         load_in_4bit=True,
+         bnb_4bit_use_double_quant=True,
+         bnb_4bit_quant_type="nf4",
+         bnb_4bit_compute_dtype=torch.bfloat16
+     )
+
+     try:
+         llm_tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
+         if llm_tokenizer.pad_token is None:
+             llm_tokenizer.pad_token = llm_tokenizer.eos_token
+
+         llm_model = AutoModelForCausalLM.from_pretrained(
+             model_id,
+             token=hf_token,
+             device_map="auto",  # Let accelerate place the weights
+             quantization_config=bnb_config,
+         )
+
+         print(f"Model {model_id} loaded using 4-bit quantization.")
+         return llm_model, llm_tokenizer
+
+     except Exception as e:
+         print(f"Error loading model {model_id}: {e}")
+         print("Please ensure bitsandbytes, trl, peft, and accelerate are installed.")
+         print("Check your Hugging Face token.")
+         # Do not raise; return None so the app can still start without the LLM
+         return None, None
+
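+ # (Background: "nf4" is the 4-bit NormalFloat data type from the QLoRA paper; double
+ # quantization additionally quantizes the quantization constants to save a little more
+ # memory, while bnb_4bit_compute_dtype keeps the matmuls in bfloat16.)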
+ # --- Load all assets on startup ---
+ print("Loading assets...")
+ nlp = load_spacy_model()
+ embedder = load_sentence_transformer()
+ data, descriptions, _ = load_google_sheet_data(SHEET_ID, GOOGLE_SERVICE_ACCOUNT_KEY_BASE64)  # data and descriptions first
+
+ if embedder is not None and descriptions:
+     print("Encoding Google Sheet descriptions...")
+     try:
+         embeddings = embedder.encode(descriptions, convert_to_tensor=True)
+         print("Encoding complete.")
+     except Exception as e:
+         print(f"Error during embedding: {e}")
+         embeddings = torch.tensor([])  # Fall back to an empty tensor on error
+ else:
+     print("Skipping embedding due to missing embedder or descriptions.")
+     embeddings = torch.tensor([])  # Empty tensor when there is nothing to encode
+
+ model, tokenizer = load_llm_model(model_id, HF_TOKEN)
+
+ # Check that the essential components loaded
+ if not model or not tokenizer or not embedder or not nlp:
+     print("\nERROR: Essential components failed to load. The application may not function correctly.")
+     if not model: print("- LLM model failed to load.")
+     if not tokenizer: print("- LLM tokenizer failed to load.")
+     if not embedder: print("- Sentence embedder failed to load.")
+     if not nlp: print("- spaCy model failed to load.")
+     # Continue anyway; respond() checks again before doing any work.
+
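+ # (For reference: all-MiniLM-L6-v2 produces 384-dimensional embeddings, so the tensor
+ # above has shape (len(descriptions), 384).)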
+ # --- Helper Functions ---
+
+ # Perform a DuckDuckGo search and return results with URLs
+ def perform_duckduckgo_search(query, max_results=3):
+     """
+     Performs a search using DuckDuckGo and returns a list of result dictionaries.
+     Includes a delay to avoid rate limits.
+     """
+     search_results_list = []
+     try:
+         time.sleep(1)  # Delay before each search to stay under rate limits
+         with DDGS() as ddgs:
+             for r in ddgs.text(query, max_results=max_results):
+                 search_results_list.append(r)  # Each r is a dict with title/body/href
+     except Exception as e:
+         print(f"Error during DuckDuckGo search for '{query}': {e}")
+         return []
+     return search_results_list
+
+ # Retrieve relevant business info from the Google Sheet embeddings
+ def retrieve_business_info(query, data, embeddings, embedder, threshold=0.50):
+     """
+     Retrieves the business entry most similar to the query.
+     Returns (match_dict, score) when the best similarity clears the threshold,
+     otherwise (None, score).
+     """
+     if not data or embeddings is None or embeddings.numel() == 0 or embedder is None:
+         print("Skipping business info retrieval: data, embeddings, or embedder not available.")
+         return None, 0.0
+
+     try:
+         user_embedding = embedder.encode(query, convert_to_tensor=True)
+         cos_scores = util.cos_sim(user_embedding, embeddings)[0]  # similarity to every description
+         best_score = cos_scores.max().item()
+
+         if best_score > threshold:
+             best_match_idx = cos_scores.argmax().item()
+             return data[best_match_idx], best_score
+         return None, best_score
+     except Exception as e:
+         print(f"Error during business information retrieval: {e}")
+         return None, 0.0
+
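+ # (Illustrative call with a made-up query and score:
+ #   retrieve_business_info("decoder installation", data, embeddings, embedder)
+ #   -> ({'Service': 'Installation', 'Description': ..., ...}, 0.72) when 0.72 > threshold.)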
+ # Split a user query into potential sub-queries using spaCy
+ def split_query(query):
+     """Splits a user query into potential sub-queries using spaCy."""
+     if nlp is None:
+         print("SpaCy model not loaded. Cannot split query.")
+         return [query]  # Fall back to the original query
+
+     try:
+         doc = nlp(query)
+         sentences = [sent.text.strip() for sent in doc.sents]
+         if len(sentences) == 1:
+             # Single sentence: try splitting on commas, semicolons, and "and <wh-word>" joins
+             parts = re.split(r',| and (who|what|where|when|why|how|is|are|can|tell me about)|;', query, flags=re.IGNORECASE)
+             parts = [part.strip() for part in parts if part is not None and part.strip()]
+             if len(parts) <= 1:
+                 return [query]
+             return parts
+         return sentences
+     except Exception as e:
+         print(f"Error during query splitting: {e}")
+         return [query]  # Fall back to the original query on error
+
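+ # (Illustrative, made-up input: "What services do you offer, and who won the match?"
+ # splits into ["What services do you offer", "who", "won the match?"] — re.split keeps
+ # the captured wh-word as its own part; exact output depends on the spaCy sentencizer.)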
+ # --- Pass 1 System Prompt ---
+ pass1_instructions_action = """You are a helpful assistant for a business. Your primary goal in this first step is to analyze the user's query and decide which actions are needed to answer it.
+
+ You have analyzed the user's query and potentially broken it down into parts. For each part, a preliminary check was done to see if it matches known business information. The results of this check are provided below.
+
+ {business_check_summary}
+
+ Based on the user's query and the results of the business info check for each part, identify which actions are needed.
+
+ Output one or more actions, each on a new line, in the format:
+ ACTION: [ACTION_TYPE]: [Argument/Query for the action]
+
+ Possible actions:
+ 1. **LOOKUP_BUSINESS_INFO**: If a part of the query asks about the business's services, prices, availability, or individuals mentioned in the business context, *and* the business info check for that part indicates high relevance ({PRE_PASS1_BUSINESS_PART_LOOKUP_THRESHOLD:.2f} or higher). The argument should be the specific phrase or name to look up.
+ 2. **SEARCH**: If a part of the query asks for current external information (e.g., current events, real-time data, general facts not in business info), *or* if a part that seems like it could be business info did *not* have a high relevance score in the preliminary check (below {PRE_PASS1_BUSINESS_PART_LOOKUP_THRESHOLD:.2f}). The argument should be the precise search query.
+ 3. **ANSWER_DIRECTLY**: If the overall query is a simple greeting or can be answered from your general knowledge without lookup or search, *and* the business info check results indicate low relevance for all parts. The argument should be the direct answer.
+
+ **Crucially:**
+ - **Prioritize LOOKUP_BUSINESS_INFO** for any part of the query where the preliminary business info check score was {PRE_PASS1_BUSINESS_PART_LOOKUP_THRESHOLD:.2f} or higher.
+ - Use **SEARCH** for parts about external information or where the business info check score was below {PRE_PASS1_BUSINESS_PART_LOOKUP_THRESHOLD:.2f}.
+ - If a part of the query is clearly external (like current events or famous people), use SEARCH for it even if its business info score wasn't zero.
+ - Do NOT output any other text besides the ACTION lines.
+ - If the results suggest a direct answer is sufficient, use ANSWER_DIRECTLY.
+
+ Now, analyze the following user query, considering the business info check results provided above, and output the required actions:
+ """
+
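+ # (Illustrative Pass 1 output for a mixed query — both arguments are made up:
+ #   ACTION: LOOKUP_BUSINESS_INFO: premium package price
+ #   ACTION: SEARCH: today's exchange rate USD to TZS
+ # )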
+ # --- Pass 2 System Prompt ---
+ pass2_instructions_synthesize = """You are a helpful assistant for a business. You have been provided with the original user query, relevant Business Information (if found), and results from external searches (if performed).
+
+ Your task is to synthesize ALL the provided information to answer the user's original question concisely and accurately.
+
+ **Prioritize Business Information** for details about the business, its services, or individuals mentioned within that context.
+ Use the Search Results for current external information that was requested.
+ If information for a specific part of the question was not found in either Business Information or Search Results, use your general knowledge if possible, or state that the information could not be found.
+
+ Synthesize the information into a natural-language response. Do NOT copy and paste raw context or strings like 'Business Information:', 'SEARCH RESULTS:', 'ACTION:', or the raw user query.
+
+ After your answer, generate a few concise follow-up questions that a user might ask based on the previous turn's conversation and your response. List these questions clearly at the end of your response.
+ When search results were used to answer the question, list the URLs from the search results you used under a "Sources:" heading at the very end.
+ """
+
+ # --- Main Inference Function for Gradio ---
+ # Called on every user submission; chat_history is managed by Gradio state.
+ def respond(user_input, chat_history):
+     """
+     Processes user input, performs actions (lookup/search), and generates a response.
+     Manages chat history within Gradio state.
+     """
+     # Check that the models loaded successfully
+     if model is None or tokenizer is None or embedder is None or nlp is None:
+         return "", chat_history + [(user_input, "Sorry, the application failed to load necessary components. Please try again later or contact the administrator.")]
+
+     original_user_input = user_input
+
+     # Initialize action-result containers for this turn
+     search_results_dicts = []
+     business_lookup_results_formatted = []
+     response_pass1_raw = ""  # Raw actions generated by Pass 1
+
+     # --- Pre-Pass 1: programmatic business info check on query parts ---
+     query_parts = split_query(original_user_input)
+     business_check_results = []
+     overall_pre_pass1_score = 0.0
+
+     print("\n--- Processing new user query ---")
+     print(f"User: {user_input}")
+     print("Performing programmatic business info check on query parts...")
+
+     if query_parts:
+         for part in query_parts:
+             match, score = retrieve_business_info(part, data, embeddings, embedder, threshold=0.0)
+             business_check_results.append({"part": part, "score": score, "match": match})
+             print(f"- Part '{part}': Score {score:.4f}")
+             overall_pre_pass1_score = max(overall_pre_pass1_score, score)
+     else:
+         match, score = retrieve_business_info(original_user_input, data, embeddings, embedder, threshold=0.0)
+         business_check_results.append({"part": original_user_input, "score": score, "match": match})
+         print(f"- Part '{original_user_input}': Score {score:.4f}")
+         overall_pre_pass1_score = score
+
+     is_likely_direct_answer = overall_pre_pass1_score < PRE_PASS1_BUSINESS_PART_LOOKUP_THRESHOLD and len(query_parts) <= 2
+
+     # Format the business check summary for the Pass 1 prompt
+     business_check_summary = "Business Info Check Results for Query Parts:\n"
+     if business_check_results:
+         for result in business_check_results:
+             status = "High Relevance" if result['score'] >= PRE_PASS1_BUSINESS_PART_LOOKUP_THRESHOLD else "Low Relevance"
+             business_check_summary += f"- Part '{result['part']}': Score {result['score']:.4f} ({status})\n"
+     else:
+         business_check_summary += "- No parts identified or check skipped.\n"
+     business_check_summary += "\n"
+
+     # --- Pass 1: Action Identification (if not a direct answer) ---
+     requested_actions = []
+     answer_directly_provided = None
+
+     if is_likely_direct_answer:
+         print("Programmatically determined likely direct answer.")
+         response_pass1_raw = "ACTION: ANSWER_DIRECTLY: "  # Signal Pass 2
+
+     else:
+         pass1_user_message_content = pass1_instructions_action.format(
+             business_check_summary=business_check_summary,
+             PRE_PASS1_BUSINESS_PART_LOOKUP_THRESHOLD=PRE_PASS1_BUSINESS_PART_LOOKUP_THRESHOLD  # Pass the threshold into the prompt
+         ) + "\n\nUser Query: " + user_input
+
+         # Temporary history for Pass 1: only the current turn's instructions and query
+         temp_chat_history_pass1 = [{"role": "user", "content": pass1_user_message_content}]
+
+         try:
+             prompt_pass1 = tokenizer.apply_chat_template(
+                 temp_chat_history_pass1,
+                 tokenize=False,
+                 add_generation_prompt=True
+             )
+             # print("\n--- Pass 1 Prompt ---")  # Debug print
+             # print(prompt_pass1)
+             # print("---------------------")
+
+             generation_config_pass1 = GenerationConfig(
+                 max_new_tokens=200,
+                 do_sample=False,  # Greedy decoding; temperature is ignored when do_sample=False
+                 temperature=0.1,
+                 eos_token_id=tokenizer.eos_token_id,
+                 pad_token_id=tokenizer.pad_token_id,
+                 use_cache=True
+             )
+
+             input_ids_pass1 = tokenizer(prompt_pass1, return_tensors="pt").input_ids.to(model.device)
+
+             if input_ids_pass1.numel() > 0:
+                 outputs_pass1 = model.generate(
+                     input_ids=input_ids_pass1,
+                     generation_config=generation_config_pass1,
+                 )
+                 prompt_length_pass1 = input_ids_pass1.shape[1]
+                 if outputs_pass1.shape[1] > prompt_length_pass1:
+                     generated_tokens_pass1 = outputs_pass1[0, prompt_length_pass1:]
+                     response_pass1_raw = tokenizer.decode(generated_tokens_pass1, skip_special_tokens=True).strip()
+                 else:
+                     response_pass1_raw = ""  # No actions generated
+             else:
+                 response_pass1_raw = ""  # Empty input
+
+             # print("\n--- Raw Pass 1 Response ---")  # Debug print
+             # print(response_pass1_raw)
+             # print("--------------------------")
+
+         except Exception as e:
+             print(f"Error during Pass 1 (Action Identification): {e}")
+             # If Pass 1 fails, fall back to attempting a direct answer in Pass 2
+             response_pass1_raw = f"ACTION: ANSWER_DIRECTLY: Error in Pass 1 - {e}"
+
+     # --- Parse the model's requested actions, with validation ---
+     # Always parse, even when flagged for a direct answer, to handle Pass 1 errors
+     if response_pass1_raw:
+         lines = response_pass1_raw.strip().split('\n')
+         for line in lines:
+             line = line.strip()
+             if line.startswith(SEARCH_MARKER):
+                 query = line[len(SEARCH_MARKER):].strip()
+                 if query:
+                     # Validate the SEARCH action: reject queries the business data already covers
+                     _, score = retrieve_business_info(query, data, embeddings, embedder, threshold=0.0)
+                     if score < SEARCH_VALIDATION_THRESHOLD:
+                         requested_actions.append(("SEARCH", query))
+                         print(f"Validated Search Action for '{query}' (Score: {score:.4f})")
+                     else:
+                         print(f"Rejected Search Action for '{query}' (Score: {score:.4f}) - Too similar to business data.")
+             elif line.startswith(BUSINESS_LOOKUP_MARKER):
+                 query = line[len(BUSINESS_LOOKUP_MARKER):].strip()
+                 if query:
+                     # Validate the business lookup query (threshold=0.0 here just to obtain a score)
+                     match, score = retrieve_business_info(query, data, embeddings, embedder, threshold=0.0)
+                     if score > BUSINESS_LOOKUP_VALIDATION_THRESHOLD:
+                         requested_actions.append(("LOOKUP_BUSINESS_INFO", query))
+                         print(f"Validated Business Lookup Action for '{query}' (Score: {score:.4f})")
+                     else:
+                         print(f"Rejected Business Lookup Action for '{query}' (Score: {score:.4f}) - Below validation threshold.")
+             elif line.startswith(ANSWER_DIRECTLY_MARKER):
+                 answer = line[len(ANSWER_DIRECTLY_MARKER):].strip()
+                 answer_directly_provided = answer if answer else original_user_input  # Explicit answer if given, else fall back to the original query
+                 requested_actions = []  # A direct answer overrides any other actions
+                 break  # Stop parsing further action lines
+
+     # --- Execute Actions (Search and Lookup) ---
+     # Actions run only when ANSWER_DIRECTLY was not the outcome of Pass 1
+     # and there are validated requested actions.
+     context_for_pass2 = ""
+
+     if requested_actions:
+         print("Executing requested actions...")
+         for action_type, query in requested_actions:
+             if action_type == "SEARCH":
+                 print(f"Performing search for: '{query}'")
+                 results = perform_duckduckgo_search(query)
+                 if results:
+                     search_results_dicts.extend(results)
+                     print(f"Found {len(results)} search results.")
+                 else:
+                     print(f"No search results found for '{query}'.")
+
+             elif action_type == "LOOKUP_BUSINESS_INFO":
+                 print(f"Performing business info lookup for: '{query}'")
+                 default_threshold = retrieve_business_info.__defaults__[0]  # the function's default threshold (0.50)
+                 match, score = retrieve_business_info(query, data, embeddings, embedder, threshold=default_threshold)
+                 print(f"Actual lookup score for '{query}': {score:.4f} (Threshold: {default_threshold})")
+                 if match:
+                     formatted_match = (
+                         f"Service: {match.get('Service', 'N/A')}\n"
+                         f"Description: {match.get('Description', 'N/A')}\n"
+                         f"Price: {match.get('Price', 'N/A')}\n"
+                         f"Available: {match.get('Available', 'N/A')}"
+                     )
+                     business_lookup_results_formatted.append(formatted_match)
+                     print("Found business info match.")
+                 else:
+                     print(f"No business info match found for '{query}' at threshold {default_threshold}.")
+
+     # --- Prepare the context for Pass 2 from the executed actions ---
+     if business_lookup_results_formatted:
+         context_for_pass2 += "Business Information (Use this for questions about the business):\n"
+         context_for_pass2 += "\n---\n".join(business_lookup_results_formatted)
+         context_for_pass2 += "\n\n"
+
+     if search_results_dicts:
+         context_for_pass2 += "SEARCH RESULTS (Use this for current external information):\n"
+         aggregated_search_results_formatted = []
+         for result in search_results_dicts:
+             aggregated_search_results_formatted.append(f"Title: {result.get('title', 'N/A')}\nSnippet: {result.get('body', 'N/A')}\nURL: {result.get('href', 'N/A')}")
+         context_for_pass2 += "\n---\n".join(aggregated_search_results_formatted) + "\n\n"
+
+     if requested_actions and not business_lookup_results_formatted and not search_results_dicts:
+         context_for_pass2 = "Note: No relevant information was found in Business Information or via Search for your query."
+         print("Note: No results were found for the requested actions.")
+
+     # If ANSWER_DIRECTLY was determined (programmatically or by the Pass 1 model output)
+     if answer_directly_provided is not None:
+         print(f"Handling as direct answer: {answer_directly_provided}")
+         # Provide a simple context indicating a direct-answer scenario
+         context_for_pass2 = "Note: This query is a simple request or greeting."
+         if answer_directly_provided != original_user_input and answer_directly_provided != "":
+             context_for_pass2 += f" Initial suggestion from action step: {answer_directly_provided}"
+         # Drop any search/lookup results when the turn is flagged as a direct answer
+         search_results_dicts = []
+         business_lookup_results_formatted = []
+
+     # If no actions were requested and no direct answer was flagged,
+     # Pass 1 failed or generated nothing useful
+     if not requested_actions and answer_directly_provided is None:
+         if response_pass1_raw.strip():
+             print("Warning: Pass 1 did not result in valid actions or a direct answer.")
+             context_for_pass2 = f"Error: Could not determine actions from Pass 1 response: '{response_pass1_raw}'."
+         else:
+             print("Warning: Pass 1 generated an empty response.")
+             context_for_pass2 = "Error: Pass 1 generated an empty response."
+         # Still attempt Pass 2 with this limited context
+
+     # --- Pass 2: Synthesize and Respond ---
+     final_response = "Sorry, I couldn't generate a response."  # Default on error
+
+     if model is not None and tokenizer is not None:
+         pass2_user_message_content = pass2_instructions_synthesize + "\n\nOriginal User Query: " + original_user_input + "\n\n" + context_for_pass2
+
+         # --- Chat history management for Pass 2 ---
+         # Gradio's history state is [(user1, bot1), (user2, bot2), ...];
+         # convert it to the role/content format the chat template expects, so that
+         # Pass 2 builds on the actual conversation, not just the Pass 2 context message.
+         model_chat_history = []
+         for user_msg, bot_msg in chat_history:
+             model_chat_history.append({"role": "user", "content": user_msg})
+             model_chat_history.append({"role": "assistant", "content": bot_msg})
+
+         # The Pass 2 instructions and context form the current user turn's input
+         model_chat_history.append({"role": "user", "content": pass2_user_message_content})
+
+         try:
+             prompt_pass2 = tokenizer.apply_chat_template(
+                 model_chat_history,
+                 tokenize=False,
+                 add_generation_prompt=True  # Append the assistant prompt token to start the response
+             )
+             # print("\n--- Pass 2 Prompt ---")  # Debug print
+             # print(prompt_pass2)
+             # print("---------------------")
+
+             generation_config_pass2 = GenerationConfig(
+                 max_new_tokens=1500,  # Allow a longer response
+                 do_sample=True,
+                 temperature=0.7,
+                 top_k=50,
+                 top_p=0.95,
+                 repetition_penalty=1.1,
+                 eos_token_id=tokenizer.eos_token_id,
+                 pad_token_id=tokenizer.pad_token_id,
+                 use_cache=True
+             )
+
+             input_ids_pass2 = tokenizer(prompt_pass2, return_tensors="pt").input_ids.to(model.device)
+
+             if input_ids_pass2.numel() > 0:
+                 outputs_pass2 = model.generate(
+                     input_ids=input_ids_pass2,
+                     generation_config=generation_config_pass2,
+                 )
+
+                 prompt_length_pass2 = input_ids_pass2.shape[1]
+                 if outputs_pass2.shape[1] > prompt_length_pass2:
+                     generated_tokens_pass2 = outputs_pass2[0, prompt_length_pass2:]
+                     final_response = tokenizer.decode(generated_tokens_pass2, skip_special_tokens=True).strip()
+                 else:
+                     final_response = "..."  # Flag a potentially empty generation
+
+         except Exception as gen_error:
+             print(f"Error during model generation in Pass 2: {gen_error}")
+             final_response = "Error generating response in Pass 2."
+
+         # --- Post-process the final response from Pass 2 ---
+         cleaned_response = final_response
+         # Filter out Pass 2 instruction/context markers that might bleed through
+         lines = cleaned_response.split('\n')
+         cleaned_lines = [line for line in lines
+                          if not line.strip().lower().startswith("business information")
+                          and not line.strip().lower().startswith("search results")
+                          and not line.strip().startswith("---")
+                          and not line.strip().lower().startswith("original user query:")
+                          and not line.strip().lower().startswith("you are a helpful assistant for a business.")]
+         cleaned_response = "\n".join(cleaned_lines).strip()
+
+         # List URLs from the search results that were provided to the model
+         # (this assumes the model used the snippets carrying those URLs)
+         urls_to_list = [result.get('href') for result in search_results_dicts if result.get('href')]
+         urls_to_list = list(dict.fromkeys(urls_to_list))  # Remove duplicates, preserving order
+
+         # Only add Sources when a search was performed AND results were found
+         if search_results_dicts and urls_to_list:
+             cleaned_response += "\n\nSources:\n" + "\n".join(urls_to_list)
+
+         final_response = cleaned_response
+
+         # Guard against an empty response after cleaning
+         if not final_response.strip():
+             final_response = "Sorry, I couldn't generate a meaningful response based on the information found."
+             print("Warning: Final response was empty after cleaning.")
+
+     else:  # Model or tokenizer not loaded
+         final_response = "Sorry, the core language model is not available."
+         print("Error: LLM model or tokenizer not loaded for Pass 2.")
+
+     # --- Update the chat history for Gradio ---
+     chat_history = chat_history + [(original_user_input, final_response)]
+
+     # Optional: cap the history length
+     max_history_pairs = 10  # Keep the last 10 turns (20 messages)
+     if len(chat_history) > max_history_pairs:
+         chat_history = chat_history[-max_history_pairs:]
+         # print(f"History truncated. Keeping last {len(chat_history)} turns.")  # Debug print
+
+     # Return an empty string to clear the input box, plus the updated history
+     return "", chat_history
+
+ # --- Gradio Interface ---
+
+ # Gradio runs this interface; the `chatbot` component holds the chat history state
  with gr.Blocks() as demo:
+     gr.Markdown(f"# Business Q&A Assistant with {model_id}")
+     gr.Markdown("Ask questions about the business (details from the Google Sheet) or general knowledge (via search).")
+
+     chatbot = gr.Chatbot(height=400)  # height in pixels
+     msg = gr.Textbox(label="Your Question")
+     clear = gr.ClearButton([msg, chatbot])
+
+     # Submitting the textbox triggers respond(); the chatbot component's value is
+     # passed in as chat_history and receives the updated history back
+     msg.submit(respond, inputs=[msg, chatbot], outputs=[msg, chatbot])
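+
+     # (Optional, not in this commit: calling demo.queue() before launch would queue
+     # concurrent requests, which matters here since generation occupies the GPU.)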
+
+ # Launch the Gradio app
+ # share=True creates a public link (useful for testing; disable for private apps)
+ if __name__ == "__main__":
+     print("Launching Gradio app...")
+     # server_name="0.0.0.0" and server_port=7860 match the Hugging Face Spaces defaults
+     demo.launch(server_name="0.0.0.0", server_port=7860)