# TextLSRDemo / app.py
# Author: SiddharthAK
# Commit 45b666a (verified): "output all nonzero terms (instead of top 20)"
# Original file size: 6.82 kB
import gradio as gr
from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoModel
import torch
# --- Model Loading ---
# Each model is loaded once at import time. If loading fails, the matching
# globals stay None and the representation functions return a "not loaded"
# message instead of raising.
tokenizer_splade = None
model_splade = None
tokenizer_unicoil = None
model_unicoil = None

# Load the SPLADE model.
# NOTE(review): earlier comments called this "SPLADE v3", but the checkpoint
# below is naver/splade-cocondenser-selfdistil; confirm which one is intended.
splade_model_name = "naver/splade-cocondenser-selfdistil"
try:
    tokenizer_splade = AutoTokenizer.from_pretrained(splade_model_name)
    model_splade = AutoModelForMaskedLM.from_pretrained(splade_model_name)
    model_splade.eval()  # Set to evaluation mode for inference
    # Report the actual checkpoint rather than a hard-coded version label.
    print(f"SPLADE model '{splade_model_name}' loaded successfully!")
except Exception as e:
    # Top-level boundary: report the failure and keep the app running
    # without SPLADE (the globals above stay None).
    print(f"Error loading SPLADE model: {e}")
    print(f"Please ensure you have accepted any user access agreements on the Hugging Face Hub page for '{splade_model_name}'.")
# Load the UNICOIL model used for the binary sparse encoding.
# The checkpoint is loaded through AutoModelForMaskedLM so that its output
# exposes `logits`, which get_unicoil_binary_representation reads.
# NOTE(review): uniCOIL checkpoints are commonly paired with a custom encoder
# (BERT plus a one-output token-weight projection) rather than an MLM head,
# so the MLM head here may be freshly initialized; check the load warnings.
unicoil_model_name = "castorini/unicoil-msmarco-passage"
try:
    tokenizer_unicoil = AutoTokenizer.from_pretrained(unicoil_model_name)
    model_unicoil = AutoModelForMaskedLM.from_pretrained(unicoil_model_name)
    model_unicoil.eval()  # inference only
    print(f"UNICOIL model '{unicoil_model_name}' loaded successfully!")
except Exception as load_error:
    # Leave tokenizer/model as None so the UI reports the model as unavailable.
    print(f"Error loading UNICOIL model: {load_error}")
    print(f"Please ensure '{unicoil_model_name}' is accessible (check Hugging Face Hub for potential agreements).")
# --- Core Representation Functions ---
def get_splade_representation(text):
    """Encode *text* with SPLADE and return its non-zero terms as Markdown.

    The sparse vector is obtained by max-pooling ``log(1 + relu(logits))``
    of the masked-LM head over all non-padding positions. This gives one
    weight per vocabulary entry. Every non-zero entry is listed, heaviest
    first, except special tokens and entries that decode to whitespace only.

    Args:
        text: Query or document text to encode.

    Returns:
        A Markdown string with the weighted terms, followed by the number of
        non-zero entries and the vector's sparsity. Returns a plain error
        message instead when the model is not loaded or its output has no
        ``logits`` attribute.
    """
    if tokenizer_splade is None or model_splade is None:
        return "SPLADE model is not loaded. Please check the console for loading errors."

    inputs = tokenizer_splade(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(model_splade.device) for k, v in inputs.items()}

    with torch.no_grad():
        output = model_splade(**inputs)

    if not hasattr(output, "logits"):
        return "Model output structure not as expected for SPLADE. 'logits' not found."

    # Multiplying by the attention mask zeroes padding positions before the
    # max over the sequence axis (dim=1).
    splade_vector = torch.max(
        torch.log(1 + torch.relu(output.logits)) * inputs["attention_mask"].unsqueeze(-1),
        dim=1,
    )[0].squeeze()

    # reshape(-1) always yields a 1-D index tensor. squeeze() would have
    # collapsed the single-hit case to a scalar.
    nonzero_ids = torch.nonzero(splade_vector).reshape(-1)
    indices = nonzero_ids.cpu().tolist()
    values = splade_vector[nonzero_ids].cpu().tolist()

    # Filter special tokens by id, so every special token the tokenizer
    # defines (including [MASK]) is excluded, not just a hard-coded few.
    special_ids = set(tokenizer_splade.all_special_ids)
    meaningful_tokens = {}
    for token_id, weight in zip(indices, values):
        if token_id in special_ids:
            continue
        decoded_token = tokenizer_splade.decode([token_id])
        if decoded_token.strip():
            meaningful_tokens[decoded_token] = weight

    sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)

    # Backslash-escape Markdown punctuation so tokens such as "*", "_" or "`"
    # render literally inside the bold markers.
    markdown_escapes = str.maketrans({ch: "\\" + ch for ch in "\\`*_[]<>~|"})

    parts = ["SPLADE Representation (All Non-Zero Terms):\n"]
    if sorted_representation:
        parts.extend(
            f"- **{term.translate(markdown_escapes)}**: {weight:.4f}\n"
            for term, weight in sorted_representation
        )
    else:
        parts.append("No significant terms found for this input.\n")
    parts.append("\n--- Raw SPLADE Vector Info ---\n")
    parts.append(f"Total non-zero terms in vector: {len(indices)}\n")
    # Use the real vector length (the model's vocabulary dimension) as the
    # denominator rather than the tokenizer's base vocab_size.
    parts.append(f"Sparsity: {1 - (len(indices) / splade_vector.numel()):.2%}\n")
    return "".join(parts)
def get_unicoil_binary_representation(text, threshold=0.5, max_display_terms=50):
    """Encode *text* with the UNICOIL model and list its activated terms as Markdown.

    For each input position, the model logit for that position's own token
    id is passed through softplus, and padding positions are zeroed. A term
    is "activated" (binary weight 1) when its score exceeds *threshold*.
    Special tokens and whitespace-only decodings are dropped, and the
    remaining distinct terms are listed alphabetically.

    NOTE(review): this reads masked-LM ``logits`` from the loaded checkpoint.
    uniCOIL is usually scored with a one-output token-weight projection on
    top of BERT rather than an MLM head, so these scores may not match the
    published uniCOIL weights. Confirm before relying on them.

    Args:
        text: Query or document text to encode.
        threshold: Softplus score a token must exceed to be activated.
        max_display_terms: Maximum number of terms listed; any extra terms
            are summarized as "...and N more terms.". ``None`` lists all.

    Returns:
        A Markdown string with the activated terms, their count and the
        sparsity. Returns a plain error message instead when the model is
        not loaded or its output has no ``logits`` attribute.
    """
    if tokenizer_unicoil is None or model_unicoil is None:
        return "UNICOIL model is not loaded. Please check the console for loading errors."

    inputs = tokenizer_unicoil(text, return_tensors="pt", padding=True, truncation=True)
    # Take ids and mask from the device-moved inputs so every tensor used
    # below lives on the model's device. Mixing CPU and GPU tensors in the
    # masking and indexing steps would raise.
    inputs = {k: v.to(model_unicoil.device) for k, v in inputs.items()}

    with torch.no_grad():
        output = model_unicoil(**inputs)

    if not hasattr(output, "logits"):
        return "UNICOIL model output structure not as expected. 'logits' not found."

    logits = output.logits.squeeze(0)            # [seq_len, vocab_size] for a single text
    token_ids = inputs["input_ids"].squeeze(0)   # [seq_len]
    mask = inputs["attention_mask"].squeeze(0)   # [seq_len]

    # Gather each position's logit for its own token first, then apply
    # softplus to those seq_len values only. F.softplus also stays finite
    # where log(1 + exp(x)) overflows to inf for large x.
    positions = torch.arange(token_ids.shape[0], device=logits.device)
    token_scores = torch.nn.functional.softplus(logits[positions, token_ids]) * mask

    activated_token_ids = token_ids[token_scores > threshold].cpu().tolist()

    # Filter special tokens by id rather than by a hard-coded string list.
    special_ids = set(tokenizer_unicoil.all_special_ids)
    binary_terms = set()
    for token_id in activated_token_ids:
        if token_id in special_ids:
            continue
        decoded_token = tokenizer_unicoil.decode([token_id])
        if decoded_token.strip():
            binary_terms.add(decoded_token)
    sorted_binary_terms = sorted(binary_terms)

    # Backslash-escape Markdown punctuation so tokens such as "*" or "_"
    # render literally inside the bold markers.
    markdown_escapes = str.maketrans({ch: "\\" + ch for ch in "\\`*_[]<>~|"})

    parts = ["UNICOIL Binary Sparse Representation (Activated Terms):\n"]
    if sorted_binary_terms:
        shown = sorted_binary_terms[:max_display_terms]
        parts.extend(f"- **{term.translate(markdown_escapes)}**\n" for term in shown)
        hidden = len(sorted_binary_terms) - len(shown)
        if hidden > 0:
            parts.append(f"...and {hidden} more terms.\n")
    else:
        parts.append("No significant terms activated for this input.\n")
    parts.append("\n--- Raw Binary Sparse Vector Info ---\n")
    parts.append(f"Total activated terms: {len(sorted_binary_terms)}\n")
    # Use the model's vocabulary dimension as the denominator.
    parts.append(f"Sparsity: {1 - (len(sorted_binary_terms) / logits.shape[-1]):.2%}\n")
    return "".join(parts)
# --- Unified Prediction Function for Gradio ---
def predict_representation(model_choice, text):
    """Route *text* to the encoder named by *model_choice*.

    Args:
        model_choice: Label selected in the UI radio button.
        text: Query or document text to encode.

    Returns:
        The chosen encoder's Markdown output, or "Please select a model."
        for any other choice.
    """
    # Lambdas defer the encoder call until a matching choice is found.
    encoders = {
        "SPLADE": lambda: get_splade_representation(text),
        "UNICOIL (Binary Sparse)": lambda: get_unicoil_binary_representation(text),
    }
    encode = encoders.get(model_choice)
    if encode is None:
        return "Please select a model."
    return encode()
# --- Gradio Interface Setup ---
# The radio labels must match the strings predict_representation dispatches on.
demo = gr.Interface(
    fn=predict_representation,
    inputs=[
        gr.Radio(
            ["SPLADE", "UNICOIL (Binary Sparse)"],
            label="Choose Representation Model",
            value="SPLADE",  # Default selection
        ),
        gr.Textbox(
            lines=5,
            label="Enter your query or document text here:",
            placeholder="e.g., Why is Padua the nicest city in Italy?",
        ),
    ],
    outputs=gr.Markdown(),
    title="🌌 Sparse and Binary Sparse Representation Generator",
    description="Enter any text to see its SPLADE sparse vector or UNICOIL binary sparse representation.",
    # NOTE(review): newer Gradio releases deprecate `allow_flagging` in favor
    # of `flagging_mode`; confirm against the pinned Gradio version.
    allow_flagging="never",
)

# Launch only when run as a script, so that importing this module does not
# start a web server as a side effect.
if __name__ == "__main__":
    demo.launch()