# NOTE: removed page-scrape residue that was not valid Python
# (original header: Hugging Face Space listing, commit 55b0a04, 10,413 bytes).
import re
import litellm
from typing import List, Dict, Any, Optional
# Import the necessary functions
from prompts import PROMPTS, format_prompt
from utils import save_results, generate_user_id
def generate_strategies(query: str, selected_strategy: Optional[str], k: int) -> List[str]:
    """
    Generate k strategy options using the LLM.

    Args:
        query: The original user query
        selected_strategy: Previously selected strategy (if any); when given,
            the continuation prompt is used to refine it
        k: Number of strategies to generate

    Returns:
        List of exactly k strategy strings (placeholders pad any shortfall)
    """
    # Choose the prompt template: continuation prompts drill into a
    # previously selected strategy; initial prompts start from scratch.
    if selected_strategy:
        prompt_key = "continuation_strategies"
        format_args = {
            "query": query,
            "selected_strategy": selected_strategy,
            "k": k,
        }
    else:
        prompt_key = "initial_strategies"
        format_args = {
            "query": query,
            "k": k,
        }

    # Format the system/user prompts from the shared template table.
    system_template = PROMPTS["ibfs"][prompt_key]["system"]
    user_template = PROMPTS["ibfs"][prompt_key]["user"]
    system_message = format_prompt(system_template, **format_args)
    user_message = format_prompt(user_template, **format_args)

    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message},
    ]

    # Get response from LLM.
    response = litellm.completion(
        model="gpt-4o",
        messages=messages,
        temperature=0.7,
        max_tokens=1000,
    )
    # message.content may be None (e.g. content filter / refusal);
    # guard so re.findall below doesn't raise TypeError on None.
    content = response.choices[0].message.content or ""

    # Parse numbered items ("1. ...", "2. ..."): each match runs until the
    # next numbered item or end of text.
    strategies = re.findall(r'\d+\.\s*(.+?)(?=\n\d+\.|\Z)', content, re.DOTALL)

    # Trim whitespace and cap at k.
    strategies = [s.strip() for s in strategies[:k]]

    # Pad with placeholders so callers can always rely on exactly k options.
    while len(strategies) < k:
        strategies.append(f"Strategy option #{len(strategies) + 1}")
    return strategies
def answer_query(query: str, final_strategy: str) -> str:
    """
    Generate the final answer based on the selected strategy.

    Args:
        query: The original user query
        final_strategy: The final selected strategy

    Returns:
        The final answer text from the LLM
    """
    # Both prompts receive the same substitution values.
    substitutions = {
        "query": query,
        "final_strategy": final_strategy,
    }
    templates = PROMPTS["ibfs"]["final_answer"]

    # Build the chat payload directly from the formatted templates.
    messages = [
        {"role": "system", "content": format_prompt(templates["system"], **substitutions)},
        {"role": "user", "content": format_prompt(templates["user"], **substitutions)},
    ]

    # Low temperature for a focused, deterministic-leaning final answer.
    response = litellm.completion(
        model="gpt-4o",
        messages=messages,
        temperature=0.3,
        max_tokens=2000,
    )
    return response.choices[0].message.content
def start_ibfs(query: str, k: int, m: int) -> tuple:
    """
    Start the IBFS process with a new query.

    Args:
        query: User's query
        k: Branching factor
        m: Depth

    Returns:
        Initial state and chat history
    """
    # Guard clause: reject empty or whitespace-only queries.
    if not query or not query.strip():
        return None, [{"role": "assistant", "content": "Please enter a query."}]

    # Each session gets its own unique identifier.
    user_id = generate_user_id()

    ui_messages = PROMPTS["ibfs"]["ui_messages"]

    # Seed the conversation: echoed query, welcome banner, progress notice.
    chat_history = [
        {"role": "user", "content": f"Query: {query}"},
        {"role": "assistant", "content": format_prompt(ui_messages["welcome"], m=m, k=k)},
        {"role": "assistant", "content": ui_messages["generating"]},
    ]

    try:
        # First level of the search tree: no prior strategy selected.
        strategies = generate_strategies(query, None, k)

        select_msg = format_prompt(
            ui_messages["select_strategy"],
            current_step=1,
            max_steps=m,
            valid_options=", ".join(str(n) for n in range(1, len(strategies) + 1)),
        )
        # Join header + numbered options with blank lines between them
        # (trailing blank line included, matching the displayed format).
        numbered = [f"{pos}. {text}" for pos, text in enumerate(strategies, start=1)]
        options_text = "\n\n".join([select_msg] + numbered) + "\n\n"
        chat_history.append({"role": "assistant", "content": options_text})

        # Fresh session state for the search loop.
        state = {
            "user_id": user_id,
            "query": query,
            "k": k,
            "m": m,
            "current_step": 0,
            "strategy_path": [],
            "chat_history": chat_history,
            "strategies": strategies,
        }
        return state, chat_history
    except Exception as e:
        error_msg = format_prompt(
            ui_messages["error_general"],
            error_message=str(e),
        )
        print(error_msg)
        # On failure, return no state and a minimal error transcript.
        return None, [
            {"role": "user", "content": f"Query: {query}"},
            {"role": "assistant", "content": "I encountered an error while starting the IBFS process."},
            {"role": "system", "content": error_msg},
        ]
def _format_strategy_options(strategies: List[str], current_step: int, max_steps: int) -> str:
    """Render the numbered strategy menu shown to the user for one step."""
    valid_options = ", ".join(map(str, range(1, len(strategies) + 1)))
    select_msg = format_prompt(
        PROMPTS["ibfs"]["ui_messages"]["select_strategy"],
        current_step=current_step,
        max_steps=max_steps,
        valid_options=valid_options,
    )
    options_text = f"{select_msg}\n\n"
    for idx, strategy in enumerate(strategies):
        options_text += f"{idx + 1}. {strategy}\n\n"
    return options_text


def _finalize_ibfs(state: Dict[str, Any], chat_history: List[Dict[str, str]],
                   selected_strategy: str, strategy_path: List[str]) -> tuple:
    """Generate and save the final answer, then reset state for a new query."""
    chat_history.append({"role": "assistant",
                         "content": PROMPTS["ibfs"]["ui_messages"]["final_processing"]})

    query = state.get("query", "")
    final_answer = answer_query(query, selected_strategy)

    result_msg = format_prompt(
        PROMPTS["ibfs"]["ui_messages"]["final_result"],
        final_strategy=selected_strategy,
        final_answer=final_answer,
    )
    chat_history.append({"role": "assistant", "content": result_msg})

    # Persist the run so it can be inspected later.
    user_id = state.get("user_id", "")
    save_path = save_results(
        user_id=user_id,
        query=query,
        final_answer=final_answer,
        method="ibfs",
        strategy_path=strategy_path,
    )
    saved_msg = format_prompt(
        PROMPTS["ibfs"]["ui_messages"]["saved_result"],
        save_path=save_path,
        user_id=user_id,
    )
    chat_history.append({"role": "system", "content": saved_msg})

    # Reset state for a new query; query/k/m are deliberately dropped so a
    # fresh start_ibfs call is required for the next round.
    new_state = {
        "user_id": user_id,
        "current_step": 0,
        "strategy_path": [],
        "strategies": [],
        "chat_history": chat_history,
    }
    return new_state, chat_history


def _advance_ibfs(state: Dict[str, Any], chat_history: List[Dict[str, str]],
                  selected_strategy: str, strategy_path: List[str],
                  current_step: int, m: int) -> tuple:
    """Generate the next level of sub-strategies and update state in place."""
    k = state.get("k", 3)
    query = state.get("query", "")

    # Sub-strategies refine the strategy the user just picked.
    sub_strategies = generate_strategies(query, selected_strategy, k)

    chat_history.append({
        "role": "assistant",
        # Display steps are 1-based, hence current_step + 1.
        "content": _format_strategy_options(sub_strategies, current_step + 1, m),
    })

    state.update({
        "current_step": current_step,
        "strategy_path": strategy_path,
        "chat_history": chat_history,
        "strategies": sub_strategies,
    })
    return state, chat_history


def handle_choice(state: Dict[str, Any], choice: str) -> tuple:
    """
    Handle the user's choice of strategy.

    Args:
        state: Current state
        choice: The user's selected option (1-based index as string)

    Returns:
        Updated state and chat history
    """
    if not state:
        return state, []

    chat_history = state.get("chat_history", [])

    # Validate choice input.
    if not choice or not choice.strip():
        chat_history.append({"role": "system",
                             "content": PROMPTS["ibfs"]["ui_messages"]["error_missing_choice"]})
        return state, chat_history

    try:
        choice_idx = int(choice.strip()) - 1  # Convert to 0-based index
        strategies = state.get("strategies", [])

        # Check if we have strategies and if choice is valid.
        if not strategies:
            chat_history.append({"role": "system",
                                 "content": PROMPTS["ibfs"]["ui_messages"]["error_no_strategies"]})
            return state, chat_history
        if choice_idx < 0 or choice_idx >= len(strategies):
            error_msg = format_prompt(
                PROMPTS["ibfs"]["ui_messages"]["error_invalid_choice"],
                max_option=len(strategies),
            )
            chat_history.append({"role": "system", "content": error_msg})
            return state, chat_history

        # Record the selection and extend the path taken so far.
        selected_strategy = strategies[choice_idx]
        strategy_path = state.get("strategy_path", [])
        strategy_path.append(selected_strategy)
        chat_history.append({"role": "user",
                             "content": f"I choose option {choice_idx + 1}: {selected_strategy}"})

        current_step = state.get("current_step", 0) + 1
        m = state.get("m", 2)

        # Either finish (depth reached) or descend one more level.
        if current_step >= m:
            return _finalize_ibfs(state, chat_history, selected_strategy, strategy_path)
        return _advance_ibfs(state, chat_history, selected_strategy, strategy_path,
                             current_step, m)
    except ValueError:
        # int() failed: the choice wasn't a number.
        error_msg = format_prompt(
            PROMPTS["ibfs"]["ui_messages"]["error_invalid_number"],
            choice=choice,
        )
        chat_history.append({"role": "system", "content": error_msg})
        return state, chat_history
    except Exception as e:
        error_msg = format_prompt(
            PROMPTS["ibfs"]["ui_messages"]["error_general"],
            error_message=str(e),
        )
        print(error_msg)
        chat_history.append({"role": "system", "content": error_msg})
        return state, chat_history