# pipeline/semantic_embedding.py
import logging
import torch
from typing import List, Any
from sentence_transformers import SentenceTransformer
# Configure logging
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
# Define the model ID for the fine-tuned Tibetan MiniLM
DEFAULT_MODEL_NAME = "buddhist-nlp/buddhist-sentence-similarity"
# FastText model identifier - this is just an internal identifier, not a HuggingFace model ID
FASTTEXT_MODEL_ID = "fasttext-tibetan"
def get_model_and_device(
model_id: str = DEFAULT_MODEL_NAME, device_preference: str = "auto"
):
"""
    Loads the requested embedding model and determines the device to run it on.
    Supports Sentence Transformer models (by Hugging Face model ID) and the local
    FastText model (via FASTTEXT_MODEL_ID).
    Device priority for "auto": CUDA -> MPS (Apple Silicon) -> CPU.
    Args:
        model_id (str): The Hugging Face model ID, or FASTTEXT_MODEL_ID for FastText.
        device_preference (str): Preferred device ("cuda", "mps", "cpu", "auto").
    Returns:
        tuple: (model, device_str, model_type)
            - model: The loaded SentenceTransformer or FastText model.
            - device_str: The device the model is loaded on ("cuda", "mps", or "cpu").
            - model_type: "sentence_transformer" or "fasttext".
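    Example (a minimal usage sketch; the device actually selected depends on the host):
        model, device, model_type = get_model_and_device(device_preference="auto")
        if model_type == "sentence_transformer":
            embeddings = model.encode(["བཀྲ་ཤིས་བདེ་ལེགས།"], convert_to_tensor=True)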
"""
selected_device_str = ""
if device_preference == "auto":
if torch.cuda.is_available():
selected_device_str = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
selected_device_str = "mps"
else:
selected_device_str = "cpu"
elif device_preference == "cuda" and torch.cuda.is_available():
selected_device_str = "cuda"
elif (
device_preference == "mps"
and hasattr(torch.backends, "mps")
and torch.backends.mps.is_available()
):
selected_device_str = "mps"
else: # Handles explicit "cpu" preference or fallback if preferred is unavailable
selected_device_str = "cpu"
logger.info("Attempting to use device: %s", selected_device_str)
try:
# Check if this is a FastText model request
if model_id == FASTTEXT_MODEL_ID:
try:
                # Imported here to avoid a hard dependency when FastText is not
                # installed; the bare import doubles as a presence check.
                import fasttext  # noqa: F401
                from .fasttext_embedding import load_fasttext_model
# Try to load the FastText model
model = load_fasttext_model()
if model is None:
error_msg = "Failed to load FastText model. Semantic similarity will not be available."
logger.error(error_msg)
raise Exception(error_msg)
logger.info("FastText model loaded successfully.")
# FastText always runs on CPU
return model, "cpu", "fasttext"
except ImportError:
logger.error("FastText module not found. Please install it with 'pip install fasttext'.")
raise
else:
logger.info(
"Loading Sentence Transformer model: %s on device: %s",
model_id, selected_device_str
)
# SentenceTransformer expects a string like 'cuda', 'mps', or 'cpu'
model = SentenceTransformer(model_id, device=selected_device_str)
logger.info("Model %s loaded successfully on %s.", model_id, selected_device_str)
return model, selected_device_str, "sentence_transformer"
except Exception as e:
logger.error(
"Error loading model %s on device %s: %s",
model_id, selected_device_str, str(e)
)
# Fallback to CPU if the initially selected device (CUDA or MPS) failed
if selected_device_str != "cpu":
logger.warning(
"Failed to load model on %s, attempting to load on CPU...",
selected_device_str
)
fallback_device_str = "cpu"
try:
# Check if this is a FastText model request during fallback
if model_id == FASTTEXT_MODEL_ID:
# Import here to avoid dependency issues if FastText is not installed
from .fasttext_embedding import load_fasttext_model
# Try to load the FastText model
model = load_fasttext_model()
if model is None:
logger.error("Failed to load FastText model during fallback. Semantic similarity will not be available.")
raise Exception("Failed to load FastText model. Please check if the model file exists.")
logger.info("FastText model loaded successfully during fallback.")
# FastText always runs on CPU
return model, "cpu", "fasttext"
else:
# Try to load as a sentence transformer
model = SentenceTransformer(model_id, device=fallback_device_str)
logger.info(
"Model %s loaded successfully on CPU after fallback.",
model_id
)
return model, fallback_device_str, "sentence_transformer"
except Exception as fallback_e:
logger.error(
"Error loading model %s on CPU during fallback: %s",
model_id, str(fallback_e)
)
raise fallback_e # Re-raise exception if CPU fallback also fails
raise e # Re-raise original exception if selected_device_str was already CPU or no fallback attempted
def generate_embeddings(
    texts: List[str],
    model: Any,
    device: str,
    model_type: str = "sentence_transformer",
    tokenize_fn=None,
    use_stopwords: bool = True,
    use_lite_stopwords: bool = False,
):
"""
Generates embeddings for a list of texts using the provided model.
Args:
texts (list[str]): A list of texts to embed.
model: The loaded model (SentenceTransformer or FastText).
        device (str): The device the model runs on ("cuda", "mps", or "cpu").
            Kept for API symmetry; SentenceTransformer.encode uses the model's own device.
        model_type (str): Type of model ("sentence_transformer" or "fasttext").
        tokenize_fn: Optional pre-tokenized tokens (one list of tokens per text) for FastText, or None.
        use_stopwords (bool): Whether to filter out stopwords for FastText embeddings.
        use_lite_stopwords (bool): Use the lighter Tibetan stopword list instead of the full one (FastText only).
    Returns:
        torch.Tensor: A tensor containing the embeddings, moved to CPU.
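    Example (a minimal sketch; assumes the default Sentence Transformer model loads successfully):
        model, device, model_type = get_model_and_device()
        emb = generate_embeddings(["བཀྲ་ཤིས་བདེ་ལེགས།", "hello world"], model, device, model_type)
        sim = torch.nn.functional.cosine_similarity(emb[0:1], emb[1:2])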
"""
if not texts:
logger.warning(
"No texts provided to generate_embeddings. Returning empty tensor."
)
return torch.empty(0)
logger.info(f"Generating embeddings for {len(texts)} texts...")
if model_type == "fasttext":
try:
# Import here to avoid dependency issues if FastText is not installed
from .fasttext_embedding import get_batch_embeddings
# For FastText, get appropriate stopwords set if filtering is enabled
stopwords_set = None
if use_stopwords:
# Choose between regular and lite stopwords sets
if use_lite_stopwords:
from .stopwords_lite_bo import TIBETAN_STOPWORDS_LITE_SET
stopwords_set = TIBETAN_STOPWORDS_LITE_SET
else:
from .stopwords_bo import TIBETAN_STOPWORDS_SET
stopwords_set = TIBETAN_STOPWORDS_SET
# Pass pre-tokenized tokens if available, otherwise pass None
# tokenize_fn should be a list of lists (tokens for each text) or None
embeddings = get_batch_embeddings(
texts,
model,
tokenize_fn=tokenize_fn,
use_stopwords=use_stopwords,
stopwords_set=stopwords_set
)
logger.info("FastText embeddings generated with shape: %s", str(embeddings.shape))
# Convert numpy array to torch tensor for consistency
return torch.tensor(embeddings)
except ImportError:
logger.error("FastText module not found. Please install it with 'pip install fasttext'.")
raise
else: # sentence_transformer
# The encode method of SentenceTransformer handles tokenization and pooling internally.
# It also manages moving data to the model's device.
embeddings = model.encode(texts, convert_to_tensor=True, show_progress_bar=True)
logger.info("Sentence Transformer embeddings generated with shape: %s", str(embeddings.shape))
        # Ensure embeddings are on CPU for consistent further processing.
        return embeddings.cpu()
def train_fasttext_model(corpus_texts: List[str], **kwargs):
"""
Train a FastText model on the provided corpus texts.
Args:
corpus_texts: List of texts to use for training
**kwargs: Additional parameters for training (dim, epoch, etc.)
Returns:
        The trained FastText model.
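    Example (a minimal sketch; dim and epoch are standard fasttext training parameters,
    but the kwargs actually accepted are defined by .fasttext_embedding.train_fasttext_model):
        model = train_fasttext_model(corpus_texts, dim=100, epoch=5)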
"""
try:
from .fasttext_embedding import prepare_corpus_file, train_fasttext_model as train_ft
# Prepare corpus file
corpus_path = prepare_corpus_file(corpus_texts)
# Train the model
model = train_ft(corpus_path=corpus_path, **kwargs)
return model
except ImportError:
logger.error("FastText module not found. Please install it with 'pip install fasttext'.")
raise
if __name__ == "__main__":
# Example usage:
logger.info("Starting example usage of semantic_embedding module...")
test_texts = [
"བཀྲ་ཤིས་བདེ་ལེགས།",
"hello world", # Test with non-Tibetan to see behavior
"དེ་རིང་གནམ་གཤིས་ཡག་པོ་འདུག",
]
logger.info("Attempting to load model using default cache directory.")
try:
# Forcing CPU for this example to avoid potential CUDA issues in diverse environments
# or if CUDA is not intended for this specific test.
model, device, model_type = get_model_and_device(
device_preference="cpu" # Explicitly use CPU for this test run
)
if model:
logger.info("Test model loaded on device: %s, type: %s", device, model_type)
example_embeddings = generate_embeddings(test_texts, model, device, model_type)
logger.info(
"Generated example embeddings shape: %s",
str(example_embeddings.shape)
)
if example_embeddings.nelement() > 0: # Check if tensor is not empty
logger.info(
"First embedding (first 10 dims): %s...",
str(example_embeddings[0][:10])
)
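                # A minimal similarity sketch (assumes at least two embeddings were produced):
                # cosine similarity between the first two example texts.
                if example_embeddings.shape[0] >= 2:
                    sim = torch.nn.functional.cosine_similarity(
                        example_embeddings[0:1], example_embeddings[1:2]
                    )
                    logger.info(
                        "Cosine similarity (text 1 vs text 2): %.4f", sim.item()
                    )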
else:
logger.info("Generated example embeddings tensor is empty.")
else:
logger.error("Failed to load model for example usage.")
except Exception as e:
logger.error("An error occurred during the example usage: %s", str(e))
logger.info("Finished example usage.")