Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -244,7 +244,8 @@ def predict_representation_explorer(model_choice, text):
|
|
244 |
else:
|
245 |
return "Please select a model."
|
246 |
|
247 |
-
# ---
|
|
|
248 |
def get_splade_cocondenser_vector(text):
|
249 |
if tokenizer_splade is None or model_splade is None:
|
250 |
return None
|
@@ -307,7 +308,8 @@ def get_splade_doc_vector(text):
|
|
307 |
return None
|
308 |
|
309 |
|
310 |
-
# ---
|
|
|
311 |
def format_sparse_vector_output(splade_vector, tokenizer, is_binary=False):
|
312 |
if splade_vector is None:
|
313 |
return "Failed to generate vector."
|
@@ -353,48 +355,42 @@ def format_sparse_vector_output(splade_vector, tokenizer, is_binary=False):
|
|
353 |
return formatted_output
|
354 |
|
355 |
|
356 |
-
# --- NEW:
|
357 |
-
def
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
if model_choice == "SPLADE-cocondenser-distil (weighting and expansion)":
|
366 |
-
query_vector = get_splade_cocondenser_vector(query_text)
|
367 |
-
doc_vector = get_splade_cocondenser_vector(doc_text)
|
368 |
-
selected_tokenizer = tokenizer_splade
|
369 |
-
query_rep_str = "Query SPLADE-cocondenser-distil Representation (Weighting and Expansion):\n"
|
370 |
-
doc_rep_str = "Document SPLADE-cocondenser-distil Representation (Weighting and Expansion):\n"
|
371 |
-
is_binary = False
|
372 |
-
elif model_choice == "SPLADE-v3-Lexical (weighting)":
|
373 |
-
query_vector = get_splade_lexical_vector(query_text)
|
374 |
-
doc_vector = get_splade_lexical_vector(doc_text)
|
375 |
-
selected_tokenizer = tokenizer_splade_lexical
|
376 |
-
query_rep_str = "Query SPLADE-v3-Lexical Representation (Weighting):\n"
|
377 |
-
doc_rep_str = "Document SPLADE-v3-Lexical Representation (Weighting):\n"
|
378 |
-
is_binary = False
|
379 |
-
elif model_choice == "SPLADE-v3-Doc (binary)":
|
380 |
-
query_vector = get_splade_doc_vector(query_text)
|
381 |
-
doc_vector = get_splade_doc_vector(doc_text)
|
382 |
-
selected_tokenizer = tokenizer_splade_doc
|
383 |
-
query_rep_str = "Query SPLADE-v3-Doc Representation (Binary):\n"
|
384 |
-
doc_rep_str = "Document SPLADE-v3-Doc Representation (Binary):\n"
|
385 |
-
is_binary = True
|
386 |
else:
|
387 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
388 |
|
389 |
if query_vector is None or doc_vector is None:
|
390 |
-
return "Failed to generate one or both vectors. Please check model loading.", "", ""
|
391 |
|
392 |
# Calculate dot product
|
|
|
|
|
393 |
dot_product = float(torch.dot(query_vector.cpu(), doc_vector.cpu()).item())
|
394 |
|
395 |
# Format representations
|
396 |
-
query_rep_str
|
397 |
-
|
|
|
|
|
|
|
398 |
|
399 |
# Combine output
|
400 |
full_output = f"### Dot Product Score: {dot_product:.6f}\n\n"
|
@@ -437,18 +433,27 @@ with gr.Blocks(title="SPLADE Demos") as demo:
|
|
437 |
|
438 |
with gr.TabItem("Query-Document Dot Product Calculator"): # NEW TAB
|
439 |
gr.Markdown("### Calculate Dot Product Similarity between Query and Document")
|
440 |
-
gr.Markdown("Select
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
441 |
gr.Interface(
|
442 |
-
fn=
|
443 |
inputs=[
|
444 |
gr.Radio(
|
445 |
-
|
446 |
-
|
447 |
-
|
448 |
-
|
449 |
-
|
450 |
-
|
451 |
-
|
|
|
452 |
),
|
453 |
gr.Textbox(
|
454 |
lines=3,
|
|
|
244 |
else:
|
245 |
return "Please select a model."
|
246 |
|
247 |
+
# --- Core Representation Functions (Return RAW TENSORS - for Dot Product Tab) ---
|
248 |
+
# These functions return the raw tensors (or None when the required tokenizer or model is not loaded).
|
249 |
def get_splade_cocondenser_vector(text):
|
250 |
if tokenizer_splade is None or model_splade is None:
|
251 |
return None
|
|
|
308 |
return None
|
309 |
|
310 |
|
311 |
+
# --- Function to get formatted representation from a raw vector and tokenizer ---
|
312 |
+
# It is a generic formatter that works for any sparse vector paired with its tokenizer.
|
313 |
def format_sparse_vector_output(splade_vector, tokenizer, is_binary=False):
|
314 |
if splade_vector is None:
|
315 |
return "Failed to generate vector."
|
|
|
355 |
return formatted_output
|
356 |
|
357 |
|
358 |
+
# --- Helper: map a model choice to its vector function, tokenizer, binary flag, and display name ---
|
359 |
+
def get_model_assets(model_choice_str):
    """Resolve a model-choice label to the assets used to encode text with it.

    Args:
        model_choice_str: Label to resolve. It is compared by equality against
            the three supported SPLADE model labels.

    Returns:
        A 4-tuple ``(vector_fn, tokenizer, is_binary, display_name)``:
        the function that builds the vector for a text, the tokenizer that
        goes with that model, whether the representation is binary, and a
        human-readable model name. For an unrecognized label the tuple is
        ``(None, None, False, "Unknown Model")``.
    """
    # One guard clause per supported label; the first match returns.
    if model_choice_str == "SPLADE-cocondenser-distil (weighting and expansion)":
        return (
            get_splade_cocondenser_vector,
            tokenizer_splade,
            False,
            "SPLADE-cocondenser-distil (Weighting and Expansion)",
        )

    if model_choice_str == "SPLADE-v3-Lexical (weighting)":
        return (
            get_splade_lexical_vector,
            tokenizer_splade_lexical,
            False,
            "SPLADE-v3-Lexical (Weighting)",
        )

    if model_choice_str == "SPLADE-v3-Doc (binary)":
        return (
            get_splade_doc_vector,
            tokenizer_splade_doc,
            True,
            "SPLADE-v3-Doc (Binary)",
        )

    # Nothing matched: a None vector function marks the choice as unrecognized.
    return None, None, False, "Unknown Model"
|
368 |
+
|
369 |
+
# --- MODIFIED: Dot Product Calculation Function for the new tab ---
|
370 |
+
def calculate_dot_product_and_representations_independent(query_model_choice, doc_model_choice, query_text, doc_text):
|
371 |
+
query_vector_fn, query_tokenizer, query_is_binary, query_model_name_display = get_model_assets(query_model_choice)
|
372 |
+
doc_vector_fn, doc_tokenizer, doc_is_binary, doc_model_name_display = get_model_assets(doc_model_choice)
|
373 |
+
|
374 |
+
if query_vector_fn is None or doc_vector_fn is None:
|
375 |
+
return "Please select valid models for both query and document encoding.", "", ""
|
376 |
+
|
377 |
+
query_vector = query_vector_fn(query_text)
|
378 |
+
doc_vector = doc_vector_fn(doc_text)
|
379 |
|
380 |
if query_vector is None or doc_vector is None:
|
381 |
+
return "Failed to generate one or both vectors. Please check model loading and input text.", "", ""
|
382 |
|
383 |
# Calculate dot product
|
384 |
+
# Ensure both vectors are on CPU before dot product to avoid device mismatch issues
|
385 |
+
# between the two tensors; .item() then returns the scalar as a Python number.
|
386 |
dot_product = float(torch.dot(query_vector.cpu(), doc_vector.cpu()).item())
|
387 |
|
388 |
# Format representations
|
389 |
+
query_rep_str = f"Query Representation ({query_model_name_display}):\n"
|
390 |
+
query_rep_str += format_sparse_vector_output(query_vector, query_tokenizer, query_is_binary)
|
391 |
+
|
392 |
+
doc_rep_str = f"Document Representation ({doc_model_name_display}):\n"
|
393 |
+
doc_rep_str += format_sparse_vector_output(doc_vector, doc_tokenizer, doc_is_binary)
|
394 |
|
395 |
# Combine output
|
396 |
full_output = f"### Dot Product Score: {dot_product:.6f}\n\n"
|
|
|
433 |
|
434 |
with gr.TabItem("Query-Document Dot Product Calculator"): # NEW TAB
|
435 |
gr.Markdown("### Calculate Dot Product Similarity between Query and Document")
|
436 |
+
gr.Markdown("Select **independent** SPLADE models to encode your query and document, then see their sparse representations and their similarity score.")
|
437 |
+
|
438 |
+
# Define the common model choices for cleaner code
|
439 |
+
model_choices = [
|
440 |
+
"SPLADE-cocondenser-distil (weighting and expansion)",
|
441 |
+
"SPLADE-v3-Lexical (weighting)",
|
442 |
+
"SPLADE-v3-Doc (binary)"
|
443 |
+
]
|
444 |
+
|
445 |
gr.Interface(
|
446 |
+
fn=calculate_dot_product_and_representations_independent, # MODIFIED FUNCTION NAME
|
447 |
inputs=[
|
448 |
gr.Radio(
|
449 |
+
model_choices,
|
450 |
+
label="Choose Query Encoding Model",
|
451 |
+
value="SPLADE-cocondenser-distil (weighting and expansion)" # Default value
|
452 |
+
),
|
453 |
+
gr.Radio(
|
454 |
+
model_choices,
|
455 |
+
label="Choose Document Encoding Model",
|
456 |
+
value="SPLADE-cocondenser-distil (weighting and expansion)" # Default value
|
457 |
),
|
458 |
gr.Textbox(
|
459 |
lines=3,
|