# ibfs_demo/ibfs.py
# Author: MarginallyEffective
# Uploaded via huggingface_hub (commit 55b0a04, verified)
import re
import litellm
from typing import List, Dict, Any, Optional
# Import the necessary functions
from prompts import PROMPTS, format_prompt
from utils import save_results, generate_user_id
def generate_strategies(query: str, selected_strategy: Optional[str], k: int,
                        model: str = "gpt-4o") -> List[str]:
    """
    Generate k strategy options using the LLM.

    Args:
        query: The original user query.
        selected_strategy: Previously selected strategy, if any. When given,
            the "continuation_strategies" prompt is used so the new options
            refine the chosen branch; otherwise "initial_strategies" is used.
        k: Number of strategies to generate; must be a positive integer.
        model: litellm model identifier (defaults to "gpt-4o", the value
            previously hard-coded here).

    Returns:
        A list of exactly k strategy strings. If the LLM output parses to
        fewer than k items, generic placeholders are appended so callers can
        always display k options.

    Raises:
        ValueError: If k is not positive.
    """
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k!r}")

    # Choose the prompt template: continuation when refining a prior pick.
    if selected_strategy:
        prompt_key = "continuation_strategies"
        format_args = {"query": query, "selected_strategy": selected_strategy, "k": k}
    else:
        prompt_key = "initial_strategies"
        format_args = {"query": query, "k": k}

    templates = PROMPTS["ibfs"][prompt_key]
    messages = [
        {"role": "system", "content": format_prompt(templates["system"], **format_args)},
        {"role": "user", "content": format_prompt(templates["user"], **format_args)},
    ]

    # Higher temperature encourages diverse strategy candidates.
    response = litellm.completion(
        model=model,
        messages=messages,
        temperature=0.7,
        max_tokens=1000,
    )
    content = response.choices[0].message.content

    # Parse a numbered list: capture each item up to the next "N." or end of
    # text. DOTALL lets a single strategy span multiple lines.
    strategies = [
        s.strip()
        for s in re.findall(r'\d+\.\s*(.+?)(?=\n\d+\.|\Z)', content, re.DOTALL)[:k]
    ]

    # Pad with placeholders if the model returned fewer than k items.
    while len(strategies) < k:
        strategies.append(f"Strategy option #{len(strategies) + 1}")
    return strategies
def answer_query(query: str, final_strategy: str, model: str = "gpt-4o") -> str:
    """
    Generate the final answer based on the selected strategy.

    Args:
        query: The original user query.
        final_strategy: The final selected strategy that guides the answer.
        model: litellm model identifier (defaults to "gpt-4o", the value
            previously hard-coded here).

    Returns:
        The final answer text produced by the LLM.
    """
    templates = PROMPTS["ibfs"]["final_answer"]
    format_args = {"query": query, "final_strategy": final_strategy}
    messages = [
        {"role": "system", "content": format_prompt(templates["system"], **format_args)},
        {"role": "user", "content": format_prompt(templates["user"], **format_args)},
    ]

    # Lower temperature than strategy generation: the final answer should be
    # focused rather than exploratory.
    response = litellm.completion(
        model=model,
        messages=messages,
        temperature=0.3,
        max_tokens=2000,
    )
    return response.choices[0].message.content
def start_ibfs(query: str, k: int, m: int) -> tuple:
    """
    Begin a new IBFS session for a user query.

    Args:
        query: User's query
        k: Branching factor (options shown at each step)
        m: Depth (number of selection rounds)

    Returns:
        A (state, chat_history) pair; state is None when the query is blank
        or strategy generation fails.
    """
    if not (query and query.strip()):
        return None, [{"role": "assistant", "content": "Please enter a query."}]

    user_id = generate_user_id()
    ui = PROMPTS["ibfs"]["ui_messages"]

    # Seed the conversation: echo the query, greet, then show progress text.
    chat_history = [
        {"role": "user", "content": f"Query: {query}"},
        {"role": "assistant", "content": format_prompt(ui["welcome"], m=m, k=k)},
        {"role": "assistant", "content": ui["generating"]},
    ]

    try:
        # First round of strategies (no prior selection yet).
        strategies = generate_strategies(query, None, k)

        option_numbers = ", ".join(str(n) for n in range(1, len(strategies) + 1))
        header = format_prompt(
            ui["select_strategy"],
            current_step=1,
            max_steps=m,
            valid_options=option_numbers,
        )
        pieces = [f"{header}\n\n"]
        for position, strategy in enumerate(strategies, start=1):
            pieces.append(f"{position}. {strategy}\n\n")
        chat_history.append({"role": "assistant", "content": "".join(pieces)})

        # Fresh session state; step counter starts at 0 and advances on each
        # user choice in handle_choice.
        state = {
            "user_id": user_id,
            "query": query,
            "k": k,
            "m": m,
            "current_step": 0,
            "strategy_path": [],
            "chat_history": chat_history,
            "strategies": strategies,
        }
        return state, chat_history
    except Exception as e:
        error_msg = format_prompt(
            ui["error_general"],
            error_message=str(e),
        )
        print(error_msg)
        # On failure, return a minimal transcript rather than the partial one.
        return None, [
            {"role": "user", "content": f"Query: {query}"},
            {"role": "assistant", "content": "I encountered an error while starting the IBFS process."},
            {"role": "system", "content": error_msg},
        ]
def handle_choice(state: Dict[str, Any], choice: str) -> tuple:
    """
    Handle the user's choice of strategy.

    Advances the IBFS loop one step: validates the choice, records it, and
    either (a) produces the final answer and saves results when the depth
    limit ``m`` is reached, or (b) generates the next round of sub-strategies.
    Mutates ``state`` and its ``chat_history`` list in place.

    Args:
        state: Current state (as produced by start_ibfs / a prior call here)
        choice: The user's selected option (1-based index as string)

    Returns:
        Updated state and chat history
    """
    # No active session (e.g. start_ibfs failed): nothing to do.
    if not state:
        return state, []
    chat_history = state.get("chat_history", [])
    # Validate choice input
    if not choice or not choice.strip():
        chat_history.append({"role": "system", "content": PROMPTS["ibfs"]["ui_messages"]["error_missing_choice"]})
        return state, chat_history
    try:
        choice_idx = int(choice.strip()) - 1  # Convert to 0-based index
        strategies = state.get("strategies", [])
        # Check if we have strategies and if choice is valid
        if not strategies:
            chat_history.append({"role": "system", "content": PROMPTS["ibfs"]["ui_messages"]["error_no_strategies"]})
            return state, chat_history
        if choice_idx < 0 or choice_idx >= len(strategies):
            error_msg = format_prompt(
                PROMPTS["ibfs"]["ui_messages"]["error_invalid_choice"],
                max_option=len(strategies)
            )
            chat_history.append({"role": "system", "content": error_msg})
            return state, chat_history
        # Get the selected strategy and update path.
        # NOTE: strategy_path is appended to in place, so the list inside
        # state is mutated even before state.update() below.
        selected_strategy = strategies[choice_idx]
        strategy_path = state.get("strategy_path", [])
        strategy_path.append(selected_strategy)
        # Record user's choice
        chat_history.append({"role": "user", "content": f"I choose option {choice_idx + 1}: {selected_strategy}"})
        # Update current step
        current_step = state.get("current_step", 0) + 1
        m = state.get("m", 2)
        # Check if we've reached the final step (depth m reached)
        if current_step >= m:
            # Generate final answer from the last selected strategy
            final_processing_msg = PROMPTS["ibfs"]["ui_messages"]["final_processing"]
            chat_history.append({"role": "assistant", "content": final_processing_msg})
            query = state.get("query", "")
            final_answer = answer_query(query, selected_strategy)
            # Format final result
            result_msg = format_prompt(
                PROMPTS["ibfs"]["ui_messages"]["final_result"],
                final_strategy=selected_strategy,
                final_answer=final_answer
            )
            chat_history.append({"role": "assistant", "content": result_msg})
            # Save results to disk and tell the user where they went
            user_id = state.get("user_id", "")
            save_path = save_results(
                user_id=user_id,
                query=query,
                final_answer=final_answer,
                method="ibfs",
                strategy_path=strategy_path
            )
            saved_msg = format_prompt(
                PROMPTS["ibfs"]["ui_messages"]["saved_result"],
                save_path=save_path,
                user_id=user_id
            )
            chat_history.append({"role": "system", "content": saved_msg})
            # Reset state for a new query. NOTE(review): "query", "k" and "m"
            # are intentionally dropped here; a further choice on this reset
            # state hits the empty-strategies guard above. Confirm the UI
            # starts a fresh session via start_ibfs before the next choice.
            state = {
                "user_id": user_id,
                "current_step": 0,
                "strategy_path": [],
                "strategies": [],
                "chat_history": chat_history,
            }
            return state, chat_history
        # If we're not at the final step, generate new strategies
        k = state.get("k", 3)
        query = state.get("query", "")
        # Generate sub-strategies refining the just-selected strategy
        sub_strategies = generate_strategies(query, selected_strategy, k)
        # Format strategies for display (1-based option numbers)
        valid_options = ", ".join(map(str, range(1, len(sub_strategies) + 1)))
        select_msg = format_prompt(
            PROMPTS["ibfs"]["ui_messages"]["select_strategy"],
            current_step=current_step + 1,
            max_steps=m,
            valid_options=valid_options
        )
        options_text = f"{select_msg}\n\n"
        for idx, strategy in enumerate(sub_strategies):
            options_text += f"{idx + 1}. {strategy}\n\n"
        chat_history.append({"role": "assistant", "content": options_text})
        # Update state with the advanced step and the new option set
        state.update({
            "current_step": current_step,
            "strategy_path": strategy_path,
            "chat_history": chat_history,
            "strategies": sub_strategies,
        })
        return state, chat_history
    except ValueError:
        # int(choice) failed: input was not a number
        error_msg = format_prompt(
            PROMPTS["ibfs"]["ui_messages"]["error_invalid_number"],
            choice=choice
        )
        chat_history.append({"role": "system", "content": error_msg})
        return state, chat_history
    except Exception as e:
        # Catch-all boundary for LLM/save failures: log and surface to the UI
        error_msg = format_prompt(
            PROMPTS["ibfs"]["ui_messages"]["error_general"],
            error_message=str(e)
        )
        print(error_msg)
        chat_history.append({"role": "system", "content": error_msg})
        return state, chat_history