File size: 1,599 Bytes
751d628
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import asyncio
import logging
import os
import ssl
from typing import Dict, Optional, Sequence
from urllib.parse import quote, urljoin

import aiohttp

# Configure root logging at import time (INFO and above, timestamped records)
# and obtain the logger used by this module.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(name)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(name=__name__)

async def _fetch_one(
    session: "aiohttp.ClientSession",
    url: str,
    task_id: str,
    ext: str,
) -> Optional[bytes]:
    """Download one candidate attachment; return ``None`` if it is unavailable.

    Any ``Exception`` raised while requesting or reading is logged and turned
    into ``None``, so one failing extension cannot abort the other probes.
    """
    try:
        async with session.get(url) as response:
            if response.status == 200:
                content = await response.read()
                logger.info("Fetched %s for task %s", ext, task_id)
                return content
            # A 404 is the normal outcome for every extension except the one
            # the task actually has, so keep it out of warning-level output.
            log = logger.debug if response.status == 404 else logger.warning
            log("No %s for task %s: HTTP %s", ext, task_id, response.status)
    except Exception as exc:  # best-effort probe: log and move on
        logger.warning("Error fetching %s for task %s: %s", ext, task_id, exc)
    return None


async def fetch_task_file(
    task_id: str,
    question: str,
    *,
    base_url: str = "https://gaia-benchmark-api.hf.space/files/",
    extensions: Sequence[str] = ("xlsx", "csv", "pdf", "txt", "mp3", "jpg", "png"),
    verify_ssl: bool = True,
    timeout: float = 30,
) -> Dict[str, bytes]:
    """
    Fetch the file(s) attached to a GAIA task.

    The URL ``{base_url}{task_id}/{task_id}.{ext}`` is requested for every
    candidate extension concurrently. Each extension answered with HTTP 200
    is included in the result.

    Args:
        task_id: Task identifier; percent-encoded before it is placed in the URL.
        question: Unused; kept so existing callers keep working.
        base_url: Directory URL holding the per-task folders. A trailing slash
            is appended if missing, because ``urljoin`` would otherwise drop
            the last path segment.
        extensions: Candidate file extensions to probe; the result follows
            this order.
        verify_ssl: Verify the server's TLS certificate and hostname. Disable
            only for a trusted endpoint whose certificate cannot be validated.
        timeout: Total time limit, in seconds, for each individual request.

    Returns:
        A dict mapping each extension that was found to the raw file bytes,
        in the order of *extensions*. It is empty if nothing was found.
        Network and HTTP errors are logged, not raised.
    """
    exts = list(extensions)  # materialize so it can be iterated twice
    if not base_url.endswith("/"):
        base_url += "/"
    # Percent-encode so IDs containing "/", "?", "#" or a leading "//" cannot
    # change which host or path urljoin resolves to.
    safe_id = quote(task_id, safe="")

    ssl_context = ssl.create_default_context()
    if not verify_ssl:
        ssl_context.check_hostname = False
        ssl_context.verify_mode = ssl.CERT_NONE

    async with aiohttp.ClientSession(
        # ``ssl=`` replaces aiohttp's deprecated ``ssl_context=`` keyword.
        connector=aiohttp.TCPConnector(ssl=ssl_context),
        timeout=aiohttp.ClientTimeout(total=timeout),
    ) as session:
        fetches = [
            _fetch_one(
                session,
                urljoin(base_url, f"{safe_id}/{safe_id}.{quote(ext, safe='')}"),
                task_id,
                ext,
            )
            for ext in exts
        ]
        # gather() returns results in submission order, so zip() below pairs
        # each payload with its extension.
        contents = await asyncio.gather(*fetches)

    return {ext: data for ext, data in zip(exts, contents) if data is not None}