# dataset-tool / app/api/datasets.py
# Commit fdc5d7a by iaroy: "Deploy full application code"
from fastapi import APIRouter, Query, HTTPException
from typing import List, Optional, Dict, Any, Set
from pydantic import BaseModel
from fastapi.concurrency import run_in_threadpool
from app.services.hf_datasets import (
get_dataset_commits,
get_dataset_files,
get_file_url,
get_datasets_page_from_zset,
get_dataset_commits_async,
get_dataset_files_async,
get_file_url_async,
get_datasets_page_from_cache,
fetch_and_cache_all_datasets,
)
from app.services.redis_client import cache_get
import logging
import time
from fastapi.responses import JSONResponse
import os
# All routes in this module are mounted under the "/datasets" prefix.
router = APIRouter(prefix="/datasets", tags=["datasets"])
log = logging.getLogger(__name__)

# Byte-size thresholds: 100 MiB and 1 GiB. They are not referenced in this
# module section; presumably used to derive impact_level — TODO confirm.
SIZE_LOW = 100 * 2**20
SIZE_MEDIUM = 2**30
class DatasetInfo(BaseModel):
    """Summary of a single dataset.

    Every field except ``id`` defaults to ``None``. Pydantic v2 treats an
    ``Optional[...]`` annotation without a default as *required*, so explicit
    defaults keep partially populated entries valid under both v1 and v2.
    """

    id: str
    name: Optional[str] = None
    description: Optional[str] = None
    size_bytes: Optional[int] = None
    impact_level: Optional[str] = None
    downloads: Optional[int] = None
    likes: Optional[int] = None
    tags: Optional[List[str]] = None

    class Config:
        # Silently drop any input keys that are not declared above.
        extra = "ignore"
class PaginatedDatasets(BaseModel):
    """Page of dataset summaries together with a total count."""

    # Presumably the overall match count rather than len(items) — TODO confirm.
    total: int
    # Dataset summaries for this page.
    items: List[DatasetInfo]
    # NOTE(review): list_datasets in this module declares response_model=None and
    # returns extra pagination keys, so this model appears unused here — confirm.
class CommitInfo(BaseModel):
    """One entry of a dataset's commit history.

    Only ``id`` is required. The other fields default to ``None`` so that a
    commit record missing e.g. ``title`` or ``message`` does not fail response
    validation under pydantic v2, where ``Optional[...]`` alone means required.
    """

    id: str
    title: Optional[str] = None
    message: Optional[str] = None
    author: Optional[Dict[str, Any]] = None
    date: Optional[str] = None
class CacheStatus(BaseModel):
    """Cache health snapshot returned by ``GET /datasets/cache-status``."""

    # None when no update timestamp has been recorded yet; defaults to None
    # for consistency with the other models' optional fields.
    last_update: Optional[str] = None
    total_items: int
    # True while the cache holds no items.
    warming_up: bool
def deduplicate_by_id(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Return *items* with duplicate ``"id"`` values removed.

    The first occurrence of each id wins and input order is preserved.
    Items whose ``"id"`` is missing or falsy are dropped entirely.
    """
    first_by_id: Dict[Any, Dict[str, Any]] = {}
    for entry in items:
        key = entry.get("id")
        if not key or key in first_by_id:
            continue
        first_by_id[key] = entry
    # dicts preserve insertion order, so this keeps first-seen ordering.
    return list(first_by_id.values())
@router.get("/cache-status", response_model=CacheStatus)
async def cache_status():
    """Report whether the dataset cache has been populated.

    Reads the ``hf:datasets:meta`` cache entry. A missing entry yields
    ``last_update=None`` and ``total_items=0``. A zero or absent item count
    is reported as ``warming_up=True``.
    """
    meta = await cache_get("hf:datasets:meta") or {}
    last_update = meta["last_update"] if "last_update" in meta else None
    total_items = meta["total_items"] if "total_items" in meta else 0
    return CacheStatus(
        last_update=last_update,
        total_items=total_items,
        warming_up=not total_items,
    )
def _matches_search(item: Dict[str, Any], needle: str) -> bool:
    """Return True if lower-cased *needle* occurs in the item's id or description.

    The whole haystack is lower-cased (id included) so matching is
    case-insensitive. Missing or None fields count as empty strings instead
    of the literal text "None".
    """
    haystack = f"{item.get('id') or ''} {item.get('description') or ''}".lower()
    return needle in haystack


def _sort_items(items: List[Dict[str, Any]], sort_by: str, descending: bool) -> List[Dict[str, Any]]:
    """Return *items* sorted by the *sort_by* field. Items lacking the field rank lowest.

    Raises HTTPException(400) if the field's values cannot be compared with
    each other (e.g. dicts or mixed types).
    """

    def key(item: Dict[str, Any]):
        value = item.get(sort_by)
        # The leading flag keeps missing values out of comparisons with real
        # ones, so string fields such as dates never get compared against 0.
        return (value is not None, value if value is not None else 0)

    try:
        return sorted(items, key=key, reverse=descending)
    except TypeError as exc:
        raise HTTPException(status_code=400, detail=f"Cannot sort by field '{sort_by}'") from exc


@router.get("/", response_model=None)
async def list_datasets(
    limit: int = Query(10, ge=1, le=1000),
    offset: int = Query(0, ge=0),
    search: Optional[str] = Query(None, description="Search term for dataset id or description"),
    sort_by: Optional[str] = Query(None, description="Field to sort by (e.g., 'downloads', 'likes', 'created_at')"),
    sort_order: str = Query("desc", regex="^(asc|desc)$", description="Sort order: 'asc' or 'desc'"),
):
    """List cached datasets with optional search, sorting and offset pagination.

    The full cached list is loaded and then filtered, sorted and sliced in
    memory. A non-200 status from the cache helper is passed through as a
    JSONResponse.

    Returns a dict with ``total``, ``current_page``, ``total_pages``,
    ``next_page``, ``prev_page`` and ``items`` (the current slice).
    """
    # The cache helper is synchronous; run it off the event loop so a large
    # read does not stall other requests.
    result, status = await run_in_threadpool(get_datasets_page_from_cache, 1000000, 0)
    if status != 200:
        return JSONResponse(result, status_code=status)
    items = result["items"]

    if search:
        needle = search.lower()
        items = [d for d in items if _matches_search(d, needle)]

    if sort_by:
        items = _sort_items(items, sort_by, descending=(sort_order == "desc"))

    total = len(items)
    page = items[offset:offset + limit]
    total_pages = (total + limit - 1) // limit
    current_page = (offset // limit) + 1
    next_page = current_page + 1 if offset + limit < total else None
    prev_page = current_page - 1 if current_page > 1 else None
    return {
        "total": total,
        "current_page": current_page,
        "total_pages": total_pages,
        "next_page": next_page,
        "prev_page": prev_page,
        "items": page,
    }
@router.get("/{dataset_id:path}/commits", response_model=List[CommitInfo])
async def get_commits(dataset_id: str):
    """
    Get commit history for a dataset.

    Any error raised by the lookup is logged with its traceback and reported
    to the client as 404, with the error text in ``detail``.
    """
    try:
        return await get_dataset_commits_async(dataset_id)
    except Exception as e:
        # log.exception keeps the traceback; lazy %-args avoid eager formatting.
        log.exception("Error fetching commits for %s: %s", dataset_id, e)
        raise HTTPException(status_code=404, detail=f"Could not fetch commits: {e}") from e
@router.get("/{dataset_id:path}/files", response_model=List[str])
async def list_files(dataset_id: str):
    """
    List files in a dataset.

    Any error raised by the lookup is logged with its traceback and reported
    to the client as 404, with the error text in ``detail``.
    """
    try:
        return await get_dataset_files_async(dataset_id)
    except Exception as e:
        # log.exception keeps the traceback; lazy %-args avoid eager formatting.
        log.exception("Error listing files for %s: %s", dataset_id, e)
        raise HTTPException(status_code=404, detail=f"Could not list files: {e}") from e
@router.get("/{dataset_id:path}/file-url")
async def get_file_url_endpoint(dataset_id: str, filename: str = Query(...), revision: Optional[str] = None):
    """
    Get download URL for a file in a dataset.

    Returns ``{"download_url": <url>}``. Any error raised while resolving the
    URL is logged with its traceback and reported to the client as 404.
    """
    try:
        url = await get_file_url_async(dataset_id, filename, revision)
    except Exception as e:
        log.exception("Error resolving file URL for %s/%s: %s", dataset_id, filename, e)
        raise HTTPException(status_code=404, detail=f"Could not get file URL: {e}") from e
    return {"download_url": url}
@router.get("/meta")
async def get_datasets_meta():
    """Return the raw ``hf:datasets:meta`` cache entry, or ``{}`` when it is missing or empty."""
    return (await cache_get("hf:datasets:meta")) or {}
# Endpoint to trigger cache refresh manually (for admin/testing).
# NOTE(review): this endpoint is unauthenticated yet triggers a full refresh
# using the server's HUGGINGFACEHUB_API_TOKEN — consider protecting it.
@router.post("/refresh-cache")
@router.post("/datasets/refresh-cache", include_in_schema=False)
def refresh_cache():
    """Re-fetch all datasets into the cache and report how many were cached.

    The router already carries the ``/datasets`` prefix, so the original path
    ``"/datasets/refresh-cache"`` resolved to ``/datasets/datasets/refresh-cache``.
    The handler is now also served at ``/datasets/refresh-cache`` as intended.
    The doubled path is kept (hidden from the OpenAPI schema) for existing callers.

    Returns 500 with an error body when HUGGINGFACEHUB_API_TOKEN is not set.
    """
    token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
    if not token:
        return JSONResponse({"error": "HUGGINGFACEHUB_API_TOKEN not set"}, status_code=500)
    count = fetch_and_cache_all_datasets(token)
    return {"status": "ok", "cached": count}