# pipeline/semantic_embedding.py
import logging

import torch
from sentence_transformers import SentenceTransformer

# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Model ID for the fine-tuned Tibetan MiniLM
DEFAULT_MODEL_NAME = "buddhist-nlp/buddhist-sentence-similarity"


def get_sentence_transformer_model_and_device(
    model_id: str = DEFAULT_MODEL_NAME, device_preference: str = "auto"
):
    """
    Loads the Sentence Transformer model and determines the device.

    Device priority: CUDA -> MPS (Apple Silicon) -> CPU.

    Args:
        model_id (str): The Hugging Face model ID.
        device_preference (str): Preferred device ("cuda", "mps", "cpu", or "auto").

    Returns:
        tuple: (model, device_str)
            - model: The loaded SentenceTransformer model.
            - device_str: The device the model is loaded on ("cuda", "mps", or "cpu").
    """
    selected_device_str = ""
    if device_preference == "auto":
        if torch.cuda.is_available():
            selected_device_str = "cuda"
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            selected_device_str = "mps"
        else:
            selected_device_str = "cpu"
    elif device_preference == "cuda" and torch.cuda.is_available():
        selected_device_str = "cuda"
    elif (
        device_preference == "mps"
        and hasattr(torch.backends, "mps")
        and torch.backends.mps.is_available()
    ):
        selected_device_str = "mps"
    else:
        # Handles an explicit "cpu" preference, or falls back to CPU when the
        # preferred device is unavailable.
        selected_device_str = "cpu"

    logger.info(f"Attempting to use device: {selected_device_str}")

    try:
        logger.info(
            f"Loading Sentence Transformer model: {model_id} on device: {selected_device_str}"
        )
        # SentenceTransformer expects a device string like "cuda", "mps", or "cpu".
        model = SentenceTransformer(model_id, device=selected_device_str)
        logger.info(f"Model {model_id} loaded successfully on {selected_device_str}.")
        return model, selected_device_str
    except Exception as e:
        logger.error(
            f"Error loading model {model_id} on device {selected_device_str}: {e}"
        )
        # Fall back to CPU if the initially selected device (CUDA or MPS) failed.
        if selected_device_str != "cpu":
            logger.warning(
                f"Failed to load model on {selected_device_str}, attempting to load on CPU..."
            )
            fallback_device_str = "cpu"
            try:
                model = SentenceTransformer(model_id, device=fallback_device_str)
                logger.info(
                    f"Model {model_id} loaded successfully on CPU after fallback."
                )
                return model, fallback_device_str
            except Exception as fallback_e:
                logger.error(
                    f"Error loading model {model_id} on CPU during fallback: {fallback_e}"
                )
                # Re-raise if the CPU fallback also fails.
                raise fallback_e
        # Re-raise the original exception if CPU was already the selected device.
        raise e
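

# A minimal memoisation sketch, assuming the web app calls the loader repeatedly
# with the same arguments: functools.lru_cache keeps one model per
# (model_id, device_preference) pair so the weights are only loaded once per
# process. The wrapper below is a hypothetical convenience, not part of the
# loader above.
from functools import lru_cache


@lru_cache(maxsize=2)
def get_cached_model_and_device(
    model_id: str = DEFAULT_MODEL_NAME, device_preference: str = "auto"
):
    """Hypothetical wrapper: load the model once per argument pair and reuse it."""
    return get_sentence_transformer_model_and_device(model_id, device_preference)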


def generate_embeddings(texts: list[str], model, device: str):
    """
    Generates embeddings for a list of texts using the provided Sentence Transformer model.

    Args:
        texts (list[str]): A list of texts to embed.
        model: The loaded SentenceTransformer model.
        device (str): The device the model is on (used for logging only;
            model.encode handles device placement itself).

    Returns:
        torch.Tensor: A tensor containing the embeddings, moved to CPU.
    """
    if not texts:
        logger.warning(
            "No texts provided to generate_embeddings. Returning empty tensor."
        )
        return torch.empty(0)

    logger.info(f"Generating embeddings for {len(texts)} texts on device {device}...")
    # SentenceTransformer.encode handles tokenization and pooling internally,
    # and manages moving data to the model's device.
    embeddings = model.encode(texts, convert_to_tensor=True, show_progress_bar=True)
    logger.info(f"Embeddings generated with shape: {embeddings.shape}")
    # Ensure embeddings are on CPU for consistent further processing.
    return embeddings.cpu()
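

# A minimal downstream sketch, assuming cosine similarity is the intended
# comparison metric for these embeddings (this module itself does not compare
# them). sentence_transformers.util.cos_sim returns a (len(a), len(b)) matrix
# of pairwise cosine similarities; the helper name is hypothetical.
def pairwise_similarity(embeddings_a: torch.Tensor, embeddings_b: torch.Tensor):
    """Hypothetical helper: pairwise cosine similarities between two embedding sets."""
    from sentence_transformers import util

    return util.cos_sim(embeddings_a, embeddings_b)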


if __name__ == "__main__":
    # Example usage:
    logger.info("Starting example usage of semantic_embedding module...")
    test_texts = [
        "བཀྲ་ཤིས་བདེ་ལེགས།",
        "hello world",  # Test with non-Tibetan input to observe the behavior.
        "དེ་རིང་གནམ་གཤིས་ཡག་པོ་འདུག",
    ]
    logger.info("Attempting to load model using default cache directory.")
    try:
        # Force CPU for this example to avoid potential CUDA issues in diverse
        # environments, or when CUDA is not intended for this specific test.
        st_model, st_device = get_sentence_transformer_model_and_device(
            device_preference="cpu"  # Explicitly use CPU for this test run.
        )
        if st_model:
            logger.info(f"Test model loaded on device: {st_device}")
            example_embeddings = generate_embeddings(test_texts, st_model, st_device)
            logger.info(
                f"Generated example embeddings shape: {example_embeddings.shape}"
            )
            if example_embeddings.nelement() > 0:  # Check that the tensor is not empty.
                logger.info(
                    f"First embedding (first 10 dims): {example_embeddings[0][:10]}..."
                )
            else:
                logger.info("Generated example embeddings tensor is empty.")
        else:
            logger.error("Failed to load model for example usage.")
    except Exception as e:
        logger.error(f"An error occurred during the example usage: {e}")
    logger.info("Finished example usage.")