Deploy full application code
- .gitignore +39 -0
- Dockerfile +8 -5
- app.py +3 -13
- app/api/__init__.py +7 -0
- app/api/datasets.py +151 -0
- app/core/celery_app.py +98 -0
- app/core/config.py +48 -0
- app/main.py +46 -0
- app/schemas/dataset.py +81 -0
- app/schemas/dataset_common.py +17 -0
- app/services/hf_datasets.py +501 -0
- app/services/redis_client.py +302 -0
- app/tasks/dataset_tasks.py +73 -0
- migrations/20250620000000_create_combined_datasets_table.sql +57 -0
- setup.py +8 -0
- tests/test_datasets.py +78 -0
- tests/test_datasets_api.py +88 -0
.gitignore
ADDED
@@ -0,0 +1,39 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# Environment
.env
.venv
env/
venv/
ENV/

# Logs
*.log
logs/
celery_worker_*.log
nohup.out

# Database
*.sqlite
*.db
celerybeat-schedule
Dockerfile
CHANGED
@@ -1,14 +1,17 @@
 # Use the official Python 3.10.9 image
 FROM python:3.10.9
 
-#
-
+# Set the working directory
+WORKDIR /app
 
-#
-
+# Copy the current directory contents into the container
+COPY . .
 
-# Install requirements.txt
+# Install requirements.txt
 RUN pip install --no-cache-dir --upgrade -r requirements.txt
 
+# Install the application in development mode
+RUN pip install -e .
+
 # Start the FastAPI app on port 7860, the default port expected by Spaces
 CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py
CHANGED
@@ -1,17 +1,7 @@
-from
-import uvicorn
+from app.main import app
 
-#
-app = FastAPI(title="Collinear API")
-
-@app.get("/")
-async def root():
-    return {"message": "Welcome to Collinear API"}
-
-@app.get("/health")
-async def health():
-    return {"status": "healthy"}
+# This file is needed for Hugging Face Spaces to find the app
 
 if __name__ == "__main__":
-
+    import uvicorn
     uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)
app/api/__init__.py
ADDED
@@ -0,0 +1,7 @@
from fastapi import APIRouter
from app.api.datasets import router as datasets_router
# from . import batch  # Removed batch import

api_router = APIRouter()
api_router.include_router(datasets_router, tags=["datasets"])
# api_router.include_router(batch.router)  # Removed batch router
app/api/datasets.py
ADDED
@@ -0,0 +1,151 @@
from fastapi import APIRouter, Query, HTTPException
from typing import List, Optional, Dict, Any, Set
from pydantic import BaseModel
from fastapi.concurrency import run_in_threadpool
from app.services.hf_datasets import (
    get_dataset_commits,
    get_dataset_files,
    get_file_url,
    get_datasets_page_from_zset,
    get_dataset_commits_async,
    get_dataset_files_async,
    get_file_url_async,
    get_datasets_page_from_cache,
    fetch_and_cache_all_datasets,
)
from app.services.redis_client import cache_get
import logging
import time
from fastapi.responses import JSONResponse
import os

router = APIRouter(prefix="/datasets", tags=["datasets"])
log = logging.getLogger(__name__)

SIZE_LOW = 100 * 1024 * 1024
SIZE_MEDIUM = 1024 * 1024 * 1024

class DatasetInfo(BaseModel):
    id: str
    name: Optional[str]
    description: Optional[str]
    size_bytes: Optional[int]
    impact_level: Optional[str]
    downloads: Optional[int]
    likes: Optional[int]
    tags: Optional[List[str]]
    class Config:
        extra = "ignore"

class PaginatedDatasets(BaseModel):
    total: int
    items: List[DatasetInfo]

class CommitInfo(BaseModel):
    id: str
    title: Optional[str]
    message: Optional[str]
    author: Optional[Dict[str, Any]]
    date: Optional[str]

class CacheStatus(BaseModel):
    last_update: Optional[str]
    total_items: int
    warming_up: bool

def deduplicate_by_id(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    seen: Set[str] = set()
    unique_items = []
    for item in items:
        item_id = item.get("id")
        if item_id and item_id not in seen:
            seen.add(item_id)
            unique_items.append(item)
    return unique_items

@router.get("/cache-status", response_model=CacheStatus)
async def cache_status():
    meta = await cache_get("hf:datasets:meta")
    last_update = meta["last_update"] if meta and "last_update" in meta else None
    total_items = meta["total_items"] if meta and "total_items" in meta else 0
    warming_up = not bool(total_items)
    return CacheStatus(last_update=last_update, total_items=total_items, warming_up=warming_up)

@router.get("/", response_model=None)
async def list_datasets(
    limit: int = Query(10, ge=1, le=1000),
    offset: int = Query(0, ge=0),
    search: str = Query(None, description="Search term for dataset id or description"),
    sort_by: str = Query(None, description="Field to sort by (e.g., 'downloads', 'likes', 'created_at')"),
    sort_order: str = Query("desc", regex="^(asc|desc)$", description="Sort order: 'asc' or 'desc'"),
):
    # Fetch the full list from cache
    result, status = get_datasets_page_from_cache(1000000, 0)  # get all for in-memory filtering
    if status != 200:
        return JSONResponse(result, status_code=status)
    items = result["items"]
    # Filtering
    if search:
        items = [d for d in items if search.lower() in (d.get("id", "") + " " + str(d.get("description", "")).lower())]
    # Sorting
    if sort_by:
        items = sorted(items, key=lambda d: d.get(sort_by) or 0, reverse=(sort_order == "desc"))
    # Pagination
    total = len(items)
    page = items[offset:offset + limit]
    total_pages = (total + limit - 1) // limit
    current_page = (offset // limit) + 1
    next_page = current_page + 1 if offset + limit < total else None
    prev_page = current_page - 1 if current_page > 1 else None
    return {
        "total": total,
        "current_page": current_page,
        "total_pages": total_pages,
        "next_page": next_page,
        "prev_page": prev_page,
        "items": page
    }

@router.get("/{dataset_id:path}/commits", response_model=List[CommitInfo])
async def get_commits(dataset_id: str):
    """
    Get commit history for a dataset.
    """
    try:
        return await get_dataset_commits_async(dataset_id)
    except Exception as e:
        log.error(f"Error fetching commits for {dataset_id}: {e}")
        raise HTTPException(status_code=404, detail=f"Could not fetch commits: {e}")

@router.get("/{dataset_id:path}/files", response_model=List[str])
async def list_files(dataset_id: str):
    """
    List files in a dataset.
    """
    try:
        return await get_dataset_files_async(dataset_id)
    except Exception as e:
        log.error(f"Error listing files for {dataset_id}: {e}")
        raise HTTPException(status_code=404, detail=f"Could not list files: {e}")

@router.get("/{dataset_id:path}/file-url")
async def get_file_url_endpoint(dataset_id: str, filename: str = Query(...), revision: Optional[str] = None):
    """
    Get download URL for a file in a dataset.
    """
    url = await get_file_url_async(dataset_id, filename, revision)
    return {"download_url": url}

@router.get("/meta")
async def get_datasets_meta():
    meta = await cache_get("hf:datasets:meta")
    return meta if meta else {}

# Endpoint to trigger cache refresh manually (for admin/testing)
@router.post("/datasets/refresh-cache")
def refresh_cache():
    token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
    if not token:
        return JSONResponse({"error": "HUGGINGFACEHUB_API_TOKEN not set"}, status_code=500)
    count = fetch_and_cache_all_datasets(token)
    return {"status": "ok", "cached": count}
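
A minimal sketch of how these endpoints could be exercised locally with FastAPI's TestClient. It assumes the environment variables required by app/core/config.py are set and that a Redis instance populated by fetch_and_cache_all_datasets is reachable for meaningful results; the query values ("squad", limit=5) are illustrative only:

from fastapi.testclient import TestClient
from app.main import app  # importing app.main requires the SUPABASE_* settings to be present

client = TestClient(app)

# Cache status works even before the dataset cache is warmed up
print(client.get("/api/datasets/cache-status").json())

# First page of cached datasets, filtered and sorted in memory by list_datasets
resp = client.get("/api/datasets/", params={"limit": 5, "search": "squad", "sort_by": "downloads"})
print(resp.status_code, resp.json().get("total") if resp.status_code == 200 else resp.json())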
app/core/celery_app.py
ADDED
@@ -0,0 +1,98 @@
"""Celery configuration for task processing."""

import logging
from celery import Celery
from celery.signals import task_failure, task_success, worker_ready, worker_shutdown

from app.core.config import settings

# Configure logging
logger = logging.getLogger(__name__)

# Celery configuration
celery_app = Celery(
    "dataset_impacts",
    broker=settings.REDIS_URL,
    backend=settings.REDIS_URL,
)

# Configure Celery settings
celery_app.conf.update(
    task_serializer="json",
    accept_content=["json"],
    result_serializer="json",
    timezone="UTC",
    enable_utc=True,
    worker_concurrency=settings.WORKER_CONCURRENCY,
    task_acks_late=True,  # Tasks are acknowledged after execution
    task_reject_on_worker_lost=True,  # Tasks are rejected if worker is terminated during execution
    task_time_limit=3600,  # 1 hour hard timeout per task
    task_soft_time_limit=3000,  # Soft timeout (50 minutes) - allows for graceful shutdown before the hard limit
    worker_prefetch_multiplier=1,  # Single prefetch - improves fair distribution of tasks
    broker_connection_retry=True,
    broker_connection_retry_on_startup=True,
    broker_connection_max_retries=10,
    broker_pool_limit=10,  # Connection pool size
    result_expires=60 * 60 * 24,  # Results expire after 24 hours
    task_track_started=True,  # Track when tasks are started
)

# Set up task routes for different task types
celery_app.conf.task_routes = {
    "app.tasks.dataset_tasks.*": {"queue": "dataset_impacts"},
    "app.tasks.maintenance.*": {"queue": "maintenance"},
}

# Configure retry settings
celery_app.conf.task_default_retry_delay = 30  # 30 seconds
celery_app.conf.task_max_retries = 3

# Setup beat schedule for periodic tasks if enabled
celery_app.conf.beat_schedule = {
    "cleanup-stale-tasks": {
        "task": "app.tasks.maintenance.cleanup_stale_tasks",
        "schedule": 3600.0,  # Run every hour
    },
    "health-check": {
        "task": "app.tasks.maintenance.health_check",
        "schedule": 300.0,  # Run every 5 minutes
    },
    "refresh-hf-datasets-cache": {
        "task": "app.tasks.dataset_tasks.refresh_hf_datasets_cache",
        "schedule": 3600.0,  # Run every hour
    },
}

# Signal handlers for monitoring and logging
@task_failure.connect
def task_failure_handler(sender=None, task_id=None, exception=None, **kwargs):
    """Log failed tasks."""
    logger.error(f"Task {task_id} {sender.name} failed: {exception}")

@task_success.connect
def task_success_handler(sender=None, result=None, **kwargs):
    """Log successful tasks."""
    task_name = sender.name if sender else "Unknown"
    logger.info(f"Task {task_name} completed successfully")

@worker_ready.connect
def worker_ready_handler(**kwargs):
    """Log when worker is ready."""
    logger.info(f"Celery worker ready: {kwargs.get('hostname')}")

@worker_shutdown.connect
def worker_shutdown_handler(**kwargs):
    """Log when worker is shutting down."""
    logger.info(f"Celery worker shutting down: {kwargs.get('hostname')}")

def get_celery_app():
    """Get the Celery app instance."""
    # Import all tasks to ensure they're registered
    try:
        # Using the improved app.tasks module which properly imports all tasks
        import app.tasks
        logger.info(f"Tasks successfully imported; registered {len(celery_app.tasks)} tasks")
    except ImportError as e:
        logger.error(f"Error importing tasks: {e}")

    return celery_app
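
As a quick sanity check, the configured app can be inspected, and the hourly refresh task can be enqueued by hand. A minimal sketch, assuming the Redis broker from settings.REDIS_URL is reachable, a worker is consuming the dataset_impacts queue, and the task name matches the app.tasks.dataset_tasks module (not shown in this diff):

from app.core.celery_app import get_celery_app

celery_app = get_celery_app()

# Registered task names (Celery built-ins plus anything imported via app.tasks)
print(sorted(celery_app.tasks.keys()))

# Manually enqueue the beat-scheduled cache refresh onto its routed queue
result = celery_app.send_task(
    "app.tasks.dataset_tasks.refresh_hf_datasets_cache",
    queue="dataset_impacts",
)
print(result.id)  # AsyncResult id; poll result.state once a worker picks it up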
app/core/config.py
ADDED
@@ -0,0 +1,48 @@
from __future__ import annotations

from typing import Final, Optional

from pydantic import SecretStr, HttpUrl, Field
from pydantic_settings import BaseSettings, SettingsConfigDict

class Settings(BaseSettings):
    """
    Core application settings.
    Reads environment variables and .env file.
    """
    # Supabase Settings
    SUPABASE_URL: HttpUrl
    SUPABASE_SERVICE_KEY: SecretStr
    SUPABASE_ANON_KEY: SecretStr
    SUPABASE_JWT_SECRET: Optional[SecretStr] = None  # Optional for local dev

    # Hugging Face API Token
    HF_API_TOKEN: Optional[SecretStr] = None

    # Redis settings
    REDIS_URL: str = "redis://localhost:6379/0"
    REDIS_PASSWORD: Optional[SecretStr] = None

    # Toggle Redis cache layer
    ENABLE_REDIS_CACHE: bool = True

    # ──────────────────────────────── Security ────────────────────────────────
    # JWT secret key. NEVER hard-code in source; override with env variable in production.
    SECRET_KEY: SecretStr = Field("change-me", env="SECRET_KEY")
    ACCESS_TOKEN_EXPIRE_MINUTES: int = Field(60 * 24 * 7, env="ACCESS_TOKEN_EXPIRE_MINUTES")  # 1 week by default

    # Worker settings
    WORKER_CONCURRENCY: int = 10  # Increased from 5 for better parallel performance

    # Batch processing chunk size for Celery dataset tasks
    DATASET_BATCH_CHUNK_SIZE: int = 50

    # Tell pydantic-settings to auto-load `.env` if present
    model_config: Final = SettingsConfigDict(
        env_file=".env",
        case_sensitive=False,
        extra="ignore"
    )

# Single, shared instance of settings
settings = Settings()
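
Because settings is instantiated at import time, the required Supabase values must already be in the environment or a .env file before anything imports app.core.config. A minimal sketch with placeholder values (illustrative only, not real credentials):

import os

# Hypothetical values for local experimentation; normally these live in .env
os.environ.setdefault("SUPABASE_URL", "https://example.supabase.co")
os.environ.setdefault("SUPABASE_SERVICE_KEY", "service-role-key")
os.environ.setdefault("SUPABASE_ANON_KEY", "anon-key")

from app.core.config import settings

print(settings.REDIS_URL)                                 # redis://localhost:6379/0 unless overridden
print(settings.SUPABASE_SERVICE_KEY)                      # SecretStr keeps the value masked in logs/repr
print(settings.SUPABASE_SERVICE_KEY.get_secret_value())   # explicit unmasking only where actually needed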
app/main.py
ADDED
@@ -0,0 +1,46 @@
import logging
import json
from fastapi import FastAPI
from app.api import api_router
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware

class JsonFormatter(logging.Formatter):
    def format(self, record):
        log_record = {
            "level": record.levelname,
            "time": self.formatTime(record, self.datefmt),
            "name": record.name,
            "message": record.getMessage(),
        }
        if record.exc_info:
            log_record["exc_info"] = self.formatException(record.exc_info)
        return json.dumps(log_record)

handler = logging.StreamHandler()
handler.setFormatter(JsonFormatter())
logging.basicConfig(level=logging.INFO, handlers=[handler])

app = FastAPI(title="Collinear API")

# Enable CORS for the frontend
frontend_origin = "http://localhost:5173"
app.add_middleware(
    CORSMiddleware,
    allow_origins=[frontend_origin],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.add_middleware(GZipMiddleware, minimum_size=1000)

app.include_router(api_router, prefix="/api")

@app.get("/")
async def root():
    return {"message": "Welcome to the Collinear Data Tool API"}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app.main:app", host="0.0.0.0", port=8000, reload=True)
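
Every log record passes through JsonFormatter, so the app emits one JSON object per line. A small standalone illustration of that output (importing app.main pulls in app.core.config, so the required env vars must be set first; the logger name "demo" and the message are made up):

import logging
from app.main import JsonFormatter

handler = logging.StreamHandler()
handler.setFormatter(JsonFormatter())
demo_log = logging.getLogger("demo")
demo_log.addHandler(handler)
demo_log.setLevel(logging.INFO)
demo_log.propagate = False  # avoid duplicate output via the root handler installed by app.main

demo_log.info("cache refresh started")
# -> {"level": "INFO", "time": "...", "name": "demo", "message": "cache refresh started"}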
app/schemas/dataset.py
ADDED
@@ -0,0 +1,81 @@
import logging
from typing import Dict, List, Optional, Any
from datetime import datetime
from pydantic import BaseModel, Field

from app.schemas.dataset_common import ImpactLevel, DatasetMetrics

# Log for this module
log = logging.getLogger(__name__)

# Supported strategies for dataset combination
SUPPORTED_STRATEGIES = ["merge", "intersect", "filter"]

class ImpactAssessment(BaseModel):
    dataset_id: str = Field(..., description="The ID of the dataset being assessed")
    impact_level: ImpactLevel = Field(..., description="The impact level: low, medium, or high")
    assessment_method: str = Field(
        "unknown",
        description="Method used to determine impact level (e.g., size_based, downloads_and_likes_based)"
    )
    metrics: DatasetMetrics = Field(
        ...,
        description="Metrics used for impact assessment"
    )
    thresholds: Dict[str, Dict[str, str]] = Field(
        {},
        description="Thresholds used for determining impact levels (for reference)"
    )

class DatasetInfo(BaseModel):
    id: str
    impact_level: Optional[ImpactLevel] = None
    impact_assessment: Optional[Dict] = None
    # Add other fields as needed
    class Config:
        extra = "allow"  # Allow extra fields from the API

class DatasetBase(BaseModel):
    name: str
    description: Optional[str] = None
    tags: Optional[List[str]] = None

class DatasetCreate(DatasetBase):
    files: Optional[List[str]] = None

class DatasetUpdate(DatasetBase):
    name: Optional[str] = None  # Make fields optional for updates

class Dataset(DatasetBase):
    id: int  # or str depending on your ID format
    owner_id: str  # Assuming user IDs are strings
    created_at: Optional[str] = None
    updated_at: Optional[str] = None
    class Config:
        pass  # Removed orm_mode = True since ORM is not used

class DatasetCombineRequest(BaseModel):
    source_datasets: List[str] = Field(..., description="List of dataset IDs to combine")
    name: str = Field(..., description="Name for the combined dataset")
    description: Optional[str] = Field(None, description="Description for the combined dataset")
    combination_strategy: str = Field("merge", description="Strategy to use when combining datasets (e.g., 'merge', 'intersect', 'filter')")
    filter_criteria: Optional[Dict[str, Any]] = Field(None, description="Criteria for filtering when combining datasets")

class CombinedDataset(BaseModel):
    id: str = Field(..., description="ID of the combined dataset")
    name: str = Field(..., description="Name of the combined dataset")
    description: Optional[str] = Field(None, description="Description of the combined dataset")
    source_datasets: List[str] = Field(..., description="IDs of the source datasets")
    created_at: datetime = Field(..., description="Creation timestamp")
    created_by: str = Field(..., description="ID of the user who created this combined dataset")
    impact_level: Optional[ImpactLevel] = Field(None, description="Calculated impact level of the combined dataset")
    status: str = Field("processing", description="Status of the dataset combination process")
    combination_strategy: str = Field(..., description="Strategy used when combining datasets")
    metrics: Optional[DatasetMetrics] = Field(None, description="Metrics for the combined dataset")
    storage_bucket_id: Optional[str] = Field(None, description="ID of the storage bucket containing dataset files")
    storage_folder_path: Optional[str] = Field(None, description="Path to the dataset files within the bucket")
    class Config:
        extra = "allow"  # Allow extra fields for flexibility

__all__ = ["ImpactLevel", "ImpactAssessment", "DatasetInfo", "DatasetMetrics",
           "Dataset", "DatasetCreate", "DatasetUpdate", "DatasetCombineRequest", "CombinedDataset"]
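
These are plain Pydantic models, so they can be constructed and validated in isolation. A short sketch with made-up dataset IDs and metric values, purely for illustration:

from app.schemas.dataset import DatasetCombineRequest, ImpactAssessment
from app.schemas.dataset_common import DatasetMetrics, ImpactLevel

request = DatasetCombineRequest(
    source_datasets=["squad", "glue"],  # hypothetical dataset IDs
    name="combined-qa",
    combination_strategy="merge",
)
print(request.model_dump())

assessment = ImpactAssessment(
    dataset_id="squad",
    impact_level=ImpactLevel.LOW,
    assessment_method="size_based",
    metrics=DatasetMetrics(size_bytes=150 * 1024 * 1024, downloads=1200, likes=15),
)
print(assessment.impact_level)  # ImpactLevel.LOW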
app/schemas/dataset_common.py
ADDED
@@ -0,0 +1,17 @@
from enum import Enum
from pydantic import BaseModel, Field
from typing import Optional

# Define the impact level as an enum for better type safety
class ImpactLevel(str, Enum):
    NA = "not_available"  # New category for when size information is unavailable
    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"

# Define metrics model for impact assessment
class DatasetMetrics(BaseModel):
    size_bytes: Optional[int] = Field(None, description="Size of the dataset in bytes")
    file_count: Optional[int] = Field(None, description="Number of files in the dataset")
    downloads: Optional[int] = Field(None, description="Number of downloads (all time)")
    likes: Optional[int] = Field(None, description="Number of likes")
app/services/hf_datasets.py
ADDED
@@ -0,0 +1,501 @@
import logging
import json
import functools
from typing import Any, List, Optional, Dict, Tuple
import requests
from huggingface_hub import HfApi
from app.core.config import settings
from app.schemas.dataset_common import ImpactLevel
from app.services.redis_client import sync_cache_set, sync_cache_get, generate_cache_key, get_redis_sync
import time
import asyncio
import redis
import gzip
from datetime import datetime, timezone
import os
from app.schemas.dataset import ImpactAssessment
from app.schemas.dataset_common import DatasetMetrics
import httpx
import redis.asyncio as aioredis

log = logging.getLogger(__name__)
api = HfApi()
redis_client = redis.Redis(host="redis", port=6379, decode_responses=True)

# Thresholds for impact categorization
SIZE_THRESHOLD_LOW = 100 * 1024 * 1024  # 100 MB
SIZE_THRESHOLD_MEDIUM = 1024 * 1024 * 1024  # 1 GB
DOWNLOADS_THRESHOLD_LOW = 1000
DOWNLOADS_THRESHOLD_MEDIUM = 10000
LIKES_THRESHOLD_LOW = 10
LIKES_THRESHOLD_MEDIUM = 100

HF_API_URL = "https://huggingface.co/api/datasets"
DATASET_CACHE_TTL = 60 * 60  # 1 hour

# Redis and HuggingFace API setup
REDIS_KEY = "hf:datasets:all:compressed"
REDIS_META_KEY = "hf:datasets:meta"
REDIS_TTL = 60 * 60  # 1 hour

# Impact thresholds (in bytes)
SIZE_LOW = 100 * 1024 * 1024
SIZE_MEDIUM = 1024 * 1024 * 1024

def get_hf_token():
    token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
    if not token:
        raise RuntimeError("HUGGINGFACEHUB_API_TOKEN environment variable is not set. Please set it securely.")
    return token

def get_dataset_commits(dataset_id: str, limit: int = 20):
    from huggingface_hub import HfApi
    import logging
    log = logging.getLogger(__name__)
    api = HfApi()
    log.info(f"[get_dataset_commits] Fetching commits for dataset_id={dataset_id}")
    try:
        commits = api.list_repo_commits(repo_id=dataset_id, repo_type="dataset")
        log.info(f"[get_dataset_commits] Received {len(commits)} commits for {dataset_id}")
    except Exception as e:
        log.error(f"[get_dataset_commits] Error fetching commits for {dataset_id}: {e}", exc_info=True)
        raise  # Let the API layer catch and handle this
    result = []
    for c in commits[:limit]:
        try:
            commit_id = getattr(c, "commit_id", "")
            title = getattr(c, "title", "")
            message = getattr(c, "message", title)
            authors = getattr(c, "authors", [])
            author_name = authors[0] if authors and isinstance(authors, list) else ""
            created_at = getattr(c, "created_at", None)
            if created_at:
                if hasattr(created_at, "isoformat"):
                    date = created_at.isoformat()
                else:
                    date = str(created_at)
            else:
                date = ""
            result.append({
                "id": commit_id or "",
                "title": title or message or "",
                "message": message or title or "",
                "author": {"name": author_name, "email": ""},
                "date": date,
            })
        except Exception as e:
            log.error(f"[get_dataset_commits] Error parsing commit: {e} | Commit: {getattr(c, '__dict__', str(c))}", exc_info=True)
    log.info(f"[get_dataset_commits] Returning {len(result)} parsed commits for {dataset_id}")
    return result

def get_dataset_files(dataset_id: str) -> List[str]:
    return api.list_repo_files(repo_id=dataset_id, repo_type="dataset")

def get_file_url(dataset_id: str, filename: str, revision: Optional[str] = None) -> str:
    from huggingface_hub import hf_hub_url
    return hf_hub_url(repo_id=dataset_id, filename=filename, repo_type="dataset", revision=revision)

def get_datasets_page_from_zset(offset: int = 0, limit: int = 10, search: str = None) -> dict:
    import redis
    import json
    redis_client = redis.Redis(host="redis", port=6379, db=0, decode_responses=True)
    zset_key = "hf:datasets:all:zset"
    hash_key = "hf:datasets:all:hash"
    # Get total count
    total = redis_client.zcard(zset_key)
    # Get dataset IDs for the page
    ids = redis_client.zrange(zset_key, offset, offset + limit - 1)
    # Fetch metadata for those IDs
    if not ids:
        return {"items": [], "count": total}
    items = redis_client.hmget(hash_key, ids)
    # Parse JSON and filter/search if needed
    parsed = []
    for raw in items:
        if not raw:
            continue
        try:
            item = json.loads(raw)
            parsed.append(item)
        except Exception:
            continue
    if search:
        parsed = [d for d in parsed if search.lower() in (d.get("id") or "").lower()]
    return {"items": parsed, "count": total}

async def _fetch_size(session: httpx.AsyncClient, dataset_id: str) -> Optional[int]:
    """Fetch dataset size from the datasets server asynchronously."""
    url = f"https://datasets-server.huggingface.co/size?dataset={dataset_id}"
    try:
        resp = await session.get(url, timeout=30)
        if resp.status_code == 200:
            data = resp.json()
            return data.get("size", {}).get("dataset", {}).get("num_bytes_original_files")
    except Exception as e:
        log.warning(f"Could not fetch size for {dataset_id}: {e}")
    return None

async def _fetch_sizes(dataset_ids: List[str]) -> Dict[str, Optional[int]]:
    """Fetch dataset sizes in parallel."""
    results: Dict[str, Optional[int]] = {}
    async with httpx.AsyncClient() as session:
        tasks = {dataset_id: asyncio.create_task(_fetch_size(session, dataset_id)) for dataset_id in dataset_ids}
        for dataset_id, task in tasks.items():
            results[dataset_id] = await task
    return results

def process_datasets_page(offset, limit):
    """
    Fetch and process a single page of datasets from Hugging Face and cache them in Redis.
    """
    import redis
    import os
    import json
    import asyncio
    log = logging.getLogger(__name__)
    log.info(f"[process_datasets_page] ENTRY: offset={offset}, limit={limit}")
    token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
    if not token:
        log.error("[process_datasets_page] HUGGINGFACEHUB_API_TOKEN environment variable is not set.")
        raise RuntimeError("HUGGINGFACEHUB_API_TOKEN environment variable is not set. Please set it securely.")
    headers = {
        "Authorization": f"Bearer {token}",
        "User-Agent": "Mozilla/5.0 (compatible; CollinearTool/1.0; +https://yourdomain.com)"
    }
    params = {"limit": limit, "offset": offset, "full": "True"}
    redis_client = redis.Redis(host="redis", port=6379, db=0, decode_responses=True)
    stream_key = "hf:datasets:all:stream"
    zset_key = "hf:datasets:all:zset"
    hash_key = "hf:datasets:all:hash"
    try:
        log.info(f"[process_datasets_page] Requesting {HF_API_URL} with params={params}")
        response = requests.get(HF_API_URL, headers=headers, params=params, timeout=120)
        response.raise_for_status()

        page_items = response.json()

        log.info(f"[process_datasets_page] Received {len(page_items)} datasets at offset {offset}")

        dataset_ids = [ds.get("id") for ds in page_items]
        size_map = asyncio.run(_fetch_sizes(dataset_ids))

        for ds in page_items:
            dataset_id = ds.get("id")
            size_bytes = size_map.get(dataset_id)
            downloads = ds.get("downloads")
            likes = ds.get("likes")
            impact_level, assessment_method = determine_impact_level_by_criteria(size_bytes, downloads, likes)
            metrics = DatasetMetrics(size_bytes=size_bytes, downloads=downloads, likes=likes)
            thresholds = {
                "size_bytes": {
                    "low": str(100 * 1024 * 1024),
                    "medium": str(1 * 1024 * 1024 * 1024),
                    "high": str(10 * 1024 * 1024 * 1024)
                }
            }
            impact_assessment = ImpactAssessment(
                dataset_id=dataset_id,
                impact_level=impact_level,
                assessment_method=assessment_method,
                metrics=metrics,
                thresholds=thresholds
            ).model_dump()
            item = {
                "id": dataset_id,
                "name": ds.get("name"),
                "description": ds.get("description"),
                "size_bytes": size_bytes,
                "impact_level": impact_level.value if isinstance(impact_level, ImpactLevel) else impact_level,
                "downloads": downloads,
                "likes": likes,
                "tags": ds.get("tags", []),
                "impact_assessment": json.dumps(impact_assessment)
            }
            final_item = {}
            for k, v in item.items():
                if isinstance(v, list) or isinstance(v, dict):
                    final_item[k] = json.dumps(v)
                elif v is None:
                    final_item[k] = 'null'
                else:
                    final_item[k] = str(v)

            redis_client.xadd(stream_key, final_item)
            redis_client.zadd(zset_key, {dataset_id: offset})
            redis_client.hset(hash_key, dataset_id, json.dumps(item))

        log.info(f"[process_datasets_page] EXIT: Cached {len(page_items)} datasets at offset {offset}")
        return len(page_items)
    except Exception as exc:
        log.error(f"[process_datasets_page] ERROR: offset={offset}, limit={limit}, exc={exc}", exc_info=True)
        raise

def refresh_datasets_cache():
    """
    Orchestrator: Enqueue Celery tasks to fetch all Hugging Face datasets in parallel.
    Uses direct calls to HF API.
    """
    import requests
    log.info("[refresh_datasets_cache] Orchestrating dataset fetch tasks using direct HF API calls.")
    token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
    if not token:
        log.error("[refresh_datasets_cache] HUGGINGFACEHUB_API_TOKEN environment variable is not set.")
        raise RuntimeError("HUGGINGFACEHUB_API_TOKEN environment variable is not set. Please set it securely.")

    headers = {
        "Authorization": f"Bearer {token}",
        "User-Agent": "Mozilla/5.0 (compatible; CollinearTool/1.0; +https://yourdomain.com)"
    }
    limit = 500

    params = {"limit": 1, "offset": 0}
    try:
        response = requests.get(HF_API_URL, headers=headers, params=params, timeout=120)
        response.raise_for_status()
        total_str = response.headers.get('X-Total-Count')
        if not total_str:
            log.error("[refresh_datasets_cache] 'X-Total-Count' header not found in HF API response.")
            raise ValueError("'X-Total-Count' header missing from Hugging Face API response.")
        total = int(total_str)
        log.info(f"[refresh_datasets_cache] Total datasets reported by HF API: {total}")
    except requests.RequestException as e:
        log.error(f"[refresh_datasets_cache] Error fetching total dataset count from HF API: {e}")
        raise
    except ValueError as e:
        log.error(f"[refresh_datasets_cache] Error parsing total dataset count: {e}")
        raise

    num_pages = (total + limit - 1) // limit
    from app.tasks.dataset_tasks import fetch_datasets_page
    from celery import group
    tasks = []
    for page_num in range(num_pages):
        offset = page_num * limit
        tasks.append(fetch_datasets_page.s(offset, limit))
        log.info(f"[refresh_datasets_cache] Scheduled page at offset {offset}, limit {limit}.")
    if tasks:
        group(tasks).apply_async()
        log.info(f"[refresh_datasets_cache] Enqueued {len(tasks)} fetch tasks.")
    else:
        log.warning("[refresh_datasets_cache] No dataset pages found to schedule.")

def determine_impact_level_by_criteria(size_bytes, downloads=None, likes=None):
    try:
        size = int(size_bytes) if size_bytes not in (None, 'null') else 0
    except Exception:
        size = 0

    # Prefer size_bytes if available
    if size >= 10 * 1024 * 1024 * 1024:
        return ("high", "large_size")
    elif size >= 1 * 1024 * 1024 * 1024:
        return ("medium", "medium_size")
    elif size >= 100 * 1024 * 1024:
        return ("low", "small_size")
    # Fallback to downloads if size_bytes is missing or too small
    if downloads is not None:
        try:
            downloads = int(downloads)
            if downloads >= 100000:
                return ("high", "downloads")
            elif downloads >= 10000:
                return ("medium", "downloads")
            elif downloads >= 1000:
                return ("low", "downloads")
        except Exception:
            pass
    # Fallback to likes if downloads is missing
    if likes is not None:
        try:
            likes = int(likes)
            if likes >= 1000:
                return ("high", "likes")
            elif likes >= 100:
                return ("medium", "likes")
            elif likes >= 10:
                return ("low", "likes")
        except Exception:
            pass
    return ("not_available", "size_and_downloads_and_likes_unknown")

def get_dataset_size(dataset: dict, dataset_id: str = None):
    """
    Extract the size in bytes from a dataset dictionary.
    Tries multiple locations based on possible HuggingFace API responses.
    """
    # Try top-level key
    size_bytes = dataset.get("size_bytes")
    if size_bytes not in (None, 'null'):
        return size_bytes
    # Try nested structure from the size API
    size_bytes = (
        dataset.get("size", {})
        .get("dataset", {})
        .get("num_bytes_original_files")
    )
    if size_bytes not in (None, 'null'):
        return size_bytes
    # Try metrics or info sub-dictionaries if present
    for key in ["metrics", "info"]:
        sub = dataset.get(key, {})
        if isinstance(sub, dict):
            size_bytes = sub.get("size_bytes")
            if size_bytes not in (None, 'null'):
                return size_bytes
    # Not found
    return None

async def get_datasets_page_from_zset_async(offset: int = 0, limit: int = 10, search: str = None) -> dict:
    redis_client = aioredis.Redis(host="redis", port=6379, db=0, decode_responses=True)
    zset_key = "hf:datasets:all:zset"
    hash_key = "hf:datasets:all:hash"
    total = await redis_client.zcard(zset_key)
    ids = await redis_client.zrange(zset_key, offset, offset + limit - 1)
    if not ids:
        return {"items": [], "count": total}
    items = await redis_client.hmget(hash_key, ids)
    parsed = []
    for raw in items:
        if not raw:
            continue
        try:
            item = json.loads(raw)
            parsed.append(item)
        except Exception:
            continue
    if search:
        parsed = [d for d in parsed if search.lower() in (d.get("id") or "").lower()]
    return {"items": parsed, "count": total}

async def get_dataset_commits_async(dataset_id: str, limit: int = 20):
    from huggingface_hub import HfApi
    import logging
    log = logging.getLogger(__name__)
    api = HfApi()
    log.info(f"[get_dataset_commits_async] Fetching commits for dataset_id={dataset_id}")
    try:
        # huggingface_hub is sync, so run in a worker thread; anyio.to_thread.run_sync
        # only forwards positional args, so keyword args are bound with functools.partial
        import anyio
        commits = await anyio.to_thread.run_sync(
            functools.partial(api.list_repo_commits, repo_id=dataset_id, repo_type="dataset")
        )
        log.info(f"[get_dataset_commits_async] Received {len(commits)} commits for {dataset_id}")
    except Exception as e:
        log.error(f"[get_dataset_commits_async] Error fetching commits for {dataset_id}: {e}", exc_info=True)
        raise
    result = []
    for c in commits[:limit]:
        try:
            commit_id = getattr(c, "commit_id", "")
            title = getattr(c, "title", "")
            message = getattr(c, "message", title)
            authors = getattr(c, "authors", [])
            author_name = authors[0] if authors and isinstance(authors, list) else ""
            created_at = getattr(c, "created_at", None)
            if created_at:
                if hasattr(created_at, "isoformat"):
                    date = created_at.isoformat()
                else:
                    date = str(created_at)
            else:
                date = ""
            result.append({
                "id": commit_id or "",
                "title": title or message or "",
                "message": message or title or "",
                "author": {"name": author_name, "email": ""},
                "date": date,
            })
        except Exception as e:
            log.error(f"[get_dataset_commits_async] Error parsing commit: {e} | Commit: {getattr(c, '__dict__', str(c))}", exc_info=True)
    log.info(f"[get_dataset_commits_async] Returning {len(result)} parsed commits for {dataset_id}")
    return result

async def get_dataset_files_async(dataset_id: str) -> List[str]:
    from huggingface_hub import HfApi
    import anyio
    api = HfApi()
    # huggingface_hub is sync, so run in a worker thread (kwargs bound via functools.partial)
    return await anyio.to_thread.run_sync(
        functools.partial(api.list_repo_files, repo_id=dataset_id, repo_type="dataset")
    )

async def get_file_url_async(dataset_id: str, filename: str, revision: Optional[str] = None) -> str:
    from huggingface_hub import hf_hub_url
    import anyio
    # hf_hub_url is sync, so run in a worker thread (kwargs bound via functools.partial)
    return await anyio.to_thread.run_sync(
        functools.partial(hf_hub_url, repo_id=dataset_id, filename=filename, repo_type="dataset", revision=revision)
    )

# Fetch and cache all datasets

class EnhancedJSONEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, datetime):
            return obj.isoformat()
        return super().default(obj)

async def fetch_size(session, dataset_id, token=None):
    url = f"https://datasets-server.huggingface.co/size?dataset={dataset_id}"
    headers = {"Authorization": f"Bearer {token}"} if token else {}
    try:
        resp = await session.get(url, headers=headers, timeout=30)
        if resp.status_code == 200:
            data = resp.json()
            return dataset_id, data.get("size", {}).get("dataset", {}).get("num_bytes_original_files")
    except Exception as e:
        log.warning(f"Could not fetch size for {dataset_id}: {e}")
    return dataset_id, None

async def fetch_all_sizes(dataset_ids, token=None, batch_size=50):
    results = {}
    async with httpx.AsyncClient() as session:
        for i in range(0, len(dataset_ids), batch_size):
            batch = dataset_ids[i:i + batch_size]
            tasks = [fetch_size(session, ds_id, token) for ds_id in batch]
            batch_results = await asyncio.gather(*tasks)
            for ds_id, size in batch_results:
                results[ds_id] = size
    return results

def fetch_and_cache_all_datasets(token: str):
    api = HfApi(token=token)
    log.info("Fetching all datasets from Hugging Face Hub...")
    all_datasets = list(api.list_datasets())
    all_datasets_dicts = []
    dataset_ids = [d.id for d in all_datasets]
    # Fetch all sizes in batches
    sizes = asyncio.run(fetch_all_sizes(dataset_ids, token=token, batch_size=50))
    for d in all_datasets:
        data = d.__dict__
        size_bytes = sizes.get(d.id)
        downloads = data.get("downloads")
        likes = data.get("likes")
        data["size_bytes"] = size_bytes
        impact_level, _ = determine_impact_level_by_criteria(size_bytes, downloads, likes)
        data["impact_level"] = impact_level
        all_datasets_dicts.append(data)
    compressed = gzip.compress(json.dumps(all_datasets_dicts, cls=EnhancedJSONEncoder).encode("utf-8"))
    r = redis.Redis(host="redis", port=6379, decode_responses=False)
    r.set(REDIS_KEY, compressed)
    log.info(f"Cached {len(all_datasets_dicts)} datasets in Redis under {REDIS_KEY}")
    return len(all_datasets_dicts)

# Native pagination from cache

def get_datasets_page_from_cache(limit: int, offset: int):
    r = redis.Redis(host="redis", port=6379, decode_responses=False)
    compressed = r.get(REDIS_KEY)
    if not compressed:
        return {"error": "Cache not found. Please refresh datasets."}, 404
    all_datasets = json.loads(gzip.decompress(compressed).decode("utf-8"))
    total = len(all_datasets)
    if offset < 0 or offset >= total:
        return {"error": "Offset out of range.", "total": total}, 400
    page = all_datasets[offset:offset + limit]
    total_pages = (total + limit - 1) // limit
    current_page = (offset // limit) + 1
    next_page = current_page + 1 if offset + limit < total else None
    prev_page = current_page - 1 if current_page > 1 else None
    return {
        "total": total,
        "current_page": current_page,
        "total_pages": total_pages,
        "next_page": next_page,
        "prev_page": prev_page,
        "items": page
    }, 200
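
The size → downloads → likes fallback in determine_impact_level_by_criteria is pure logic, so it can be checked without Redis or a Hugging Face token (importing the module still requires the env vars read by app.core.config). A brief illustration:

from app.services.hf_datasets import determine_impact_level_by_criteria

print(determine_impact_level_by_criteria(5 * 1024**3))            # ('medium', 'medium_size') — size wins when present
print(determine_impact_level_by_criteria(None, downloads=25000))  # ('medium', 'downloads') — falls back to downloads
print(determine_impact_level_by_criteria(None, likes=5))          # ('not_available', 'size_and_downloads_and_likes_unknown')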
app/services/redis_client.py
ADDED
@@ -0,0 +1,302 @@
"""Redis client for caching and task queue management."""

import json
from typing import Any, Dict, Optional, TypeVar
from datetime import datetime
import logging
from time import time as _time

import redis.asyncio as redis_async
import redis as redis_sync  # Import synchronous Redis client
from pydantic import BaseModel
from tenacity import retry, stop_after_attempt, wait_exponential

from app.core.config import settings

# Type variable for cache
T = TypeVar('T')

# Configure logging
log = logging.getLogger(__name__)

# Redis connection pools for reusing connections
_redis_pool_async = None
_redis_pool_sync = None  # Synchronous pool

# Default cache expiration (12 hours)
DEFAULT_CACHE_EXPIRY = 60 * 60 * 12

@retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=1, max=10))
async def get_redis_pool() -> redis_async.Redis:
    """Get or create async Redis connection pool with retry logic."""
    global _redis_pool_async

    if _redis_pool_async is None:
        # Get Redis configuration from settings
        redis_url = settings.REDIS_URL or "redis://localhost:6379/0"

        try:
            # Create connection pool with reasonable defaults
            _redis_pool_async = redis_async.ConnectionPool.from_url(
                redis_url,
                max_connections=10,
                decode_responses=True,
                health_check_interval=5,
                socket_connect_timeout=5,
                socket_keepalive=True,
                retry_on_timeout=True
            )
            log.info(f"Created async Redis connection pool with URL: {redis_url}")
        except Exception as e:
            log.error(f"Error creating async Redis connection pool: {e}")
            raise

    return redis_async.Redis(connection_pool=_redis_pool_async)

def get_redis_pool_sync() -> redis_sync.Redis:
    """Get or create synchronous Redis connection pool."""
    global _redis_pool_sync

    if _redis_pool_sync is None:
        # Get Redis configuration from settings
        redis_url = settings.REDIS_URL or "redis://localhost:6379/0"

        try:
            # Create connection pool with reasonable defaults
            _redis_pool_sync = redis_sync.ConnectionPool.from_url(
                redis_url,
                max_connections=10,
                decode_responses=True,
                socket_connect_timeout=5,
                socket_keepalive=True,
                retry_on_timeout=True
            )
            log.info(f"Created sync Redis connection pool with URL: {redis_url}")
        except Exception as e:
            log.error(f"Error creating sync Redis connection pool: {e}")
            raise

    return redis_sync.Redis(connection_pool=_redis_pool_sync)

@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=1, max=5))
async def get_redis() -> redis_async.Redis:
    """Get Redis client from pool with retry logic."""
    try:
        redis_client = await get_redis_pool()
        return redis_client
    except Exception as e:
        log.error(f"Error getting Redis client: {e}")
        raise

@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=1, max=5))
def get_redis_sync() -> redis_sync.Redis:
    """Get synchronous Redis client from pool with retry logic."""
    try:
        return get_redis_pool_sync()
    except Exception as e:
        log.error(f"Error getting synchronous Redis client: {e}")
        raise

# Cache key generation
def generate_cache_key(prefix: str, *args: Any) -> str:
    """Generate cache key with prefix and args."""
    key_parts = [prefix] + [str(arg) for arg in args if arg]
    return ":".join(key_parts)

# JSON serialization helpers
def _json_serialize(obj: Any) -> str:
    """Serialize object to JSON with datetime support."""
    def _serialize_datetime(o: Any) -> str:
        if isinstance(o, datetime):
            return o.isoformat()
        if isinstance(o, BaseModel):
            return o.dict()
        return str(o)

    return json.dumps(obj, default=_serialize_datetime)

def _json_deserialize(data: str, model_class: Optional[type] = None) -> Any:
    """Deserialize JSON string to object with datetime support."""
    result = json.loads(data)

    if model_class and issubclass(model_class, BaseModel):
        return model_class.parse_obj(result)

    return result

# Async cache operations
async def cache_set(key: str, value: Any, expire: int = DEFAULT_CACHE_EXPIRY) -> bool:
    """Set cache value with expiration (async version)."""
    redis_client = await get_redis()
    serialized = _json_serialize(value)

    try:
        await redis_client.set(key, serialized, ex=expire)
        log.debug(f"Cached data at key: {key}, expires in {expire}s")
        return True
    except Exception as e:
        log.error(f"Error caching data at key {key}: {e}")
        return False

async def cache_get(key: str, model_class: Optional[type] = None) -> Optional[Any]:
    """Get cache value with optional model deserialization (async version)."""
    redis_client = await get_redis()

    try:
        data = await redis_client.get(key)
        if not data:
            return None

        log.debug(f"Cache hit for key: {key}")
        return _json_deserialize(data, model_class)
    except Exception as e:
        log.error(f"Error retrieving cache for key {key}: {e}")
        return None

# Synchronous cache operations for Celery tasks
def sync_cache_set(key: str, value: Any, expire: int = DEFAULT_CACHE_EXPIRY) -> bool:
    """Set cache value with expiration (synchronous version for Celery tasks). Logs slow operations."""
    redis_client = get_redis_sync()
    serialized = _json_serialize(value)
    start = _time()
    try:
        redis_client.set(key, serialized, ex=expire)
        elapsed = _time() - start
        if elapsed > 2:
            log.warning(f"Slow sync_cache_set for key {key}: {elapsed:.2f}s")
        log.debug(f"Cached data at key: {key}, expires in {expire}s (sync)")
        return True
    except Exception as e:
        log.error(f"Error caching data at key {key}: {e}")
        return False

def sync_cache_get(key: str, model_class: Optional[type] = None) -> Optional[Any]:
    """Get cache value with optional model deserialization (synchronous version for Celery tasks). Logs slow operations."""
    redis_client = get_redis_sync()
    start = _time()
    try:
        data = redis_client.get(key)
        elapsed = _time() - start
        if elapsed > 2:
            log.warning(f"Slow sync_cache_get for key {key}: {elapsed:.2f}s")
        if not data:
            return None
        log.debug(f"Cache hit for key: {key} (sync)")
        return _json_deserialize(data, model_class)
    except Exception as e:
        log.error(f"Error retrieving cache for key {key}: {e}")
        return None

async def cache_invalidate(key: str) -> bool:
    """Invalidate cache for key."""
    redis_client = await get_redis()

    try:
        await redis_client.delete(key)
        log.debug(f"Invalidated cache for key: {key}")
        return True
    except Exception as e:
        log.error(f"Error invalidating cache for key {key}: {e}")
        return False

async def cache_invalidate_pattern(pattern: str) -> int:
    """Invalidate all cache keys matching pattern."""
    redis_client = await get_redis()

    try:
        keys = await redis_client.keys(pattern)
        if not keys:
            return 0

        count = await redis_client.delete(*keys)
        log.debug(f"Invalidated {count} keys matching pattern: {pattern}")
        return count
    except Exception as e:
        log.error(f"Error invalidating keys with pattern {pattern}: {e}")
        return 0

# Task queue operations
async def enqueue_task(queue_name: str, task_id: str, payload: Dict[str, Any]) -> bool:
    """Add task to queue."""
    redis_client = await get_redis()

    try:
        serialized = _json_serialize(payload)
        await redis_client.lpush(f"queue:{queue_name}", serialized)
        await redis_client.hset(f"tasks:{queue_name}", task_id, "pending")
        log.info(f"Enqueued task {task_id} to queue {queue_name}")
        return True
    except Exception as e:
        log.error(f"Error enqueueing task {task_id} to {queue_name}: {e}")
        return False

async def mark_task_complete(queue_name: str, task_id: str, result: Optional[Dict[str, Any]] = None) -> bool:
    """Mark task as complete with optional result."""
    redis_client = await get_redis()

    try:
        # Store result if provided
        if result:
            await redis_client.hset(
                f"results:{queue_name}",
                task_id,
                _json_serialize(result)
            )

        # Mark task as complete
        await redis_client.hset(f"tasks:{queue_name}", task_id, "complete")
        await redis_client.expire(f"tasks:{queue_name}", 86400)  # Expire after 24 hours

        log.info(f"Marked task {task_id} as complete in queue {queue_name}")
        return True
    except Exception as e:
        log.error(f"Error marking task {task_id} as complete: {e}")
        return False

async def get_task_status(queue_name: str, task_id: str) -> Optional[str]:
    """Get status of a task."""
    redis_client = await get_redis()

    try:
        status = await redis_client.hget(f"tasks:{queue_name}", task_id)
        return status
    except Exception as e:
|
264 |
+
log.error(f"Error getting status for task {task_id}: {e}")
|
265 |
+
return None
|
266 |
+
|
267 |
+
async def get_task_result(queue_name: str, task_id: str) -> Optional[Dict[str, Any]]:
|
268 |
+
"""Get result of a completed task."""
|
269 |
+
redis_client = await get_redis()
|
270 |
+
|
271 |
+
try:
|
272 |
+
data = await redis_client.hget(f"results:{queue_name}", task_id)
|
273 |
+
if not data:
|
274 |
+
return None
|
275 |
+
|
276 |
+
return _json_deserialize(data)
|
277 |
+
except Exception as e:
|
278 |
+
log.error(f"Error getting result for task {task_id}: {e}")
|
279 |
+
return None
|
280 |
+
|
281 |
+
# Stream processing for real-time updates
|
282 |
+
async def add_to_stream(stream: str, data: Dict[str, Any], max_len: int = 1000) -> str:
|
283 |
+
"""Add event to Redis stream."""
|
284 |
+
redis_client = await get_redis()
|
285 |
+
|
286 |
+
try:
|
287 |
+
# Convert dict values to strings (Redis streams requirement)
|
288 |
+
entry = {k: _json_serialize(v) for k, v in data.items()}
|
289 |
+
|
290 |
+
# Add to stream with automatic ID generation
|
291 |
+
event_id = await redis_client.xadd(
|
292 |
+
stream,
|
293 |
+
entry,
|
294 |
+
maxlen=max_len,
|
295 |
+
approximate=True
|
296 |
+
)
|
297 |
+
|
298 |
+
log.debug(f"Added event {event_id} to stream {stream}")
|
299 |
+
return event_id
|
300 |
+
except Exception as e:
|
301 |
+
log.error(f"Error adding to stream {stream}: {e}")
|
302 |
+
raise
|
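For reference, a minimal usage sketch of the helpers above: the async functions are what request handlers would call, while the sync_* variants are meant for Celery tasks. The key parts and the DatasetSummary model here are made up purely for illustration.

import asyncio
from pydantic import BaseModel

from app.services.redis_client import (
    cache_get,
    cache_set,
    generate_cache_key,
    sync_cache_get,
    sync_cache_set,
)

class DatasetSummary(BaseModel):
    # Hypothetical payload model, used only in this sketch.
    id: str
    downloads: int = 0

async def demo() -> None:
    # Produces "datasets:openbmb/Ultra-FineWeb:summary"; falsy args are skipped.
    key = generate_cache_key("datasets", "openbmb/Ultra-FineWeb", "summary")
    await cache_set(key, DatasetSummary(id="openbmb/Ultra-FineWeb", downloads=123), expire=3600)
    summary = await cache_get(key, model_class=DatasetSummary)  # DatasetSummary or None
    print(summary)

# Inside a Celery task, the blocking helpers are used instead:
#   sync_cache_set(key, {"status": "pending"})
#   sync_cache_get(key)

if __name__ == "__main__":
    asyncio.run(demo())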
app/tasks/dataset_tasks.py
ADDED
@@ -0,0 +1,73 @@
import logging
import time
import asyncio
from datetime import datetime, timezone
from typing import Dict, List, Any, Optional, Tuple
from celery import Task, shared_task
from app.core.celery_app import get_celery_app
from app.services.hf_datasets import (
    determine_impact_level_by_criteria,
    get_hf_token,
    get_dataset_size,
    refresh_datasets_cache,
    fetch_and_cache_all_datasets,
)
from app.services.redis_client import sync_cache_set, sync_cache_get, generate_cache_key
from app.core.config import settings
import requests
import os

# Configure logging
logger = logging.getLogger(__name__)

# Get Celery app instance
celery_app = get_celery_app()

# Constants
DATASET_CACHE_TTL = 60 * 60 * 24 * 30  # 30 days
BATCH_PROGRESS_CACHE_TTL = 60 * 60 * 24 * 7  # 7 days for batch progress
DATASET_SIZE_CACHE_TTL = 60 * 60 * 24 * 30  # 30 days for size info

@celery_app.task(name="app.tasks.dataset_tasks.refresh_hf_datasets_cache")
def refresh_hf_datasets_cache():
    """Celery task to refresh the HuggingFace datasets cache in Redis."""
    logger.info("Starting refresh of HuggingFace datasets cache via Celery task.")
    try:
        refresh_datasets_cache()
        logger.info("Successfully refreshed HuggingFace datasets cache.")
        return {"status": "success"}
    except Exception as e:
        logger.error(f"Failed to refresh HuggingFace datasets cache: {e}")
        return {"status": "error", "error": str(e)}

@shared_task(bind=True, max_retries=3, default_retry_delay=10)
def fetch_datasets_page(self, offset, limit):
    """
    Celery task to fetch and cache a single page of datasets from Hugging Face.
    Retries on failure.
    """
    logger.info(f"[fetch_datasets_page] ENTRY: offset={offset}, limit={limit}")
    try:
        from app.services.hf_datasets import process_datasets_page
        logger.info(f"[fetch_datasets_page] Calling process_datasets_page with offset={offset}, limit={limit}")
        result = process_datasets_page(offset, limit)
        logger.info(f"[fetch_datasets_page] SUCCESS: offset={offset}, limit={limit}, result={result}")
        return result
    except Exception as exc:
        logger.error(f"[fetch_datasets_page] ERROR: offset={offset}, limit={limit}, exc={exc}", exc_info=True)
        raise self.retry(exc=exc)

@shared_task(bind=True, max_retries=3, default_retry_delay=60)
def refresh_hf_datasets_full_cache(self):
    """Celery task to fetch and cache the full Hugging Face datasets listing. Retries on failure."""
    logger.info("[refresh_hf_datasets_full_cache] Starting full Hugging Face datasets cache refresh.")
    try:
        token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
        if not token:
            logger.error("[refresh_hf_datasets_full_cache] HUGGINGFACEHUB_API_TOKEN not set.")
            return {"status": "error", "error": "HUGGINGFACEHUB_API_TOKEN not set"}
        count = fetch_and_cache_all_datasets(token)
        logger.info(f"[refresh_hf_datasets_full_cache] Cached {count} datasets.")
        return {"status": "ok", "cached": count}
    except Exception as exc:
        logger.error(f"[refresh_hf_datasets_full_cache] ERROR: {exc}", exc_info=True)
        raise self.retry(exc=exc)
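For reference, a minimal sketch of how these tasks can be triggered by hand, assuming a Celery worker is running against the broker configured in app.core.celery_app (the exact broker and beat-schedule wiring lives in that module and is not shown here):

from app.tasks.dataset_tasks import (
    fetch_datasets_page,
    refresh_hf_datasets_cache,
    refresh_hf_datasets_full_cache,
)

# Fire-and-forget refresh of the cached dataset listing.
refresh_hf_datasets_cache.delay()

# Fetch and cache a single page; returns an AsyncResult handle.
result = fetch_datasets_page.delay(offset=0, limit=100)
print(result.get(timeout=300))  # blocks; requires a result backend to be configured

# Full rebuild; the worker needs HUGGINGFACEHUB_API_TOKEN in its environment.
refresh_hf_datasets_full_cache.apply_async()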
migrations/20250620000000_create_combined_datasets_table.sql
ADDED
@@ -0,0 +1,57 @@
-- Create combined_datasets table
CREATE TABLE IF NOT EXISTS public.combined_datasets (
    id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
    name TEXT NOT NULL,
    description TEXT,
    source_datasets TEXT[] NOT NULL,
    created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
    updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
    created_by UUID REFERENCES auth.users(id),
    impact_level TEXT CHECK (impact_level = ANY (ARRAY['low', 'medium', 'high']::text[])),
    status TEXT NOT NULL DEFAULT 'processing',
    combination_strategy TEXT NOT NULL DEFAULT 'merge',
    size_bytes BIGINT,
    file_count INTEGER,
    downloads INTEGER,
    likes INTEGER
);

-- Add indexes for faster querying
CREATE INDEX IF NOT EXISTS idx_combined_datasets_created_by ON public.combined_datasets(created_by);
CREATE INDEX IF NOT EXISTS idx_combined_datasets_impact_level ON public.combined_datasets(impact_level);
CREATE INDEX IF NOT EXISTS idx_combined_datasets_status ON public.combined_datasets(status);

-- Add Row Level Security (RLS) policies
ALTER TABLE public.combined_datasets ENABLE ROW LEVEL SECURITY;

-- Policy to allow users to see all combined datasets
CREATE POLICY "Anyone can view combined datasets"
    ON public.combined_datasets
    FOR SELECT USING (true);

-- Policy to allow users to create their own combined datasets
CREATE POLICY "Users can create their own combined datasets"
    ON public.combined_datasets
    FOR INSERT
    WITH CHECK (auth.uid() = created_by);

-- Policy to allow users to update only their own combined datasets
CREATE POLICY "Users can update their own combined datasets"
    ON public.combined_datasets
    FOR UPDATE
    USING (auth.uid() = created_by);

-- Function to automatically update updated_at timestamp
CREATE OR REPLACE FUNCTION update_combined_datasets_updated_at()
RETURNS TRIGGER AS $$
BEGIN
    NEW.updated_at = now();
    RETURN NEW;
END;
$$ LANGUAGE plpgsql;

-- Trigger to automatically update updated_at timestamp
CREATE TRIGGER update_combined_datasets_updated_at_trigger
    BEFORE UPDATE ON public.combined_datasets
    FOR EACH ROW
    EXECUTE FUNCTION update_combined_datasets_updated_at();
setup.py
ADDED
@@ -0,0 +1,8 @@
from setuptools import setup, find_packages

setup(
    name="collinear-tool",
    version="0.1.0",
    packages=find_packages(),
    include_package_data=True,
)
tests/test_datasets.py
ADDED
@@ -0,0 +1,78 @@
import os
import pytest
import requests

BASE_URL = os.environ.get("BASE_URL", "http://127.0.0.1:8000/api")

# --- /datasets ---
def test_list_datasets_http():
    resp = requests.get(f"{BASE_URL}/datasets")
    assert resp.status_code == 200
    data = resp.json()
    assert "items" in data
    assert "total" in data
    assert "warming_up" in data

def test_list_datasets_offset_limit_http():
    resp = requests.get(f"{BASE_URL}/datasets?offset=0&limit=3")
    assert resp.status_code == 200
    data = resp.json()
    assert isinstance(data["items"], list)
    assert len(data["items"]) <= 3

def test_list_datasets_large_offset_http():
    resp = requests.get(f"{BASE_URL}/datasets?offset=99999&limit=2")
    assert resp.status_code == 200
    data = resp.json()
    assert data["items"] == []
    assert "warming_up" in data

def test_list_datasets_invalid_limit_http():
    resp = requests.get(f"{BASE_URL}/datasets?limit=-5")
    assert resp.status_code == 422

# --- /datasets/cache-status ---
def test_cache_status_http():
    resp = requests.get(f"{BASE_URL}/datasets/cache-status")
    assert resp.status_code == 200
    data = resp.json()
    assert "warming_up" in data
    assert "total_items" in data
    assert "last_update" in data

# --- /datasets/{dataset_id}/commits ---
def test_commits_valid_http():
    resp = requests.get(f"{BASE_URL}/datasets/openbmb/Ultra-FineWeb/commits")
    assert resp.status_code in (200, 404)
    if resp.status_code == 200:
        assert isinstance(resp.json(), list)

def test_commits_invalid_http():
    resp = requests.get(f"{BASE_URL}/datasets/invalid-dataset-id/commits")
    assert resp.status_code in (404, 422)

# --- /datasets/{dataset_id}/files ---
def test_files_valid_http():
    resp = requests.get(f"{BASE_URL}/datasets/openbmb/Ultra-FineWeb/files")
    assert resp.status_code in (200, 404)
    if resp.status_code == 200:
        assert isinstance(resp.json(), list)

def test_files_invalid_http():
    resp = requests.get(f"{BASE_URL}/datasets/invalid-dataset-id/files")
    assert resp.status_code in (404, 422)

# --- /datasets/{dataset_id}/file-url ---
def test_file_url_valid_http():
    resp = requests.get(f"{BASE_URL}/datasets/openbmb/Ultra-FineWeb/file-url", params={"filename": "README.md"})
    assert resp.status_code in (200, 404)
    if resp.status_code == 200:
        assert "download_url" in resp.json()

def test_file_url_invalid_file_http():
    resp = requests.get(f"{BASE_URL}/datasets/openbmb/Ultra-FineWeb/file-url", params={"filename": "not_a_real_file.txt"})
    assert resp.status_code in (404, 200)

def test_file_url_missing_filename_http():
    resp = requests.get(f"{BASE_URL}/datasets/openbmb/Ultra-FineWeb/file-url")
    assert resp.status_code in (404, 422)
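These tests exercise a running server over HTTP via the BASE_URL environment variable, whereas tests/test_datasets_api.py below drives the app in-process through FastAPI's TestClient. As an optional addition (not part of this commit, just a sketch), a module-scoped guard at the top of this file could skip the HTTP tests when no server is reachable:

import os

import pytest
import requests

BASE_URL = os.environ.get("BASE_URL", "http://127.0.0.1:8000/api")

@pytest.fixture(autouse=True, scope="module")
def _require_live_server():
    """Skip every test in this module when nothing is listening on BASE_URL."""
    try:
        requests.get(f"{BASE_URL}/datasets/cache-status", timeout=2)
    except requests.RequestException:
        pytest.skip("API server not reachable; skipping HTTP-level tests")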
tests/test_datasets_api.py
ADDED
@@ -0,0 +1,88 @@
import pytest
from fastapi.testclient import TestClient
from app.main import app

client = TestClient(app)

# --- /api/datasets ---
def test_list_datasets_default():
    resp = client.get("/api/datasets")
    assert resp.status_code == 200
    data = resp.json()
    assert "items" in data
    assert isinstance(data["items"], list)
    assert "total" in data
    assert "warming_up" in data

def test_list_datasets_offset_limit():
    resp = client.get("/api/datasets?offset=0&limit=2")
    assert resp.status_code == 200
    data = resp.json()
    assert isinstance(data["items"], list)
    assert len(data["items"]) <= 2

def test_list_datasets_large_offset():
    resp = client.get("/api/datasets?offset=100000&limit=2")
    assert resp.status_code == 200
    data = resp.json()
    assert data["items"] == []
    assert data["warming_up"] in (True, False)

def test_list_datasets_negative_limit():
    resp = client.get("/api/datasets?limit=-1")
    assert resp.status_code == 422

def test_list_datasets_missing_params():
    resp = client.get("/api/datasets")
    assert resp.status_code == 200
    data = resp.json()
    assert "items" in data
    assert "total" in data
    assert "warming_up" in data

# --- /api/datasets/cache-status ---
def test_cache_status():
    resp = client.get("/api/datasets/cache-status")
    assert resp.status_code == 200
    data = resp.json()
    assert "warming_up" in data
    assert "total_items" in data
    assert "last_update" in data

# --- /api/datasets/{dataset_id}/commits ---
def test_get_commits_valid():
    resp = client.get("/api/datasets/openbmb/Ultra-FineWeb/commits")
    # Accept 200 (found) or 404 (not found)
    assert resp.status_code in (200, 404)
    if resp.status_code == 200:
        assert isinstance(resp.json(), list)

def test_get_commits_invalid():
    resp = client.get("/api/datasets/invalid-dataset-id/commits")
    assert resp.status_code in (404, 422)

# --- /api/datasets/{dataset_id}/files ---
def test_list_files_valid():
    resp = client.get("/api/datasets/openbmb/Ultra-FineWeb/files")
    assert resp.status_code in (200, 404)
    if resp.status_code == 200:
        assert isinstance(resp.json(), list)

def test_list_files_invalid():
    resp = client.get("/api/datasets/invalid-dataset-id/files")
    assert resp.status_code in (404, 422)

# --- /api/datasets/{dataset_id}/file-url ---
def test_get_file_url_valid():
    resp = client.get("/api/datasets/openbmb/Ultra-FineWeb/file-url", params={"filename": "README.md"})
    assert resp.status_code in (200, 404)
    if resp.status_code == 200:
        assert "download_url" in resp.json()

def test_get_file_url_invalid_file():
    resp = client.get("/api/datasets/openbmb/Ultra-FineWeb/file-url", params={"filename": "not_a_real_file.txt"})
    assert resp.status_code in (404, 200)

def test_get_file_url_missing_filename():
    resp = client.get("/api/datasets/openbmb/Ultra-FineWeb/file-url")
    assert resp.status_code in (404, 422)