from fastapi import APIRouter, Query, HTTPException
from typing import List, Optional, Dict, Any, Set
from pydantic import BaseModel
from fastapi.concurrency import run_in_threadpool
from app.services.hf_datasets import (
    get_dataset_commits,
    get_dataset_files,
    get_file_url,
    get_datasets_page_from_zset,
    get_dataset_commits_async,
    get_dataset_files_async,
    get_file_url_async,
    get_datasets_page_from_cache,
    fetch_and_cache_all_datasets,
)
from app.services.redis_client import cache_get
import logging
import time
from fastapi.responses import JSONResponse
import os

router = APIRouter(prefix="/datasets", tags=["datasets"])
log = logging.getLogger(__name__)

# Byte-size thresholds (100 MiB and 1 GiB); not referenced within this module.
SIZE_LOW = 100 * 1024 * 1024
SIZE_MEDIUM = 1024 * 1024 * 1024


class DatasetInfo(BaseModel):
    id: str
    name: Optional[str]
    description: Optional[str]
    size_bytes: Optional[int]
    impact_level: Optional[str]
    downloads: Optional[int]
    likes: Optional[int]
    tags: Optional[List[str]]

    class Config:
        extra = "ignore"


class PaginatedDatasets(BaseModel):
    total: int
    items: List[DatasetInfo]


class CommitInfo(BaseModel):
    id: str
    title: Optional[str]
    message: Optional[str]
    author: Optional[Dict[str, Any]]
    date: Optional[str]


class CacheStatus(BaseModel):
    last_update: Optional[str]
    total_items: int
    warming_up: bool


def deduplicate_by_id(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Drop duplicate entries by 'id', keeping the first occurrence of each."""
    seen: Set[str] = set()
    unique_items = []
    for item in items:
        item_id = item.get("id")
        if item_id and item_id not in seen:
            seen.add(item_id)
            unique_items.append(item)
    return unique_items
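
# Quick illustration of deduplicate_by_id (hypothetical records, not real cache data):
#   deduplicate_by_id([{"id": "a"}, {"id": "a"}, {"id": "b"}])
#   -> [{"id": "a"}, {"id": "b"}]   (the first occurrence of each id wins)
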
@router.get("/cache-status", response_model=CacheStatus)
async def cache_status():
    meta = await cache_get("hf:datasets:meta")
    last_update = meta.get("last_update") if meta else None
    total_items = meta.get("total_items", 0) if meta else 0
    warming_up = not bool(total_items)
    return CacheStatus(last_update=last_update, total_items=total_items, warming_up=warming_up)
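
# Sketch of the cached metadata read above (the key names come from this module; the
# timestamp format is an assumption about what the cache warmer writes):
#   {"last_update": "2024-01-01T00:00:00Z", "total_items": 12345}
# A missing entry or total_items of 0 is reported as warming_up=True.
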
@router.get("/", response_model=None)
async def list_datasets(
    limit: int = Query(10, ge=1, le=1000),
    offset: int = Query(0, ge=0),
    search: Optional[str] = Query(None, description="Search term for dataset id or description"),
    sort_by: Optional[str] = Query(None, description="Field to sort by (e.g., 'downloads', 'likes', 'created_at')"),
    sort_order: str = Query("desc", regex="^(asc|desc)$", description="Sort order: 'asc' or 'desc'"),
):
    # Fetch the full list from the cache so filtering and sorting happen in memory.
    result, status = get_datasets_page_from_cache(1000000, 0)
    if status != 200:
        return JSONResponse(result, status_code=status)
    items = result["items"]
    # Filtering: case-insensitive match against the dataset id and description.
    if search:
        needle = search.lower()
        items = [
            d for d in items
            if needle in (d.get("id", "") + " " + str(d.get("description", ""))).lower()
        ]
    # Sorting
    if sort_by:
        items = sorted(items, key=lambda d: d.get(sort_by) or 0, reverse=(sort_order == "desc"))
    # Pagination
    total = len(items)
    page = items[offset:offset + limit]
    total_pages = (total + limit - 1) // limit
    current_page = (offset // limit) + 1
    next_page = current_page + 1 if offset + limit < total else None
    prev_page = current_page - 1 if current_page > 1 else None
    return {
        "total": total,
        "current_page": current_page,
        "total_pages": total_pages,
        "next_page": next_page,
        "prev_page": prev_page,
        "items": page,
    }
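
# Example listing request against this router (query values are illustrative only):
#   GET /datasets/?limit=10&offset=20&search=imdb&sort_by=downloads&sort_order=desc
# which returns the envelope assembled above, e.g.:
#   {"total": 137, "current_page": 3, "total_pages": 14, "next_page": 4,
#    "prev_page": 2, "items": [...]}
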
@router.get("/{dataset_id:path}/commits", response_model=List[CommitInfo])
async def get_commits(dataset_id: str):
    """Get commit history for a dataset."""
    try:
        return await get_dataset_commits_async(dataset_id)
    except Exception as e:
        log.error(f"Error fetching commits for {dataset_id}: {e}")
        raise HTTPException(status_code=404, detail=f"Could not fetch commits: {e}")


@router.get("/{dataset_id:path}/files", response_model=List[str])
async def list_files(dataset_id: str):
    """List files in a dataset."""
    try:
        return await get_dataset_files_async(dataset_id)
    except Exception as e:
        log.error(f"Error listing files for {dataset_id}: {e}")
        raise HTTPException(status_code=404, detail=f"Could not list files: {e}")


@router.get("/{dataset_id:path}/file-url")
async def get_file_url_endpoint(dataset_id: str, filename: str = Query(...), revision: Optional[str] = None):
    """Get download URL for a file in a dataset."""
    url = await get_file_url_async(dataset_id, filename, revision)
    return {"download_url": url}


@router.get("/meta")
async def get_datasets_meta():
    meta = await cache_get("hf:datasets:meta")
    return meta if meta else {}


# Manually trigger a cache refresh (for admin/testing); the router prefix makes the
# full path /datasets/refresh-cache.
@router.post("/refresh-cache")
def refresh_cache():
    token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
    if not token:
        return JSONResponse({"error": "HUGGINGFACEHUB_API_TOKEN not set"}, status_code=500)
    count = fetch_and_cache_all_datasets(token)
    return {"status": "ok", "cached": count}
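
# Route summary (paths include the "/datasets" router prefix; the example id and
# filename are placeholders, and dataset_id may contain slashes thanks to the
# ":path" converter):
#   GET  /datasets/                                  paginated, filterable listing
#   GET  /datasets/cache-status                      cache freshness / warm-up state
#   GET  /datasets/meta                              raw cached metadata
#   GET  /datasets/{dataset_id}/commits              e.g. /datasets/user/my-set/commits
#   GET  /datasets/{dataset_id}/files
#   GET  /datasets/{dataset_id}/file-url?filename=data.csv
#   POST /datasets/refresh-cache                     requires HUGGINGFACEHUB_API_TOKEN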