SiddharthAK committed on
Commit
3bcd060
·
verified ·
1 Parent(s): 372cab2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -45
app.py CHANGED
@@ -244,7 +244,8 @@ def predict_representation_explorer(model_choice, text):
244
  else:
245
  return "Please select a model."
246
 
247
- # --- NEW: Core Representation Functions (Return RAW TENSORS - for Dot Product Tab) ---
 
248
  def get_splade_cocondenser_vector(text):
249
  if tokenizer_splade is None or model_splade is None:
250
  return None
@@ -307,7 +308,8 @@ def get_splade_doc_vector(text):
307
  return None
308
 
309
 
310
- # --- NEW: Function to get formatted representation from a raw vector and tokenizer ---
 
311
  def format_sparse_vector_output(splade_vector, tokenizer, is_binary=False):
312
  if splade_vector is None:
313
  return "Failed to generate vector."
@@ -353,48 +355,42 @@ def format_sparse_vector_output(splade_vector, tokenizer, is_binary=False):
353
  return formatted_output
354
 
355
 
356
- # --- NEW: Dot Product Calculation Function for the new tab ---
357
- def calculate_dot_product_and_representations(model_choice, query_text, doc_text):
358
- query_vector = None
359
- doc_vector = None
360
- query_rep_str = ""
361
- doc_rep_str = ""
362
-
363
- selected_tokenizer = None
364
-
365
- if model_choice == "SPLADE-cocondenser-distil (weighting and expansion)":
366
- query_vector = get_splade_cocondenser_vector(query_text)
367
- doc_vector = get_splade_cocondenser_vector(doc_text)
368
- selected_tokenizer = tokenizer_splade
369
- query_rep_str = "Query SPLADE-cocondenser-distil Representation (Weighting and Expansion):\n"
370
- doc_rep_str = "Document SPLADE-cocondenser-distil Representation (Weighting and Expansion):\n"
371
- is_binary = False
372
- elif model_choice == "SPLADE-v3-Lexical (weighting)":
373
- query_vector = get_splade_lexical_vector(query_text)
374
- doc_vector = get_splade_lexical_vector(doc_text)
375
- selected_tokenizer = tokenizer_splade_lexical
376
- query_rep_str = "Query SPLADE-v3-Lexical Representation (Weighting):\n"
377
- doc_rep_str = "Document SPLADE-v3-Lexical Representation (Weighting):\n"
378
- is_binary = False
379
- elif model_choice == "SPLADE-v3-Doc (binary)":
380
- query_vector = get_splade_doc_vector(query_text)
381
- doc_vector = get_splade_doc_vector(doc_text)
382
- selected_tokenizer = tokenizer_splade_doc
383
- query_rep_str = "Query SPLADE-v3-Doc Representation (Binary):\n"
384
- doc_rep_str = "Document SPLADE-v3-Doc Representation (Binary):\n"
385
- is_binary = True
386
  else:
387
- return "Please select a model.", "", ""
 
 
 
 
 
 
 
 
 
 
 
388
 
389
  if query_vector is None or doc_vector is None:
390
- return "Failed to generate one or both vectors. Please check model loading.", "", ""
391
 
392
  # Calculate dot product
 
 
393
  dot_product = float(torch.dot(query_vector.cpu(), doc_vector.cpu()).item())
394
 
395
  # Format representations
396
- query_rep_str += format_sparse_vector_output(query_vector, selected_tokenizer, is_binary)
397
- doc_rep_str += format_sparse_vector_output(doc_vector, selected_tokenizer, is_binary)
 
 
 
398
 
399
  # Combine output
400
  full_output = f"### Dot Product Score: {dot_product:.6f}\n\n"
@@ -437,18 +433,27 @@ with gr.Blocks(title="SPLADE Demos") as demo:
437
 
438
  with gr.TabItem("Query-Document Dot Product Calculator"): # NEW TAB
439
  gr.Markdown("### Calculate Dot Product Similarity between Query and Document")
440
- gr.Markdown("Select a SPLADE model to encode both your query and document, then see their sparse representations and their similarity score.")
 
 
 
 
 
 
 
 
441
  gr.Interface(
442
- fn=calculate_dot_product_and_representations,
443
  inputs=[
444
  gr.Radio(
445
- [
446
- "SPLADE-cocondenser-distil (weighting and expansion)",
447
- "SPLADE-v3-Lexical (weighting)",
448
- "SPLADE-v3-Doc (binary)"
449
- ],
450
- label="Choose Encoding Model",
451
- value="SPLADE-cocondenser-distil (weighting and expansion)"
 
452
  ),
453
  gr.Textbox(
454
  lines=3,
 
244
  else:
245
  return "Please select a model."
246
 
247
+ # --- Core Representation Functions (Return RAW TENSORS - for Dot Product Tab) ---
248
+ # These functions remain unchanged from the previous iteration, as they return the raw tensors.
249
  def get_splade_cocondenser_vector(text):
250
  if tokenizer_splade is None or model_splade is None:
251
  return None
 
308
  return None
309
 
310
 
311
+ # --- Function to get formatted representation from a raw vector and tokenizer ---
312
+ # This function remains unchanged as it's a generic formatter for any sparse vector.
313
  def format_sparse_vector_output(splade_vector, tokenizer, is_binary=False):
314
  if splade_vector is None:
315
  return "Failed to generate vector."
 
355
  return formatted_output
356
 
357
 
358
+ # --- NEW/MODIFIED: Helper to get the correct vector function, tokenizer, binary flag, and display name ---
359
+ def get_model_assets(model_choice_str):
360
+ if model_choice_str == "SPLADE-cocondenser-distil (weighting and expansion)":
361
+ return get_splade_cocondenser_vector, tokenizer_splade, False, "SPLADE-cocondenser-distil (Weighting and Expansion)"
362
+ elif model_choice_str == "SPLADE-v3-Lexical (weighting)":
363
+ return get_splade_lexical_vector, tokenizer_splade_lexical, False, "SPLADE-v3-Lexical (Weighting)"
364
+ elif model_choice_str == "SPLADE-v3-Doc (binary)":
365
+ return get_splade_doc_vector, tokenizer_splade_doc, True, "SPLADE-v3-Doc (Binary)"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
366
  else:
367
+ return None, None, False, "Unknown Model"
368
+
369
+ # --- MODIFIED: Dot Product Calculation Function for the new tab ---
370
+ def calculate_dot_product_and_representations_independent(query_model_choice, doc_model_choice, query_text, doc_text):
371
+ query_vector_fn, query_tokenizer, query_is_binary, query_model_name_display = get_model_assets(query_model_choice)
372
+ doc_vector_fn, doc_tokenizer, doc_is_binary, doc_model_name_display = get_model_assets(doc_model_choice)
373
+
374
+ if query_vector_fn is None or doc_vector_fn is None:
375
+ return "Please select valid models for both query and document encoding.", "", ""
376
+
377
+ query_vector = query_vector_fn(query_text)
378
+ doc_vector = doc_vector_fn(doc_text)
379
 
380
  if query_vector is None or doc_vector is None:
381
+ return "Failed to generate one or both vectors. Please check model loading and input text.", "", ""
382
 
383
  # Calculate dot product
384
+ # Ensure both vectors are on CPU before dot product to avoid device mismatch issues
385
+ # and to ensure .item() works reliably for conversion to float.
386
  dot_product = float(torch.dot(query_vector.cpu(), doc_vector.cpu()).item())
387
 
388
  # Format representations
389
+ query_rep_str = f"Query Representation ({query_model_name_display}):\n"
390
+ query_rep_str += format_sparse_vector_output(query_vector, query_tokenizer, query_is_binary)
391
+
392
+ doc_rep_str = f"Document Representation ({doc_model_name_display}):\n"
393
+ doc_rep_str += format_sparse_vector_output(doc_vector, doc_tokenizer, doc_is_binary)
394
 
395
  # Combine output
396
  full_output = f"### Dot Product Score: {dot_product:.6f}\n\n"
 
433
 
434
  with gr.TabItem("Query-Document Dot Product Calculator"): # NEW TAB
435
  gr.Markdown("### Calculate Dot Product Similarity between Query and Document")
436
+ gr.Markdown("Select **independent** SPLADE models to encode your query and document, then see their sparse representations and their similarity score.")
437
+
438
+ # Define the common model choices for cleaner code
439
+ model_choices = [
440
+ "SPLADE-cocondenser-distil (weighting and expansion)",
441
+ "SPLADE-v3-Lexical (weighting)",
442
+ "SPLADE-v3-Doc (binary)"
443
+ ]
444
+
445
  gr.Interface(
446
+ fn=calculate_dot_product_and_representations_independent, # MODIFIED FUNCTION NAME
447
  inputs=[
448
  gr.Radio(
449
+ model_choices,
450
+ label="Choose Query Encoding Model",
451
+ value="SPLADE-cocondenser-distil (weighting and expansion)" # Default value
452
+ ),
453
+ gr.Radio(
454
+ model_choices,
455
+ label="Choose Document Encoding Model",
456
+ value="SPLADE-cocondenser-distil (weighting and expansion)" # Default value
457
  ),
458
  gr.Textbox(
459
  lines=3,