# ibfs_demo/zero_shot.py
import litellm
from typing import Dict, List

from prompts import PROMPTS, format_prompt
from utils import generate_user_id, save_results
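
# PROMPTS is a nested mapping defined in prompts.py (loaded from YAML, per the
# comments below). The exact template text is not shown in this module; the
# sketch here is an assumption that only records the key paths actually read:
#
#   zero_shot:
#     direct_answer:
#       system: <system prompt, formatted with {query}>
#       user: <user prompt, formatted with {query}>
#     ui_messages:
#       generating: <status message shown while the answer is generated>
#       result: <template formatted with {answer}>
#       saved_result: <template formatted with {save_path} and {user_id}>
#       error_general: <template formatted with {error_message}>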


def zero_shot_answer(query: str) -> str:
    """
    Generate a direct answer without interactive search.

    Args:
        query: The user's query

    Returns:
        The direct answer
    """
    # Get the prompt templates.
    system_template = PROMPTS["zero_shot"]["direct_answer"]["system"]
    user_template = PROMPTS["zero_shot"]["direct_answer"]["user"]

    # Format the prompts with the query.
    format_args = {"query": query}
    system_message = format_prompt(system_template, **format_args)
    user_message = format_prompt(user_template, **format_args)

    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message},
    ]

    # Get the response from the LLM via litellm's unified completion API.
    response = litellm.completion(
        model="gpt-4o",
        messages=messages,
        temperature=0.3,
        max_tokens=2000,
    )
    return response.choices[0].message.content
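
# Hypothetical usage -- litellm routes "gpt-4o" to OpenAI, so OPENAI_API_KEY
# must be set in the environment for this call to succeed:
#
#   answer = zero_shot_answer("Why is the sky blue?")
#   print(answer)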


def start_zero_shot(query: str) -> List[Dict[str, str]]:
    """
    Generate a direct answer using the zero-shot approach.

    Args:
        query: User's query

    Returns:
        Chat history with the answer
    """
    if not query or not query.strip():
        return [{"role": "assistant", "content": "Please enter a query."}]

    # Generate a unique user ID for this session.
    user_id = generate_user_id()

    # Initialize the chat history with the original query.
    chat_history = [{"role": "user", "content": f"Query: {query}"}]

    try:
        # Show the "generating" status message template from the YAML.
        generating_msg = PROMPTS["zero_shot"]["ui_messages"]["generating"]
        chat_history.append({"role": "assistant", "content": generating_msg})

        # Generate the direct answer.
        answer = zero_shot_answer(query)

        # Format and append the result message.
        result_msg = format_prompt(
            PROMPTS["zero_shot"]["ui_messages"]["result"],
            answer=answer,
        )
        chat_history.append({"role": "assistant", "content": result_msg})

        # Save the results to disk.
        save_path = save_results(
            user_id=user_id,
            query=query,
            final_answer=answer,
            method="zero_shot",
        )
        saved_msg = format_prompt(
            PROMPTS["zero_shot"]["ui_messages"]["saved_result"],
            save_path=save_path,
            user_id=user_id,
        )
        chat_history.append({"role": "system", "content": saved_msg})

        return chat_history
    except Exception as e:
        error_msg = format_prompt(
            PROMPTS["zero_shot"]["ui_messages"]["error_general"],
            error_message=str(e),
        )
        print(error_msg)
        # On failure, return a fresh history: the query, a generic apology,
        # and the formatted error as a system message.
        return [
            {"role": "user", "content": f"Query: {query}"},
            {"role": "assistant", "content": "I encountered an error while generating a direct answer."},
            {"role": "system", "content": error_msg},
        ]
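

if __name__ == "__main__":
    # Minimal smoke test (an assumption -- not part of the original demo UI):
    # run the zero-shot pipeline on a sample query and print the chat history.
    # Requires the model API key (e.g. OPENAI_API_KEY) in the environment.
    for message in start_zero_shot("What are the main causes of inflation?"):
        print(f"[{message['role']}] {message['content']}")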