from fastapi import APIRouter, HTTPException, Query
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from typing import Any, Dict, List, Optional, Set
from pydantic import BaseModel
from app.services.hf_datasets import (
    get_dataset_commits_async,
    get_dataset_files_async,
    get_file_url_async,
    get_datasets_page_from_cache,
    fetch_and_cache_all_datasets,
)
from app.services.redis_client import cache_get
import logging
import os

router = APIRouter(prefix="/datasets", tags=["datasets"])
log = logging.getLogger(__name__)

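# Size thresholds in bytes (100 MB / 1 GB); not referenced elsewhere in this module.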
SIZE_LOW = 100 * 1024 * 1024
SIZE_MEDIUM = 1024 * 1024 * 1024

class DatasetInfo(BaseModel):
    id: str
    name: Optional[str] = None
    description: Optional[str] = None
    size_bytes: Optional[int] = None
    impact_level: Optional[str] = None
    downloads: Optional[int] = None
    likes: Optional[int] = None
    tags: Optional[List[str]] = None

    class Config:
        extra = "ignore"

class PaginatedDatasets(BaseModel):
    total: int
    items: List[DatasetInfo]

class CommitInfo(BaseModel):
    id: str
    title: Optional[str] = None
    message: Optional[str] = None
    author: Optional[Dict[str, Any]] = None
    date: Optional[str] = None

class CacheStatus(BaseModel):
    last_update: Optional[str] = None
    total_items: int
    warming_up: bool

def deduplicate_by_id(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
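    """Return items with duplicate "id" values removed, keeping the first occurrence."""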
    seen: Set[str] = set()
    unique_items = []
    for item in items:
        item_id = item.get("id")
        if item_id and item_id not in seen:
            seen.add(item_id)
            unique_items.append(item)
    return unique_items

@router.get("/cache-status", response_model=CacheStatus)
async def cache_status():
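    """Report when the dataset cache was last refreshed and how many items it holds."""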
    meta = await cache_get("hf:datasets:meta")
    last_update = meta.get("last_update") if meta else None
    total_items = meta.get("total_items", 0) if meta else 0
    warming_up = not bool(total_items)
    return CacheStatus(last_update=last_update, total_items=total_items, warming_up=warming_up)

@router.get("/", response_model=None)
async def list_datasets(
    limit: int = Query(10, ge=1, le=1000),
    offset: int = Query(0, ge=0),
    search: Optional[str] = Query(None, description="Search term for dataset id or description"),
    sort_by: Optional[str] = Query(None, description="Field to sort by (e.g., 'downloads', 'likes', 'created_at')"),
    sort_order: str = Query("desc", regex="^(asc|desc)$", description="Sort order: 'asc' or 'desc'"),
):
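    """List cached datasets with optional search, sorting, and pagination."""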
    # Fetch the full cached list for in-memory filtering. The cache read is a
    # synchronous helper, so run it in a worker thread to keep the event loop free.
    result, status = await run_in_threadpool(get_datasets_page_from_cache, 1000000, 0)
    if status != 200:
        return JSONResponse(result, status_code=status)
    items = result["items"]
    # Filtering: case-insensitive match against dataset id and description
    if search:
        needle = search.lower()
        items = [
            d for d in items
            if needle in (d.get("id", "") + " " + str(d.get("description", ""))).lower()
        ]
    # Sorting
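    # Missing values fall back to 0, so sorting a text field only behaves if every item has it.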
    if sort_by:
        items = sorted(items, key=lambda d: d.get(sort_by) or 0, reverse=(sort_order == "desc"))
    # Pagination
    total = len(items)
    page = items[offset:offset+limit]
    total_pages = (total + limit - 1) // limit
    current_page = (offset // limit) + 1
    next_page = current_page + 1 if offset + limit < total else None
    prev_page = current_page - 1 if current_page > 1 else None
    return {
        "total": total,
        "current_page": current_page,
        "total_pages": total_pages,
        "next_page": next_page,
        "prev_page": prev_page,
        "items": page
    }

@router.get("/{dataset_id:path}/commits", response_model=List[CommitInfo])
async def get_commits(dataset_id: str):
    """
    Get commit history for a dataset.
    """
    try:
        return await get_dataset_commits_async(dataset_id)
    except Exception as e:
        log.error(f"Error fetching commits for {dataset_id}: {e}")
        raise HTTPException(status_code=404, detail=f"Could not fetch commits: {e}")

@router.get("/{dataset_id:path}/files", response_model=List[str])
async def list_files(dataset_id: str):
    """
    List files in a dataset.
    """
    try:
        return await get_dataset_files_async(dataset_id)
    except Exception as e:
        log.error(f"Error listing files for {dataset_id}: {e}")
        raise HTTPException(status_code=404, detail=f"Could not list files: {e}")

@router.get("/{dataset_id:path}/file-url")
async def get_file_url_endpoint(dataset_id: str, filename: str = Query(...), revision: Optional[str] = None):
    """
    Get download URL for a file in a dataset.
    """
    url = await get_file_url_async(dataset_id, filename, revision)
    return {"download_url": url}

@router.get("/meta")
async def get_datasets_meta():
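    """Return the cached dataset metadata entry, or an empty dict if the cache is cold."""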
    meta = await cache_get("hf:datasets:meta")
    return meta if meta else {}

# Endpoint to trigger cache refresh manually (for admin/testing)
@router.post("/refresh-cache")
def refresh_cache():
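    """Rebuild the dataset cache; declared sync so FastAPI runs it in a worker thread."""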
    token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
    if not token:
        return JSONResponse({"error": "HUGGINGFACEHUB_API_TOKEN not set"}, status_code=500)
    count = fetch_and_cache_all_datasets(token)
    return {"status": "ok", "cached": count}