|
import importlib
import importlib.resources
import logging
import os

import requests
import yaml
from dotenv import find_dotenv, load_dotenv
from litellm._logging import _disable_debugging
from openinference.instrumentation.smolagents import SmolagentsInstrumentor
from phoenix.otel import register
from smolagents import CodeAgent, LiteLLMModel
from smolagents.default_tools import (
    DuckDuckGoSearchTool,
    VisitWebpageTool,
    WikipediaSearchTool,
)
from smolagents.monitoring import LogLevel

from agents.data_agent.agent import create_data_agent
from agents.media_agent.agent import create_media_agent
from agents.web_agent.agent import create_web_agent
from utils import extract_final_answer
|
|
|
# Load .env before anything that reads configuration from the environment.
# phoenix.otel.register() takes its defaults (e.g. PHOENIX_COLLECTOR_ENDPOINT,
# PHOENIX_PROJECT_NAME) from os.environ at call time, so loading .env after it
# would silently ignore any Phoenix settings kept there.
load_dotenv(find_dotenv())

# LiteLLM internal helper; presumably turns off LiteLLM's debug output.
_disable_debugging()

# Set up Phoenix (OpenTelemetry) tracing and instrument smolagents so agent
# runs are traced.
register()
SmolagentsInstrumentor().instrument()

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
|
|
|
# LLM endpoint settings taken from the environment; any variable that is not
# set comes back as None.
API_BASE, API_KEY, MODEL_ID = (
    os.getenv(var_name) for var_name in ("API_BASE", "API_KEY", "MODEL_ID")
)

# One LiteLLM-backed model instance, handed to every agent this module builds.
model = LiteLLMModel(
    model_id=MODEL_ID,
    api_base=API_BASE,
    api_key=API_KEY,
)
|
|
|
# Specialist sub-agents, each produced by its project factory from the shared
# model (built in the order data -> media -> web).
# NOTE(review): none of these is passed to the manager CodeAgent constructed
# later in this module (it has no managed_agents=...); confirm that leaving
# them unused is intentional.
data_agent, media_agent, web_agent = (
    make_agent(model)
    for make_agent in (create_data_agent, create_media_agent, create_web_agent)
)
|
|
|
# Reuse smolagents' stock CodeAgent prompt templates, shipped as the package
# resource smolagents/prompts/code_agent.yaml. The file is decoded explicitly
# as UTF-8 so the result does not depend on the platform's locale encoding
# (read_text() without an encoding uses the locale default, e.g. cp1252 on
# Windows). Relies on `importlib.resources` being imported explicitly;
# `import importlib` alone does not load that submodule.
prompt_templates = yaml.safe_load(
    importlib.resources.files("smolagents.prompts")
    .joinpath("code_agent.yaml")
    .read_text(encoding="utf-8")
)
|
|
|
# Manager agent: a smolagents CodeAgent that uses the shared model and the
# stock CodeAgent prompts. It has three tools (DuckDuckGo search limited to 3
# results, a webpage visitor capped at 1024 output characters, and Wikipedia
# search). Its generated code may also import json, pandas, numpy and re.
# NOTE(review): data_agent / media_agent / web_agent are created earlier in
# this module but are not registered here (no managed_agents=...) — confirm.
agent = CodeAgent(
    model=model,
    prompt_templates=prompt_templates,
    additional_authorized_imports=["json", "pandas", "numpy", "re"],
    tools=[
        DuckDuckGoSearchTool(max_results=3),
        VisitWebpageTool(max_output_length=1024),
        WikipediaSearchTool(),
    ],
    verbosity_level=LogLevel.ERROR,
    step_callbacks=None,
)

# smolagents' visualize(): displays the agent's structure. Note that this runs
# on every import of this module, not only when it is executed as a script.
agent.visualize()
|
|
|
|
|
# Answer-format rules prepended to every task so the agent replies with a bare
# answer string (presumably for GAIA-style exact-match scoring).
GAIA_INSTRUCTIONS = """Instructions:
1. Your response must contain ONLY the answer to the question, nothing else
2. Do not repeat the question or any part of it
3. Do not include any explanations, reasoning, or context
4. Do not include source attribution or references
5. Do not use phrases like "The answer is" or "I found that"
6. Do not include any formatting, bullet points, or line breaks
7. If the answer is a number, return only the number
8. If the answer requires multiple items, separate them with commas
9. If the answer requires ordering, maintain the specified order
10. Use the most direct and succinct form possible"""


def main(task: str, max_steps: int = 5):
    """Run the module-level manager agent on one question and return its answer.

    Args:
        task: The question text; it is appended after ``GAIA_INSTRUCTIONS``,
            separated by a blank line.
        max_steps: Maximum number of agent steps for this run. Defaults to 5,
            the value that was previously hard-coded.

    Returns:
        Whatever ``extract_final_answer`` derives from the agent's raw result.
    """
    gaia_task = f"{GAIA_INSTRUCTIONS}\n\n{task}"

    # reset=True: do not carry memory over from a previous question.
    # stream=False: run() returns the final result instead of a step stream.
    result = agent.run(
        task=gaia_task,
        max_steps=max_steps,
        reset=True,
        stream=False,
        images=None,
        additional_args=None,
    )

    # Lazy %-style args: the message is only formatted if INFO is enabled.
    logger.info("Result: %s", result)

    return extract_final_answer(result)
|
|
|
|
|
if __name__ == "__main__":
    DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"

    api_url = DEFAULT_API_URL
    questions_url = f"{api_url}/questions"

    # Fetch the question list; fail loudly on HTTP errors.
    response = requests.get(questions_url, timeout=15)
    response.raise_for_status()
    questions_data = response.json()

    # Only the first question is processed (looks like a debugging limit —
    # widen the slice to run the whole set).
    for question_data in questions_data[:1]:
        file_name = question_data["file_name"]
        question = question_data["question"]

        logger.info("Question: %s", question)

        if file_name:
            # NOTE(review): the attachment is only logged; it is never
            # downloaded or passed to main(), so the agent cannot see it.
            logger.info("File Name: %s", file_name)

        final_answer = main(question)
        logger.info("Final Answer: %s", final_answer)
        logger.info("--------------------------------")
|
|