iaroy committed
Commit fdc5d7a · 1 Parent(s): f401f1d

Deploy full application code

.gitignore ADDED
@@ -0,0 +1,39 @@
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+
+ # Environment
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+
+ # Logs
+ *.log
+ logs/
+ celery_worker_*.log
+ nohup.out
+
+ # Database
+ *.sqlite
+ *.db
+ celerybeat-schedule
Dockerfile CHANGED
@@ -1,14 +1,17 @@
  # Use the official Python 3.10.9 image
  FROM python:3.10.9

- # Copy the current directory contents into the container at .
- COPY . .
-
- # Set the working directory to /
- WORKDIR /
-
- # Install requirements.txt
+ # Set the working directory
+ WORKDIR /app
+
+ # Copy the current directory contents into the container
+ COPY . .
+
+ # Install requirements.txt
  RUN pip install --no-cache-dir --upgrade -r requirements.txt

+ # Install the application in development mode
+ RUN pip install -e .
+
  # Start the FastAPI app on port 7860, the default port expected by Spaces
  CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py CHANGED
@@ -1,17 +1,7 @@
- from fastapi import FastAPI
- import uvicorn
+ from app.main import app

- # Create a FastAPI app for Hugging Face Spaces
- app = FastAPI(title="Collinear API")
-
- @app.get("/")
- async def root():
-     return {"message": "Welcome to Collinear API"}
-
- @app.get("/health")
- async def health():
-     return {"status": "healthy"}
-
+ # This file is needed for Hugging Face Spaces to find the app
+
  if __name__ == "__main__":
-     # This is used when running locally
+     import uvicorn
      uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)
app/api/__init__.py ADDED
@@ -0,0 +1,7 @@
+ from fastapi import APIRouter
+ from app.api.datasets import router as datasets_router
+ # from . import batch # Removed batch import
+
+ api_router = APIRouter()
+ api_router.include_router(datasets_router, tags=["datasets"])
+ # api_router.include_router(batch.router) # Removed batch router
app/api/datasets.py ADDED
@@ -0,0 +1,151 @@
+ from fastapi import APIRouter, Query, HTTPException
+ from typing import List, Optional, Dict, Any, Set
+ from pydantic import BaseModel
+ from fastapi.concurrency import run_in_threadpool
+ from app.services.hf_datasets import (
+     get_dataset_commits,
+     get_dataset_files,
+     get_file_url,
+     get_datasets_page_from_zset,
+     get_dataset_commits_async,
+     get_dataset_files_async,
+     get_file_url_async,
+     get_datasets_page_from_cache,
+     fetch_and_cache_all_datasets,
+ )
+ from app.services.redis_client import cache_get
+ import logging
+ import time
+ from fastapi.responses import JSONResponse
+ import os
+
+ router = APIRouter(prefix="/datasets", tags=["datasets"])
+ log = logging.getLogger(__name__)
+
+ SIZE_LOW = 100 * 1024 * 1024
+ SIZE_MEDIUM = 1024 * 1024 * 1024
+
+ class DatasetInfo(BaseModel):
+     id: str
+     name: Optional[str]
+     description: Optional[str]
+     size_bytes: Optional[int]
+     impact_level: Optional[str]
+     downloads: Optional[int]
+     likes: Optional[int]
+     tags: Optional[List[str]]
+     class Config:
+         extra = "ignore"
+
+ class PaginatedDatasets(BaseModel):
+     total: int
+     items: List[DatasetInfo]
+
+ class CommitInfo(BaseModel):
+     id: str
+     title: Optional[str]
+     message: Optional[str]
+     author: Optional[Dict[str, Any]]
+     date: Optional[str]
+
+ class CacheStatus(BaseModel):
+     last_update: Optional[str]
+     total_items: int
+     warming_up: bool
+
+ def deduplicate_by_id(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+     seen: Set[str] = set()
+     unique_items = []
+     for item in items:
+         item_id = item.get("id")
+         if item_id and item_id not in seen:
+             seen.add(item_id)
+             unique_items.append(item)
+     return unique_items
+
+ @router.get("/cache-status", response_model=CacheStatus)
+ async def cache_status():
+     meta = await cache_get("hf:datasets:meta")
+     last_update = meta["last_update"] if meta and "last_update" in meta else None
+     total_items = meta["total_items"] if meta and "total_items" in meta else 0
+     warming_up = not bool(total_items)
+     return CacheStatus(last_update=last_update, total_items=total_items, warming_up=warming_up)
+
+ @router.get("/", response_model=None)
+ async def list_datasets(
+     limit: int = Query(10, ge=1, le=1000),
+     offset: int = Query(0, ge=0),
+     search: str = Query(None, description="Search term for dataset id or description"),
+     sort_by: str = Query(None, description="Field to sort by (e.g., 'downloads', 'likes', 'created_at')"),
+     sort_order: str = Query("desc", regex="^(asc|desc)$", description="Sort order: 'asc' or 'desc'"),
+ ):
+     # Fetch the full list from cache
+     result, status = get_datasets_page_from_cache(1000000, 0) # get all for in-memory filtering
+     if status != 200:
+         return JSONResponse(result, status_code=status)
+     items = result["items"]
+     # Filtering
+     if search:
+         items = [d for d in items if search.lower() in (d.get("id", "") + " " + str(d.get("description", ""))).lower()]
+     # Sorting
+     if sort_by:
+         items = sorted(items, key=lambda d: d.get(sort_by) or 0, reverse=(sort_order == "desc"))
+     # Pagination
+     total = len(items)
+     page = items[offset:offset+limit]
+     total_pages = (total + limit - 1) // limit
+     current_page = (offset // limit) + 1
+     next_page = current_page + 1 if offset + limit < total else None
+     prev_page = current_page - 1 if current_page > 1 else None
+     return {
+         "total": total,
+         "current_page": current_page,
+         "total_pages": total_pages,
+         "next_page": next_page,
+         "prev_page": prev_page,
+         "items": page
+     }
+
+ @router.get("/{dataset_id:path}/commits", response_model=List[CommitInfo])
+ async def get_commits(dataset_id: str):
+     """
+     Get commit history for a dataset.
+     """
+     try:
+         return await get_dataset_commits_async(dataset_id)
+     except Exception as e:
+         log.error(f"Error fetching commits for {dataset_id}: {e}")
+         raise HTTPException(status_code=404, detail=f"Could not fetch commits: {e}")
+
+ @router.get("/{dataset_id:path}/files", response_model=List[str])
+ async def list_files(dataset_id: str):
+     """
+     List files in a dataset.
+     """
+     try:
+         return await get_dataset_files_async(dataset_id)
+     except Exception as e:
+         log.error(f"Error listing files for {dataset_id}: {e}")
+         raise HTTPException(status_code=404, detail=f"Could not list files: {e}")
+
+ @router.get("/{dataset_id:path}/file-url")
+ async def get_file_url_endpoint(dataset_id: str, filename: str = Query(...), revision: Optional[str] = None):
+     """
+     Get download URL for a file in a dataset.
+     """
+     url = await get_file_url_async(dataset_id, filename, revision)
+     return {"download_url": url}
+
+ @router.get("/meta")
+ async def get_datasets_meta():
+     meta = await cache_get("hf:datasets:meta")
+     return meta if meta else {}
+
+ # Endpoint to trigger cache refresh manually (for admin/testing)
+ @router.post("/datasets/refresh-cache")
+ def refresh_cache():
+     token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
+     if not token:
+         return JSONResponse({"error": "HUGGINGFACEHUB_API_TOKEN not set"}, status_code=500)
+     count = fetch_and_cache_all_datasets(token)
+     return {"status": "ok", "cached": count}
app/core/celery_app.py ADDED
@@ -0,0 +1,98 @@
+ """Celery configuration for task processing."""
+
+ import logging
+ from celery import Celery
+ from celery.signals import task_failure, task_success, worker_ready, worker_shutdown
+
+ from app.core.config import settings
+
+ # Configure logging
+ logger = logging.getLogger(__name__)
+
+ # Celery configuration
+ celery_app = Celery(
+     "dataset_impacts",
+     broker=settings.REDIS_URL,
+     backend=settings.REDIS_URL,
+ )
+
+ # Configure Celery settings
+ celery_app.conf.update(
+     task_serializer="json",
+     accept_content=["json"],
+     result_serializer="json",
+     timezone="UTC",
+     enable_utc=True,
+     worker_concurrency=settings.WORKER_CONCURRENCY,
+     task_acks_late=True, # Tasks are acknowledged after execution
+     task_reject_on_worker_lost=True, # Tasks are rejected if worker is terminated during execution
+     task_time_limit=3600, # 1 hour timeout per task
+     task_soft_time_limit=3000, # Soft timeout (50 minutes) - allows for graceful shutdown
+     worker_prefetch_multiplier=1, # Single prefetch - improves fair distribution of tasks
+     broker_connection_retry=True,
+     broker_connection_retry_on_startup=True,
+     broker_connection_max_retries=10,
+     broker_pool_limit=10, # Connection pool size
+     result_expires=60 * 60 * 24, # Results expire after 24 hours
+     task_track_started=True, # Track when tasks are started
+ )
+
+ # Set up task routes for different task types
+ celery_app.conf.task_routes = {
+     "app.tasks.dataset_tasks.*": {"queue": "dataset_impacts"},
+     "app.tasks.maintenance.*": {"queue": "maintenance"},
+ }
+
+ # Configure retry settings
+ celery_app.conf.task_default_retry_delay = 30 # 30 seconds
+ celery_app.conf.task_max_retries = 3
+
+ # Setup beat schedule for periodic tasks if enabled
+ celery_app.conf.beat_schedule = {
+     "cleanup-stale-tasks": {
+         "task": "app.tasks.maintenance.cleanup_stale_tasks",
+         "schedule": 3600.0, # Run every hour
+     },
+     "health-check": {
+         "task": "app.tasks.maintenance.health_check",
+         "schedule": 300.0, # Run every 5 minutes
+     },
+     "refresh-hf-datasets-cache": {
+         "task": "app.tasks.dataset_tasks.refresh_hf_datasets_cache",
+         "schedule": 3600.0, # Run every hour
+     },
+ }
+
+ # Signal handlers for monitoring and logging
+ @task_failure.connect
+ def task_failure_handler(sender=None, task_id=None, exception=None, **kwargs):
+     """Log failed tasks."""
+     logger.error(f"Task {task_id} {sender.name} failed: {exception}")
+
+ @task_success.connect
+ def task_success_handler(sender=None, result=None, **kwargs):
+     """Log successful tasks."""
+     task_name = sender.name if sender else "Unknown"
+     logger.info(f"Task {task_name} completed successfully")
+
+ @worker_ready.connect
+ def worker_ready_handler(**kwargs):
+     """Log when worker is ready."""
+     logger.info(f"Celery worker ready: {kwargs.get('hostname')}")
+
+ @worker_shutdown.connect
+ def worker_shutdown_handler(**kwargs):
+     """Log when worker is shutting down."""
+     logger.info(f"Celery worker shutting down: {kwargs.get('hostname')}")
+
+ def get_celery_app():
+     """Get the Celery app instance."""
+     # Import all tasks to ensure they're registered
+     try:
+         # Using the improved app.tasks module which properly imports all tasks
+         import app.tasks
+         logger.info(f"Tasks successfully imported; registered {len(celery_app.tasks)} tasks")
+     except ImportError as e:
+         logger.error(f"Error importing tasks: {e}")
+
+     return celery_app
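A short sketch of queuing the periodic refresh task by name from this Celery app (not part of this commit; assumes the application environment is configured, the Redis broker at settings.REDIS_URL is reachable, and a worker is consuming the dataset_impacts queue):

from app.core.celery_app import get_celery_app

celery_app = get_celery_app()

# The task_routes mapping above sends app.tasks.dataset_tasks.* to the dataset_impacts queue
result = celery_app.send_task("app.tasks.dataset_tasks.refresh_hf_datasets_cache")
print(result.id, result.status)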
app/core/config.py ADDED
@@ -0,0 +1,48 @@
+ from __future__ import annotations
+
+ from typing import Final, Optional
+
+ from pydantic import SecretStr, HttpUrl, Field
+ from pydantic_settings import BaseSettings, SettingsConfigDict
+
+ class Settings(BaseSettings):
+     """
+     Core application settings.
+     Reads environment variables and .env file.
+     """
+     # Supabase Settings
+     SUPABASE_URL: HttpUrl
+     SUPABASE_SERVICE_KEY: SecretStr
+     SUPABASE_ANON_KEY: SecretStr
+     SUPABASE_JWT_SECRET: Optional[SecretStr] = None # Optional for local dev
+
+     # Hugging Face API Token
+     HF_API_TOKEN: Optional[SecretStr] = None
+
+     # Redis settings
+     REDIS_URL: str = "redis://localhost:6379/0"
+     REDIS_PASSWORD: Optional[SecretStr] = None
+
+     # Toggle Redis cache layer
+     ENABLE_REDIS_CACHE: bool = True
+
+     # ──────────────────────────────── Security ────────────────────────────────
+     # JWT secret key. NEVER hard-code in source; override with env variable in production.
+     SECRET_KEY: SecretStr = Field("change-me", env="SECRET_KEY")
+     ACCESS_TOKEN_EXPIRE_MINUTES: int = Field(60 * 24 * 7, env="ACCESS_TOKEN_EXPIRE_MINUTES") # 1 week by default
+
+     # Worker settings
+     WORKER_CONCURRENCY: int = 10 # Increased from 5 for better parallel performance
+
+     # Batch processing chunk size for Celery dataset tasks
+     DATASET_BATCH_CHUNK_SIZE: int = 50
+
+     # Tell pydantic-settings to auto-load `.env` if present
+     model_config: Final = SettingsConfigDict(
+         env_file=".env",
+         case_sensitive=False,
+         extra="ignore"
+     )
+
+ # Single, shared instance of settings
+ settings = Settings()
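A small sketch of how these settings resolve at runtime (not part of this commit; the environment values below are placeholders for illustration only):

import os

# Placeholder values; real deployments supply these via the environment or .env
os.environ["SUPABASE_URL"] = "https://example.supabase.co"
os.environ["SUPABASE_SERVICE_KEY"] = "service-key"
os.environ["SUPABASE_ANON_KEY"] = "anon-key"
os.environ["REDIS_URL"] = "redis://localhost:6379/0"

from app.core.config import settings  # the shared instance reads env vars and .env on import

print(settings.REDIS_URL)
print(settings.SUPABASE_SERVICE_KEY.get_secret_value())  # SecretStr stays wrapped until unwrapped explicitly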
app/main.py ADDED
@@ -0,0 +1,46 @@
+ import logging
+ import json
+ from fastapi import FastAPI
+ from app.api import api_router
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.middleware.gzip import GZipMiddleware
+
+ class JsonFormatter(logging.Formatter):
+     def format(self, record):
+         log_record = {
+             "level": record.levelname,
+             "time": self.formatTime(record, self.datefmt),
+             "name": record.name,
+             "message": record.getMessage(),
+         }
+         if record.exc_info:
+             log_record["exc_info"] = self.formatException(record.exc_info)
+         return json.dumps(log_record)
+
+ handler = logging.StreamHandler()
+ handler.setFormatter(JsonFormatter())
+ logging.basicConfig(level=logging.INFO, handlers=[handler])
+
+ app = FastAPI(title="Collinear API")
+
+ # Enable CORS for the frontend
+ frontend_origin = "http://localhost:5173"
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=[frontend_origin],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ app.add_middleware(GZipMiddleware, minimum_size=1000)
+
+ app.include_router(api_router, prefix="/api")
+
+ @app.get("/")
+ async def root():
+     return {"message": "Welcome to the Collinear Data Tool API"}
+
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run("app.main:app", host="0.0.0.0", port=8000, reload=True)
app/schemas/dataset.py ADDED
@@ -0,0 +1,81 @@
+ import logging
+ from typing import Dict, List, Optional, Any
+ from datetime import datetime
+ from pydantic import BaseModel, Field
+
+ from app.schemas.dataset_common import ImpactLevel, DatasetMetrics
+
+ # Log for this module
+ log = logging.getLogger(__name__)
+
+ # Supported strategies for dataset combination
+ SUPPORTED_STRATEGIES = ["merge", "intersect", "filter"]
+
+ class ImpactAssessment(BaseModel):
+     dataset_id: str = Field(..., description="The ID of the dataset being assessed")
+     impact_level: ImpactLevel = Field(..., description="The impact level: low, medium, or high")
+     assessment_method: str = Field(
+         "unknown",
+         description="Method used to determine impact level (e.g., size_based, downloads_and_likes_based)"
+     )
+     metrics: DatasetMetrics = Field(
+         ...,
+         description="Metrics used for impact assessment"
+     )
+     thresholds: Dict[str, Dict[str, str]] = Field(
+         {},
+         description="Thresholds used for determining impact levels (for reference)"
+     )
+
+ class DatasetInfo(BaseModel):
+     id: str
+     impact_level: Optional[ImpactLevel] = None
+     impact_assessment: Optional[Dict] = None
+     # Add other fields as needed
+     class Config:
+         extra = "allow" # Allow extra fields from the API
+
+ class DatasetBase(BaseModel):
+     name: str
+     description: Optional[str] = None
+     tags: Optional[List[str]] = None
+
+ class DatasetCreate(DatasetBase):
+     files: Optional[List[str]] = None
+
+ class DatasetUpdate(DatasetBase):
+     name: Optional[str] = None # Make fields optional for updates
+
+ class Dataset(DatasetBase):
+     id: int # or str depending on your ID format
+     owner_id: str # Assuming user IDs are strings
+     created_at: Optional[str] = None
+     updated_at: Optional[str] = None
+     class Config:
+         pass # Removed orm_mode = True since ORM is not used
+
+ class DatasetCombineRequest(BaseModel):
+     source_datasets: List[str] = Field(..., description="List of dataset IDs to combine")
+     name: str = Field(..., description="Name for the combined dataset")
+     description: Optional[str] = Field(None, description="Description for the combined dataset")
+     combination_strategy: str = Field("merge", description="Strategy to use when combining datasets (e.g., 'merge', 'intersect', 'filter')")
+     filter_criteria: Optional[Dict[str, Any]] = Field(None, description="Criteria for filtering when combining datasets")
+
+ class CombinedDataset(BaseModel):
+     id: str = Field(..., description="ID of the combined dataset")
+     name: str = Field(..., description="Name of the combined dataset")
+     description: Optional[str] = Field(None, description="Description of the combined dataset")
+     source_datasets: List[str] = Field(..., description="IDs of the source datasets")
+     created_at: datetime = Field(..., description="Creation timestamp")
+     created_by: str = Field(..., description="ID of the user who created this combined dataset")
+     impact_level: Optional[ImpactLevel] = Field(None, description="Calculated impact level of the combined dataset")
+     status: str = Field("processing", description="Status of the dataset combination process")
+     combination_strategy: str = Field(..., description="Strategy used when combining datasets")
+     metrics: Optional[DatasetMetrics] = Field(None, description="Metrics for the combined dataset")
+     storage_bucket_id: Optional[str] = Field(None, description="ID of the storage bucket containing dataset files")
+     storage_folder_path: Optional[str] = Field(None, description="Path to the dataset files within the bucket")
+     class Config:
+         extra = "allow" # Allow extra fields for flexibility
+
+ __all__ = ["ImpactLevel", "ImpactAssessment", "DatasetInfo", "DatasetMetrics",
+            "Dataset", "DatasetCreate", "DatasetUpdate", "DatasetCombineRequest", "CombinedDataset"]
app/schemas/dataset_common.py ADDED
@@ -0,0 +1,17 @@
+ from enum import Enum
+ from pydantic import BaseModel, Field
+ from typing import Optional
+
+ # Define the impact level as an enum for better type safety
+ class ImpactLevel(str, Enum):
+     NA = "not_available" # New category for when size information is unavailable
+     LOW = "low"
+     MEDIUM = "medium"
+     HIGH = "high"
+
+ # Define metrics model for impact assessment
+ class DatasetMetrics(BaseModel):
+     size_bytes: Optional[int] = Field(None, description="Size of the dataset in bytes")
+     file_count: Optional[int] = Field(None, description="Number of files in the dataset")
+     downloads: Optional[int] = Field(None, description="Number of downloads (all time)")
+     likes: Optional[int] = Field(None, description="Number of likes")
app/services/hf_datasets.py ADDED
@@ -0,0 +1,501 @@
+ import logging
+ import json
+ from typing import Any, List, Optional, Dict, Tuple
+ import requests
+ from huggingface_hub import HfApi
+ from app.core.config import settings
+ from app.schemas.dataset_common import ImpactLevel
+ from app.services.redis_client import sync_cache_set, sync_cache_get, generate_cache_key, get_redis_sync
+ import time
+ import asyncio
+ import redis
+ import gzip
+ from datetime import datetime, timezone
+ import os
+ from app.schemas.dataset import ImpactAssessment
+ from app.schemas.dataset_common import DatasetMetrics
+ import httpx
+ import redis.asyncio as aioredis
+
+ log = logging.getLogger(__name__)
+ api = HfApi()
+ redis_client = redis.Redis(host="redis", port=6379, decode_responses=True)
+
+ # Thresholds for impact categorization
+ SIZE_THRESHOLD_LOW = 100 * 1024 * 1024 # 100 MB
+ SIZE_THRESHOLD_MEDIUM = 1024 * 1024 * 1024 # 1 GB
+ DOWNLOADS_THRESHOLD_LOW = 1000
+ DOWNLOADS_THRESHOLD_MEDIUM = 10000
+ LIKES_THRESHOLD_LOW = 10
+ LIKES_THRESHOLD_MEDIUM = 100
+
+ HF_API_URL = "https://huggingface.co/api/datasets"
+ DATASET_CACHE_TTL = 60 * 60 # 1 hour
+
+ # Redis and HuggingFace API setup
+ REDIS_KEY = "hf:datasets:all:compressed"
+ REDIS_META_KEY = "hf:datasets:meta"
+ REDIS_TTL = 60 * 60 # 1 hour
+
+ # Impact thresholds (in bytes)
+ SIZE_LOW = 100 * 1024 * 1024
+ SIZE_MEDIUM = 1024 * 1024 * 1024
+
+ def get_hf_token():
+     token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
+     if not token:
+         raise RuntimeError("HUGGINGFACEHUB_API_TOKEN environment variable is not set. Please set it securely.")
+     return token
+
+ def get_dataset_commits(dataset_id: str, limit: int = 20):
+     from huggingface_hub import HfApi
+     import logging
+     log = logging.getLogger(__name__)
+     api = HfApi()
+     log.info(f"[get_dataset_commits] Fetching commits for dataset_id={dataset_id}")
+     try:
+         commits = api.list_repo_commits(repo_id=dataset_id, repo_type="dataset")
+         log.info(f"[get_dataset_commits] Received {len(commits)} commits for {dataset_id}")
+     except Exception as e:
+         log.error(f"[get_dataset_commits] Error fetching commits for {dataset_id}: {e}", exc_info=True)
+         raise # Let the API layer catch and handle this
+     result = []
+     for c in commits[:limit]:
+         try:
+             commit_id = getattr(c, "commit_id", "")
+             title = getattr(c, "title", "")
+             message = getattr(c, "message", title)
+             authors = getattr(c, "authors", [])
+             author_name = authors[0] if authors and isinstance(authors, list) else ""
+             created_at = getattr(c, "created_at", None)
+             if created_at:
+                 if hasattr(created_at, "isoformat"):
+                     date = created_at.isoformat()
+                 else:
+                     date = str(created_at)
+             else:
+                 date = ""
+             result.append({
+                 "id": commit_id or "",
+                 "title": title or message or "",
+                 "message": message or title or "",
+                 "author": {"name": author_name, "email": ""},
+                 "date": date,
+             })
+         except Exception as e:
+             log.error(f"[get_dataset_commits] Error parsing commit: {e} | Commit: {getattr(c, '__dict__', str(c))}", exc_info=True)
+     log.info(f"[get_dataset_commits] Returning {len(result)} parsed commits for {dataset_id}")
+     return result
+
+ def get_dataset_files(dataset_id: str) -> List[str]:
+     return api.list_repo_files(repo_id=dataset_id, repo_type="dataset")
+
+ def get_file_url(dataset_id: str, filename: str, revision: Optional[str] = None) -> str:
+     from huggingface_hub import hf_hub_url
+     return hf_hub_url(repo_id=dataset_id, filename=filename, repo_type="dataset", revision=revision)
+
+ def get_datasets_page_from_zset(offset: int = 0, limit: int = 10, search: str = None) -> dict:
+     import redis
+     import json
+     redis_client = redis.Redis(host="redis", port=6379, db=0, decode_responses=True)
+     zset_key = "hf:datasets:all:zset"
+     hash_key = "hf:datasets:all:hash"
+     # Get total count
+     total = redis_client.zcard(zset_key)
+     # Get dataset IDs for the page
+     ids = redis_client.zrange(zset_key, offset, offset + limit - 1)
+     # Fetch metadata for those IDs
+     if not ids:
+         return {"items": [], "count": total}
+     items = redis_client.hmget(hash_key, ids)
+     # Parse JSON and filter/search if needed
+     parsed = []
+     for raw in items:
+         if not raw:
+             continue
+         try:
+             item = json.loads(raw)
+             parsed.append(item)
+         except Exception:
+             continue
+     if search:
+         parsed = [d for d in parsed if search.lower() in (d.get("id") or "").lower()]
+     return {"items": parsed, "count": total}
+
+ async def _fetch_size(session: httpx.AsyncClient, dataset_id: str) -> Optional[int]:
+     """Fetch dataset size from the datasets server asynchronously."""
+     url = f"https://datasets-server.huggingface.co/size?dataset={dataset_id}"
+     try:
+         resp = await session.get(url, timeout=30)
+         if resp.status_code == 200:
+             data = resp.json()
+             return data.get("size", {}).get("dataset", {}).get("num_bytes_original_files")
+     except Exception as e:
+         log.warning(f"Could not fetch size for {dataset_id}: {e}")
+     return None
+
+ async def _fetch_sizes(dataset_ids: List[str]) -> Dict[str, Optional[int]]:
+     """Fetch dataset sizes in parallel."""
+     results: Dict[str, Optional[int]] = {}
+     async with httpx.AsyncClient() as session:
+         tasks = {dataset_id: asyncio.create_task(_fetch_size(session, dataset_id)) for dataset_id in dataset_ids}
+         for dataset_id, task in tasks.items():
+             results[dataset_id] = await task
+     return results
+
+ def process_datasets_page(offset, limit):
+     """
+     Fetch and process a single page of datasets from Hugging Face and cache them in Redis.
+     """
+     import redis
+     import os
+     import json
+     import asyncio
+     log = logging.getLogger(__name__)
+     log.info(f"[process_datasets_page] ENTRY: offset={offset}, limit={limit}")
+     token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
+     if not token:
+         log.error("[process_datasets_page] HUGGINGFACEHUB_API_TOKEN environment variable is not set.")
+         raise RuntimeError("HUGGINGFACEHUB_API_TOKEN environment variable is not set. Please set it securely.")
+     headers = {
+         "Authorization": f"Bearer {token}",
+         "User-Agent": "Mozilla/5.0 (compatible; CollinearTool/1.0; +https://yourdomain.com)"
+     }
+     params = {"limit": limit, "offset": offset, "full": "True"}
+     redis_client = redis.Redis(host="redis", port=6379, db=0, decode_responses=True)
+     stream_key = "hf:datasets:all:stream"
+     zset_key = "hf:datasets:all:zset"
+     hash_key = "hf:datasets:all:hash"
+     try:
+         log.info(f"[process_datasets_page] Requesting {HF_API_URL} with params={params}")
+         response = requests.get(HF_API_URL, headers=headers, params=params, timeout=120)
+         response.raise_for_status()
+
+         page_items = response.json()
+
+         log.info(f"[process_datasets_page] Received {len(page_items)} datasets at offset {offset}")
+
+         dataset_ids = [ds.get("id") for ds in page_items]
+         size_map = asyncio.run(_fetch_sizes(dataset_ids))
+
+         for ds in page_items:
+             dataset_id = ds.get("id")
+             size_bytes = size_map.get(dataset_id)
+             downloads = ds.get("downloads")
+             likes = ds.get("likes")
+             impact_level, assessment_method = determine_impact_level_by_criteria(size_bytes, downloads, likes)
+             metrics = DatasetMetrics(size_bytes=size_bytes, downloads=downloads, likes=likes)
+             thresholds = {
+                 "size_bytes": {
+                     "low": str(100 * 1024 * 1024),
+                     "medium": str(1 * 1024 * 1024 * 1024),
+                     "high": str(10 * 1024 * 1024 * 1024)
+                 }
+             }
+             impact_assessment = ImpactAssessment(
+                 dataset_id=dataset_id,
+                 impact_level=impact_level,
+                 assessment_method=assessment_method,
+                 metrics=metrics,
+                 thresholds=thresholds
+             ).model_dump()
+             item = {
+                 "id": dataset_id,
+                 "name": ds.get("name"),
+                 "description": ds.get("description"),
+                 "size_bytes": size_bytes,
+                 "impact_level": impact_level.value if isinstance(impact_level, ImpactLevel) else impact_level,
+                 "downloads": downloads,
+                 "likes": likes,
+                 "tags": ds.get("tags", []),
+                 "impact_assessment": json.dumps(impact_assessment)
+             }
+             final_item = {}
+             for k, v in item.items():
+                 if isinstance(v, list) or isinstance(v, dict):
+                     final_item[k] = json.dumps(v)
+                 elif v is None:
+                     final_item[k] = 'null'
+                 else:
+                     final_item[k] = str(v)
+
+             redis_client.xadd(stream_key, final_item)
+             redis_client.zadd(zset_key, {dataset_id: offset})
+             redis_client.hset(hash_key, dataset_id, json.dumps(item))
+
+         log.info(f"[process_datasets_page] EXIT: Cached {len(page_items)} datasets at offset {offset}")
+         return len(page_items)
+     except Exception as exc:
+         log.error(f"[process_datasets_page] ERROR: offset={offset}, limit={limit}, exc={exc}", exc_info=True)
+         raise
+
+ def refresh_datasets_cache():
+     """
+     Orchestrator: Enqueue Celery tasks to fetch all Hugging Face datasets in parallel.
+     Uses direct calls to HF API.
+     """
+     import requests
+     log.info("[refresh_datasets_cache] Orchestrating dataset fetch tasks using direct HF API calls.")
+     token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
+     if not token:
+         log.error("[refresh_datasets_cache] HUGGINGFACEHUB_API_TOKEN environment variable is not set.")
+         raise RuntimeError("HUGGINGFACEHUB_API_TOKEN environment variable is not set. Please set it securely.")
+
+     headers = {
+         "Authorization": f"Bearer {token}",
+         "User-Agent": "Mozilla/5.0 (compatible; CollinearTool/1.0; +https://yourdomain.com)"
+     }
+     limit = 500
+
+     params = {"limit": 1, "offset": 0}
+     try:
+         response = requests.get(HF_API_URL, headers=headers, params=params, timeout=120)
+         response.raise_for_status()
+         total_str = response.headers.get('X-Total-Count')
+         if not total_str:
+             log.error("[refresh_datasets_cache] 'X-Total-Count' header not found in HF API response.")
+             raise ValueError("'X-Total-Count' header missing from Hugging Face API response.")
+         total = int(total_str)
+         log.info(f"[refresh_datasets_cache] Total datasets reported by HF API: {total}")
+     except requests.RequestException as e:
+         log.error(f"[refresh_datasets_cache] Error fetching total dataset count from HF API: {e}")
+         raise
+     except ValueError as e:
+         log.error(f"[refresh_datasets_cache] Error parsing total dataset count: {e}")
+         raise
+
+     num_pages = (total + limit - 1) // limit
+     from app.tasks.dataset_tasks import fetch_datasets_page
+     from celery import group
+     tasks = []
+     for page_num in range(num_pages):
+         offset = page_num * limit
+         tasks.append(fetch_datasets_page.s(offset, limit))
+         log.info(f"[refresh_datasets_cache] Scheduled page at offset {offset}, limit {limit}.")
+     if tasks:
+         group(tasks).apply_async()
+         log.info(f"[refresh_datasets_cache] Enqueued {len(tasks)} fetch tasks.")
+     else:
+         log.warning("[refresh_datasets_cache] No dataset pages found to schedule.")
+
+ def determine_impact_level_by_criteria(size_bytes, downloads=None, likes=None):
+     try:
+         size = int(size_bytes) if size_bytes not in (None, 'null') else 0
+     except Exception:
+         size = 0
+
+     # Prefer size_bytes if available
+     if size >= 10 * 1024 * 1024 * 1024:
+         return ("high", "large_size")
+     elif size >= 1 * 1024 * 1024 * 1024:
+         return ("medium", "medium_size")
+     elif size >= 100 * 1024 * 1024:
+         return ("low", "small_size")
+     # Fallback to downloads if size_bytes is missing or too small
+     if downloads is not None:
+         try:
+             downloads = int(downloads)
+             if downloads >= 100000:
+                 return ("high", "downloads")
+             elif downloads >= 10000:
+                 return ("medium", "downloads")
+             elif downloads >= 1000:
+                 return ("low", "downloads")
+         except Exception:
+             pass
+     # Fallback to likes if downloads is missing
+     if likes is not None:
+         try:
+             likes = int(likes)
+             if likes >= 1000:
+                 return ("high", "likes")
+             elif likes >= 100:
+                 return ("medium", "likes")
+             elif likes >= 10:
+                 return ("low", "likes")
+         except Exception:
+             pass
+     return ("not_available", "size_and_downloads_and_likes_unknown")
+
+ def get_dataset_size(dataset: dict, dataset_id: str = None):
+     """
+     Extract the size in bytes from a dataset dictionary.
+     Tries multiple locations based on possible HuggingFace API responses.
+     """
+     # Try top-level key
+     size_bytes = dataset.get("size_bytes")
+     if size_bytes not in (None, 'null'):
+         return size_bytes
+     # Try nested structure from the size API
+     size_bytes = (
+         dataset.get("size", {})
+         .get("dataset", {})
+         .get("num_bytes_original_files")
+     )
+     if size_bytes not in (None, 'null'):
+         return size_bytes
+     # Try metrics or info sub-dictionaries if present
+     for key in ["metrics", "info"]:
+         sub = dataset.get(key, {})
+         if isinstance(sub, dict):
+             size_bytes = sub.get("size_bytes")
+             if size_bytes not in (None, 'null'):
+                 return size_bytes
+     # Not found
+     return None
+
+ async def get_datasets_page_from_zset_async(offset: int = 0, limit: int = 10, search: str = None) -> dict:
+     redis_client = aioredis.Redis(host="redis", port=6379, db=0, decode_responses=True)
+     zset_key = "hf:datasets:all:zset"
+     hash_key = "hf:datasets:all:hash"
+     total = await redis_client.zcard(zset_key)
+     ids = await redis_client.zrange(zset_key, offset, offset + limit - 1)
+     if not ids:
+         return {"items": [], "count": total}
+     items = await redis_client.hmget(hash_key, ids)
+     parsed = []
+     for raw in items:
+         if not raw:
+             continue
+         try:
+             item = json.loads(raw)
+             parsed.append(item)
+         except Exception:
+             continue
+     if search:
+         parsed = [d for d in parsed if search.lower() in (d.get("id") or "").lower()]
+     return {"items": parsed, "count": total}
+
+ async def get_dataset_commits_async(dataset_id: str, limit: int = 20):
+     from huggingface_hub import HfApi
+     import logging
+     log = logging.getLogger(__name__)
+     api = HfApi()
+     log.info(f"[get_dataset_commits_async] Fetching commits for dataset_id={dataset_id}")
+     try:
+         # huggingface_hub is sync, so run in threadpool
+         import anyio
+         commits = await anyio.to_thread.run_sync(lambda: api.list_repo_commits(dataset_id, repo_type="dataset"))
+         log.info(f"[get_dataset_commits_async] Received {len(commits)} commits for {dataset_id}")
+     except Exception as e:
+         log.error(f"[get_dataset_commits_async] Error fetching commits for {dataset_id}: {e}", exc_info=True)
+         raise
+     result = []
+     for c in commits[:limit]:
+         try:
+             commit_id = getattr(c, "commit_id", "")
+             title = getattr(c, "title", "")
+             message = getattr(c, "message", title)
+             authors = getattr(c, "authors", [])
+             author_name = authors[0] if authors and isinstance(authors, list) else ""
+             created_at = getattr(c, "created_at", None)
+             if created_at:
+                 if hasattr(created_at, "isoformat"):
+                     date = created_at.isoformat()
+                 else:
+                     date = str(created_at)
+             else:
+                 date = ""
+             result.append({
+                 "id": commit_id or "",
+                 "title": title or message or "",
+                 "message": message or title or "",
+                 "author": {"name": author_name, "email": ""},
+                 "date": date,
+             })
+         except Exception as e:
+             log.error(f"[get_dataset_commits_async] Error parsing commit: {e} | Commit: {getattr(c, '__dict__', str(c))}", exc_info=True)
+     log.info(f"[get_dataset_commits_async] Returning {len(result)} parsed commits for {dataset_id}")
+     return result
+
+ async def get_dataset_files_async(dataset_id: str) -> List[str]:
+     from huggingface_hub import HfApi
+     import anyio
+     api = HfApi()
+     # huggingface_hub is sync, so run in threadpool
+     return await anyio.to_thread.run_sync(lambda: api.list_repo_files(dataset_id, repo_type="dataset"))
+
+ async def get_file_url_async(dataset_id: str, filename: str, revision: Optional[str] = None) -> str:
+     from huggingface_hub import hf_hub_url
+     import anyio
+     # huggingface_hub is sync, so run in threadpool
+     return await anyio.to_thread.run_sync(lambda: hf_hub_url(repo_id=dataset_id, filename=filename, repo_type="dataset", revision=revision))
+
+ # Fetch and cache all datasets
+
+ class EnhancedJSONEncoder(json.JSONEncoder):
+     def default(self, obj):
+         if isinstance(obj, datetime):
+             return obj.isoformat()
+         return super().default(obj)
+
+ async def fetch_size(session, dataset_id, token=None):
+     url = f"https://datasets-server.huggingface.co/size?dataset={dataset_id}"
+     headers = {"Authorization": f"Bearer {token}"} if token else {}
+     try:
+         resp = await session.get(url, headers=headers, timeout=30)
+         if resp.status_code == 200:
+             data = resp.json()
+             return dataset_id, data.get("size", {}).get("dataset", {}).get("num_bytes_original_files")
+     except Exception as e:
+         log.warning(f"Could not fetch size for {dataset_id}: {e}")
+     return dataset_id, None
+
+ async def fetch_all_sizes(dataset_ids, token=None, batch_size=50):
+     results = {}
+     async with httpx.AsyncClient() as session:
+         for i in range(0, len(dataset_ids), batch_size):
+             batch = dataset_ids[i:i+batch_size]
+             tasks = [fetch_size(session, ds_id, token) for ds_id in batch]
+             batch_results = await asyncio.gather(*tasks)
+             for ds_id, size in batch_results:
+                 results[ds_id] = size
+     return results
+
+ def fetch_and_cache_all_datasets(token: str):
+     api = HfApi(token=token)
+     log.info("Fetching all datasets from Hugging Face Hub...")
+     all_datasets = list(api.list_datasets())
+     all_datasets_dicts = []
+     dataset_ids = [d.id for d in all_datasets]
+     # Fetch all sizes in batches
+     sizes = asyncio.run(fetch_all_sizes(dataset_ids, token=token, batch_size=50))
+     for d in all_datasets:
+         data = d.__dict__
+         size_bytes = sizes.get(d.id)
+         downloads = data.get("downloads")
+         likes = data.get("likes")
+         data["size_bytes"] = size_bytes
+         impact_level, _ = determine_impact_level_by_criteria(size_bytes, downloads, likes)
+         data["impact_level"] = impact_level
+         all_datasets_dicts.append(data)
+     compressed = gzip.compress(json.dumps(all_datasets_dicts, cls=EnhancedJSONEncoder).encode("utf-8"))
+     r = redis.Redis(host="redis", port=6379, decode_responses=False)
+     r.set(REDIS_KEY, compressed)
+     log.info(f"Cached {len(all_datasets_dicts)} datasets in Redis under {REDIS_KEY}")
+     return len(all_datasets_dicts)
+
+ # Native pagination from cache
+
+ def get_datasets_page_from_cache(limit: int, offset: int):
+     r = redis.Redis(host="redis", port=6379, decode_responses=False)
+     compressed = r.get(REDIS_KEY)
+     if not compressed:
+         return {"error": "Cache not found. Please refresh datasets."}, 404
+     all_datasets = json.loads(gzip.decompress(compressed).decode("utf-8"))
+     total = len(all_datasets)
+     if offset < 0 or offset >= total:
+         return {"error": "Offset out of range.", "total": total}, 400
+     page = all_datasets[offset:offset+limit]
+     total_pages = (total + limit - 1) // limit
+     current_page = (offset // limit) + 1
+     next_page = current_page + 1 if offset + limit < total else None
+     prev_page = current_page - 1 if current_page > 1 else None
+     return {
+         "total": total,
+         "current_page": current_page,
+         "total_pages": total_pages,
+         "next_page": next_page,
+         "prev_page": prev_page,
+         "items": page
+     }, 200
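A quick sketch of how the size/downloads/likes fallback in determine_impact_level_by_criteria plays out (not part of this commit; assumes the application environment is configured so the module imports cleanly, and the values are illustrative):

from app.services.hf_datasets import determine_impact_level_by_criteria

# Size takes precedence: 2 GiB falls in the 1 GiB - 10 GiB band
print(determine_impact_level_by_criteria(2 * 1024**3))             # ("medium", "medium_size")

# With no usable size, downloads decide the level
print(determine_impact_level_by_criteria(None, downloads=50_000))  # ("medium", "downloads")

# With neither size nor downloads, likes are the last resort
print(determine_impact_level_by_criteria(None, likes=15))          # ("low", "likes")

# Nothing known at all
print(determine_impact_level_by_criteria(None))                    # ("not_available", "size_and_downloads_and_likes_unknown")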
app/services/redis_client.py ADDED
@@ -0,0 +1,302 @@
+ """Redis client for caching and task queue management."""
+
+ import json
+ from typing import Any, Dict, Optional, TypeVar
+ from datetime import datetime
+ import logging
+ from time import time as _time
+
+ import redis.asyncio as redis_async
+ import redis as redis_sync # Import synchronous Redis client
+ from pydantic import BaseModel
+ from tenacity import retry, stop_after_attempt, wait_exponential
+
+ from app.core.config import settings
+
+ # Type variable for cache
+ T = TypeVar('T')
+
+ # Configure logging
+ log = logging.getLogger(__name__)
+
+ # Redis connection pools for reusing connections
+ _redis_pool_async = None
+ _redis_pool_sync = None # Synchronous pool
+
+ # Default cache expiration (12 hours)
+ DEFAULT_CACHE_EXPIRY = 60 * 60 * 12
+
+ @retry(stop=stop_after_attempt(5), wait=wait_exponential(multiplier=1, min=1, max=10))
+ async def get_redis_pool() -> redis_async.Redis:
+     """Get or create async Redis connection pool with retry logic."""
+     global _redis_pool_async
+
+     if _redis_pool_async is None:
+         # Get Redis configuration from settings
+         redis_url = settings.REDIS_URL or "redis://localhost:6379/0"
+
+         try:
+             # Create connection pool with reasonable defaults
+             _redis_pool_async = redis_async.ConnectionPool.from_url(
+                 redis_url,
+                 max_connections=10,
+                 decode_responses=True,
+                 health_check_interval=5,
+                 socket_connect_timeout=5,
+                 socket_keepalive=True,
+                 retry_on_timeout=True
+             )
+             log.info(f"Created async Redis connection pool with URL: {redis_url}")
+         except Exception as e:
+             log.error(f"Error creating async Redis connection pool: {e}")
+             raise
+
+     return redis_async.Redis(connection_pool=_redis_pool_async)
+
+ def get_redis_pool_sync() -> redis_sync.Redis:
+     """Get or create synchronous Redis connection pool."""
+     global _redis_pool_sync
+
+     if _redis_pool_sync is None:
+         # Get Redis configuration from settings
+         redis_url = settings.REDIS_URL or "redis://localhost:6379/0"
+
+         try:
+             # Create connection pool with reasonable defaults
+             _redis_pool_sync = redis_sync.ConnectionPool.from_url(
+                 redis_url,
+                 max_connections=10,
+                 decode_responses=True,
+                 socket_connect_timeout=5,
+                 socket_keepalive=True,
+                 retry_on_timeout=True
+             )
+             log.info(f"Created sync Redis connection pool with URL: {redis_url}")
+         except Exception as e:
+             log.error(f"Error creating sync Redis connection pool: {e}")
+             raise
+
+     return redis_sync.Redis(connection_pool=_redis_pool_sync)
+
+ @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=1, max=5))
+ async def get_redis() -> redis_async.Redis:
+     """Get Redis client from pool with retry logic."""
+     try:
+         redis_client = await get_redis_pool()
+         return redis_client
+     except Exception as e:
+         log.error(f"Error getting Redis client: {e}")
+         raise
+
+ @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=1, max=5))
+ def get_redis_sync() -> redis_sync.Redis:
+     """Get synchronous Redis client from pool with retry logic."""
+     try:
+         return get_redis_pool_sync()
+     except Exception as e:
+         log.error(f"Error getting synchronous Redis client: {e}")
+         raise
+
+ # Cache key generation
+ def generate_cache_key(prefix: str, *args: Any) -> str:
+     """Generate cache key with prefix and args."""
+     key_parts = [prefix] + [str(arg) for arg in args if arg]
+     return ":".join(key_parts)
+
+ # JSON serialization helpers
+ def _json_serialize(obj: Any) -> str:
+     """Serialize object to JSON with datetime support."""
+     def _serialize_datetime(o: Any) -> str:
+         if isinstance(o, datetime):
+             return o.isoformat()
+         if isinstance(o, BaseModel):
+             return o.dict()
+         return str(o)
+
+     return json.dumps(obj, default=_serialize_datetime)
+
+ def _json_deserialize(data: str, model_class: Optional[type] = None) -> Any:
+     """Deserialize JSON string to object with datetime support."""
+     result = json.loads(data)
+
+     if model_class and issubclass(model_class, BaseModel):
+         return model_class.parse_obj(result)
+
+     return result
+
+ # Async cache operations
+ async def cache_set(key: str, value: Any, expire: int = DEFAULT_CACHE_EXPIRY) -> bool:
+     """Set cache value with expiration (async version)."""
+     redis_client = await get_redis()
+     serialized = _json_serialize(value)
+
+     try:
+         await redis_client.set(key, serialized, ex=expire)
+         log.debug(f"Cached data at key: {key}, expires in {expire}s")
+         return True
+     except Exception as e:
+         log.error(f"Error caching data at key {key}: {e}")
+         return False
+
+ async def cache_get(key: str, model_class: Optional[type] = None) -> Optional[Any]:
+     """Get cache value with optional model deserialization (async version)."""
+     redis_client = await get_redis()
+
+     try:
+         data = await redis_client.get(key)
+         if not data:
+             return None
+
+         log.debug(f"Cache hit for key: {key}")
+         return _json_deserialize(data, model_class)
+     except Exception as e:
+         log.error(f"Error retrieving cache for key {key}: {e}")
+         return None
+
+ # Synchronous cache operations for Celery tasks
+ def sync_cache_set(key: str, value: Any, expire: int = DEFAULT_CACHE_EXPIRY) -> bool:
+     """Set cache value with expiration (synchronous version for Celery tasks). Logs slow operations."""
+     redis_client = get_redis_sync()
+     serialized = _json_serialize(value)
+     start = _time()
+     try:
+         redis_client.set(key, serialized, ex=expire)
+         elapsed = _time() - start
+         if elapsed > 2:
+             log.warning(f"Slow sync_cache_set for key {key}: {elapsed:.2f}s")
+         log.debug(f"Cached data at key: {key}, expires in {expire}s (sync)")
+         return True
+     except Exception as e:
+         log.error(f"Error caching data at key {key}: {e}")
+         return False
+
+ def sync_cache_get(key: str, model_class: Optional[type] = None) -> Optional[Any]:
+     """Get cache value with optional model deserialization (synchronous version for Celery tasks). Logs slow operations."""
+     redis_client = get_redis_sync()
+     start = _time()
+     try:
+         data = redis_client.get(key)
+         elapsed = _time() - start
+         if elapsed > 2:
+             log.warning(f"Slow sync_cache_get for key {key}: {elapsed:.2f}s")
+         if not data:
+             return None
+         log.debug(f"Cache hit for key: {key} (sync)")
+         return _json_deserialize(data, model_class)
+     except Exception as e:
+         log.error(f"Error retrieving cache for key {key}: {e}")
+         return None
+
+ async def cache_invalidate(key: str) -> bool:
+     """Invalidate cache for key."""
+     redis_client = await get_redis()
+
+     try:
+         await redis_client.delete(key)
+         log.debug(f"Invalidated cache for key: {key}")
+         return True
+     except Exception as e:
+         log.error(f"Error invalidating cache for key {key}: {e}")
+         return False
+
+ async def cache_invalidate_pattern(pattern: str) -> int:
+     """Invalidate all cache keys matching pattern."""
+     redis_client = await get_redis()
+
+     try:
+         keys = await redis_client.keys(pattern)
+         if not keys:
+             return 0
+
+         count = await redis_client.delete(*keys)
+         log.debug(f"Invalidated {count} keys matching pattern: {pattern}")
+         return count
+     except Exception as e:
+         log.error(f"Error invalidating keys with pattern {pattern}: {e}")
+         return 0
+
+ # Task queue operations
+ async def enqueue_task(queue_name: str, task_id: str, payload: Dict[str, Any]) -> bool:
+     """Add task to queue."""
+     redis_client = await get_redis()
+
+     try:
+         serialized = _json_serialize(payload)
+         await redis_client.lpush(f"queue:{queue_name}", serialized)
+         await redis_client.hset(f"tasks:{queue_name}", task_id, "pending")
+         log.info(f"Enqueued task {task_id} to queue {queue_name}")
+         return True
+     except Exception as e:
+         log.error(f"Error enqueueing task {task_id} to {queue_name}: {e}")
+         return False
+
+ async def mark_task_complete(queue_name: str, task_id: str, result: Optional[Dict[str, Any]] = None) -> bool:
+     """Mark task as complete with optional result."""
+     redis_client = await get_redis()
+
+     try:
+         # Store result if provided
+         if result:
+             await redis_client.hset(
+                 f"results:{queue_name}",
+                 task_id,
+                 _json_serialize(result)
+             )
+
+         # Mark task as complete
+         await redis_client.hset(f"tasks:{queue_name}", task_id, "complete")
+         await redis_client.expire(f"tasks:{queue_name}", 86400) # Expire after 24 hours
+
+         log.info(f"Marked task {task_id} as complete in queue {queue_name}")
+         return True
+     except Exception as e:
+         log.error(f"Error marking task {task_id} as complete: {e}")
+         return False
+
+ async def get_task_status(queue_name: str, task_id: str) -> Optional[str]:
+     """Get status of a task."""
+     redis_client = await get_redis()
+
+     try:
+         status = await redis_client.hget(f"tasks:{queue_name}", task_id)
+         return status
+     except Exception as e:
+         log.error(f"Error getting status for task {task_id}: {e}")
+         return None
+
+ async def get_task_result(queue_name: str, task_id: str) -> Optional[Dict[str, Any]]:
+     """Get result of a completed task."""
+     redis_client = await get_redis()
+
+     try:
+         data = await redis_client.hget(f"results:{queue_name}", task_id)
+         if not data:
+             return None
+
+         return _json_deserialize(data)
+     except Exception as e:
+         log.error(f"Error getting result for task {task_id}: {e}")
+         return None
+
+ # Stream processing for real-time updates
+ async def add_to_stream(stream: str, data: Dict[str, Any], max_len: int = 1000) -> str:
+     """Add event to Redis stream."""
+     redis_client = await get_redis()
+
+     try:
+         # Convert dict values to strings (Redis streams requirement)
+         entry = {k: _json_serialize(v) for k, v in data.items()}
+
+         # Add to stream with automatic ID generation
+         event_id = await redis_client.xadd(
+             stream,
+             entry,
+             maxlen=max_len,
+             approximate=True
+         )
+
+         log.debug(f"Added event {event_id} to stream {stream}")
+         return event_id
+     except Exception as e:
+         log.error(f"Error adding to stream {stream}: {e}")
+         raise
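A minimal sketch of the async cache helpers above (not part of this commit; assumes the application environment is configured and REDIS_URL points at a running Redis instance):

import asyncio

from app.services.redis_client import cache_set, cache_get, generate_cache_key

async def main() -> None:
    key = generate_cache_key("hf", "datasets", "example")  # -> "hf:datasets:example"
    await cache_set(key, {"id": "user/dataset", "likes": 42}, expire=60)
    print(await cache_get(key))  # {'id': 'user/dataset', 'likes': 42}

asyncio.run(main())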
app/tasks/dataset_tasks.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ import logging
+ import time
+ import asyncio
+ from datetime import datetime, timezone
+ from typing import Dict, List, Any, Optional, Tuple
+ from celery import Task, shared_task
+ from app.core.celery_app import get_celery_app
+ from app.services.hf_datasets import (
+     determine_impact_level_by_criteria,
+     get_hf_token,
+     get_dataset_size,
+     refresh_datasets_cache,
+     fetch_and_cache_all_datasets,
+ )
+ from app.services.redis_client import sync_cache_set, sync_cache_get, generate_cache_key
+ from app.core.config import settings
+ import requests
+ import os
+
+ # Configure logging
+ logger = logging.getLogger(__name__)
+
+ # Get Celery app instance
+ celery_app = get_celery_app()
+
+ # Constants
+ DATASET_CACHE_TTL = 60 * 60 * 24 * 30  # 30 days
+ BATCH_PROGRESS_CACHE_TTL = 60 * 60 * 24 * 7  # 7 days for batch progress
+ DATASET_SIZE_CACHE_TTL = 60 * 60 * 24 * 30  # 30 days for size info
+
+ @celery_app.task(name="app.tasks.dataset_tasks.refresh_hf_datasets_cache")
+ def refresh_hf_datasets_cache():
+     """Celery task to refresh the HuggingFace datasets cache in Redis."""
+     logger.info("Starting refresh of HuggingFace datasets cache via Celery task.")
+     try:
+         refresh_datasets_cache()
+         logger.info("Successfully refreshed HuggingFace datasets cache.")
+         return {"status": "success"}
+     except Exception as e:
+         logger.error(f"Failed to refresh HuggingFace datasets cache: {e}")
+         return {"status": "error", "error": str(e)}
+
+ @shared_task(bind=True, max_retries=3, default_retry_delay=10)
+ def fetch_datasets_page(self, offset, limit):
+     """
+     Celery task to fetch and cache a single page of datasets from Hugging Face.
+     Retries on failure.
+     """
+     logger.info(f"[fetch_datasets_page] ENTRY: offset={offset}, limit={limit}")
+     try:
+         from app.services.hf_datasets import process_datasets_page
+         logger.info(f"[fetch_datasets_page] Calling process_datasets_page with offset={offset}, limit={limit}")
+         result = process_datasets_page(offset, limit)
+         logger.info(f"[fetch_datasets_page] SUCCESS: offset={offset}, limit={limit}, result={result}")
+         return result
+     except Exception as exc:
+         logger.error(f"[fetch_datasets_page] ERROR: offset={offset}, limit={limit}, exc={exc}", exc_info=True)
+         raise self.retry(exc=exc)
+
+ @shared_task(bind=True, max_retries=3, default_retry_delay=60)
+ def refresh_hf_datasets_full_cache(self):
+     logger.info("[refresh_hf_datasets_full_cache] Starting full Hugging Face datasets cache refresh.")
+     try:
+         token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
+         if not token:
+             logger.error("[refresh_hf_datasets_full_cache] HUGGINGFACEHUB_API_TOKEN not set.")
+             return {"status": "error", "error": "HUGGINGFACEHUB_API_TOKEN not set"}
+         count = fetch_and_cache_all_datasets(token)
+         logger.info(f"[refresh_hf_datasets_full_cache] Cached {count} datasets.")
+         return {"status": "ok", "cached": count}
+     except Exception as exc:
+         logger.error(f"[refresh_hf_datasets_full_cache] ERROR: {exc}", exc_info=True)
+         raise self.retry(exc=exc)
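These tasks only run if something enqueues them; a common pattern is a Celery beat schedule. The snippet below is a hedged sketch, not code from this commit: the schedule entry names and intervals are assumptions, and the real project may configure beat elsewhere (for example in app/core/celery_app.py). The task names rely on Celery's default naming for shared_task plus the explicit name given in the decorator above.

# Hypothetical beat schedule (assumption, not part of this commit)
from celery.schedules import crontab
from app.core.celery_app import get_celery_app

celery_app = get_celery_app()

celery_app.conf.beat_schedule = {
    "refresh-hf-datasets-cache-hourly": {
        "task": "app.tasks.dataset_tasks.refresh_hf_datasets_cache",
        "schedule": crontab(minute=0),  # top of every hour
    },
    "refresh-hf-datasets-full-cache-nightly": {
        "task": "app.tasks.dataset_tasks.refresh_hf_datasets_full_cache",
        "schedule": crontab(hour=3, minute=0),  # once a day at 03:00
    },
}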
migrations/20250620000000_create_combined_datasets_table.sql ADDED
@@ -0,0 +1,57 @@
+ -- Create combined_datasets table
+ CREATE TABLE IF NOT EXISTS public.combined_datasets (
+     id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
+     name TEXT NOT NULL,
+     description TEXT,
+     source_datasets TEXT[] NOT NULL,
+     created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
+     updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
+     created_by UUID REFERENCES auth.users(id),
+     impact_level TEXT CHECK (impact_level = ANY (ARRAY['low', 'medium', 'high']::text[])),
+     status TEXT NOT NULL DEFAULT 'processing',
+     combination_strategy TEXT NOT NULL DEFAULT 'merge',
+     size_bytes BIGINT,
+     file_count INTEGER,
+     downloads INTEGER,
+     likes INTEGER
+ );
+
+ -- Add indexes for faster querying
+ CREATE INDEX IF NOT EXISTS idx_combined_datasets_created_by ON public.combined_datasets(created_by);
+ CREATE INDEX IF NOT EXISTS idx_combined_datasets_impact_level ON public.combined_datasets(impact_level);
+ CREATE INDEX IF NOT EXISTS idx_combined_datasets_status ON public.combined_datasets(status);
+
+ -- Add Row Level Security (RLS) policies
+ ALTER TABLE public.combined_datasets ENABLE ROW LEVEL SECURITY;
+
+ -- Policy to allow users to see all combined datasets
+ CREATE POLICY "Anyone can view combined datasets"
+     ON public.combined_datasets
+     FOR SELECT USING (true);
+
+ -- Policy to allow users to create their own combined datasets
+ CREATE POLICY "Users can create their own combined datasets"
+     ON public.combined_datasets
+     FOR INSERT
+     WITH CHECK (auth.uid() = created_by);
+
+ -- Policy to allow users to update only their own combined datasets
+ CREATE POLICY "Users can update their own combined datasets"
+     ON public.combined_datasets
+     FOR UPDATE
+     USING (auth.uid() = created_by);
+
+ -- Function to automatically update updated_at timestamp
+ CREATE OR REPLACE FUNCTION update_combined_datasets_updated_at()
+ RETURNS TRIGGER AS $$
+ BEGIN
+     NEW.updated_at = now();
+     RETURN NEW;
+ END;
+ $$ LANGUAGE plpgsql;
+
+ -- Trigger to automatically update updated_at timestamp
+ CREATE TRIGGER update_combined_datasets_updated_at_trigger
+     BEFORE UPDATE ON public.combined_datasets
+     FOR EACH ROW
+     EXECUTE FUNCTION update_combined_datasets_updated_at();
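For reference, a row in this table can be written and read back from Python. The sketch below is illustrative only: it uses psycopg2 with a placeholder DSN and placeholder values, assumes the uuid-ossp extension (needed for uuid_generate_v4()) is already enabled, and ignores the fact that under Supabase RLS the insert would normally go through an authenticated role so that auth.uid() matches created_by.

# Illustrative only (assumption, not part of the migration)
import psycopg2

DSN = "postgresql://user:password@localhost:5432/postgres"  # placeholder

with psycopg2.connect(DSN) as conn:
    with conn.cursor() as cur:
        cur.execute(
            """
            INSERT INTO public.combined_datasets (name, description, source_datasets, impact_level)
            VALUES (%s, %s, %s, %s)
            RETURNING id, status, created_at
            """,
            # psycopg2 adapts the Python list to a Postgres TEXT[] array
            ("my-combined-set", "Demo merge", ["ds/one", "ds/two"], "low"),
        )
        print(cur.fetchone())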
setup.py ADDED
@@ -0,0 +1,8 @@
+ from setuptools import setup, find_packages
+
+ setup(
+     name="collinear-tool",
+     version="0.1.0",
+     packages=find_packages(),
+     include_package_data=True,
+ )
tests/test_datasets.py ADDED
@@ -0,0 +1,78 @@
+ import os
+ import pytest
+ import requests
+
+ BASE_URL = os.environ.get("BASE_URL", "http://127.0.0.1:8000/api")
+
+ # --- /datasets ---
+ def test_list_datasets_http():
+     resp = requests.get(f"{BASE_URL}/datasets")
+     assert resp.status_code == 200
+     data = resp.json()
+     assert "items" in data
+     assert "total" in data
+     assert "warming_up" in data
+
+ def test_list_datasets_offset_limit_http():
+     resp = requests.get(f"{BASE_URL}/datasets?offset=0&limit=3")
+     assert resp.status_code == 200
+     data = resp.json()
+     assert isinstance(data["items"], list)
+     assert len(data["items"]) <= 3
+
+ def test_list_datasets_large_offset_http():
+     resp = requests.get(f"{BASE_URL}/datasets?offset=99999&limit=2")
+     assert resp.status_code == 200
+     data = resp.json()
+     assert data["items"] == []
+     assert "warming_up" in data
+
+ def test_list_datasets_invalid_limit_http():
+     resp = requests.get(f"{BASE_URL}/datasets?limit=-5")
+     assert resp.status_code == 422
+
+ # --- /datasets/cache-status ---
+ def test_cache_status_http():
+     resp = requests.get(f"{BASE_URL}/datasets/cache-status")
+     assert resp.status_code == 200
+     data = resp.json()
+     assert "warming_up" in data
+     assert "total_items" in data
+     assert "last_update" in data
+
+ # --- /datasets/{dataset_id}/commits ---
+ def test_commits_valid_http():
+     resp = requests.get(f"{BASE_URL}/datasets/openbmb/Ultra-FineWeb/commits")
+     assert resp.status_code in (200, 404)
+     if resp.status_code == 200:
+         assert isinstance(resp.json(), list)
+
+ def test_commits_invalid_http():
+     resp = requests.get(f"{BASE_URL}/datasets/invalid-dataset-id/commits")
+     assert resp.status_code in (404, 422)
+
+ # --- /datasets/{dataset_id}/files ---
+ def test_files_valid_http():
+     resp = requests.get(f"{BASE_URL}/datasets/openbmb/Ultra-FineWeb/files")
+     assert resp.status_code in (200, 404)
+     if resp.status_code == 200:
+         assert isinstance(resp.json(), list)
+
+ def test_files_invalid_http():
+     resp = requests.get(f"{BASE_URL}/datasets/invalid-dataset-id/files")
+     assert resp.status_code in (404, 422)
+
+ # --- /datasets/{dataset_id}/file-url ---
+ def test_file_url_valid_http():
+     resp = requests.get(f"{BASE_URL}/datasets/openbmb/Ultra-FineWeb/file-url", params={"filename": "README.md"})
+     assert resp.status_code in (200, 404)
+     if resp.status_code == 200:
+         assert "download_url" in resp.json()
+
+ def test_file_url_invalid_file_http():
+     resp = requests.get(f"{BASE_URL}/datasets/openbmb/Ultra-FineWeb/file-url", params={"filename": "not_a_real_file.txt"})
+     assert resp.status_code in (404, 200)
+
+ def test_file_url_missing_filename_http():
+     resp = requests.get(f"{BASE_URL}/datasets/openbmb/Ultra-FineWeb/file-url")
+     assert resp.status_code in (404, 422)
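These tests hit a live server at BASE_URL rather than the in-process app, so they fail with connection errors when nothing is listening. A guard like the sketch below (hypothetical, e.g. placed in a conftest.py; the fixture name, probe endpoint, and timeout are assumptions) would skip the module instead:

# Hypothetical guard (not in this commit): skip live-HTTP tests when the API
# at BASE_URL is not reachable, instead of failing with connection errors.
import os
import pytest
import requests

BASE_URL = os.environ.get("BASE_URL", "http://127.0.0.1:8000/api")

@pytest.fixture(autouse=True)
def require_live_server():
    try:
        requests.get(f"{BASE_URL}/datasets/cache-status", timeout=2)
    except requests.exceptions.ConnectionError:
        pytest.skip("API server is not running at BASE_URL")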
tests/test_datasets_api.py ADDED
@@ -0,0 +1,88 @@
+ import pytest
+ from fastapi.testclient import TestClient
+ from app.main import app
+
+ client = TestClient(app)
+
+ # --- /api/datasets ---
+ def test_list_datasets_default():
+     resp = client.get("/api/datasets")
+     assert resp.status_code == 200
+     data = resp.json()
+     assert "items" in data
+     assert isinstance(data["items"], list)
+     assert "total" in data
+     assert "warming_up" in data
+
+ def test_list_datasets_offset_limit():
+     resp = client.get("/api/datasets?offset=0&limit=2")
+     assert resp.status_code == 200
+     data = resp.json()
+     assert isinstance(data["items"], list)
+     assert len(data["items"]) <= 2
+
+ def test_list_datasets_large_offset():
+     resp = client.get("/api/datasets?offset=100000&limit=2")
+     assert resp.status_code == 200
+     data = resp.json()
+     assert data["items"] == []
+     assert data["warming_up"] in (True, False)
+
+ def test_list_datasets_negative_limit():
+     resp = client.get("/api/datasets?limit=-1")
+     assert resp.status_code == 422
+
+ def test_list_datasets_missing_params():
+     resp = client.get("/api/datasets")
+     assert resp.status_code == 200
+     data = resp.json()
+     assert "items" in data
+     assert "total" in data
+     assert "warming_up" in data
+
+ # --- /api/datasets/cache-status ---
+ def test_cache_status():
+     resp = client.get("/api/datasets/cache-status")
+     assert resp.status_code == 200
+     data = resp.json()
+     assert "warming_up" in data
+     assert "total_items" in data
+     assert "last_update" in data
+
+ # --- /api/datasets/{dataset_id}/commits ---
+ def test_get_commits_valid():
+     resp = client.get("/api/datasets/openbmb/Ultra-FineWeb/commits")
+     # Accept 200 (found) or 404 (not found)
+     assert resp.status_code in (200, 404)
+     if resp.status_code == 200:
+         assert isinstance(resp.json(), list)
+
+ def test_get_commits_invalid():
+     resp = client.get("/api/datasets/invalid-dataset-id/commits")
+     assert resp.status_code in (404, 422)
+
+ # --- /api/datasets/{dataset_id}/files ---
+ def test_list_files_valid():
+     resp = client.get("/api/datasets/openbmb/Ultra-FineWeb/files")
+     assert resp.status_code in (200, 404)
+     if resp.status_code == 200:
+         assert isinstance(resp.json(), list)
+
+ def test_list_files_invalid():
+     resp = client.get("/api/datasets/invalid-dataset-id/files")
+     assert resp.status_code in (404, 422)
+
+ # --- /api/datasets/{dataset_id}/file-url ---
+ def test_get_file_url_valid():
+     resp = client.get("/api/datasets/openbmb/Ultra-FineWeb/file-url", params={"filename": "README.md"})
+     assert resp.status_code in (200, 404)
+     if resp.status_code == 200:
+         assert "download_url" in resp.json()
+
+ def test_get_file_url_invalid_file():
+     resp = client.get("/api/datasets/openbmb/Ultra-FineWeb/file-url", params={"filename": "not_a_real_file.txt"})
+     assert resp.status_code in (404, 200)
+
+ def test_get_file_url_missing_filename():
+     resp = client.get("/api/datasets/openbmb/Ultra-FineWeb/file-url")
+     assert resp.status_code in (404, 422)
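Because these TestClient tests import app.main directly, the app's settings must be loadable without real credentials. The conftest.py sketch below is an assumption about this project's configuration, not part of the commit; the variable names mirror the ones used in dataset_tasks.py and tests/test_datasets.py.

# Hypothetical tests/conftest.py (assumption): provide harmless defaults so the
# FastAPI app can be imported by TestClient without real credentials.
import os

os.environ.setdefault("HUGGINGFACEHUB_API_TOKEN", "hf_dummy_token_for_tests")
os.environ.setdefault("BASE_URL", "http://127.0.0.1:8000/api")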