Update app.py
app.py CHANGED
@@ -430,18 +430,71 @@ def calculate_dot_product_and_representations_independent(query_model_choice, do
     if query_vector is None or doc_vector is None:
         return "Failed to generate one or both vectors. Please check model loading and input text.", ""
 
-    # Calculate dot product
+    # Calculate overall dot product
     dot_product = float(torch.dot(query_vector.cpu(), doc_vector.cpu()).item())
 
-    # Format representations
+    # Format representations for display
     query_main_rep_str, query_info_str = format_sparse_vector_output(query_vector, query_tokenizer, query_is_binary)
     doc_main_rep_str, doc_info_str = format_sparse_vector_output(doc_vector, doc_tokenizer, doc_is_binary)
 
+    # --- NEW FEATURE: Calculate dot product of overlapping terms ---
+    overlapping_terms_dot_products = {}
+    query_indices = torch.nonzero(query_vector).squeeze().cpu()
+    doc_indices = torch.nonzero(doc_vector).squeeze().cpu()
+
+    # Handle cases where vectors are empty or single element
+    if query_indices.dim() == 0 and query_indices.numel() == 1:
+        query_indices = query_indices.unsqueeze(0)
+    if doc_indices.dim() == 0 and doc_indices.numel() == 1:
+        doc_indices = doc_indices.unsqueeze(0)
+
+    # Convert indices to sets for efficient intersection
+    query_index_set = set(query_indices.tolist())
+    doc_index_set = set(doc_indices.tolist())
+
+    common_indices = sorted(query_index_set.intersection(doc_index_set))
+
+    if common_indices:
+        for idx in common_indices:
+            query_weight = query_vector[idx].item()
+            doc_weight = doc_vector[idx].item()
+            term = query_tokenizer.decode([idx])  # Tokenizers should be the same for this purpose
+            if term not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(term.strip()) > 0:
+                # Keep both weights with the product; decode -> encode is not a
+                # reliable round-trip for subword terms, so store them here.
+                overlapping_terms_dot_products[term] = (query_weight, doc_weight, query_weight * doc_weight)
+
+    sorted_overlapping_dot_products = sorted(
+        overlapping_terms_dot_products.items(),
+        key=lambda item: item[1][2],  # sort by the product
+        reverse=True
+    )
+    # --- End NEW FEATURE ---
+
     # Combine output into a single string for the Markdown component
-
-    full_output = f"### Dot Product Score: {dot_product:.6f}\n\n"
+    full_output = f"### Overall Dot Product Score: {dot_product:.6f}\n\n"
     full_output += "---\n\n"
 
+    # Overlapping Terms Dot Products
+    if sorted_overlapping_dot_products:
+        full_output += "### Dot Products of Overlapping Terms:\n"
+        full_output += "*(Term: Query_Weight x Document_Weight = Product)*\n\n"
+        overlap_list = []
+        for term, (query_weight, doc_weight, product_val) in sorted_overlapping_dot_products:  # weights captured above; re-encoding the term would be unreliable
+            if query_is_binary and doc_is_binary:
+                overlap_list.append(f"**{term}**: 1.0000 x 1.0000 = {product_val:.4f}")
+            elif query_is_binary:
+                overlap_list.append(f"**{term}**: 1.0000 x {doc_weight:.4f} = {product_val:.4f}")
+            elif doc_is_binary:
+                overlap_list.append(f"**{term}**: {query_weight:.4f} x 1.0000 = {product_val:.4f}")
+            else:
+                overlap_list.append(f"**{term}**: {query_weight:.4f} x {doc_weight:.4f} = {product_val:.4f}")
+        full_output += ", ".join(overlap_list) + ".\n\n"
+        full_output += "---\n\n"
+    else:
+        full_output += "### No Overlapping Terms Found.\n\n"
+        full_output += "---\n\n"
+
     # Query Representation
     full_output += f"#### Query Representation ({query_model_name_display}):\n"  # Smaller heading for sub-section
     full_output += f"> {query_main_rep_str}\n"  # Using blockquote for the sparse list
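
For intuition, the change decomposes the sparse dot product term by term: the overall score is the sum, over vocabulary indices where both vectors are nonzero, of query weight times document weight. Below is a minimal, self-contained sketch of that decomposition on toy tensors; the five-term vocabulary and all weights are invented for illustration, and `flatten()` is used on the `nonzero` output so the index tensors stay 1-D even for a single match, which avoids the `squeeze()`/`unsqueeze()` edge case the diff handles explicitly.

```python
import torch

# Hypothetical 5-term vocabulary and made-up sparse weights.
vocab = {0: "[CLS]", 1: "solar", 2: "panel", 3: "energy", 4: "[SEP]"}
query_vector = torch.tensor([0.0, 1.2, 0.0, 0.8, 0.0])
doc_vector = torch.tensor([0.0, 0.9, 0.5, 1.1, 0.0])

# Overall score: full dot product of the two vectors.
overall = torch.dot(query_vector, doc_vector).item()

# Per-term contributions: products over the shared nonzero indices.
# flatten() keeps the index tensors 1-D even when only one index matches.
query_idx = set(torch.nonzero(query_vector).flatten().tolist())
doc_idx = set(torch.nonzero(doc_vector).flatten().tolist())
per_term = {
    vocab[i]: query_vector[i].item() * doc_vector[i].item()
    for i in sorted(query_idx & doc_idx)
}

print(f"overall: {overall:.4f}")  # 1.9600 (= 1.2*0.9 + 0.8*1.1)
for term, prod in sorted(per_term.items(), key=lambda kv: kv[1], reverse=True):
    print(f"{term}: {prod:.4f}")  # solar: 1.0800, energy: 0.8800
```

The weights are stored next to each product in the diff above because decoding a token id and re-encoding the resulting string is not a reliable round-trip: for a WordPiece tokenizer, a subword such as "##ing" typically decodes to "ing", which re-encodes to a different id (or several), so looking the weights up again by term string can return the wrong values.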