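"""Gradio demo: a playground for SPLADE-style sparse text encoders.

Two tabs are provided:
  * "Sparse Representation" - encode a text with an MLM encoder
    (SPLADE-cocondenser-selfdistil), an MLP encoder (SPLADE-v3-lexical),
    or a plain binary bag-of-words, and inspect the non-zero terms.
  * "Compute Query-Document Similarity Score" - encode a query and a
    document independently and report their dot product along with the
    per-term contributions.
"""
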
import gradio as gr
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch


# CSS to style the custom share button (for the "Sparse Representation" tab)
css = """
.share-button-container {
    display: flex;
    justify-content: center;
    margin-top: 10px;
    margin-bottom: 20px;
}

.custom-share-button {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    border: none;
    color: white;
    padding: 8px 16px;
    text-align: center;
    text-decoration: none;
    display: inline-block;
    font-size: 14px;
    margin: 4px 2px;
    cursor: pointer;
    border-radius: 6px;
    transition: all 0.3s ease;
}

.custom-share-button:hover {
    transform: translateY(-2px);
    box-shadow: 0 4px 8px rgba(0,0,0,0.2);
}
/* IMPORTANT: This CSS targets Gradio's *default* share button that appears
   when demo.launch(share=True) is used.
   You might want to comment this out if you prefer Gradio's default positioning
   for the main share button (usually in the header/footer) and rely only on your custom one.
*/
.share-button {
    position: fixed !important;
    top: 20px !important;
    right: 20px !important;
    z-index: 1000 !important;
    background: #4CAF50 !important;
    color: white !important;
    border-radius: 8px !important;
    padding: 8px 16px !important;
    font-weight: bold !important;
    box-shadow: 0 2px 10px rgba(0,0,0,0.2) !important;
}

.share-button:hover {
    background: #45a049 !important;
    transform: translateY(-1px) !important;
}
"""


# --- Model Loading ---
tokenizer_splade = None
model_splade = None
tokenizer_splade_lexical = None
model_splade_lexical = None
tokenizer_splade_doc = None
model_splade_doc = None

# Load SPLADE cocondenser-selfdistil model (MLM encoder)
try:
    tokenizer_splade = AutoTokenizer.from_pretrained("naver/splade-cocondenser-selfdistil")
    model_splade = AutoModelForMaskedLM.from_pretrained("naver/splade-cocondenser-selfdistil")
    model_splade.eval()
    print("SPLADE-cocondenser-distil model loaded successfully!")
except Exception as e:
    print(f"Error loading SPLADE-cocondenser-distil model: {e}")
    print("Please ensure you have accepted any user access agreements on the Hugging Face Hub page for 'naver/splade-cocondenser-selfdistil'.")

# Load SPLADE v3 Lexical model
try:
    splade_lexical_model_name = "naver/splade-v3-lexical"
    tokenizer_splade_lexical = AutoTokenizer.from_pretrained(splade_lexical_model_name)
    model_splade_lexical = AutoModelForMaskedLM.from_pretrained(splade_lexical_model_name)
    model_splade_lexical.eval()
    print(f"SPLADE-v3-Lexical model '{splade_lexical_model_name}' loaded successfully!")
except Exception as e:
    print(f"Error loading SPLADE-v3-Lexical model: {e}")
    print(f"Please ensure '{splade_lexical_model_name}' is accessible (check Hugging Face Hub for potential agreements).")

# Load SPLADE v3 Doc model (only its tokenizer is used here; the binary bag-of-words is built directly from input_ids)
try:
    splade_doc_model_name = "naver/splade-v3-doc"
    tokenizer_splade_doc = AutoTokenizer.from_pretrained(splade_doc_model_name)
    model_splade_doc = AutoModelForMaskedLM.from_pretrained(splade_doc_model_name) # Still load the model
    model_splade_doc.eval()
    print(f"SPLADE-v3-Doc model '{splade_doc_model_name}' loaded successfully!")
except Exception as e:
    print(f"Error loading SPLADE-v3-Doc model: {e}")
    print(f"Please ensure '{splade_doc_model_name}' is accessible (check Hugging Face Hub for potential agreements).")


# --- Helper function for the lexical bag-of-words mask (batch-capable; used here with single inputs) ---
def create_lexical_bow_mask(input_ids_batch, vocab_size, tokenizer):
    """
    Creates a batch of lexical BOW masks.
    input_ids_batch: torch.Tensor of shape (batch_size, sequence_length)
    vocab_size: int, size of the tokenizer vocabulary
    tokenizer: the tokenizer object
    Returns: torch.Tensor of shape (batch_size, vocab_size)
    """
    batch_size = input_ids_batch.shape[0]
    bow_masks = torch.zeros(batch_size, vocab_size, device=input_ids_batch.device)

    for i in range(batch_size):
        input_ids = input_ids_batch[i] # Get input_ids for the current item in the batch
        meaningful_token_ids = []
        for token_id in input_ids.tolist():
            if token_id not in [
                tokenizer.pad_token_id,
                tokenizer.cls_token_id,
                tokenizer.sep_token_id,
                tokenizer.mask_token_id,
                tokenizer.unk_token_id
            ]:
                meaningful_token_ids.append(token_id)

        if meaningful_token_ids:
            # Apply mask to the current row in the batch
            bow_masks[i, list(set(meaningful_token_ids))] = 1

    return bow_masks


# --- Core Representation Functions (Return Formatted Strings - for Explorer Tab) ---
# Each function returns a tuple: (main_representation_str, info_str)
def get_splade_cocondenser_representation(text):
    if tokenizer_splade is None or model_splade is None:
        return "SPLADE-cocondenser-distil model is not loaded. Please check the console for loading errors.", ""

    inputs = tokenizer_splade(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(model_splade.device) for k, v in inputs.items()}

    with torch.no_grad():
        output = model_splade(**inputs)

    if hasattr(output, 'logits'):
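        # SPLADE pooling: each vocabulary term j gets weight
        #   w_j = max over sequence positions i of log(1 + ReLU(logit_ij)),
        # with padded positions zeroed out by the attention mask.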
        splade_vector = torch.max(
            torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
            dim=1
        )[0].squeeze() # Squeeze is fine here as it's a single input
    else:
        return "Model output structure not as expected for SPLADE-cocondenser-distil. 'logits' not found.", ""

    indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
    if not isinstance(indices, list):
        indices = [indices] if indices else []

    values = splade_vector[indices].cpu().tolist()
    token_weights = dict(zip(indices, values))

    meaningful_tokens = {}
    for token_id, weight in token_weights.items():
        decoded_token = tokenizer_splade.decode([token_id])
        if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
            meaningful_tokens[decoded_token] = weight

    sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)

    formatted_output = "MLM Representation:\n\n"
    if not sorted_representation:
        formatted_output += "No significant terms found for this input.\n"
    else:
        # Render terms as a single comma-separated paragraph
        terms_list = []
        for term, weight in sorted_representation:
            terms_list.append(f"**{term}**: {weight:.4f}")
        formatted_output += ", ".join(terms_list) + "."

    info_output = f"Total non-zero terms in vector: {len(indices)}\n"


    return formatted_output, info_output


def get_splade_lexical_representation(text):
    if tokenizer_splade_lexical is None or model_splade_lexical is None:
        return "SPLADE-v3-Lexical model is not loaded. Please check the console for loading errors.", ""

    inputs = tokenizer_splade_lexical(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(model_splade_lexical.device) for k, v in inputs.items()}

    with torch.no_grad():
        output = model_splade_lexical(**inputs)

    if hasattr(output, 'logits'):
        splade_vector = torch.max(
            torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
            dim=1
        )[0].squeeze() # Squeeze is fine here
    else:
        return "Model output structure not as expected for SPLADE-v3-Lexical. 'logits' not found.", ""

    # Always apply lexical mask for this model's specific behavior
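    # (The mask keeps only terms that literally occur in the input text,
    #  i.e. lexical weighting without expansion terms.)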
    vocab_size = tokenizer_splade_lexical.vocab_size
    # input_ids already carries a batch dimension of 1; squeeze the result back to a 1-D vector
    bow_mask = create_lexical_bow_mask(
        inputs['input_ids'], vocab_size, tokenizer_splade_lexical
    ).squeeze() # Squeeze back for single output
    splade_vector = splade_vector * bow_mask

    indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
    if not isinstance(indices, list):
        indices = [indices] if indices else []

    values = splade_vector[indices].cpu().tolist()
    token_weights = dict(zip(indices, values))

    meaningful_tokens = {}
    for token_id, weight in token_weights.items():
        decoded_token = tokenizer_splade_lexical.decode([token_id])
        if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
            meaningful_tokens[decoded_token] = weight

    sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)

    formatted_output = "MLP Representation:\n\n"
    if not sorted_representation:
        formatted_output += "No significant terms found for this input.\n"
    else:
        # Render terms as a single comma-separated paragraph
        terms_list = []
        for term, weight in sorted_representation:
            terms_list.append(f"**{term}**: {weight:.4f}")
        formatted_output += ", ".join(terms_list) + "."

    info_output = f"Total non-zero terms in vector: {len(indices)}\n"


    return formatted_output, info_output


def get_splade_doc_representation(text):
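    """Binary bag-of-words representation: every non-special input token gets weight 1.

    Intended to mirror the query side of SPLADE-doc style models, where the query
    is a plain binary term vector (no weighting or expansion).
    """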
    if tokenizer_splade_doc is None:  # Only the tokenizer is needed; the model's logits are not used
        return "SPLADE-v3-Doc tokenizer is not loaded. Please check the console for loading errors.", ""

    inputs = tokenizer_splade_doc(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(torch.device("cpu")) for k, v in inputs.items()} # Ensure on CPU for direct mask creation

    vocab_size = tokenizer_splade_doc.vocab_size
    # Directly create the binary Bag-of-Words vector using the input_ids
    binary_bow_vector = create_lexical_bow_mask(
        inputs['input_ids'], vocab_size, tokenizer_splade_doc
    ).squeeze() # Squeeze back for single output

    indices = torch.nonzero(binary_bow_vector).squeeze().cpu().tolist()
    if not isinstance(indices, list):
        indices = [indices] if indices else []

    values = [1.0] * len(indices) # All values are 1 for binary representation
    token_weights = dict(zip(indices, values))

    meaningful_tokens = {}
    for token_id, weight in token_weights.items():
        decoded_token = tokenizer_splade_doc.decode([token_id])
        if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
            meaningful_tokens[decoded_token] = weight

    sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[0]) # Sort alphabetically for clarity

    formatted_output = "Binary:\n\n"
    if not sorted_representation:
        formatted_output += "No significant terms found for this input.\n"
    else:
        # Render terms as a single comma-separated paragraph
        terms_list = []
        for term, _ in sorted_representation: # For binary, weight is always 1, so no need to display
            terms_list.append(f"**{term}**")
        formatted_output += ", ".join(terms_list) + "."

    info_output = f"Total non-zero terms in vector: {len(indices)}"

    return formatted_output, info_output


# --- Unified Prediction Function for the Explorer Tab ---
def predict_representation_explorer(model_choice, text):
    if model_choice == "MLM encoder (SPLADE-cocondenser-distil)":
        return get_splade_cocondenser_representation(text)
    elif model_choice == "MLP encoder (SPLADE-v3-lexical)":
        return get_splade_lexical_representation(text)
    elif model_choice == "Binary":
        return get_splade_doc_representation(text)
    else:
        return "Please select a model.", "" # Return two empty strings for consistency

# --- Core Representation Functions (Return RAW TENSORS - for Dot Product Tab) ---
# Each function returns the raw sparse vector (a 1-D tensor over the vocabulary) used for scoring.
def get_splade_cocondenser_vector(text):
    if tokenizer_splade is None or model_splade is None:
        return None

    inputs = tokenizer_splade(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(model_splade.device) for k, v in inputs.items()}

    with torch.no_grad():
        output = model_splade(**inputs)

    if hasattr(output, 'logits'):
        splade_vector = torch.max(
            torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
            dim=1
        )[0].squeeze()
        return splade_vector
    return None

def get_splade_lexical_vector(text):
    if tokenizer_splade_lexical is None or model_splade_lexical is None:
        return None

    inputs = tokenizer_splade_lexical(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(model_splade_lexical.device) for k, v in inputs.items()}

    with torch.no_grad():
        output = model_splade_lexical(**inputs)

    if hasattr(output, 'logits'):
        splade_vector = torch.max(
            torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
            dim=1
        )[0].squeeze()

        vocab_size = tokenizer_splade_lexical.vocab_size
        bow_mask = create_lexical_bow_mask(
            inputs['input_ids'], vocab_size, tokenizer_splade_lexical
        ).squeeze()

        splade_vector = splade_vector * bow_mask
        return splade_vector
    return None

def get_splade_doc_vector(text):
    if tokenizer_splade_doc is None:  # Only the tokenizer is needed; the model's logits are not used
        return None

    inputs = tokenizer_splade_doc(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(torch.device("cpu")) for k, v in inputs.items()} # Ensure on CPU for direct mask creation

    vocab_size = tokenizer_splade_doc.vocab_size
    # Directly create the binary Bag-of-Words vector using the input_ids
    binary_bow_vector = create_lexical_bow_mask(
        inputs['input_ids'], vocab_size, tokenizer_splade_doc
    ).squeeze()
    return binary_bow_vector


# --- Function to get formatted representation from a raw vector and tokenizer ---
# Generic formatter for any sparse vocabulary vector.
def format_sparse_vector_output(splade_vector, tokenizer, is_binary=False):
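    """Format a 1-D sparse vocabulary vector as a Markdown term list.

    Returns (formatted_terms_str, info_str). Binary vectors are listed
    alphabetically without weights; weighted vectors are sorted by
    descending weight.
    """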
    if splade_vector is None:
        return "Failed to generate vector.", ""

    indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
    if not isinstance(indices, list):
        indices = [indices] if indices else []

    if is_binary:
        values = [1.0] * len(indices)
    else:
        values = splade_vector[indices].cpu().tolist()

    token_weights = dict(zip(indices, values))

    meaningful_tokens = {}
    for token_id, weight in token_weights.items():
        decoded_token = tokenizer.decode([token_id])
        if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
            meaningful_tokens[decoded_token] = weight

    if is_binary:
        sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[0]) # Sort alphabetically for binary
    else:
        sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)

    formatted_output = ""
    if not sorted_representation:
        formatted_output += "No significant terms found.\n"
    else:
        terms_list = []
        for i, (term, weight) in enumerate(sorted_representation):
            # Limit the display to the top 50 terms to avoid overly long output
            if i >= 50:
                terms_list.append(f"... ({len(sorted_representation) - 50} more terms)")
                break
            if is_binary:
                terms_list.append(f"**{term}**")
            else:
                terms_list.append(f"**{term}**: {weight:.4f}")
        formatted_output += ", ".join(terms_list) + "."

    info_output = f"Total non-zero terms: {len(indices)}\n"


    return formatted_output, info_output


# --- Helper to get the vector function, tokenizer, and binary flag for a model choice ---
def get_model_assets(model_choice_str):
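    """Map a UI model choice to (vector_fn, tokenizer, is_binary, display_name)."""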
    if model_choice_str == "MLM encoder (SPLADE-cocondenser-distil)":
        return get_splade_cocondenser_vector, tokenizer_splade, False, "MLM encoder (SPLADE-cocondenser-distil)"
    elif model_choice_str == "MLP encoder (SPLADE-v3-lexical)":
        return get_splade_lexical_vector, tokenizer_splade_lexical, False, "MLP encoder (SPLADE-v3-lexical)"
    elif model_choice_str == "Binary":
        return get_splade_doc_vector, tokenizer_splade_doc, True, "Binary Bag-of-Words"
    else:
        return None, None, False, "Unknown Model"

# --- Dot product calculation for the similarity tab ---
def calculate_dot_product_and_representations_independent(query_model_choice, doc_model_choice, query_text, doc_text):
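    """Encode the query and document independently with the selected models,
    compute the dot product of the two sparse vectors, and return a Markdown
    report with per-term contributions and both representations.
    """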
    query_vector_fn, query_tokenizer, query_is_binary, query_model_name_display = get_model_assets(query_model_choice)
    doc_vector_fn, doc_tokenizer, doc_is_binary, doc_model_name_display = get_model_assets(doc_model_choice)

    if query_vector_fn is None or doc_vector_fn is None:
        return "Please select valid models for both query and document encoding.", ""

    query_vector = query_vector_fn(query_text)
    doc_vector = doc_vector_fn(doc_text)

    if query_vector is None or doc_vector is None:
        return "Failed to generate one or both vectors. Please check model loading and input text.", ""

    # Overall dot product (assumes both vectors live in the same vocabulary space)
    dot_product = float(torch.dot(query_vector.cpu(), doc_vector.cpu()).item())

    # Format representations for display
    query_main_rep_str, query_info_str = format_sparse_vector_output(query_vector, query_tokenizer, query_is_binary)
    doc_main_rep_str, doc_info_str = format_sparse_vector_output(doc_vector, doc_tokenizer, doc_is_binary)

    # Per-term contributions: the dot product decomposes as a sum over shared
    # vocabulary indices of query_weight * doc_weight, so each overlapping term's
    # product shows how much it contributes to the overall score.
    overlapping_terms_dot_products = {}
    query_indices = torch.nonzero(query_vector).squeeze().cpu()
    doc_indices = torch.nonzero(doc_vector).squeeze().cpu()

    # Handle cases where vectors are empty or single element
    if query_indices.dim() == 0 and query_indices.numel() == 1:
        query_indices = query_indices.unsqueeze(0)
    if doc_indices.dim() == 0 and doc_indices.numel() == 1:
        doc_indices = doc_indices.unsqueeze(0)

    # Convert indices to sets for efficient intersection
    query_index_set = set(query_indices.tolist())
    doc_index_set = set(doc_indices.tolist())

    common_indices = sorted(list(query_index_set.intersection(doc_index_set)))

    if common_indices:
        for idx in common_indices:
            query_weight = query_vector[idx].item()
            doc_weight = doc_vector[idx].item()
            term = query_tokenizer.decode([idx])  # assumes both tokenizers share the same vocabulary
            if term not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(term.strip()) > 0:
                overlapping_terms_dot_products[term] = query_weight * doc_weight

    sorted_overlapping_dot_products = sorted(
        overlapping_terms_dot_products.items(),
        key=lambda item: item[1],
        reverse=True
    )
    # --- End per-term contributions ---

    # Combine output into a single string for the Markdown component
    full_output = f"### Overall Dot Product Score: {dot_product:.6f}\n\n"
    full_output += "---\n\n"

    # Overlapping Terms Dot Products
    if sorted_overlapping_dot_products:
        full_output += "### Product of Query and Document Term Scores:\n"
        full_output += "\n"
        overlap_list = []
        for term, product_val in sorted_overlapping_dot_products:
            overlap_list.append(f"**{term}**: {product_val:.4f}")
        full_output += ", ".join(overlap_list) + ".\n\n"
        full_output += "---\n\n"
    else:
        full_output += "### No Overlapping Terms Found.\n\n"
        full_output += "---\n\n"

    # Query Representation
    full_output += f"#### Query Representation: {query_model_name_display}\n" # Smaller heading for sub-section
    full_output += f"> {query_main_rep_str}\n" # Using blockquote for the sparse list
    full_output += f"> {query_info_str}\n" # Using blockquote for info as well
    full_output += "\n---\n\n" # Separator

    # Document Representation
    full_output += f"#### Document Representation: {doc_model_name_display}\n" # Smaller heading for sub-section
    full_output += f"> {doc_main_rep_str}\n" # Using blockquote
    full_output += f"> {doc_info_str}" # Using blockquote

    return full_output


# Global variable to store the share URL once the app is launched
global_share_url = None

def get_current_share_url():
    """Returns the globally stored share URL."""
    return global_share_url if global_share_url else "Share URL not available yet."

# --- Gradio Interface Setup with Tabs ---
with gr.Blocks(title="SPLADE Demos", css=css) as demo:
    gr.Markdown("# 🌌 Sparse Encoder Playground")
    gr.Markdown("Explore different SPLADE models and their sparse representation types, and calculate similarity between query and document representations.")

    with gr.Tabs():
        with gr.TabItem("Sparse Representation"):
            gr.Markdown("### Produce a Sparse Representation of an Input Text")
            with gr.Row():
                with gr.Column(scale=1): # Left column for inputs and info
                    model_radio = gr.Radio(
                        [
                            "MLM encoder (SPLADE-cocondenser-distil)",
                            "MLP encoder (SPLADE-v3-lexical)",
                            "Binary"
                        ],
                        label="Choose Sparse Encoder",
                        value="MLM encoder (SPLADE-cocondenser-distil)"
                    )
                    input_text = gr.Textbox(
                        lines=5,
                        label="Enter your query or document text here:",
                        placeholder="e.g., Why is Padua the nicest city in Italy?"
                    )

                    # Custom Share Button and URL display
                    with gr.Row(elem_classes="share-button-container"):
                        share_button = gr.Button(
                            "🔗 Get Share Link",
                            elem_classes="custom-share-button",
                            size="sm"
                        )

                    share_output = gr.Textbox(
                        label="Share URL",
                        interactive=True,
                        visible=False,
                        placeholder="Click 'Get Share Link' to generate URL..."
                    )

                    info_output_display = gr.Markdown(
                        value="",
                        label="Vector Information",
                        elem_id="info_output_display"
                    )
                with gr.Column(scale=2): # Right column for the main representation output
                    main_representation_output = gr.Markdown()

            # Connect share button.
            share_button.click(
                fn=get_current_share_url,
                outputs=share_output
            ).then(
                fn=lambda: gr.update(visible=True),
                outputs=share_output
            )


            # Connect the core prediction logic
            model_radio.change(
                fn=predict_representation_explorer,
                inputs=[model_radio, input_text],
                outputs=[main_representation_output, info_output_display]
            )
            input_text.change(
                fn=predict_representation_explorer,
                inputs=[model_radio, input_text],
                outputs=[main_representation_output, info_output_display]
            )

            # Initial call to populate on load (optional, but good for demo)
            demo.load(
                fn=lambda: predict_representation_explorer(model_radio.value, input_text.value),
                outputs=[main_representation_output, info_output_display]
            )


        with gr.TabItem("Compute Query-Document Similarity Score"):
            gr.Markdown("### Calculate Dot Product Similarity Between Encoded Query and Document")

            model_choices = [
                "MLM encoder (SPLADE-cocondenser-distil)",
                "MLP encoder (SPLADE-v3-lexical)",
                "Binary"
            ]

            # Input components for the second tab
            query_model_radio = gr.Radio(
                model_choices,
                label="Choose Query Encoding Model",
                value="MLM encoder (SPLADE-cocondenser-distil)"
            )
            doc_model_radio = gr.Radio(
                model_choices,
                label="Choose Document Encoding Model",
                value="MLM encoder (SPLADE-cocondenser-distil)"
            )
            query_text_input = gr.Textbox(
                lines=3,
                label="Enter Query Text:",
                placeholder="e.g., famous dishes of Padua"
            )
            doc_text_input = gr.Textbox(
                lines=5,
                label="Enter Document Text:",
                placeholder="e.g., Padua's cuisine is as famous as its legendary University."
            )

            # Output component: a gr.Markdown with a fixed height so long
            # reports scroll instead of stretching the page.
            output_dot_product_markdown = gr.Markdown(
                value="",  # populated by the prediction function
                height=500,  # fixed height in pixels; content scrolls if it overflows
                elem_classes=["scrollable-output"]  # custom class for optional CSS targeting
            )

            # Wire the similarity computation to the inputs and output above.
            gr.Interface(
                fn=calculate_dot_product_and_representations_independent,
                inputs=[
                    query_model_radio,
                    doc_model_radio,
                    query_text_input,
                    doc_text_input
                ],
                outputs=output_dot_product_markdown,
                allow_flagging="never"
            )
    
    # --- References ---
    gr.Markdown(
        """
        ---
        ### References

        This demo utilizes **SPLADE** models. For more details, please refer to the following papers:

        1.  Formal, T., Lassance, C., Piwowarski, B., & Clinchant, S. (2022). **From Distillation to Hard Negative Sampling: Making Sparse Neural IR Models More Effective**. *arXiv preprint arXiv:2205.04733*. Available at: [https://arxiv.org/abs/2205.04733](https://arxiv.org/abs/2205.04733)

        2.  Lassance, C., Déjean, H., Formal, T., & Clinchant, S. (2024). **SPLADE-v3: New baselines for SPLADE**. *arXiv preprint arXiv:2403.06789*. Available at: [https://arxiv.org/abs/2403.06789](https://arxiv.org/abs/2403.06789)
        """
    )


# Launch the app and capture the public share URL so the "Get Share Link"
# button can return it. Blocks.launch() returns (app, local_url, share_url);
# prevent_thread_lock=True lets us store the URL before blocking the thread.
if __name__ == "__main__":
    app, local_url, share_url = demo.launch(share=True, prevent_thread_lock=True)
    global_share_url = share_url
    print("\n--- Gradio App Launched ---")
    print(f"Local URL: {local_url}")
    print(f"Public share URL: {share_url if share_url else 'not available'}")
    print("You can also use the '🔗 Get Share Link' button on the 'Sparse Representation' tab.")
    print("---------------------------\n")
    demo.block_thread()  # keep the server alive (launch() would otherwise block here itself)