import os
import time
import logging
from typing import List, Union, Tuple
from cachetools import LRUCache
import hashlib
import asyncio
from functools import lru_cache
from contextlib import asynccontextmanager

os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Configure logging
logging.basicConfig(
    level=(
        logging.DEBUG if os.environ.get("ENVIRONMENT") != "production" else logging.INFO
    ),
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
from fastapi import FastAPI, HTTPException, Request, Depends
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.exceptions import RequestValidationError
from pydantic import BaseModel, Field, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict  # SettingsConfigDict replaces the old `class Config`
from transformers import AutoModel, AutoTokenizer
import torch
import torch.nn.functional as F
import uvicorn
from starlette import status

from models_config import MODELS, get_model_config, CANONICAL_MODELS, MODEL_ALIASES

# --- Configuration Management ---
class AppSettings(BaseSettings):
    """
    Application settings loaded from environment variables.
    """

    cuda_cache_clear_enabled: bool = Field(
        True, json_schema_extra={"env": "CUDA_CACHE_CLEAR_ENABLED"}, description="Enable CUDA cache clearing after each batch."
    )
    default_model: str = Field(
        "text-embedding-3-large", json_schema_extra={"env": "DEFAULT_MODEL"}, description="Default embedding model to use."
    )
    warmup_enabled: bool = Field(
        True, json_schema_extra={"env": "WARMUP_ENABLED"}, description="Enable model warmup on startup."
    )
    app_port: int = Field(
        8000, json_schema_extra={"env": "APP_PORT"}, description="Port for the FastAPI application."
    )
    app_host: str = Field(
        "0.0.0.0", json_schema_extra={"env": "APP_HOST"}, description="Host for the FastAPI application."
    )
    embedding_batch_size: int = Field(
        8, json_schema_extra={"env": "EMBEDDING_BATCH_SIZE"}, description="Batch size for embedding generation."
    )
    embeddings_cache_enabled: bool = Field(
        True, json_schema_extra={"env": "EMBEDDINGS_CACHE_ENABLED"}, description="Enable in-memory embeddings cache."
    )
    report_cached_tokens: bool = Field(
        False, json_schema_extra={"env": "REPORT_CACHED_TOKENS"}, description="Report token count for cached embeddings."
    )
    embeddings_cache_maxsize: int = Field(
        2048, json_schema_extra={"env": "EMBEDDINGS_CACHE_MAXSIZE"}, description="Maximum size of the embeddings cache."
    )
    environment: str = Field(
        "development", json_schema_extra={"env": "ENVIRONMENT"}, description="Application environment (e.g., 'production', 'development')."
    )
    model_config = SettingsConfigDict(env_file=".env")  # Settings-specific config (replaces the old `class Config`)

# Cache the settings instance for performance
@lru_cache()
def get_app_settings():
    return AppSettings()
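
# Example (illustrative): the settings above can be overridden via environment
# variables or a local .env file. The values below are sample overrides, not
# defaults shipped with this service.
#
#   EMBEDDING_BATCH_SIZE=16
#   EMBEDDINGS_CACHE_MAXSIZE=4096
#   DEFAULT_MODEL=text-embedding-3-large
#   ENVIRONMENT=production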

# Set up device configuration
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {DEVICE}")

# Initialize global embeddings cache (size determined by settings)
embeddings_cache = LRUCache(maxsize=0)  # Will be updated on startup based on settings

# --- Lifespan Event Handler ---
@asynccontextmanager
async def lifespan(app: FastAPI):
""" | |
Handles application startup and shutdown events. | |
Initializes the embeddings cache and warms up the default model. | |
""" | |
settings = get_app_settings() # Directly get settings here | |
global embeddings_cache | |
embeddings_cache = LRUCache(maxsize=settings.embeddings_cache_maxsize) | |
logger.info(f"Embeddings cache initialized with max size: {settings.embeddings_cache_maxsize}") | |
default_model = settings.default_model | |
if default_model not in MODELS: | |
logger.error(f"Default model '{default_model}' is not configured in MODELS.") | |
raise ValueError( | |
f"Default model '{default_model}' is not configured in MODELS." | |
) | |
if settings.warmup_enabled: | |
logger.info(f"Warming up default model: {default_model}...") | |
try: | |
# Pass settings to get_embeddings_batch | |
await get_embeddings_batch(["warmup"], default_model, settings) | |
logger.info("Model warmup complete.") | |
except Exception as e: | |
logger.error(f"Model warmup failed for {default_model}: {e}", exc_info=True) | |
yield # Application starts here | |
# Clean up code (if any) goes here when application shuts down | |
logger.info("Application shutdown.") | |

app = FastAPI(
    title="Embedding API",
    description="API for generating embeddings using a transformer model.",
    version="0.1.0",
    lifespan=lifespan,  # Assign the lifespan context manager
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)

# Mount the static directory to serve index.html and other static files.
app.mount("/static", StaticFiles(directory="static"), name="static")

# Initialize model and tokenizer caches to avoid reloading models on every request
model_cache = {}
tokenizer_cache = {}

# Initialize global dictionary of per-model loading locks
model_load_locks = {}

async def load_model(model_name: str):
    """
    Load model and tokenizer if not already loaded, with asynchronous locking.

    Args:
        model_name (str): The name of the model to load.

    Returns:
        tuple: A tuple containing the loaded model and tokenizer.
    """
    config = get_model_config(model_name)
    canonical_hf_model_name = config["name"]

    async with model_load_locks.setdefault(canonical_hf_model_name, asyncio.Lock()):
        if canonical_hf_model_name not in model_cache:
            logger.info(f"Loading model: {canonical_hf_model_name}")
            model_path = config["name"]
            trust_remote = config.get("requires_remote_code", False)
            model_cache[canonical_hf_model_name] = AutoModel.from_pretrained(
                model_path, trust_remote_code=trust_remote
            ).to(DEVICE)
            model_cache[canonical_hf_model_name].eval()
            tokenizer_cache[canonical_hf_model_name] = AutoTokenizer.from_pretrained(model_path)
            logger.info(f"Model loaded: {canonical_hf_model_name}")

    return model_cache[canonical_hf_model_name], tokenizer_cache[canonical_hf_model_name]

class EmbeddingRequest(BaseModel):
    """
    Represents a request for generating embeddings.

    Attributes:
        input (Union[str, List[str]]): The input text to embed, can be a single string or a list of strings.
        model (str): The name of the model to use for embedding.
        encoding_format (str): The format of the embeddings. Currently only 'float' is supported.
    """

    input: Union[str, List[str]] = Field(
        ...,
        description="The input text to embed, can be a single string or a list of strings.",
        json_schema_extra={"example": "This is an example sentence."},
    )
    model: str = Field(
        "text-embedding-3-large",
        description="The name of the model to use for embedding. Supports both original model names and OpenAI-compatible names.",
        json_schema_extra={"example": "text-embedding-3-large"},
    )
    encoding_format: str = Field(
        "float",
        description="The format of the embeddings. Currently only 'float' is supported.",
        json_schema_extra={"example": "float"},
    )

    @field_validator("model")
    @classmethod
    def validate_model(cls, value: str) -> str:
        if value not in MODELS:
            valid_models = list(CANONICAL_MODELS.keys()) + list(MODEL_ALIASES.keys())
            raise ValueError(f"Model must be one of: {', '.join(sorted(valid_models))}")
        return value

    @field_validator("encoding_format")
    @classmethod
    def validate_encoding_format(cls, value: str) -> str:
        if value != "float":
            raise ValueError("Only 'float' encoding format is supported")
        return value
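
# Example (illustrative) request body accepted by EmbeddingRequest, mirroring
# the OpenAI Embeddings API shape:
#
#   {
#     "input": ["first sentence", "second sentence"],
#     "model": "text-embedding-3-large",
#     "encoding_format": "float"
#   }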

class EmbeddingObject(BaseModel):
    """
    Represents an embedding object.

    Attributes:
        object (str): The type of object, which is "embedding".
        embedding (List[float]): The embedding vector.
        index (int): The index of the embedding.
    """

    object: str = "embedding"
    embedding: List[float]
    index: int


class EmbeddingResponse(BaseModel):
    """
    Represents the response containing a list of embedding objects.
    """

    data: List[EmbeddingObject]
    model: str
    object: str = "list"
    usage: dict


class ModelObject(BaseModel):
    """
    Represents a single model object in the list of models.
    """

    id: str
    object: str = "model"
    created: int
    owned_by: str


class ListModelsResponse(BaseModel):
    """
    Represents the response containing a list of available models.
    """

    data: List[ModelObject]
    object: str = "list"

# --- Helper functions for get_embeddings_batch refactoring ---
def _process_texts_for_cache_and_batching(
    texts: List[str],
    model_config: dict,
    settings: AppSettings
) -> Tuple[List[Union[torch.Tensor, None]], int, List[str], List[int]]:
    """
    Checks cache for each text and prepares texts for model processing.
    Returns cached embeddings, total cached tokens, texts to process, and their original indices.
    """
    final_ordered_embeddings = [None] * len(texts)
    total_prompt_tokens = 0
    texts_to_process_in_model = []
    original_indices_for_model_output = []
    canonical_hf_model_name = model_config["name"]

    for i, text in enumerate(texts):
        text_hash = hashlib.sha256(text.encode('utf-8')).hexdigest()
        cache_key = (text_hash, canonical_hf_model_name)
        if settings.embeddings_cache_enabled and cache_key in embeddings_cache:
            cached_embedding, cached_tokens = embeddings_cache[cache_key]
            final_ordered_embeddings[i] = cached_embedding.unsqueeze(0)
            if settings.report_cached_tokens:
                total_prompt_tokens += cached_tokens
            logger.debug(f"Cache hit for text at index {i}")
        else:
            texts_to_process_in_model.append(text)
            original_indices_for_model_output.append(i)
            logger.debug(f"Cache miss for text at index {i}")

    return final_ordered_embeddings, total_prompt_tokens, texts_to_process_in_model, original_indices_for_model_output

def _apply_instruction_prefix(texts: List[str], model_config: dict) -> List[str]:
    """
    Applies instruction prefixes to texts if required by the model configuration.
    """
    if model_config.get("instruction_prefix_required", False):
        processed_texts = []
        default_prefix = model_config.get("default_instruction_prefix", "")
        known_prefixes = model_config.get("known_instruction_prefixes", [])
        for text in texts:
            if not any(text.startswith(prefix) for prefix in known_prefixes):
                processed_texts.append(f"{default_prefix}{text}")
            else:
                processed_texts.append(text)
        return processed_texts
    return texts
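
# Example (illustrative): for a model whose config in models_config sets
# instruction_prefix_required=True and, say, default_instruction_prefix="query: ",
# _apply_instruction_prefix(["hello"], config) would return ["query: hello"],
# while an input that already starts with one of the known_instruction_prefixes
# is passed through unchanged. The actual prefix strings live in models_config
# and may differ per model.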

def _perform_model_inference(
    texts_to_tokenize: List[str],
    model,
    tokenizer,
    model_max_tokens: int,
    model_dimension: int,
    settings: AppSettings
) -> Tuple[torch.Tensor, List[int], int]:
    """
    Performs model inference for a batch of texts and returns embeddings,
    individual token counts, and total prompt tokens for the batch.
    Handles CUDA Out of Memory errors.
    """
    try:
        batch_dict = tokenizer(
            texts_to_tokenize,
            max_length=model_max_tokens,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        individual_tokens_in_batch = [
            int(torch.sum(mask).item()) for mask in batch_dict["attention_mask"]
        ]
        prompt_tokens_current_batch = int(torch.sum(batch_dict["attention_mask"]).item())
        batch_dict = {k: v.to(DEVICE) for k, v in batch_dict.items()}

        with torch.no_grad():
            outputs = model(**batch_dict)
            embeddings = outputs.last_hidden_state[:, 0]
            embeddings = embeddings[:, :model_dimension]
            embeddings = F.normalize(embeddings, p=2, dim=1)

        return embeddings, individual_tokens_in_batch, prompt_tokens_current_batch
    except torch.cuda.OutOfMemoryError as e:
        logger.error(
            f"CUDA Out of Memory Error during embedding generation: {e}. "
            "Consider reducing EMBEDDING_BATCH_SIZE or using a smaller model.",
            exc_info=True
        )
        raise HTTPException(
            status_code=status.HTTP_507_INSUFFICIENT_STORAGE,
            detail=f"GPU out of memory: {e}. Please try with a smaller batch size or a different model."
        )
    except Exception as e:
        logger.error(
            f"An unexpected error occurred during batch embedding generation: {e}", exc_info=True
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Internal server error during embedding generation: {str(e)}"
        )
    finally:
        if settings.cuda_cache_clear_enabled and torch.cuda.is_available():
            torch.cuda.empty_cache()
            logger.debug("CUDA cache cleared after processing chunk.")

def _store_embeddings_in_cache(
    embeddings: torch.Tensor,
    individual_tokens_in_batch: List[int],
    batch_original_indices: List[int],
    texts: List[str],
    model_config: dict,
    final_ordered_embeddings: List[Union[torch.Tensor, None]],
    settings: AppSettings
):
    """
    Stores newly generated embeddings in the cache and updates the final ordered embeddings list.
    """
    canonical_hf_model_name = model_config["name"]
    for j, original_idx in enumerate(batch_original_indices):
        current_text = texts[original_idx]
        current_embedding = embeddings[j].cpu()
        current_tokens = individual_tokens_in_batch[j]
        current_text_hash = hashlib.sha256(current_text.encode('utf-8')).hexdigest()
        if settings.embeddings_cache_enabled:
            embeddings_cache[(current_text_hash, canonical_hf_model_name)] = (current_embedding, current_tokens)
        final_ordered_embeddings[original_idx] = current_embedding.unsqueeze(0)

async def get_embeddings_batch(
    texts: List[str],
    model_name: str,
    settings: AppSettings
) -> Tuple[torch.Tensor, int]:
    """
    Generates embeddings for a batch of texts using the specified model.
    Handles potential CUDA out of memory errors by processing texts in chunks.
    Includes an in-memory cache for individual text-model pairs.

    Args:
        texts (List[str]): The list of input texts to embed.
        model_name (str): The name of the model to use.
        settings (AppSettings): Application settings passed in by the caller.
    """
    config = get_model_config(model_name)
    model, tokenizer = await load_model(model_name)
    model_max_tokens = config.get("max_tokens", 8192)
    model_dimension = config["dimension"]
    max_batch_size = settings.embedding_batch_size

    final_ordered_embeddings, total_prompt_tokens, texts_to_process_in_model, original_indices_for_model_output = \
        _process_texts_for_cache_and_batching(texts, config, settings)

    if texts_to_process_in_model:
        for i in range(0, len(texts_to_process_in_model), max_batch_size):
            batch_texts = texts_to_process_in_model[i : i + max_batch_size]
            batch_original_indices = original_indices_for_model_output[i : i + max_batch_size]
            texts_to_tokenize = _apply_instruction_prefix(batch_texts, config)

            embeddings, individual_tokens_in_batch, prompt_tokens_current_batch = \
                _perform_model_inference(texts_to_tokenize, model, tokenizer, model_max_tokens, model_dimension, settings)
            total_prompt_tokens += prompt_tokens_current_batch

            _store_embeddings_in_cache(
                embeddings,
                individual_tokens_in_batch,
                batch_original_indices,
                texts,
                config,
                final_ordered_embeddings,
                settings
            )

    final_embeddings_tensor = torch.cat([e for e in final_ordered_embeddings if e is not None], dim=0)
    return final_embeddings_tensor, total_prompt_tokens
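
# Example (illustrative): calling get_embeddings_batch directly, e.g. from a
# script or test, outside the HTTP layer. Assumes "text-embedding-3-large" is
# configured in MODELS; the call must run inside an event loop since the
# function is async.
#
#   async def _demo():
#       settings = get_app_settings()
#       vectors, tokens = await get_embeddings_batch(
#           ["first sentence", "second sentence"], "text-embedding-3-large", settings
#       )
#       print(vectors.shape, tokens)  # e.g. torch.Size([2, <dimension>]) and a token count
#
#   asyncio.run(_demo())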

@app.get("/")
async def read_root():
    """
    Serve the static index.html file at the root route.
    """
    return FileResponse("static/index.html")

@app.get("/v1/models")
async def list_models():
    """
    Lists the available embedding models.

    Returns:
        ListModelsResponse: The response containing a list of model objects.
    """
    model_list = []
    current_time = int(time.time())
    for model_name in MODELS.keys():
        model_list.append(
            ModelObject(
                id=model_name,
                created=current_time,
                owned_by="local",
            )
        )
    return ListModelsResponse(data=model_list)

@app.get("/v1/models/{model_id}")
async def get_model(model_id: str):
    """
    Retrieves information about a specific embedding model.

    Args:
        model_id (str): The ID of the model to retrieve.
    """
    if model_id in MODELS:
        current_time = int(time.time())
        return ModelObject(
            id=model_id,
            created=current_time,
            owned_by="local",
        )
    else:
        raise HTTPException(status_code=404, detail="Model not found")

@app.post("/v1/embeddings")
async def create_embeddings(request: EmbeddingRequest, settings: AppSettings = Depends(get_app_settings)):
    """
    Generates embeddings for the given input text(s) using batch processing.
    Compatible with OpenAI's Embeddings API format.
    The input can be a single string or a list of strings.
    Returns a list of embedding objects, each containing the embedding vector.
    """
    try:
        start_time = time.time()

        if isinstance(request.input, str):
            texts = [request.input]
        else:
            texts = request.input

        if not texts:
            return EmbeddingResponse(
                data=[],
                model=request.model,
                object="list",
                usage={"prompt_tokens": 0, "total_tokens": 0},
            )

        embeddings_tensor, total_tokens = await get_embeddings_batch(texts, request.model, settings)

        data = [
            EmbeddingObject(embedding=embeddings_tensor[i].tolist(), index=i)
            for i in range(len(texts))
        ]
        usage = {
            "prompt_tokens": total_tokens,
            "total_tokens": total_tokens,
        }

        end_time = time.time()
        processing_time = end_time - start_time
        if settings.environment != "production":
            logger.debug(
                f"Processed {len(texts)} inputs in {processing_time:.4f} seconds. "
                f"Model: {request.model}. Tokens: {total_tokens}."
            )

        return EmbeddingResponse(
            data=data, model=request.model, object="list", usage=usage
        )
    except ValueError as e:
        logger.error(f"Validation error in /v1/embeddings: {e}", exc_info=True)
        raise HTTPException(status_code=422, detail=str(e))
    except HTTPException as e:
        logger.error(f"HTTPException in /v1/embeddings: {e.detail}", exc_info=True)
        raise e
    except Exception as e:
        logger.error(f"Unhandled error in /v1/embeddings: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
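
# Example (illustrative) client usage against this endpoint. Assumes the server
# is running on localhost:8000 and the `openai` Python package (v1+) is
# installed; the API key is unused by this service but required by the client.
#
#   from openai import OpenAI
#   client = OpenAI(base_url="http://localhost:8000/v1", api_key="unused")
#   resp = client.embeddings.create(
#       model="text-embedding-3-large",
#       input=["first sentence", "second sentence"],
#   )
#   print(len(resp.data), len(resp.data[0].embedding))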

@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    logger.error(f"Validation error for request to {request.url}: {exc.errors()}")
    return JSONResponse(status_code=422, content={"detail": str(exc.errors())})

if __name__ == "__main__":
    uvicorn.run(app, host=get_app_settings().app_host, port=get_app_settings().app_port)
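
# Example (illustrative): running and querying the service from a shell. The
# module name "app" below is an assumption; adjust it to this file's actual name.
#
#   uvicorn app:app --host 0.0.0.0 --port 8000
#
#   curl -s http://localhost:8000/v1/embeddings \
#     -H "Content-Type: application/json" \
#     -d '{"input": "This is an example sentence.", "model": "text-embedding-3-large"}'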