SiddharthAK committed
Commit 372cab2 · verified · 1 Parent(s): d302c88

Update app.py

Files changed (1)
  1. app.py +192 -3
app.py CHANGED
@@ -244,11 +244,170 @@ def predict_representation_explorer(model_choice, text):
     else:
         return "Please select a model."
 
+# --- NEW: Core Representation Functions (Return RAW TENSORS - for Dot Product Tab) ---
+def get_splade_cocondenser_vector(text):
+    if tokenizer_splade is None or model_splade is None:
+        return None
+
+    inputs = tokenizer_splade(text, return_tensors="pt", padding=True, truncation=True)
+    inputs = {k: v.to(model_splade.device) for k, v in inputs.items()}
+
+    with torch.no_grad():
+        output = model_splade(**inputs)
+
+    if hasattr(output, 'logits'):
+        splade_vector = torch.max(
+            torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
+            dim=1
+        )[0].squeeze()
+        return splade_vector
+    return None
+
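Note: the pooling above is the standard SPLADE aggregation. For vocabulary term $j$ with MLM logit $l_{ij}$ at sequence position $i$ and attention mask $m_i$,

$$w_j = \max_i \; m_i \cdot \log\bigl(1 + \mathrm{ReLU}(l_{ij})\bigr),$$

so each term's weight is a log-saturated activation, max-pooled over the unpadded token positions.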
+def get_splade_lexical_vector(text):
+    if tokenizer_splade_lexical is None or model_splade_lexical is None:
+        return None
+
+    inputs = tokenizer_splade_lexical(text, return_tensors="pt", padding=True, truncation=True)
+    inputs = {k: v.to(model_splade_lexical.device) for k, v in inputs.items()}
+
+    with torch.no_grad():
+        output = model_splade_lexical(**inputs)
+
+    if hasattr(output, 'logits'):
+        splade_vector = torch.max(
+            torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
+            dim=1
+        )[0].squeeze()
+
+        vocab_size = tokenizer_splade_lexical.vocab_size
+        bow_mask = create_lexical_bow_mask(
+            inputs['input_ids'], vocab_size, tokenizer_splade_lexical
+        ).squeeze()
+
+        splade_vector = splade_vector * bow_mask
+        return splade_vector
+    return None
+
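`create_lexical_bow_mask` is called here but defined earlier in app.py, outside this diff. A minimal sketch of what such a helper could look like, given only how it is called (the body below is an assumption, not the repository's implementation):

```python
import torch

# Assumed sketch -- the real helper lives elsewhere in app.py and may differ.
# It builds a {0,1} vector over the vocabulary marking which token ids
# actually occur in the input, skipping special tokens.
def create_lexical_bow_mask(input_ids, vocab_size, tokenizer):
    bow_mask = torch.zeros(1, vocab_size, device=input_ids.device)
    special_ids = set(tokenizer.all_special_ids)  # [CLS], [SEP], [PAD], ...
    for token_id in input_ids.squeeze().tolist():
        if token_id not in special_ids:
            bow_mask[0, token_id] = 1.0
    return bow_mask
```

Multiplying the SPLADE vector by this mask zeroes every expansion term, which is what makes the Lexical variant weight-only.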
+def get_splade_doc_vector(text):
+    if tokenizer_splade_doc is None or model_splade_doc is None:
+        return None
+
+    inputs = tokenizer_splade_doc(text, return_tensors="pt", padding=True, truncation=True)
+    inputs = {k: v.to(model_splade_doc.device) for k, v in inputs.items()}
+
+    with torch.no_grad():
+        output = model_splade_doc(**inputs)
+
+    if hasattr(output, "logits"):
+        vocab_size = tokenizer_splade_doc.vocab_size
+        binary_splade_vector = create_lexical_bow_mask(
+            inputs['input_ids'], vocab_size, tokenizer_splade_doc
+        ).squeeze()
+        return binary_splade_vector
+    return None
+
+
+# --- NEW: Function to get formatted representation from a raw vector and tokenizer ---
+def format_sparse_vector_output(splade_vector, tokenizer, is_binary=False):
+    if splade_vector is None:
+        return "Failed to generate vector."
+
+    indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
+    if not isinstance(indices, list):
+        indices = [indices] if indices else []
+
+    if is_binary:
+        values = [1.0] * len(indices)
+    else:
+        values = splade_vector[indices].cpu().tolist()
+
+    token_weights = dict(zip(indices, values))
+
+    meaningful_tokens = {}
+    for token_id, weight in token_weights.items():
+        decoded_token = tokenizer.decode([token_id])
+        if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
+            meaningful_tokens[decoded_token] = weight
+
+    if is_binary:
+        sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[0]) # Sort alphabetically for binary
+    else:
+        sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)
+
+    formatted_output = ""
+    if not sorted_representation:
+        formatted_output += "No significant terms found.\n"
+    else:
+        for i, (term, weight) in enumerate(sorted_representation):
+            if i >= 50 and is_binary: # Limit display for very long binary lists
+                formatted_output += f"...and {len(sorted_representation) - 50} more terms.\n"
+                break
+            if is_binary:
+                formatted_output += f"- **{term}**\n"
+            else:
+                formatted_output += f"- **{term}**: {weight:.4f}\n"
+
+    formatted_output += f"\nTotal non-zero terms: {len(indices)}\n"
+    formatted_output += f"Sparsity: {1 - (len(indices) / tokenizer.vocab_size):.2%}\n"
+
+    return formatted_output
+
+
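For scale: assuming the BERT-style vocabulary of roughly 30,522 tokens these checkpoints use, a vector with 120 non-zero entries would report a sparsity of 1 − 120/30522 ≈ 99.61%.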
+# --- NEW: Dot Product Calculation Function for the new tab ---
+def calculate_dot_product_and_representations(model_choice, query_text, doc_text):
+    query_vector = None
+    doc_vector = None
+    query_rep_str = ""
+    doc_rep_str = ""
+
+    selected_tokenizer = None
+
+    if model_choice == "SPLADE-cocondenser-distil (weighting and expansion)":
+        query_vector = get_splade_cocondenser_vector(query_text)
+        doc_vector = get_splade_cocondenser_vector(doc_text)
+        selected_tokenizer = tokenizer_splade
+        query_rep_str = "Query SPLADE-cocondenser-distil Representation (Weighting and Expansion):\n"
+        doc_rep_str = "Document SPLADE-cocondenser-distil Representation (Weighting and Expansion):\n"
+        is_binary = False
+    elif model_choice == "SPLADE-v3-Lexical (weighting)":
+        query_vector = get_splade_lexical_vector(query_text)
+        doc_vector = get_splade_lexical_vector(doc_text)
+        selected_tokenizer = tokenizer_splade_lexical
+        query_rep_str = "Query SPLADE-v3-Lexical Representation (Weighting):\n"
+        doc_rep_str = "Document SPLADE-v3-Lexical Representation (Weighting):\n"
+        is_binary = False
+    elif model_choice == "SPLADE-v3-Doc (binary)":
+        query_vector = get_splade_doc_vector(query_text)
+        doc_vector = get_splade_doc_vector(doc_text)
+        selected_tokenizer = tokenizer_splade_doc
+        query_rep_str = "Query SPLADE-v3-Doc Representation (Binary):\n"
+        doc_rep_str = "Document SPLADE-v3-Doc Representation (Binary):\n"
+        is_binary = True
+    else:
+        return "Please select a model."
+
+    if query_vector is None or doc_vector is None:
+        return "Failed to generate one or both vectors. Please check model loading."
+
+    # Calculate dot product
+    dot_product = float(torch.dot(query_vector.cpu(), doc_vector.cpu()).item())
+
+    # Format representations
+    query_rep_str += format_sparse_vector_output(query_vector, selected_tokenizer, is_binary)
+    doc_rep_str += format_sparse_vector_output(doc_vector, selected_tokenizer, is_binary)
+
+    # Combine output
+    full_output = f"### Dot Product Score: {dot_product:.6f}\n\n"
+    full_output += "---\n\n"
+    full_output += f"{query_rep_str}\n\n---\n\n{doc_rep_str}"
+
+    return full_output
+
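Because both vectors live in the same vocabulary space, this dot product reduces to a sum of weight products over the terms the two representations share. A toy illustration with made-up weights (not real model output):

```python
# Hypothetical query/document term weights, keyed by vocabulary term.
query = {"pizza": 2.1, "naples": 1.8, "best": 0.9}
doc   = {"pizza": 1.5, "naples": 1.2, "crust": 0.7}

# Only terms present in both vectors contribute: 2.1*1.5 + 1.8*1.2
score = sum(w * doc[t] for t, w in query.items() if t in doc)
print(score)  # ≈ 5.31
```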
 
 # --- Gradio Interface Setup with Tabs ---
 with gr.Blocks(title="SPLADE Demos") as demo:
-    gr.Markdown("# 🌌 SPLADE Demos: Sparse Representation Explorer") # Updated title
-    gr.Markdown("Explore different SPLADE models and their sparse representation types.") # Updated description
+    gr.Markdown("# 🌌 SPLADE Demos: Sparse Representation Explorer and Retriever") # Updated title
+    gr.Markdown("Explore different SPLADE models and their sparse representation types, and calculate similarity between query and document representations.") # Updated description
 
     with gr.Tabs():
         with gr.TabItem("Sparse Representation Explorer"):
@@ -275,5 +434,35 @@ with gr.Blocks(title="SPLADE Demos") as demo:
                 allow_flagging="never",
                 # live=True # Setting live=True might be slow for complex models on every keystroke
             )
+
+        with gr.TabItem("Query-Document Dot Product Calculator"): # NEW TAB
+            gr.Markdown("### Calculate Dot Product Similarity between Query and Document")
+            gr.Markdown("Select a SPLADE model to encode both your query and document, then see their sparse representations and their similarity score.")
+            gr.Interface(
+                fn=calculate_dot_product_and_representations,
+                inputs=[
+                    gr.Radio(
+                        [
+                            "SPLADE-cocondenser-distil (weighting and expansion)",
+                            "SPLADE-v3-Lexical (weighting)",
+                            "SPLADE-v3-Doc (binary)"
+                        ],
+                        label="Choose Encoding Model",
+                        value="SPLADE-cocondenser-distil (weighting and expansion)"
+                    ),
+                    gr.Textbox(
+                        lines=3,
+                        label="Enter Query Text:",
+                        placeholder="e.g., best pizza in Naples"
+                    ),
+                    gr.Textbox(
+                        lines=5,
+                        label="Enter Document Text:",
+                        placeholder="e.g., Naples is famous for its delicious pizza, known for its soft, chewy crust and fresh ingredients."
+                    )
+                ],
+                outputs=gr.Markdown(),
+                allow_flagging="never"
+            )
 
-demo.launch()
+demo.launch()
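For a quick headless check of the new function (assuming app.py's model and tokenizer globals loaded), something like the following should print the markdown the new tab renders; the texts are the tab's own placeholders:

```python
# Hypothetical smoke test; assumes app.py's models and tokenizers loaded.
result = calculate_dot_product_and_representations(
    "SPLADE-cocondenser-distil (weighting and expansion)",
    "best pizza in Naples",
    "Naples is famous for its delicious pizza, known for its soft, chewy "
    "crust and fresh ingredients.",
)
print(result)  # markdown: dot-product score, then both sparse representations
```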