import logging
import torch
from typing import List, Any
from sentence_transformers import SentenceTransformer

# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Define the model ID for the fine-tuned Tibetan MiniLM
DEFAULT_MODEL_NAME = "buddhist-nlp/buddhist-sentence-similarity"

# Internal identifier for the FastText model (not a Hugging Face model ID)
FASTTEXT_MODEL_ID = "fasttext-tibetan"


def get_model_and_device(
    model_id: str = DEFAULT_MODEL_NAME, device_preference: str = "auto"
):
    """
    Loads the Sentence Transformer model and determines the device.
    Priority: CUDA -> MPS (Apple Silicon) -> CPU.

    Args:
        model_id (str): The Hugging Face model ID.
        device_preference (str): Preferred device ("cuda", "mps", "cpu", "auto").

    Returns:
        tuple: (model, device_str, model_type)
               - model: The loaded model (SentenceTransformer or FastText).
               - device_str: The device the model is loaded on ("cuda", "mps", or "cpu").
               - model_type: Either "sentence_transformer" or "fasttext".
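
    Example (a minimal sketch; downloads the default model on first use):
        >>> model, device, model_type = get_model_and_device(device_preference="cpu")
        >>> model_type
        'sentence_transformer'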
    """
    selected_device_str = ""

    if device_preference == "auto":
        if torch.cuda.is_available():
            selected_device_str = "cuda"
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            selected_device_str = "mps"
        else:
            selected_device_str = "cpu"
    elif device_preference == "cuda" and torch.cuda.is_available():
        selected_device_str = "cuda"
    elif (
        device_preference == "mps"
        and hasattr(torch.backends, "mps")
        and torch.backends.mps.is_available()
    ):
        selected_device_str = "mps"
    else:  # Handles explicit "cpu" preference or fallback if preferred is unavailable
        selected_device_str = "cpu"

    logger.info("Attempting to use device: %s", selected_device_str)

    try:
        # Check if this is a FastText model request
        if model_id == FASTTEXT_MODEL_ID:
            try:
                # Imported here (not at module level) so the module still works
                # when the optional fasttext dependency is absent; the bare
                # import doubles as an availability check.
                import fasttext  # noqa: F401
                from .fasttext_embedding import load_fasttext_model

                # Try to load the FastText model
                model = load_fasttext_model()

                if model is None:
                    error_msg = "Failed to load FastText model. Semantic similarity will not be available."
                    logger.error(error_msg)
                    raise RuntimeError(error_msg)

                logger.info("FastText model loaded successfully.")
                # FastText always runs on CPU
                return model, "cpu", "fasttext"
            except ImportError:
                logger.error("FastText module not found. Please install it with 'pip install fasttext'.")
                raise
        else:
            logger.info(
                "Loading Sentence Transformer model: %s on device: %s",
                model_id, selected_device_str
            )
            # SentenceTransformer expects a string like 'cuda', 'mps', or 'cpu'
            model = SentenceTransformer(model_id, device=selected_device_str)
            logger.info("Model %s loaded successfully on %s.", model_id, selected_device_str)
            return model, selected_device_str, "sentence_transformer"
    except Exception as e:
        logger.error(
            "Error loading model %s on device %s: %s",
            model_id, selected_device_str, str(e)
        )
        # Fallback to CPU if the initially selected device (CUDA or MPS) failed
        if selected_device_str != "cpu":
            logger.warning(
                "Failed to load model on %s, attempting to load on CPU...",
                selected_device_str
            )
            fallback_device_str = "cpu"
            try:
                # Check if this is a FastText model request during fallback
                if model_id == FASTTEXT_MODEL_ID:
                    # Import here to avoid dependency issues if FastText is not installed
                    from .fasttext_embedding import load_fasttext_model
                    
                    # Try to load the FastText model
                    model = load_fasttext_model()
                    
                    if model is None:
                        logger.error("Failed to load FastText model during fallback. Semantic similarity will not be available.")
                        raise RuntimeError("Failed to load FastText model. Please check if the model file exists.")
                        
                    logger.info("FastText model loaded successfully during fallback.")
                    # FastText always runs on CPU
                    return model, "cpu", "fasttext"
                else:
                    # Try to load as a sentence transformer
                    model = SentenceTransformer(model_id, device=fallback_device_str)
                    logger.info(
                        "Model %s loaded successfully on CPU after fallback.",
                        model_id
                    )
                    return model, fallback_device_str, "sentence_transformer"
            except Exception as fallback_e:
                logger.error(
                    "Error loading model %s on CPU during fallback: %s",
                    model_id, str(fallback_e)
                )
                raise fallback_e  # Re-raise exception if CPU fallback also fails
        raise e  # Re-raise original exception if selected_device_str was already CPU or no fallback attempted


def generate_embeddings(
    texts: List[str],
    model: Any,
    device: str,
    model_type: str = "sentence_transformer",
    tokenize_fn=None,
    use_stopwords: bool = True,
    use_lite_stopwords: bool = False,
):
    """
    Generates embeddings for a list of texts using the provided model.

    Args:
        texts (list[str]): A list of texts to embed.
        model: The loaded model (SentenceTransformer or FastText).
        device (str): The device string returned by get_model_and_device
            (encoding uses the device the model was loaded on).
        model_type (str): Type of model ("sentence_transformer" or "fasttext").
        tokenize_fn: Optional pre-tokenized token lists (one list per text) for
            FastText, or None to let the embedding code tokenize internally.
        use_stopwords (bool): Whether to filter out stopwords for FastText embeddings.
        use_lite_stopwords (bool): If True, use the smaller "lite" stopword set
            for FastText instead of the full set.

    Returns:
        torch.Tensor: A tensor containing the embeddings, moved to CPU.
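
    Example (a minimal sketch; reuses a model loaded via get_model_and_device):
        >>> model, device, model_type = get_model_and_device(device_preference="cpu")
        >>> emb = generate_embeddings(["བཀྲ་ཤིས་བདེ་ལེགས།"], model, device, model_type)
        >>> emb.shape[0]
        1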
    """
    if not texts:
        logger.warning(
            "No texts provided to generate_embeddings. Returning empty tensor."
        )
        return torch.empty(0)

    logger.info("Generating embeddings for %d texts...", len(texts))

    if model_type == "fasttext":
        try:
            # Imported here to avoid dependency issues if FastText is not installed
            from .fasttext_embedding import get_batch_embeddings

            # For FastText, pick the appropriate stopwords set if filtering is enabled
            stopwords_set = None
            if use_stopwords:
                # Choose between the regular and lite stopword sets
                if use_lite_stopwords:
                    from .stopwords_lite_bo import TIBETAN_STOPWORDS_LITE_SET
                    stopwords_set = TIBETAN_STOPWORDS_LITE_SET
                else:
                    from .stopwords_bo import TIBETAN_STOPWORDS_SET
                    stopwords_set = TIBETAN_STOPWORDS_SET
            
            # Pass pre-tokenized tokens if available, otherwise pass None
            # tokenize_fn should be a list of lists (tokens for each text) or None
            embeddings = get_batch_embeddings(
                texts, 
                model, 
                tokenize_fn=tokenize_fn, 
                use_stopwords=use_stopwords, 
                stopwords_set=stopwords_set
            )
            logger.info("FastText embeddings generated with shape: %s", str(embeddings.shape))
            # Convert numpy array to torch tensor for consistency
            return torch.tensor(embeddings)
        except ImportError:
            logger.error("FastText module not found. Please install it with 'pip install fasttext'.")
            raise
    else:  # sentence_transformer
        # The encode method of SentenceTransformer handles tokenization and pooling internally.
        # It also manages moving data to the model's device.
        embeddings = model.encode(texts, convert_to_tensor=True, show_progress_bar=True)
        logger.info("Sentence Transformer embeddings generated with shape: %s", str(embeddings.shape))
        # Ensure embeddings are on CPU for consistent further processing
        return embeddings.cpu()
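

# Illustrative helper (not part of the original module API): cosine similarity
# between embedding batches produced by generate_embeddings. A minimal sketch;
# the function name is an assumption, not an existing identifier.
def cosine_similarity_matrix(emb_a: torch.Tensor, emb_b: torch.Tensor) -> torch.Tensor:
    """Return the matrix of pairwise cosine similarities between two batches."""
    a = torch.nn.functional.normalize(emb_a.float(), dim=-1)
    b = torch.nn.functional.normalize(emb_b.float(), dim=-1)
    return a @ b.T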


def train_fasttext_model(corpus_texts: List[str], **kwargs):
    """
    Train a FastText model on the provided corpus texts.
    
    Args:
        corpus_texts: List of texts to use for training
        **kwargs: Additional parameters for training (dim, epoch, etc.)
        
    Returns:
        The trained FastText model.
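
    Example (a minimal sketch; dim and epoch values are illustrative):
        >>> model = train_fasttext_model(corpus_texts, dim=100, epoch=5)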
    """
    try:
        from .fasttext_embedding import prepare_corpus_file, train_fasttext_model as train_ft
        
        # Prepare corpus file
        corpus_path = prepare_corpus_file(corpus_texts)
        
        # Train the model
        model = train_ft(corpus_path=corpus_path, **kwargs)
        
        return model
    except ImportError:
        logger.error("FastText module not found. Please install it with 'pip install fasttext'.")
        raise


if __name__ == "__main__":
    # Example usage:
    logger.info("Starting example usage of semantic_embedding module...")

    test_texts = [
        "བཀྲ་ཤིས་བདེ་ལེགས།",
        "hello world",  # Test with non-Tibetan to see behavior
        "དེ་རིང་གནམ་གཤིས་ཡག་པོ་འདུག",
    ]

    logger.info("Attempting to load model using default cache directory.")
    try:
        # Forcing CPU for this example to avoid potential CUDA issues in diverse environments
        # or if CUDA is not intended for this specific test.
        model, device, model_type = get_model_and_device(
            device_preference="cpu"  # Explicitly use CPU for this test run
        )

        if model:
            logger.info("Test model loaded on device: %s, type: %s", device, model_type)
            example_embeddings = generate_embeddings(test_texts, model, device, model_type)
            logger.info(
                "Generated example embeddings shape: %s",
                str(example_embeddings.shape)
            )
            if example_embeddings.nelement() > 0:  # Check if tensor is not empty
                logger.info(
                    "First embedding (first 10 dims): %s...",
                    str(example_embeddings[0][:10])
                )
            else:
                logger.info("Generated example embeddings tensor is empty.")
        else:
            logger.error("Failed to load model for example usage.")

    except Exception as e:
        logger.error("An error occurred during the example usage: %s", str(e))

    logger.info("Finished example usage.")