import logging
import time
from typing import Optional

import pandas as pd
import polars as pl
import torch
from datasets import Dataset
from sentence_transformers import SentenceTransformer

logger = logging.getLogger(__name__)


def sts(modelname: str, data1: str, data2: str, score: float) -> Optional[pl.DataFrame]:
    """
    Calculate semantic textual similarity between two sets of sentences.

    Args:
        modelname: Name of the sentence-transformers model to use
        data1: Path to the first tab-separated input file (must have a "text" column)
        data2: Path to the second tab-separated input file (must have a "text" column)
        score: Minimum similarity score threshold

    Returns:
        Optional[pl.DataFrame]: DataFrame of sentence pairs scoring above the
        threshold, or None if an error occurs.
    """
    try:
        st = time.time()

        # Initialize model
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {device}")
        model = SentenceTransformer(
            modelname,
            device=device,
            trust_remote_code=True,
        )

        # Read and validate input data (tab-separated, header row expected)
        sentences1 = Dataset.from_pandas(
            pd.read_csv(data1, on_bad_lines="skip", header=0, sep="\t")
        )
        sentences2 = Dataset.from_pandas(
            pd.read_csv(data2, on_bad_lines="skip", header=0, sep="\t")
        )

        if sentences1.num_rows == 0 or sentences2.num_rows == 0:
            logger.error("Empty input data found")
            return None

        # Generate embeddings
        logger.info("Generating embeddings for first set...")
        embeddings1 = model.encode(
            sentences1["text"],
            normalize_embeddings=True,
            batch_size=1024,
            show_progress_bar=True,
        )
        logger.info("Generating embeddings for second set...")
        embeddings2 = model.encode(
            sentences2["text"],
            normalize_embeddings=True,
            batch_size=1024,
            show_progress_bar=True,
        )

        # Calculate similarity matrix (shape: len(sentences1) x len(sentences2))
        logger.info("Calculating similarity matrix...")
        similarity_matrix = model.similarity(embeddings1, embeddings2)

        # Move the score tensor to the CPU and convert to numpy before
        # handing it to pandas; then bring it into polars
        df_pd = pd.DataFrame(similarity_matrix.cpu().numpy())
        df = pl.from_pandas(df_pd)

        # Transform matrix to long format: one row per (row_index, column_index, score)
        df_matrix_with_index = df.with_row_index(name="row_index").with_columns(
            pl.col("row_index").cast(pl.UInt64)
        )
        df_long = df_matrix_with_index.unpivot(
            index="row_index", variable_name="column_index", value_name="score"
        ).with_columns(pl.col("column_index").cast(pl.UInt64))

        # Join with the original text via the positional indices
        df_sentences1 = (
            pl.DataFrame(sentences1.to_pandas())
            .with_row_index(name="row_index")
            .with_columns(pl.col("row_index").cast(pl.UInt64))
        )
        df_sentences2 = (
            pl.DataFrame(sentences2.to_pandas())
            .with_row_index(name="column_index")
            .with_columns(pl.col("column_index").cast(pl.UInt64))
        )

        # Attach both sentences to each score; polars suffixes the duplicate
        # "text" column from the second join with "_right"
        df_long = (
            df_long.with_columns(pl.col("score").round(4).cast(pl.Float32))
            .join(df_sentences1, on="row_index")
            .join(df_sentences2, on="column_index")
        )
        df_long = df_long.rename(
            {
                "text": "sentences1",
                "text_right": "sentences2",
            }
        ).drop(["row_index", "column_index"])

        # Filter and sort results
        result_df = df_long.filter(pl.col("score") > score).sort("score", descending=True)

        elapsed_time = time.time() - st
        logger.info(
            f'Execution time: {time.strftime("%H:%M:%S", time.gmtime(elapsed_time))}'
        )
        logger.info(f"Found {len(result_df)} pairs above score threshold {score}")
        return result_df

    except Exception as e:
        logger.error(f"Error in STS process: {str(e)}")
        return None
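

# A minimal usage sketch. The model name, input paths, threshold, and output
# file below are hypothetical placeholders, assuming two tab-separated files
# that each contain a "text" column.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    matches = sts(
        modelname="sentence-transformers/all-MiniLM-L6-v2",
        data1="sentences_a.tsv",
        data2="sentences_b.tsv",
        score=0.8,
    )
    if matches is not None:
        # Persist the matched pairs (sentences1, sentences2, score)
        matches.write_csv("similar_pairs.csv")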