File size: 2,990 Bytes
b2ce320
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b44d470
 
 
 
 
 
b2ce320
 
 
 
 
 
b44d470
 
b2ce320
 
b44d470
b2ce320
 
b44d470
 
 
 
b2ce320
 
 
b44d470
 
 
 
 
 
b2ce320
 
 
 
b44d470
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import logging
from typing import List, Any, Optional, Tuple
import numpy as np
from sentence_transformers import SentenceTransformer

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Cache for loaded models, keyed by model identifier. get_model() stores each
# successfully loaded model here and returns the stored instance on later
# calls with the same identifier.
_model_cache = {}

def get_model(model_id: str) -> Tuple[Optional[SentenceTransformer], Optional[str]]:
    """
    Return a SentenceTransformer for ``model_id``, loading it on first use.

    Models that load successfully are stored in the module-level cache, so
    later calls with the same identifier return the same instance.

    Args:
        model_id (str): The identifier passed to ``SentenceTransformer``
            (e.g., 'sentence-transformers/LaBSE').

    Returns:
        Tuple[Optional[SentenceTransformer], Optional[str]]: The model together
        with the type label 'sentence-transformer', or (None, None) when
        loading raises. Failures are logged and are not cached.
    """
    model_type = "sentence-transformer"

    # Cache miss: load the model, remember it, and hand it back.
    if model_id not in _model_cache:
        logger.info(f"Loading SentenceTransformer model: {model_id}")
        try:
            loaded_model = SentenceTransformer(model_id)
            _model_cache[model_id] = loaded_model
            logger.info(f"Model '{model_id}' loaded successfully.")
        except Exception as e:
            logger.error(f"Failed to load SentenceTransformer model '{model_id}': {e}", exc_info=True)
            return None, None
        return loaded_model, model_type

    # Cache hit: reuse the instance loaded earlier.
    logger.info(f"Returning cached model: {model_id}")
    return _model_cache[model_id], model_type

# Dimension used for the zero fallback when the model cannot report its own.
_FALLBACK_EMBEDDING_DIM = 768


def _zero_embeddings(texts: Any, model: Any) -> np.ndarray:
    """Build the zero-filled fallback array with one row per input text."""
    try:
        row_count = len(texts)
    except TypeError:
        # texts is None or has no length, so there are no rows to report.
        row_count = 0

    embedding_dim = None
    if isinstance(model, SentenceTransformer):
        try:
            # May legitimately return None when the dimension is not known.
            embedding_dim = model.get_sentence_embedding_dimension()
        except Exception as e:
            logger.warning(f"Could not determine embedding dimension: {e}")
    if embedding_dim is None:
        embedding_dim = _FALLBACK_EMBEDDING_DIM

    return np.zeros((row_count, embedding_dim))


def generate_embeddings(
    texts: List[str],
    model: SentenceTransformer,
    batch_size: int = 32,
    show_progress_bar: bool = False
) -> np.ndarray:
    """
    Generates embeddings for a list of texts using a SentenceTransformer model.

    This function does not raise on bad input or on encoding errors. It logs
    the problem and returns a zero-filled array instead, so callers always
    receive a 2-D numpy array.

    Args:
        texts (list[str]): A list of texts to embed. ``None`` or an empty list
            is treated as invalid input.
        model (SentenceTransformer): The loaded SentenceTransformer model.
        batch_size (int): The batch size for encoding.
        show_progress_bar (bool): Whether to display a progress bar.

    Returns:
        np.ndarray: The embeddings returned by ``model.encode``. On invalid
        input or when encoding raises, an all-zeros array of shape
        ``(len(texts), embedding_dim)`` is returned, where ``embedding_dim``
        is the dimension reported by the model, or 768 when the model is not
        a SentenceTransformer or cannot report its dimension.
    """
    if not texts or not isinstance(model, SentenceTransformer):
        logger.warning("Invalid input for generating embeddings. Returning empty array.")
        return _zero_embeddings(texts, model)

    logger.info(f"Generating embeddings for {len(texts)} texts with {type(model).__name__}...")
    try:
        embeddings = model.encode(
            texts,
            batch_size=batch_size,
            convert_to_numpy=True,
            show_progress_bar=show_progress_bar
        )
        logger.info(f"Embeddings generated with shape: {embeddings.shape}")
        return embeddings
    except Exception as e:
        logger.error(f"An unexpected error occurred during embedding generation: {e}", exc_info=True)
        # The fallback must not raise, even if the model cannot report its dimension.
        return _zero_embeddings(texts, model)