File size: 15,590 Bytes
b4c92f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
"""
FastText embedding module for Tibetan text.
This module provides functions to train and use FastText models for Tibetan text.
"""

import logging
import math
import os
import re
import tempfile
from collections import Counter
from typing import List, Optional

import fasttext
import numpy as np
from huggingface_hub import hf_hub_download

# Set up logging
logger = logging.getLogger(__name__)

# Default training hyperparameters used by train_fasttext_model.
# NOTE(review): these were described as "optimized for Tibetan"; nothing in this
# module shows how the values were chosen — confirm before relying on that.
DEFAULT_DIM = 100  # embedding dimension
DEFAULT_EPOCH = 5  # number of training epochs
DEFAULT_MIN_COUNT = 5  # passed to fastText as minCount (minimum token occurrences)
DEFAULT_WINDOW = 5  # context window size
DEFAULT_MINN = 3  # minimum character n-gram length
DEFAULT_MAXN = 6  # maximum character n-gram length
DEFAULT_NEG = 5  # number of negatives in negative sampling

# Define paths for model storage: a "models" directory next to this source file.
DEFAULT_MODEL_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models")
DEFAULT_MODEL_PATH = os.path.join(DEFAULT_MODEL_DIR, "fasttext_model.bin")

# Facebook's official Tibetan FastText model: Hugging Face Hub repo id and the
# file name within that repo, used by load_fasttext_model via hf_hub_download.
FACEBOOK_TIBETAN_MODEL_ID = "facebook/fasttext-bo-vectors"
FACEBOOK_TIBETAN_MODEL_FILE = "model.bin"

# Create models directory if it doesn't exist.
# Note: this runs as a side effect of importing the module.
os.makedirs(DEFAULT_MODEL_DIR, exist_ok=True)

def ensure_dir_exists(directory: str) -> None:
    """
    Ensure that a directory exists, creating it (and missing parents) if necessary.

    An empty string is what ``os.path.dirname`` returns for a bare filename and
    means "the current directory"; it is treated as a no-op rather than being
    passed to ``os.makedirs``, which would raise ``FileNotFoundError``.
    If ``directory`` already exists (as anything), nothing is done.

    Args:
        directory: Directory path to ensure exists.
    """
    if directory and not os.path.exists(directory):
        os.makedirs(directory, exist_ok=True)


def train_fasttext_model(
    corpus_path: str,
    model_path: str = DEFAULT_MODEL_PATH,
    dim: int = DEFAULT_DIM,
    epoch: int = DEFAULT_EPOCH,
    min_count: int = DEFAULT_MIN_COUNT,
    window: int = DEFAULT_WINDOW,
    minn: int = DEFAULT_MINN,
    maxn: int = DEFAULT_MAXN,
    neg: int = DEFAULT_NEG,
    model_type: str = "skipgram"
) -> fasttext.FastText._FastText:
    """
    Train an unsupervised FastText model on a Tibetan corpus and save it.

    The corpus is read as UTF-8 and a space is inserted after every tsheg
    (``'་'``) that is not already followed by whitespace, so fastText's
    whitespace tokenizer sees syllables as tokens. The segmented text is
    written to a temporary file that is deleted after training; the input
    corpus file itself is never modified. If the corpus cannot be read or the
    temporary file cannot be written, a warning is logged and training runs on
    the unmodified ``corpus_path``.

    Args:
        corpus_path: Path to the UTF-8 corpus file.
        model_path: Path where the trained model is saved; its parent
            directory is created if needed.
        dim: Embedding dimension (default: DEFAULT_DIM, currently 100).
        epoch: Number of training epochs (default: DEFAULT_EPOCH, currently 5).
        min_count: Minimum count of words, passed as fastText ``minCount``
            (default: DEFAULT_MIN_COUNT, currently 5).
        window: Size of context window (default: DEFAULT_WINDOW, currently 5).
        minn: Minimum length of char n-gram (default: DEFAULT_MINN, currently 3).
        maxn: Maximum length of char n-gram (default: DEFAULT_MAXN, currently 6).
        neg: Number of negatives in negative sampling
            (default: DEFAULT_NEG, currently 5).
        model_type: ``'skipgram'`` or ``'cbow'``. Any other value is trained
            as ``'cbow'`` (as before), with a warning.

    Returns:
        Trained FastText model.
    """
    model_dir = os.path.dirname(model_path)
    if model_dir:  # a bare filename has no directory component to create
        ensure_dir_exists(model_dir)

    if model_type not in ("skipgram", "cbow"):
        logger.warning("Unknown model_type %r; training a 'cbow' model instead.", model_type)
    fasttext_algorithm = "skipgram" if model_type == "skipgram" else "cbow"

    logger.info("Training FastText model with %s, dim=%d, epoch=%d, window=%d, minn=%d, maxn=%d...",
                model_type, dim, epoch, window, minn, maxn)

    # Syllable-segment into a temporary copy so the caller's corpus is never
    # rewritten (a failed in-place write would truncate it).
    train_path = corpus_path
    tmp_path = None
    try:
        with open(corpus_path, 'r', encoding='utf-8') as f:
            content = f.read()
        # Only add a space where one is missing, so repeated runs are idempotent.
        processed_content = re.sub(r"་(?!\s)", "་ ", content)
        with tempfile.NamedTemporaryFile(
            mode='w', encoding='utf-8', suffix='.txt',
            dir=model_dir or None, delete=False,
        ) as tmp:
            tmp_path = tmp.name  # recorded first so a failed write is still cleaned up
            tmp.write(processed_content)
        train_path = tmp_path
        logger.info("Preprocessed corpus with syllable segmentation for Tibetan text")
    except (OSError, UnicodeError) as e:
        logger.warning("Could not preprocess corpus for syllable segmentation: %s", str(e))

    try:
        model = fasttext.train_unsupervised(
            train_path,
            model=fasttext_algorithm,
            dim=dim,
            epoch=epoch,
            minCount=min_count,
            wordNgrams=1,
            minn=minn,
            maxn=maxn,
            neg=neg,
            window=window,
        )
    finally:
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError as e:  # best-effort cleanup; training result is unaffected
                logger.warning("Could not remove temporary corpus %s: %s", tmp_path, str(e))

    model.save_model(model_path)
    logger.info("FastText model trained and saved to %s", model_path)

    return model


def load_fasttext_model(model_path: str = DEFAULT_MODEL_PATH) -> Optional[fasttext.FastText._FastText]:
    """
    Load a FastText model, preferring Facebook's official Tibetan vectors.

    The official model (``FACEBOOK_TIBETAN_MODEL_ID`` /
    ``FACEBOOK_TIBETAN_MODEL_FILE``) is fetched with ``hf_hub_download`` into
    ``DEFAULT_MODEL_DIR`` and loaded first. Only if that download or load
    fails is the local file at ``model_path`` tried. Consequently
    ``model_path`` is ignored whenever the official model loads successfully.

    Args:
        model_path: Path to a local model file used as the fallback.

    Returns:
        Loaded FastText model, or None if neither model could be loaded.
        Errors are logged, not raised.
    """
    try:
        logger.info("Attempting to download and load official Facebook FastText Tibetan model")
        facebook_model_path = hf_hub_download(
            repo_id=FACEBOOK_TIBETAN_MODEL_ID,
            filename=FACEBOOK_TIBETAN_MODEL_FILE,
            cache_dir=DEFAULT_MODEL_DIR
        )
        logger.info("Loading official Facebook FastText Tibetan model from %s", facebook_model_path)
        return fasttext.load_model(facebook_model_path)
    except Exception as e:  # network, hub, cache or load failure: fall back to local
        logger.warning("Could not load official Facebook FastText Tibetan model: %s", str(e))
        logger.info("Falling back to local model")

    try:
        if not os.path.exists(model_path):
            logger.warning("Model path %s does not exist", model_path)
            return None
        logger.info("Loading local FastText model from %s", model_path)
        return fasttext.load_model(model_path)
    except Exception as e:
        logger.error("Error loading FastText model: %s", str(e))
        return None


def get_text_embedding(
    text: str,
    model: fasttext.FastText._FastText,
    tokenize_fn=None,
    use_stopwords: bool = True,
    stopwords_set=None,
    use_tfidf_weighting: bool = True,  # Enabled by default for better results
    corpus_token_freq=None
) -> np.ndarray:
    """
    Get embedding for a text using a FastText model with optional TF-IDF weighting.

    Tokens are obtained, empty tokens and (optionally) stopwords are dropped,
    and the word vectors of the remaining tokens are combined:

    * without weighting: the plain mean over all token occurrences;
    * with weighting: a weighted average over *unique* tokens, where token t
      has weight ``clamp(tf(t) * idf(t), 0.1, 10.0)`` with
      ``tf(t) = count_in_text(t) / n_tokens`` and
      ``idf(t) = log((N + 1) / (corpus_token_freq[t] + 1)) + 1``,
      ``N`` being the sum of all counts in ``corpus_token_freq``.
      Aggregating per unique token ensures term frequency is counted once
      (weighting every occurrence by tf would effectively square it).

    Args:
        text: Input text.
        model: FastText model.
        tokenize_fn: None (whitespace split), a callable ``text -> tokens``,
            or an already-tokenized list of tokens for ``text``. Any other
            value logs a warning and falls back to whitespace split.
        use_stopwords: Whether to filter out stopwords before computing embeddings.
        stopwords_set: Set of stopwords to filter out (if use_stopwords is True).
        use_tfidf_weighting: Whether to use TF-IDF weighting for averaging word vectors.
        corpus_token_freq: Mapping of token -> count across the corpus
            (required for TF-IDF; without it the plain mean is used).

    Returns:
        Text embedding vector of length ``model.get_dimension()``; all zeros
        when ``text`` is blank or no tokens survive filtering.
    """
    dim = model.get_dimension()
    if not text.strip():
        return np.zeros(dim)

    # Resolve tokens from the flexible tokenize_fn argument.
    if tokenize_fn is None:
        tokens = text.split()
    elif isinstance(tokenize_fn, list):
        tokens = tokenize_fn
    elif callable(tokenize_fn):
        tokens = tokenize_fn(text)
    else:
        logger.warning("Unexpected tokenize_fn type: %s. Using default whitespace tokenization.",
                       type(tokenize_fn))
        tokens = text.split()

    # Drop empty tokens and, if requested, stopwords in a single pass.
    drop_stopwords = use_stopwords and stopwords_set is not None
    tokens = [tok for tok in tokens
              if tok.strip() and not (drop_stopwords and tok in stopwords_set)]
    if not tokens:
        return np.zeros(dim)

    if not use_tfidf_weighting or corpus_token_freq is None:
        return np.mean([model.get_word_vector(tok) for tok in tokens], axis=0)

    counts = Counter(tokens)
    n_tokens = len(tokens)
    corpus_total = max(sum(corpus_token_freq.values()), 1)  # avoid division by zero

    unique_tokens = list(counts)
    weights = []
    for tok in unique_tokens:
        tf = counts[tok] / n_tokens
        # Smoothed inverse frequency; +1 keeps unseen and very common tokens finite.
        idf = math.log((corpus_total + 1) / (corpus_token_freq.get(tok, 0) + 1)) + 1
        # Clamp so no single token dominates or vanishes.
        weights.append(min(max(tf * idf, 0.1), 10.0))

    total_weight = sum(weights)
    # NaN survives the clamp above (e.g. NaN counts in corpus_token_freq).
    if not math.isfinite(total_weight) or total_weight <= 0:
        logger.warning("Found NaN or infinite weights in TF-IDF calculation. Using uniform weights instead.")
        return np.mean([model.get_word_vector(tok) for tok in tokens], axis=0)

    weighted_vectors = [(w / total_weight) * model.get_word_vector(tok)
                        for tok, w in zip(unique_tokens, weights)]
    return np.sum(weighted_vectors, axis=0)


def _tokenize_for_batch(text: str, index: int, tokenize_fn) -> List[str]:
    """Return the token list for ``texts[index]``; ``tokenize_fn`` is None, a list of per-text token lists, or a callable."""
    if not text.strip():
        # Blank texts always embed to zeros, so never call the tokenizer on them.
        return []
    if isinstance(tokenize_fn, list):
        pre_tokenized = tokenize_fn[index] if index < len(tokenize_fn) else None
        # Missing or non-list entries fall back to whitespace tokenization.
        return pre_tokenized if isinstance(pre_tokenized, list) else text.split()
    if callable(tokenize_fn):
        return list(tokenize_fn(text))
    return text.split()


def get_batch_embeddings(
    texts: List[str], 
    model: fasttext.FastText._FastText,
    tokenize_fn=None,
    use_stopwords: bool = True,
    stopwords_set=None,
    use_tfidf_weighting: bool = True,  # Enabled by default for better results
    corpus_token_freq=None
) -> np.ndarray:
    """
    Get embeddings for a batch of texts with optional TF-IDF weighting.

    Each text is tokenized exactly once. The same tokens are used both to build
    the corpus frequency table (when one is not supplied) and to compute that
    text's embedding, so a custom tokenizer is honoured consistently.

    Args:
        texts: List of input texts.
        model: FastText model.
        tokenize_fn: None (whitespace split), a callable ``text -> tokens``, or
            a list of pre-tokenized token lists aligned with ``texts``
            (``tokenize_fn[i]`` belongs to ``texts[i]``; texts without a list
            entry use whitespace split). Any other value logs a warning and is
            treated as None.
        use_stopwords: Whether to filter out stopwords before computing embeddings.
        stopwords_set: Set of stopwords to filter out (if use_stopwords is True).
        use_tfidf_weighting: Whether to use TF-IDF weighting for averaging word vectors.
        corpus_token_freq: Mapping of token -> count across the corpus. If None
            and TF-IDF weighting is requested, it is built from the non-blank
            texts of this batch after stopword and empty-token filtering.

    Returns:
        ``np.array`` of the per-text embedding vectors, one row per text.
    """
    if tokenize_fn is not None and not isinstance(tokenize_fn, list) and not callable(tokenize_fn):
        logger.warning("Unexpected tokenize_fn type: %s. Using default whitespace tokenization.",
                       type(tokenize_fn))
        tokenize_fn = None

    token_lists = [_tokenize_for_batch(text, i, tokenize_fn) for i, text in enumerate(texts)]

    if use_tfidf_weighting and corpus_token_freq is None:
        logger.info("Building corpus token frequency dictionary for TF-IDF weighting")
        drop_stopwords = use_stopwords and stopwords_set is not None
        corpus_token_freq = Counter(
            tok
            for tokens in token_lists
            for tok in tokens
            if tok.strip() and not (drop_stopwords and tok in stopwords_set)
        )
        logger.info("Built corpus token frequency dictionary with %d unique tokens", len(corpus_token_freq))

    embeddings = [
        get_text_embedding(
            text,
            model,
            tokenize_fn=tokens,  # already-tokenized list for this text
            use_stopwords=use_stopwords,
            stopwords_set=stopwords_set,
            use_tfidf_weighting=use_tfidf_weighting,
            corpus_token_freq=corpus_token_freq
        )
        for text, tokens in zip(texts, token_lists)
    ]

    return np.array(embeddings)


def generate_embeddings(
    texts: List[str],
    model: fasttext.FastText._FastText,
    device: str,
    model_type: str = "sentence_transformer",
    tokenize_fn=None,
    use_stopwords: bool = True,
    use_lite_stopwords: bool = False
) -> np.ndarray:
    """
    Generate TF-IDF-weighted embeddings for a list of texts using a FastText model.

    Stopwords come from the package's ``tibetan_stopwords.get_stopwords``
    when ``use_stopwords`` is True. Generation is best-effort: any exception
    is logged with its traceback and a zero matrix is returned instead.

    Args:
        texts: List of input texts.
        model: FastText model.
        device: Device to use for computation (not used for FastText).
        model_type: Model type; anything other than ``'fasttext'`` logs a
            warning, but FastText is used regardless.
        tokenize_fn: Optional tokenization function or list of pre-tokenized
            token lists aligned with ``texts``.
        use_stopwords: Whether to filter out stopwords.
        use_lite_stopwords: Whether to use a lighter set of stopwords.

    Returns:
        Array of shape ``(len(texts), model.get_dimension())``; all zeros if
        generation failed. An empty ``texts`` yields shape ``(0, dim)``.
    """
    if model_type != "fasttext":
        logger.warning("Model type %s not supported for FastText. Using FastText anyway.", model_type)

    if not texts:
        # Keep the result 2-D so callers can rely on the column count.
        return np.zeros((0, model.get_dimension()))

    try:
        stopwords_set = None
        if use_stopwords:
            from .tibetan_stopwords import get_stopwords
            stopwords_set = get_stopwords(use_lite=use_lite_stopwords)
            logger.info("Loaded %d Tibetan stopwords", len(stopwords_set))

        embeddings = get_batch_embeddings(
            texts,
            model,
            tokenize_fn=tokenize_fn,
            use_stopwords=use_stopwords,
            stopwords_set=stopwords_set,
            use_tfidf_weighting=True  # Enable TF-IDF weighting for better results
        )

        logger.info("FastText embeddings generated with shape: %s", str(embeddings.shape))
        return embeddings
    except Exception as e:
        # Deliberate best-effort fallback; logger.exception keeps the traceback.
        logger.exception("Error generating FastText embeddings: %s", str(e))
        return np.zeros((len(texts), model.get_dimension()))