Spaces:
Sleeping
Sleeping
import re | |
import litellm | |
from typing import List, Dict, Any, Optional | |
# Import the necessary functions | |
from prompts import PROMPTS, format_prompt | |
from utils import save_results, generate_user_id | |
def generate_strategies(query: str, selected_strategy: Optional[str], k: int) -> List[str]:
    """
    Generate k strategy options using the LLM.

    The model is prompted (``initial_strategies`` or ``continuation_strategies``
    template) and its numbered-list reply is parsed into items. The result
    always has exactly ``max(k, 0)`` entries: surplus items are dropped and
    missing ones are padded with ``"Strategy option #<n>"`` placeholders.

    Args:
        query: The original user query
        selected_strategy: Previously selected strategy (if any); when given,
            the continuation prompt is used instead of the initial one
        k: Number of strategies to generate; ``k <= 0`` returns ``[]``
            without calling the model

    Returns:
        List of strategy options
    """
    if k <= 0:
        # Nothing to ask for. This also avoids ``items[:k]`` with a negative k,
        # which would silently drop items from the end of the list.
        return []

    # Choose the appropriate prompt template
    if selected_strategy:
        prompt_key = "continuation_strategies"
        format_args = {"query": query, "selected_strategy": selected_strategy, "k": k}
    else:
        prompt_key = "initial_strategies"
        format_args = {"query": query, "k": k}

    templates = PROMPTS["ibfs"][prompt_key]
    messages = [
        {"role": "system", "content": format_prompt(templates["system"], **format_args)},
        {"role": "user", "content": format_prompt(templates["user"], **format_args)},
    ]

    response = litellm.completion(
        model="gpt-4o",
        messages=messages,
        temperature=0.7,
        max_tokens=1000,
    )
    # The OpenAI-style message schema allows a null ``content``; treat it as
    # an empty reply so the placeholders below are used instead of crashing.
    content = response.choices[0].message.content or ""

    # An item is a numbered entry ("1. ...") that starts a line, optionally
    # preceded by markdown such as "### " or "**". Anchoring at line starts
    # stops numbers inside the text ("step 2.", "2.5x") from splitting an
    # item; ``(?!\d)`` skips decimals; ``[ \t]*`` (not ``\s*``) keeps the
    # match from running past the end of the numbered line. Each item
    # extends until the next numbered line or the end of the reply.
    items = re.findall(
        r"^[ \t#*_]*\d+\.(?!\d)[ \t]*(.+?)(?=^[ \t#*_]*\d+\.(?!\d)|\Z)",
        content,
        re.MULTILINE | re.DOTALL,
    )
    strategies = [item.strip() for item in items if item.strip()][:k]

    # Pad with placeholders so callers always receive exactly k options.
    strategies += [f"Strategy option #{n}" for n in range(len(strategies) + 1, k + 1)]
    return strategies
def answer_query(query: str, final_strategy: str) -> str:
    """
    Generate the final answer based on the selected strategy.

    Both the system and user messages come from the ``final_answer`` prompt
    templates, formatted with ``query`` and ``final_strategy``.

    Args:
        query: The original user query
        final_strategy: The final selected strategy

    Returns:
        The model's answer text, or ``""`` if the model returned no content
    """
    format_args = {"query": query, "final_strategy": final_strategy}
    templates = PROMPTS["ibfs"]["final_answer"]
    messages = [
        {"role": "system", "content": format_prompt(templates["system"], **format_args)},
        {"role": "user", "content": format_prompt(templates["user"], **format_args)},
    ]

    response = litellm.completion(
        model="gpt-4o",
        messages=messages,
        temperature=0.3,
        max_tokens=2000,
    )
    # The OpenAI-style message schema allows a null ``content``; honour the
    # declared ``str`` return type so callers can format the answer safely.
    return response.choices[0].message.content or ""
def start_ibfs(query: str, k: int, m: int) -> tuple:
    """
    Begin a new IBFS session for ``query`` and present the first strategies.

    Args:
        query: User's query
        k: Branching factor (number of strategies requested per step)
        m: Depth (number of steps, shown to the user as ``max_steps``)

    Returns:
        A ``(state, chat_history)`` pair. ``state`` is None when the query is
        blank or the first round of strategies could not be generated; the
        history then only explains what went wrong.
    """
    if not (query and query.strip()):
        return None, [{"role": "assistant", "content": "Please enter a query."}]

    user_id = generate_user_id()
    ui_messages = PROMPTS["ibfs"]["ui_messages"]

    welcome = format_prompt(ui_messages["welcome"], m=m, k=k)
    history = [
        {"role": "user", "content": f"Query: {query}"},
        {"role": "assistant", "content": welcome},
        {"role": "assistant", "content": ui_messages["generating"]},
    ]

    try:
        options = generate_strategies(query, None, k)

        numbering = ", ".join(str(n) for n in range(1, len(options) + 1))
        prompt_text = format_prompt(
            ui_messages["select_strategy"],
            current_step=1,
            max_steps=m,
            valid_options=numbering,
        )
        listing = "".join(f"{n}. {text}\n\n" for n, text in enumerate(options, start=1))
        history.append({"role": "assistant", "content": f"{prompt_text}\n\n{listing}"})

        new_state = {
            "user_id": user_id,
            "query": query,
            "k": k,
            "m": m,
            "current_step": 0,
            "strategy_path": [],
            "chat_history": history,
            "strategies": options,
        }
        return new_state, history
    except Exception as exc:
        failure = format_prompt(ui_messages["error_general"], error_message=str(exc))
        print(failure)
        return None, [
            {"role": "user", "content": f"Query: {query}"},
            {"role": "assistant", "content": "I encountered an error while starting the IBFS process."},
            {"role": "system", "content": failure},
        ]
def handle_choice(state: Dict[str, Any], choice: str) -> tuple:
    """
    Handle the user's choice of strategy.

    Validates ``choice`` against the strategies currently offered in
    ``state`` and records it. It then either presents the next round of
    strategies or, once ``state["m"]`` choices have been made, generates,
    displays and saves the final answer and resets the state.

    Args:
        state: Current state (as built by ``start_ibfs`` or a previous call);
            a falsy value is returned unchanged with an empty history
        choice: The user's selected option as a 1-based index. Normally a
            string; an int is accepted as well.

    Returns:
        Updated state and chat history. Validation and runtime errors are
        reported as ``"system"`` messages appended to the chat history; in
        that case the state's step and strategy path are left unchanged.
    """
    if not state:
        return state, []

    chat_history = state.get("chat_history", [])
    ui_messages = PROMPTS["ibfs"]["ui_messages"]

    # Accept non-str input (e.g. an int from a numeric widget) instead of
    # crashing on ``.strip()``.
    choice_text = "" if choice is None else str(choice).strip()
    if not choice_text:
        chat_history.append({"role": "system", "content": ui_messages["error_missing_choice"]})
        return state, chat_history

    # Parse in a dedicated try so that a ValueError raised later (by the LLM
    # client, prompt formatting or saving) is not misreported as a bad number.
    try:
        choice_idx = int(choice_text) - 1  # Convert to 0-based index
    except ValueError:
        error_msg = format_prompt(ui_messages["error_invalid_number"], choice=choice)
        chat_history.append({"role": "system", "content": error_msg})
        return state, chat_history

    try:
        strategies = state.get("strategies", [])
        if not strategies:
            chat_history.append({"role": "system", "content": ui_messages["error_no_strategies"]})
            return state, chat_history
        if not 0 <= choice_idx < len(strategies):
            error_msg = format_prompt(
                ui_messages["error_invalid_choice"], max_option=len(strategies)
            )
            chat_history.append({"role": "system", "content": error_msg})
            return state, chat_history

        selected_strategy = strategies[choice_idx]
        # Build a new list rather than appending to the one held by ``state``:
        # if a later step fails, the stored path is untouched and a retry does
        # not record the same choice twice.
        strategy_path = [*state.get("strategy_path", []), selected_strategy]
        chat_history.append(
            {"role": "user", "content": f"I choose option {choice_idx + 1}: {selected_strategy}"}
        )

        current_step = state.get("current_step", 0) + 1
        if current_step >= state.get("m", 2):
            return _finish_ibfs(state, chat_history, selected_strategy, strategy_path)
        return _advance_ibfs(state, chat_history, selected_strategy, strategy_path, current_step)
    except Exception as e:
        error_msg = format_prompt(ui_messages["error_general"], error_message=str(e))
        print(error_msg)
        chat_history.append({"role": "system", "content": error_msg})
        return state, chat_history


def _finish_ibfs(
    state: Dict[str, Any],
    chat_history: List[Dict[str, str]],
    final_strategy: str,
    strategy_path: List[str],
) -> tuple:
    """Generate, display and save the final answer, then reset state for a new query."""
    ui_messages = PROMPTS["ibfs"]["ui_messages"]
    chat_history.append({"role": "assistant", "content": ui_messages["final_processing"]})

    query = state.get("query", "")
    final_answer = answer_query(query, final_strategy)
    result_msg = format_prompt(
        ui_messages["final_result"],
        final_strategy=final_strategy,
        final_answer=final_answer,
    )
    chat_history.append({"role": "assistant", "content": result_msg})

    user_id = state.get("user_id", "")
    save_path = save_results(
        user_id=user_id,
        query=query,
        final_answer=final_answer,
        method="ibfs",
        strategy_path=strategy_path,
    )
    saved_msg = format_prompt(ui_messages["saved_result"], save_path=save_path, user_id=user_id)
    chat_history.append({"role": "system", "content": saved_msg})

    # Reset state for a new query: only the user id and history carry over.
    new_state = {
        "user_id": user_id,
        "current_step": 0,
        "strategy_path": [],
        "strategies": [],
        "chat_history": chat_history,
    }
    return new_state, chat_history


def _advance_ibfs(
    state: Dict[str, Any],
    chat_history: List[Dict[str, str]],
    selected_strategy: str,
    strategy_path: List[str],
    current_step: int,
) -> tuple:
    """Generate and display sub-strategies of ``selected_strategy``, then advance ``state``."""
    m = state.get("m", 2)
    k = state.get("k", 3)
    query = state.get("query", "")
    sub_strategies = generate_strategies(query, selected_strategy, k)

    valid_options = ", ".join(str(n) for n in range(1, len(sub_strategies) + 1))
    select_msg = format_prompt(
        PROMPTS["ibfs"]["ui_messages"]["select_strategy"],
        current_step=current_step + 1,  # 1-based number of the step now offered
        max_steps=m,
        valid_options=valid_options,
    )
    options_text = "".join(
        f"{n}. {strategy}\n\n" for n, strategy in enumerate(sub_strategies, start=1)
    )
    chat_history.append({"role": "assistant", "content": f"{select_msg}\n\n{options_text}"})

    # Commit the new step only now that everything above has succeeded.
    state.update({
        "current_step": current_step,
        "strategy_path": strategy_path,
        "chat_history": chat_history,
        "strategies": sub_strategies,
    })
    return state, chat_history