# This app uses Gemma 2 9B with bitsandbytes 4-bit quantization
import os
import torch
import re
import warnings
import time
import json
import base64
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer, util
import gspread
from google.oauth2 import service_account  # Service-account auth for gspread
from tqdm import tqdm
from duckduckgo_search import DDGS
import spacy
import gradio as gr
from pathlib import Path  # For handling the spaCy model path
# Suppress warnings
warnings.filterwarnings("ignore", category=UserWarning)

# --- Configuration ---
SHEET_ID = "19ipxC2vHYhpXCefpxpIkpeYdI43a1Ku2kYwecgUULIw"  # Your Google Sheet ID
HF_TOKEN = os.getenv("HF_TOKEN")  # Get Hugging Face token from Space Secrets
# It's highly recommended to use a Google Service Account for GSheets on Spaces.
# Store the JSON key as a base64-encoded string in a Space Secret (e.g., GOOGLE_SERVICE_ACCOUNT_KEY_BASE64).
GOOGLE_SERVICE_ACCOUNT_KEY_BASE64 = os.getenv("GOOGLE_SERVICE_ACCOUNT_KEY_BASE64")
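# Creating the secret value locally (illustrative; run once on your own machine, not in the Space;
# the filename service_account.json is just an example):
#   python -c "import base64; print(base64.b64encode(open('service_account.json','rb').read()).decode())"
# Paste the printed string into the GOOGLE_SERVICE_ACCOUNT_KEY_BASE64 Space secret.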
# Changed model_id to Gemma 2 9B
model_id = "google/gemma-2-9b-it"  # Ensure this model is accessible with your HF_TOKEN

# --- Constants for Prompting and Validation ---
SEARCH_MARKER = "ACTION: SEARCH:"
BUSINESS_LOOKUP_MARKER = "ACTION: LOOKUP_BUSINESS_INFO:"
ANSWER_DIRECTLY_MARKER = "ACTION: ANSWER_DIRECTLY:"
BUSINESS_LOOKUP_VALIDATION_THRESHOLD = 0.6
SEARCH_VALIDATION_THRESHOLD = 0.6
PRE_PASS1_BUSINESS_PART_LOOKUP_THRESHOLD = 0.5
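# These thresholds are cosine-similarity cutoffs over the sentence embeddings:
# - PRE_PASS1_BUSINESS_PART_LOOKUP_THRESHOLD gates the programmatic pre-check run on each query part.
# - SEARCH_VALIDATION_THRESHOLD: a Pass 1 SEARCH action is accepted only if its query scores *below* this
#   against the business data (otherwise it is considered too similar to business info and is rejected).
# - BUSINESS_LOOKUP_VALIDATION_THRESHOLD: a Pass 1 LOOKUP_BUSINESS_INFO action is accepted only if its
#   argument scores *above* this against the business data.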
# --- Global variables to load once ---
tokenizer = None
model = None
nlp = None  # spaCy model
embedder = None  # Sentence Transformer
data = []  # Google Sheet data
descriptions = []
embeddings = torch.tensor([])  # Google Sheet embeddings

# --- Loading Functions (run once on startup) ---
def load_spacy_model():
    """Loads the spaCy model, downloading it if it is not installed."""
    model_name = "en_core_web_sm"
    try:
        print(f"Loading spaCy model '{model_name}'...")
        nlp_model = spacy.load(model_name)
        print(f"SpaCy model '{model_name}' loaded.")
        return nlp_model
    except OSError:
        print(f"SpaCy model '{model_name}' not found locally. Attempting download...")
        # Downloading at runtime works, but pre-installing the model via requirements.txt is faster on Spaces.
        try:
            spacy.cli.download(model_name)
            nlp_model = spacy.load(model_name)
            print(f"SpaCy model '{model_name}' loaded after download.")
            return nlp_model
        except Exception as e:
            print(f"Failed to download or load spaCy model '{model_name}': {e}")
            print("SpaCy will not be available.")
            return None  # Query splitting will fall back to the raw query
def load_sentence_transformer():
    """Loads the Sentence Transformer model."""
    print("Loading Sentence Transformer...")
    try:
        embedder_model = SentenceTransformer("all-MiniLM-L6-v2")
        print("Sentence Transformer loaded.")
        return embedder_model
    except Exception as e:
        print(f"Error loading Sentence Transformer: {e}")
        return None
def load_google_sheet_data(sheet_id, service_account_key_base64):
    """Authenticates with a service account and loads data from the Google Sheet."""
    print(f"Attempting to load Google Sheet data from ID: {sheet_id}")
    if not service_account_key_base64:
        print("Warning: GOOGLE_SERVICE_ACCOUNT_KEY_BASE64 secret is not set. Cannot access Google Sheets.")
        return [], [], torch.tensor([])
    try:
        # Decode the base64-encoded service account key
        key_bytes = base64.b64decode(service_account_key_base64)
        key_dict = json.loads(key_bytes)
        # Authenticate with the service account key, using the scopes gspread needs
        scopes = [
            "https://www.googleapis.com/auth/spreadsheets",
            "https://www.googleapis.com/auth/drive",
        ]
        creds = service_account.Credentials.from_service_account_info(key_dict, scopes=scopes)
        client = gspread.authorize(creds)
        sheet = client.open_by_key(sheet_id).sheet1
        print(f"Successfully opened Google Sheet with ID: {sheet_id}")
        sheet_data = sheet.get_all_records()
        if not sheet_data:
            print(f"Warning: No data records found in Google Sheet with ID: {sheet_id}")
            return [], [], torch.tensor([])
        # Keep only rows that have both a 'Service' and a 'Description'
        filtered_data = [row for row in sheet_data if row.get('Service') and row.get('Description')]
        if not filtered_data:
            print("Error: Google Sheet data must contain non-empty 'Service' and 'Description' columns.")
            return [], [], torch.tensor([])
        services = [row["Service"] for row in filtered_data]
        descriptions = [row["Description"] for row in filtered_data]
        print(f"Loaded {len(descriptions)} entries from Google Sheet for embedding.")
        # Descriptions are encoded later, once the Sentence Transformer has been loaded
        return filtered_data, descriptions, torch.tensor([])
    except gspread.exceptions.SpreadsheetNotFound:
        print(f"Error: Google Sheet with ID '{sheet_id}' not found.")
        print("Please check the SHEET_ID and ensure the service account has access.")
        return [], [], torch.tensor([])
    except Exception as e:
        print(f"An error occurred while accessing the Google Sheet: {e}")
        return [], [], torch.tensor([])
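# Expected sheet layout (header row): rows must have 'Service' and 'Description' (used for embedding);
# 'Price' and 'Available' are also read when lookup results are formatted further down in this script.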
def load_llm_model(model_id, hf_token):
    """Loads the LLM with 4-bit quantization."""
    print(f"Loading model {model_id}...")
    if not hf_token:
        print("Error: HF_TOKEN secret is not set. Cannot load Hugging Face model.")
        return None, None
    # 4-bit NF4 quantization (with double quantization) keeps the 9B model small enough for a single GPU
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )
    try:
        llm_tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
        if llm_tokenizer.pad_token is None:
            llm_tokenizer.pad_token = llm_tokenizer.eos_token
        llm_model = AutoModelForCausalLM.from_pretrained(
            model_id,
            token=hf_token,
            device_map="auto",  # Let accelerate place the layers
            quantization_config=bnb_config,
        )
        print(f"Model {model_id} loaded using 4-bit quantization.")
        return llm_model, llm_tokenizer
    except Exception as e:
        print(f"Error loading model {model_id}: {e}")
        print("Please ensure bitsandbytes and accelerate are installed, and check your Hugging Face token.")
        # Do not raise; return None so the app can start without the LLM
        return None, None
# --- Load all assets on startup ---
print("Loading assets...")
nlp = load_spacy_model()
embedder = load_sentence_transformer()
data, descriptions, _ = load_google_sheet_data(SHEET_ID, GOOGLE_SERVICE_ACCOUNT_KEY_BASE64)  # Load data and descriptions first

if embedder and descriptions:
    print("Encoding Google Sheet descriptions...")
    try:
        embeddings = embedder.encode(descriptions, convert_to_tensor=True)
        print("Encoding complete.")
    except Exception as e:
        print(f"Error during embedding: {e}")
        embeddings = torch.tensor([])  # Ensure embeddings is an empty tensor on error
else:
    print("Skipping embedding due to missing embedder or descriptions.")
    embeddings = torch.tensor([])  # Ensure embeddings is an empty tensor if no descriptions

model, tokenizer = load_llm_model(model_id, HF_TOKEN)

# Check whether essential components loaded
if not model or not tokenizer or not embedder or not nlp:
    print("\nERROR: Essential components failed to load. The application may not function correctly.")
    if not model: print("- LLM model failed to load.")
    if not tokenizer: print("- LLM tokenizer failed to load.")
    if not embedder: print("- Sentence embedder failed to load.")
    if not nlp: print("- spaCy model failed to load.")
    # Continue, but the main inference function will need checks
# --- Helper Functions ---

# Performs a DuckDuckGo search and returns results with URLs
def perform_duckduckgo_search(query, max_results=3):
    """
    Performs a search using DuckDuckGo and returns a list of result dictionaries.
    Includes a delay to avoid rate limits.
    """
    search_results_list = []
    try:
        time.sleep(1)  # Add a delay before each search
        with DDGS() as ddgs:
            for r in ddgs.text(query, max_results=max_results):
                search_results_list.append(r)  # Append the result dictionary directly
    except Exception as e:
        print(f"Error during DuckDuckGo search for '{query}': {e}")
        return []
    return search_results_list
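# Each result returned by DDGS().text() is a dict; the keys used later in this script are
# 'title', 'body' (the snippet) and 'href' (the source URL), roughly:
#   {"title": "...", "body": "...", "href": "https://..."}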
# Retrieves relevant business info from the Google Sheet embeddings
def retrieve_business_info(query, data, embeddings, embedder, threshold=0.50):
    """
    Retrieves relevant business information based on query similarity.
    Returns a (match, score) tuple: the matching row dict if the best score is above
    the threshold (otherwise None), plus the best similarity score.
    """
    if not data or (embeddings is None or embeddings.numel() == 0) or embedder is None:
        print("Skipping business info retrieval: data, embeddings or embedder not available.")
        return None, 0.0
    try:
        user_embedding = embedder.encode(query, convert_to_tensor=True)
        cos_scores = util.cos_sim(user_embedding, embeddings)[0]
        best_score = cos_scores.max().item()
        if best_score > threshold:
            best_match_idx = cos_scores.argmax().item()
            best_match = data[best_match_idx]
            return best_match, best_score
        else:
            return None, best_score
    except Exception as e:
        print(f"Error during business information retrieval: {e}")
        return None, 0.0
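# Illustrative call (row contents and score are hypothetical):
#   retrieve_business_info("How much is a haircut?", data, embeddings, embedder)
#   -> ({'Service': 'Haircut', 'Description': '...', 'Price': '...', 'Available': '...'}, 0.78)
# An unrelated query returns (None, <low score>) instead.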
# Splits a user query into potential sub-queries using spaCy
def split_query(query):
    """Splits a user query into potential sub-queries using spaCy."""
    if nlp is None:
        print("SpaCy model not loaded. Cannot split query.")
        return [query]  # Return the original query if spaCy is not available
    try:
        doc = nlp(query)
        sentences = [sent.text.strip() for sent in doc.sents]
        if len(sentences) == 1:
            # Fall back to splitting a single sentence on commas, semicolons, and "and <question word>"
            parts = re.split(r',| and (who|what|where|when|why|how|is|are|can|tell me about)|;', query, flags=re.IGNORECASE)
            parts = [part.strip() for part in parts if part is not None and part.strip()]
            if len(parts) <= 1:
                return [query]
            return parts
        return sentences
    except Exception as e:
        print(f"Error during query splitting: {e}")
        return [query]  # Return the original query on error
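# Illustrative behaviour (assuming the spaCy model is loaded): multi-sentence input is split on
# sentence boundaries, e.g.
#   split_query("What are your prices? Who won the match last night?")
#   -> ["What are your prices?", "Who won the match last night?"]
# A single compound sentence falls back to the comma/conjunction regex above.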
# --- Pass 1 System Prompt ---
pass1_instructions_action = """You are a helpful assistant for a business. Your primary goal in this first step is to analyze the user's query and decide which actions are needed to answer it.
You have analyzed the user's query and potentially broken it down into parts. For each part, a preliminary check was done to see if it matches known business information. The results of this check are provided below.
{business_check_summary}
Based on the user's query and the results of the business info check for each part, identify if you need to perform actions.
Output one or more actions, each on a new line, in the format:
ACTION: [ACTION_TYPE]: [Argument/Query for the action]
Possible actions:
1. **LOOKUP_BUSINESS_INFO**: If a part of the query asks about the business's services, prices, availability, or individuals mentioned in the business context, *and* the business info check for that part indicates a high relevance ({PRE_PASS1_BUSINESS_PART_LOOKUP_THRESHOLD:.2f} or higher). The argument should be the specific phrase or name to look up.
2. **SEARCH**: If a part of the query asks for current external information (e.g., current events, real-time data, general facts not in business info), *or* if a part that seems like it could be business info did *not* have a high relevance score in the preliminary check (below {PRE_PASS1_BUSINESS_PART_LOOKUP_THRESHOLD:.2f}). The argument should be the precise search query.
3. **ANSWER_DIRECTLY**: If the overall query is a simple greeting or can be answered from your general knowledge without lookup or search, *and* the business info check results indicate low relevance for all parts. The argument should be the direct answer.
**Crucially:**
- **Prioritize LOOKUP_BUSINESS_INFO** for any part of the query where the preliminary business info check score was {PRE_PASS1_BUSINESS_PART_LOOKUP_THRESHOLD:.2f} or higher.
- Use **SEARCH** for parts about external information or where the business info check score was below {PRE_PASS1_BUSINESS_PART_LOOKUP_THRESHOLD:.2f}.
- If a part of the query is clearly external (like asking about current events or famous people), even if its business info score wasn't zero, you should likely use SEARCH for it.
- Do NOT output any other text besides the ACTION lines.
- If the results suggest a direct answer is sufficient, use ANSWER_DIRECTLY.
Now, analyze the following user query, considering the business info check results provided above, and output the required actions:
"""
# --- Pass 2 System Prompt ---
pass2_instructions_synthesize = """You are a helpful assistant for a business. You have been provided with the original user query, relevant Business Information (if found), and results from external searches (if performed).
Your task is to synthesize ALL the provided information to answer the user's original question concisely and accurately.
**Prioritize Business Information** for details about the business, its services, or individuals mentioned within that context.
Use the Search Results for current external information that was requested.
If information for a specific part of the question was not found in either Business Information or Search Results, use your general knowledge if possible, or state that the information could not be found.
Synthesize the information into a natural language response. Do NOT copy and paste raw context or strings like 'Business Information:' or 'SEARCH RESULTS:' or 'ACTION:' or the raw user query.
After your answer, generate a few concise follow-up questions that a user might ask based on the previous turn's conversation and your response. List these questions clearly at the end of your response.
When search results were used to answer the question, list the URLs from the search results you used under a "Sources:" heading at the very end.
"""
# --- Main Inference Function for Gradio ---
# This function is called every time the user submits a query.
# chat_history is a parameter managed by Gradio's State.
def respond(user_input, chat_history):
    """
    Processes user input, performs actions (lookup/search), and generates a response.
    Manages chat history within Gradio state.
    """
    # Check that the models loaded successfully
    if model is None or tokenizer is None or embedder is None or nlp is None:
        # Return a value for the input box too, matching the function's normal return signature
        return "", chat_history + [(user_input, "Sorry, the application failed to load necessary components. Please try again later or contact the administrator.")]
    original_user_input = user_input

    # Initialize action results containers for this turn
    search_results_dicts = []
    business_lookup_results_formatted = []
    response_pass1_raw = ""  # To store the raw actions generated by Pass 1

    # --- Pre-Pass 1: Programmatic Business Info Check for Query Parts ---
    query_parts = split_query(original_user_input)
    business_check_results = []
    overall_pre_pass1_score = 0.0
    print("\n--- Processing new user query ---")
    print(f"User: {user_input}")
    print("Performing programmatic business info check on query parts...")
    if query_parts:
        for i, part in enumerate(query_parts):
            match, score = retrieve_business_info(part, data, embeddings, embedder, threshold=0.0)
            business_check_results.append({"part": part, "score": score, "match": match})
            print(f"- Part '{part}': Score {score:.4f}")
            overall_pre_pass1_score = max(overall_pre_pass1_score, score)
    else:
        match, score = retrieve_business_info(original_user_input, data, embeddings, embedder, threshold=0.0)
        business_check_results.append({"part": original_user_input, "score": score, "match": match})
        print(f"- Part '{original_user_input}': Score {score:.4f}")
        overall_pre_pass1_score = score

    is_likely_direct_answer = overall_pre_pass1_score < PRE_PASS1_BUSINESS_PART_LOOKUP_THRESHOLD and len(query_parts) <= 2
    # Format the business check summary for the Pass 1 prompt
    business_check_summary = "Business Info Check Results for Query Parts:\n"
    if business_check_results:
        for result in business_check_results:
            status = "High Relevance" if result['score'] >= PRE_PASS1_BUSINESS_PART_LOOKUP_THRESHOLD else "Low Relevance"
            business_check_summary += f"- Part '{result['part']}': Score {result['score']:.4f} ({status})\n"
    else:
        business_check_summary += "- No parts identified or check skipped.\n"
    business_check_summary += "\n"
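    # Illustrative summary passed to Pass 1 (parts and scores are hypothetical):
    #   Business Info Check Results for Query Parts:
    #   - Part 'How much is a logo design': Score 0.7123 (High Relevance)
    #   - Part 'who won the match last night': Score 0.1045 (Low Relevance)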
    # --- Pass 1: Action Identification (if not a direct answer) ---
    requested_actions = []
    answer_directly_provided = None

    if is_likely_direct_answer:
        print("Programmatically determined likely direct answer.")
        response_pass1_raw = "ACTION: ANSWER_DIRECTLY: "  # Signal Pass 2
    else:
        pass1_user_message_content = pass1_instructions_action.format(
            business_check_summary=business_check_summary,
            PRE_PASS1_BUSINESS_PART_LOOKUP_THRESHOLD=PRE_PASS1_BUSINESS_PART_LOOKUP_THRESHOLD  # Pass threshold to prompt
        ) + "\n\nUser Query: " + user_input
        # Create a temporary history for Pass 1 containing only the current turn's instructions and user query
        temp_chat_history_pass1 = [{"role": "user", "content": pass1_user_message_content}]
        try:
            prompt_pass1 = tokenizer.apply_chat_template(
                temp_chat_history_pass1,
                tokenize=False,
                add_generation_prompt=True
            )
            # print("\n--- Pass 1 Prompt ---")  # Debug print
            # print(prompt_pass1)
            # print("---------------------")
            generation_config_pass1 = GenerationConfig(
                max_new_tokens=200,
                do_sample=False,  # Greedy decoding for deterministic action output
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
                use_cache=True
            )
            input_ids_pass1 = tokenizer(prompt_pass1, return_tensors="pt").input_ids.to(model.device)
            if input_ids_pass1.numel() > 0:
                outputs_pass1 = model.generate(
                    input_ids=input_ids_pass1,
                    generation_config=generation_config_pass1,
                )
                prompt_length_pass1 = input_ids_pass1.shape[1]
                if outputs_pass1.shape[1] > prompt_length_pass1:
                    generated_tokens_pass1 = outputs_pass1[0, prompt_length_pass1:]
                    response_pass1_raw = tokenizer.decode(generated_tokens_pass1, skip_special_tokens=True).strip()
                else:
                    response_pass1_raw = ""  # No actions generated
            else:
                response_pass1_raw = ""  # Empty input
            # print("\n--- Raw Pass 1 Response ---")  # Debug print
            # print(response_pass1_raw)
            # print("--------------------------")
        except Exception as e:
            print(f"Error during Pass 1 (Action Identification): {e}")
            # If Pass 1 fails, fall back to attempting a direct answer in Pass 2
            response_pass1_raw = f"ACTION: ANSWER_DIRECTLY: Error in Pass 1 - {e}"
    # --- Parse the Model's Requested Actions, with Validation ---
    # Always parse, even if flagged for direct answer, to handle potential Pass 1 errors
    if response_pass1_raw:
        lines = response_pass1_raw.strip().split('\n')
        for line in lines:
            line = line.strip()
            if line.startswith(SEARCH_MARKER):
                query = line[len(SEARCH_MARKER):].strip()
                if query:
                    # Validate SEARCH action
                    _, score = retrieve_business_info(query, data, embeddings, embedder, threshold=0.0)
                    if score < SEARCH_VALIDATION_THRESHOLD:
                        requested_actions.append(("SEARCH", query))
                        print(f"Validated Search Action for '{query}' (Score: {score:.4f})")
                    else:
                        print(f"Rejected Search Action for '{query}' (Score: {score:.4f}) - Too similar to business data.")
            elif line.startswith(BUSINESS_LOOKUP_MARKER):
                query = line[len(BUSINESS_LOOKUP_MARKER):].strip()
                if query:
                    # Validate business lookup query
                    match, score = retrieve_business_info(query, data, embeddings, embedder, threshold=0.0)  # Use low threshold for scoring
                    if score > BUSINESS_LOOKUP_VALIDATION_THRESHOLD:
                        requested_actions.append(("LOOKUP_BUSINESS_INFO", query))
                        print(f"Validated Business Lookup Action for '{query}' (Score: {score:.4f})")
                    else:
                        print(f"Rejected Business Lookup Action for '{query}' (Score: {score:.4f}) - Below validation threshold.")
            elif line.startswith(ANSWER_DIRECTLY_MARKER):
                answer = line[len(ANSWER_DIRECTLY_MARKER):].strip()
                answer_directly_provided = answer if answer else original_user_input  # Use the explicit answer if provided, else fall back to the original query
                requested_actions = []  # Clear other actions if ANSWER_DIRECTLY is given
                break  # Exit action parsing loop
    # --- Execute Actions (Search and Lookup) ---
    # Only execute actions if ANSWER_DIRECTLY was NOT the primary outcome of Pass 1
    # and there are validated requested actions.
    context_for_pass2 = ""
    if requested_actions:
        print("Executing requested actions...")
        for action_type, query in requested_actions:
            if action_type == "SEARCH":
                print(f"Performing search for: '{query}'")
                results = perform_duckduckgo_search(query)
                if results:
                    search_results_dicts.extend(results)
                    print(f"Found {len(results)} search results.")
                else:
                    print(f"No search results found for '{query}'.")
            elif action_type == "LOOKUP_BUSINESS_INFO":
                print(f"Performing business info lookup for: '{query}'")
                match, score = retrieve_business_info(query, data, embeddings, embedder, threshold=retrieve_business_info.__defaults__[0])  # Use the default threshold for retrieval
                print(f"Actual lookup score for '{query}': {score:.4f} (Threshold: {retrieve_business_info.__defaults__[0]})")
                if match:
                    formatted_match = (
                        f"Service: {match.get('Service', 'N/A')}\n"
                        f"Description: {match.get('Description', 'N/A')}\n"
                        f"Price: {match.get('Price', 'N/A')}\n"
                        f"Available: {match.get('Available', 'N/A')}"
                    )
                    business_lookup_results_formatted.append(formatted_match)
                    print("Found business info match.")
                else:
                    print(f"No business info match found for '{query}' at threshold {retrieve_business_info.__defaults__[0]}.")
    # --- Prepare Context for Pass 2 based on executed actions ---
    if business_lookup_results_formatted:
        context_for_pass2 += "Business Information (Use this for questions about the business):\n"
        context_for_pass2 += "\n---\n".join(business_lookup_results_formatted)
        context_for_pass2 += "\n\n"
    if search_results_dicts:
        context_for_pass2 += "SEARCH RESULTS (Use this for current external information):\n"
        aggregated_search_results_formatted = []
        for result in search_results_dicts:
            aggregated_search_results_formatted.append(f"Title: {result.get('title', 'N/A')}\nSnippet: {result.get('body', 'N/A')}\nURL: {result.get('href', 'N/A')}")
        context_for_pass2 += "\n---\n".join(aggregated_search_results_formatted) + "\n\n"
    if requested_actions and not business_lookup_results_formatted and not search_results_dicts:
        context_for_pass2 = "Note: No relevant information was found in Business Information or via Search for your query."
        print("Note: No results were found for the requested actions.")

    # If ANSWER_DIRECTLY was determined (either programmatically or by the Pass 1 model output)
    if answer_directly_provided is not None:
        print(f"Handling as direct answer: {answer_directly_provided}")
        # Provide a simple context indicating it's a direct answer scenario
        context_for_pass2 = "Note: This query is a simple request or greeting."
        if answer_directly_provided != original_user_input and answer_directly_provided != "":
            context_for_pass2 += f" Initial suggestion from action step: {answer_directly_provided}"
        # Ensure no search/lookup results are included if it was flagged as a direct answer
        search_results_dicts = []
        business_lookup_results_formatted = []

    # If no actions were requested and no direct answer was flagged,
    # Pass 1 either failed or generated nothing useful
    if not requested_actions and answer_directly_provided is None:
        if response_pass1_raw.strip():
            print("Warning: Pass 1 did not result in valid actions or a direct answer.")
            context_for_pass2 = f"Error: Could not determine actions from Pass 1 response: '{response_pass1_raw}'."
        else:
            print("Warning: Pass 1 generated an empty response.")
            context_for_pass2 = "Error: Pass 1 generated an empty response."
        # In this case, we still try Pass 2 with the limited context
    # --- Pass 2: Synthesize and Respond ---
    final_response = "Sorry, I couldn't generate a response."  # Default response on error
    if model is not None and tokenizer is not None:
        pass2_user_message_content = pass2_instructions_synthesize + "\n\nOriginal User Query: " + original_user_input + "\n\n" + context_for_pass2

        # --- Chat History Management for Pass 2 ---
        # Gradio's chat history state is [(User1, Bot1), (User2, Bot2), ...]
        # We need to format the history correctly for the model template.
        # The Pass 2 prompt should build upon the *actual* conversation history, not just the Pass 2 context message.
        model_chat_history = []
        for user_msg, bot_msg in chat_history:
            model_chat_history.append({"role": "user", "content": user_msg})
            model_chat_history.append({"role": "assistant", "content": bot_msg})
        # Add the *current* user query and the Pass 2 specific content as the latest turn;
        # the Pass 2 instructions and context are part of the *current* user turn's input to the model
        model_chat_history.append({"role": "user", "content": pass2_user_message_content})

        try:
            prompt_pass2 = tokenizer.apply_chat_template(
                model_chat_history,
                tokenize=False,
                add_generation_prompt=True  # Add the assistant prompt token to start the response
            )
            # print("\n--- Pass 2 Prompt ---")  # Debug print
            # print(prompt_pass2)
            # print("---------------------")
            generation_config_pass2 = GenerationConfig(
                max_new_tokens=1500,  # Generate a longer response
                do_sample=True,
                temperature=0.7,
                top_k=50,
                top_p=0.95,
                repetition_penalty=1.1,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
                use_cache=True
            )
            input_ids_pass2 = tokenizer(prompt_pass2, return_tensors="pt").input_ids.to(model.device)
            if input_ids_pass2.numel() > 0:
                outputs_pass2 = model.generate(
                    input_ids=input_ids_pass2,
                    generation_config=generation_config_pass2,
                )
                prompt_length_pass2 = input_ids_pass2.shape[1]
                if outputs_pass2.shape[1] > prompt_length_pass2:
                    generated_tokens_pass2 = outputs_pass2[0, prompt_length_pass2:]
                    final_response = tokenizer.decode(generated_tokens_pass2, skip_special_tokens=True).strip()
                else:
                    final_response = "..."  # Indicate a potentially empty response
        except Exception as gen_error:
            print(f"Error during model generation in Pass 2: {gen_error}")
            final_response = "Error generating response in Pass 2."
        # --- Post-process Final Response from Pass 2 ---
        cleaned_response = final_response
        # Filter out the Pass 2 instructions and context markers that might bleed through
        lines = cleaned_response.split('\n')
        cleaned_lines = [line for line in lines if not line.strip().lower().startswith("business information")
                         and not line.strip().lower().startswith("search results")
                         and not line.strip().startswith("---")
                         and not line.strip().lower().startswith("original user query:")
                         and not line.strip().lower().startswith("you are a helpful assistant for a business.")]
        cleaned_response = "\n".join(cleaned_lines).strip()

        # Extract and list URLs from the search results that were actually used;
        # this assumes the model used the provided snippets with URLs
        urls_to_list = [result.get('href') for result in search_results_dicts if result.get('href')]
        urls_to_list = list(dict.fromkeys(urls_to_list))  # Remove duplicates while preserving order
        # Only add Sources if a search was performed AND results were found
        if search_results_dicts and urls_to_list:
            cleaned_response += "\n\nSources:\n" + "\n".join(urls_to_list)
        final_response = cleaned_response

        # Check whether the final response is empty or just whitespace after cleaning
        if not final_response.strip():
            final_response = "Sorry, I couldn't generate a meaningful response based on the information found."
            print("Warning: Final response was empty after cleaning.")
    else:  # Model or tokenizer not loaded (pairs with the 'if model is not None and tokenizer is not None:' check above)
        final_response = "Sorry, the core language model is not available."
        print("Error: LLM model or tokenizer not loaded for Pass 2.")

    # --- Update Chat History for Gradio ---
    # Append the user's original message and the final bot response to the history state
    chat_history = chat_history + [(original_user_input, final_response)]

    # Optional: manage history length
    max_history_pairs = 10  # Keep the last 10 user/bot turns (20 messages total)
    if len(chat_history) > max_history_pairs:
        chat_history = chat_history[-max_history_pairs:]
        # print(f"History truncated. Keeping last {len(chat_history)} turns.")  # Debug print

    # Return an empty string (to clear the input box) and the updated history state
    return "", chat_history