SiddharthAK committed · Commit 8e4067c · verified · 1 Parent(s): 7ffc91a

Update app.py

Files changed (1):
  1. app.py +59 -4
app.py CHANGED
@@ -430,18 +430,73 @@ def calculate_dot_product_and_representations_independent(query_model_choice, do
     if query_vector is None or doc_vector is None:
         return "Failed to generate one or both vectors. Please check model loading and input text.", ""
 
-    # Calculate dot product
+    # Calculate overall dot product
     dot_product = float(torch.dot(query_vector.cpu(), doc_vector.cpu()).item())
 
-    # Format representations - these functions now return two strings (main_output, info_output)
+    # Format representations for display
     query_main_rep_str, query_info_str = format_sparse_vector_output(query_vector, query_tokenizer, query_is_binary)
     doc_main_rep_str, doc_info_str = format_sparse_vector_output(doc_vector, doc_tokenizer, doc_is_binary)
 
+    # --- NEW FEATURE: Calculate dot product of overlapping terms ---
+    overlapping_terms_dot_products = {}
+    query_indices = torch.nonzero(query_vector).squeeze().cpu()
+    doc_indices = torch.nonzero(doc_vector).squeeze().cpu()
+
+    # Handle cases where vectors are empty or single element
+    if query_indices.dim() == 0 and query_indices.numel() == 1:
+        query_indices = query_indices.unsqueeze(0)
+    if doc_indices.dim() == 0 and doc_indices.numel() == 1:
+        doc_indices = doc_indices.unsqueeze(0)
+
+    # Convert indices to sets for efficient intersection
+    query_index_set = set(query_indices.tolist())
+    doc_index_set = set(doc_indices.tolist())
+
+    common_indices = sorted(list(query_index_set.intersection(doc_index_set)))
+
+    if common_indices:
+        for idx in common_indices:
+            query_weight = query_vector[idx].item()
+            doc_weight = doc_vector[idx].item()
+            term = query_tokenizer.decode([idx])  # Tokenizers should be the same for this purpose
+            if term not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(term.strip()) > 0:
+                overlapping_terms_dot_products[term] = query_weight * doc_weight
+
+    sorted_overlapping_dot_products = sorted(
+        overlapping_terms_dot_products.items(),
+        key=lambda item: item[1],
+        reverse=True
+    )
+    # --- End NEW FEATURE ---
+
     # Combine output into a single string for the Markdown component
-    # Using Markdown's blockquote for representations can help them stand out as code/data
-    full_output = f"### Dot Product Score: {dot_product:.6f}\n\n"
+    full_output = f"### Overall Dot Product Score: {dot_product:.6f}\n\n"
     full_output += "---\n\n"
 
+    # Overlapping Terms Dot Products
+    if sorted_overlapping_dot_products:
+        full_output += "### Dot Products of Overlapping Terms:\n"
+        full_output += "*(Term: Query_Weight x Document_Weight = Product)*\n\n"
+        overlap_list = []
+        for term, product_val in sorted_overlapping_dot_products:
+            # Get individual weights for display
+            query_weight = query_vector[query_tokenizer.encode(term, add_special_tokens=False)[0]].item()
+            doc_weight = doc_vector[doc_tokenizer.encode(term, add_special_tokens=False)[0]].item()
+
+            if query_is_binary and doc_is_binary:
+                overlap_list.append(f"**{term}**: 1.0000 x 1.0000 = {product_val:.4f}")
+            elif query_is_binary:
+                overlap_list.append(f"**{term}**: 1.0000 x {doc_weight:.4f} = {product_val:.4f}")
+            elif doc_is_binary:
+                overlap_list.append(f"**{term}**: {query_weight:.4f} x 1.0000 = {product_val:.4f}")
+            else:
+                overlap_list.append(f"**{term}**: {query_weight:.4f} x {doc_weight:.4f} = {product_val:.4f}")
+        full_output += ", ".join(overlap_list) + ".\n\n"
+        full_output += "---\n\n"
+    else:
+        full_output += "### No Overlapping Terms Found.\n\n"
+        full_output += "---\n\n"
+
     # Query Representation
     full_output += f"#### Query Representation ({query_model_name_display}):\n"  # Smaller heading for sub-section
     full_output += f"> {query_main_rep_str}\n"  # Using blockquote for the sparse list