import logging

import torch
from typing import List, Any
from sentence_transformers import SentenceTransformer

# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Define the model ID for the fine-tuned Tibetan MiniLM
DEFAULT_MODEL_NAME = "buddhist-nlp/buddhist-sentence-similarity"

# FastText model identifier - this is just an internal identifier, not a Hugging Face model ID
FASTTEXT_MODEL_ID = "fasttext-tibetan"


def get_model_and_device(
    model_id: str = DEFAULT_MODEL_NAME, device_preference: str = "auto"
):
    """
    Loads the embedding model and determines the device.

    Priority: CUDA -> MPS (Apple Silicon) -> CPU.

    Args:
        model_id (str): The Hugging Face model ID, or FASTTEXT_MODEL_ID for the
            local FastText model.
        device_preference (str): Preferred device ("cuda", "mps", "cpu", "auto").

    Returns:
        tuple: (model, device_str, model_type)
            - model: The loaded SentenceTransformer or FastText model.
            - device_str: The device the model is loaded on ("cuda", "mps", or "cpu").
            - model_type: Either "sentence_transformer" or "fasttext".
    """
    selected_device_str = ""
    if device_preference == "auto":
        if torch.cuda.is_available():
            selected_device_str = "cuda"
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            selected_device_str = "mps"
        else:
            selected_device_str = "cpu"
    elif device_preference == "cuda" and torch.cuda.is_available():
        selected_device_str = "cuda"
    elif (
        device_preference == "mps"
        and hasattr(torch.backends, "mps")
        and torch.backends.mps.is_available()
    ):
        selected_device_str = "mps"
    else:  # Handles explicit "cpu" preference or fallback if the preferred device is unavailable
        selected_device_str = "cpu"

    logger.info("Attempting to use device: %s", selected_device_str)

    try:
        # Check if this is a FastText model request
        if model_id == FASTTEXT_MODEL_ID:
            try:
                # Imported here so a missing FastText install fails fast with ImportError
                import fasttext  # noqa: F401 - availability check only
                from .fasttext_embedding import load_fasttext_model

                # Try to load the FastText model
                model = load_fasttext_model()
                if model is None:
                    error_msg = (
                        "Failed to load FastText model. "
                        "Semantic similarity will not be available."
                    )
                    logger.error(error_msg)
                    raise RuntimeError(error_msg)
                logger.info("FastText model loaded successfully.")
                # FastText always runs on CPU
                return model, "cpu", "fasttext"
            except ImportError:
                logger.error(
                    "FastText module not found. Please install it with 'pip install fasttext'."
                )
                raise
        else:
            logger.info(
                "Loading Sentence Transformer model: %s on device: %s",
                model_id, selected_device_str
            )
            # SentenceTransformer expects a device string like 'cuda', 'mps', or 'cpu'
            model = SentenceTransformer(model_id, device=selected_device_str)
            logger.info("Model %s loaded successfully on %s.", model_id, selected_device_str)
            return model, selected_device_str, "sentence_transformer"
    except Exception as e:
        logger.error(
            "Error loading model %s on device %s: %s",
            model_id, selected_device_str, str(e)
        )
        # Fall back to CPU if the initially selected device (CUDA or MPS) failed
        if selected_device_str != "cpu":
            logger.warning(
                "Failed to load model on %s, attempting to load on CPU...",
                selected_device_str
            )
            fallback_device_str = "cpu"
            try:
                # Check if this is a FastText model request during fallback
                if model_id == FASTTEXT_MODEL_ID:
                    # Imported here to avoid a hard dependency if FastText is not installed
                    from .fasttext_embedding import load_fasttext_model

                    # Try to load the FastText model
                    model = load_fasttext_model()
                    if model is None:
                        logger.error(
                            "Failed to load FastText model during fallback. "
                            "Semantic similarity will not be available."
                        )
                        raise RuntimeError(
                            "Failed to load FastText model. Please check if the model file exists."
                        )
                    logger.info("FastText model loaded successfully during fallback.")
                    # FastText always runs on CPU
                    return model, "cpu", "fasttext"
                else:
                    # Try to load as a Sentence Transformer
                    model = SentenceTransformer(model_id, device=fallback_device_str)
                    logger.info(
                        "Model %s loaded successfully on CPU after fallback.",
                        model_id
                    )
                    return model, fallback_device_str, "sentence_transformer"
            except Exception as fallback_e:
                logger.error(
                    "Error loading model %s on CPU during fallback: %s",
                    model_id, str(fallback_e)
                )
                raise fallback_e  # Re-raise if the CPU fallback also fails
        raise e  # Re-raise the original exception if no fallback was attempted
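# Illustrative usage sketch (not exercised by this module): how a caller might
# consume the (model, device, model_type) triple returned above and dispatch on
# the model type. Kept as a comment so importing the module has no side effects.
#
#     model, device, model_type = get_model_and_device(device_preference="auto")
#     if model_type == "fasttext":
#         pass  # FastText ignores `device` and always runs on CPU
#     else:
#         vecs = model.encode(["བཀྲ་ཤིས་བདེ་ལེགས།"], convert_to_tensor=True)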
def generate_embeddings(
    texts: List[str],
    model: Any,
    device: str,
    model_type: str = "sentence_transformer",
    tokenize_fn=None,
    use_stopwords: bool = True,
    use_lite_stopwords: bool = False,
):
    """
    Generates embeddings for a list of texts using the provided model.

    Args:
        texts (list[str]): A list of texts to embed.
        model: The loaded model (SentenceTransformer or FastText).
        device (str): The device to use ("cuda", "mps", or "cpu").
        model_type (str): Type of model ("sentence_transformer" or "fasttext").
        tokenize_fn: Optional tokenization function or pre-tokenized list of
            token lists for FastText.
        use_stopwords (bool): Whether to filter out stopwords for FastText embeddings.
        use_lite_stopwords (bool): Whether to use the smaller "lite" stopword set
            instead of the full one.

    Returns:
        torch.Tensor: A tensor containing the embeddings, moved to CPU.
    """
    if not texts:
        logger.warning(
            "No texts provided to generate_embeddings. Returning empty tensor."
        )
        return torch.empty(0)

    logger.info("Generating embeddings for %d texts...", len(texts))

    if model_type == "fasttext":
        try:
            # Imported here to avoid a hard dependency if FastText is not installed
            from .fasttext_embedding import get_batch_embeddings

            # For FastText, pick the appropriate stopword set if filtering is enabled
            stopwords_set = None
            if use_stopwords:
                # Choose between the regular and lite stopword sets
                if use_lite_stopwords:
                    from .stopwords_lite_bo import TIBETAN_STOPWORDS_LITE_SET
                    stopwords_set = TIBETAN_STOPWORDS_LITE_SET
                else:
                    from .stopwords_bo import TIBETAN_STOPWORDS_SET
                    stopwords_set = TIBETAN_STOPWORDS_SET

            # Pass pre-tokenized tokens if available, otherwise pass None.
            # tokenize_fn should be a list of token lists (one per text) or None.
            embeddings = get_batch_embeddings(
                texts,
                model,
                tokenize_fn=tokenize_fn,
                use_stopwords=use_stopwords,
                stopwords_set=stopwords_set,
            )
            logger.info("FastText embeddings generated with shape: %s", str(embeddings.shape))
            # Convert the numpy array to a torch tensor for consistency
            return torch.tensor(embeddings)
        except ImportError:
            logger.error(
                "FastText module not found. Please install it with 'pip install fasttext'."
            )
            raise
    else:  # sentence_transformer
        # SentenceTransformer.encode handles tokenization and pooling internally
        # and manages moving data to the model's device.
        embeddings = model.encode(texts, convert_to_tensor=True, show_progress_bar=True)
        logger.info(
            "Sentence Transformer embeddings generated with shape: %s", str(embeddings.shape)
        )
        # Ensure embeddings are on CPU for consistent further processing
        return embeddings.cpu()
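# Illustrative sketch (not part of the original module): the tensor returned by
# generate_embeddings() is on CPU, so a pairwise cosine-similarity matrix can be
# computed directly with torch, e.g.:
#
#     import torch.nn.functional as F
#     emb = generate_embeddings(texts, model, device, model_type)
#     normalized = F.normalize(emb, p=2, dim=1)
#     similarity_matrix = normalized @ normalized.T  # shape: (len(texts), len(texts))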
def train_fasttext_model(corpus_texts: List[str], **kwargs):
    """
    Train a FastText model on the provided corpus texts.

    Args:
        corpus_texts: List of texts to use for training.
        **kwargs: Additional training parameters (dim, epoch, etc.).

    Returns:
        The trained FastText model.
    """
    try:
        from .fasttext_embedding import prepare_corpus_file, train_fasttext_model as train_ft

        # Prepare the corpus file
        corpus_path = prepare_corpus_file(corpus_texts)

        # Train the model
        model = train_ft(corpus_path=corpus_path, **kwargs)
        return model
    except ImportError:
        logger.error(
            "FastText module not found. Please install it with 'pip install fasttext'."
        )
        raise
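# Illustrative sketch (assumption: the keyword arguments are forwarded to the
# underlying FastText trainer, as the docstring's "dim, epoch, etc." suggests):
#
#     model = train_fasttext_model(corpus_texts, dim=100, epoch=10)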
if __name__ == "__main__":
    # Example usage:
    logger.info("Starting example usage of semantic_embedding module...")
    test_texts = [
        "བཀྲ་ཤིས་བདེ་ལེགས།",
        "hello world",  # Test with non-Tibetan input to see the behavior
        "དེ་རིང་གནམ་གཤིས་ཡག་པོ་འདུག",
    ]
    logger.info("Attempting to load model using the default cache directory.")
    try:
        # Force CPU for this example to avoid potential CUDA issues in diverse
        # environments, or if CUDA is not intended for this specific test.
        model, device, model_type = get_model_and_device(
            device_preference="cpu"  # Explicitly use CPU for this test run
        )
        if model:
            logger.info("Test model loaded on device: %s, type: %s", device, model_type)
            example_embeddings = generate_embeddings(test_texts, model, device, model_type)
            logger.info(
                "Generated example embeddings shape: %s",
                str(example_embeddings.shape)
            )
            if example_embeddings.nelement() > 0:  # Check that the tensor is not empty
                logger.info(
                    "First embedding (first 10 dims): %s...",
                    str(example_embeddings[0][:10])
                )
            else:
                logger.info("Generated example embeddings tensor is empty.")
        else:
            logger.error("Failed to load model for example usage.")
    except Exception as e:
        logger.error("An error occurred during the example usage: %s", str(e))
    logger.info("Finished example usage.")