TextLSRDemo / app.py
SiddharthAK's picture
Update app.py
f4c84bc verified
raw
history blame
12.3 kB
import gradio as gr
from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoModel
import torch
# --- Model Loading ---
# Each checkpoint is loaded independently so that one failure (e.g. a gated
# repository whose access agreement has not been accepted) does not prevent the
# others from loading. On failure BOTH the tokenizer and model globals are None,
# which the representation functions check before running.


def _load_masked_lm(model_name, display_name, hint):
    """Load a tokenizer / masked-LM pair for inference.

    Args:
        model_name: Hugging Face Hub id of the checkpoint.
        display_name: Human-readable name used in console messages.
        hint: Extra message printed when loading fails.

    Returns:
        ``(tokenizer, model)`` with the model in eval mode, or ``(None, None)``
        if anything went wrong. A tokenizer is never returned without its model.
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForMaskedLM.from_pretrained(model_name)
    except Exception as e:  # top-level boundary: report and keep the app usable
        print(f"Error loading {display_name} model: {e}")
        print(hint)
        return None, None
    model.eval()  # Set to evaluation mode for inference
    print(f"{display_name} model '{model_name}' loaded successfully!")
    return tokenizer, model


# SPLADE (cocondenser self-distil) model
splade_cocondenser_model_name = "naver/splade-cocondenser-selfdistil"
tokenizer_splade, model_splade = _load_masked_lm(
    splade_cocondenser_model_name,
    "SPLADE (cocondenser)",
    "Please ensure you have accepted any user access agreements on the Hugging Face Hub "
    f"page for '{splade_cocondenser_model_name}'.",
)

# SPLADE v3 Lexical model
splade_lexical_model_name = "naver/splade-v3-lexical"
tokenizer_splade_lexical, model_splade_lexical = _load_masked_lm(
    splade_lexical_model_name,
    "SPLADE v3 Lexical",
    f"Please ensure '{splade_lexical_model_name}' is accessible "
    "(check Hugging Face Hub for potential agreements).",
)

# SPLADE v3 Doc model
splade_doc_model_name = "naver/splade-v3-doc"
tokenizer_splade_doc, model_splade_doc = _load_masked_lm(
    splade_doc_model_name,
    "SPLADE v3 Doc",
    f"Please ensure '{splade_doc_model_name}' is accessible "
    "(check Hugging Face Hub for potential agreements).",
)
# --- Helper function for lexical mask (used by splade-v3-lexical and splade-v3-doc) ---
def create_lexical_bow_mask(input_ids, vocab_size, tokenizer):
    """
    Creates a binary bag-of-words mask from input_ids,
    zeroing out special tokens and padding.

    Args:
        input_ids: Tensor of token ids of any shape (e.g. ``(1, seq_len)`` as
            produced by a tokenizer with ``return_tensors="pt"``). All entries
            are pooled, so a batch yields the union of its tokens.
        vocab_size: Length of the mask; every kept id must be < vocab_size.
        tokenizer: Tokenizer whose pad/cls/sep/mask/unk token ids are excluded.
            Ids that are None on this tokenizer simply never match.

    Returns:
        Float tensor of shape ``(1, vocab_size)`` on ``input_ids.device`` with 1
        at every kept token id and 0 elsewhere.
    """
    # Built once, outside the scan, for O(1) membership tests.
    excluded_ids = {
        tokenizer.pad_token_id,
        tokenizer.cls_token_id,
        tokenizer.sep_token_id,
        tokenizer.mask_token_id,
        tokenizer.unk_token_id,
    }
    # reshape(-1) instead of squeeze(): squeeze() turns a single-token input into
    # a 0-d tensor whose .tolist() is a bare int (not iterable), and leaves a
    # batched input 2-D so each "token" would be a whole row.
    kept_ids = {
        token_id
        for token_id in input_ids.reshape(-1).tolist()
        if token_id not in excluded_ids
    }
    bow_mask = torch.zeros(vocab_size, device=input_ids.device)
    if kept_ids:
        bow_mask[sorted(kept_ids)] = 1
    return bow_mask.unsqueeze(0)
# --- Core Representation Functions ---
def get_splade_representation(text):
    """Render the SPLADE (cocondenser) sparse vector of ``text`` as Markdown.

    Each vocabulary weight is ``log(1 + relu(logit))`` from the masked-LM head,
    max-pooled over the non-padding positions. Every non-zero term (special
    tokens and blank decodings aside) is listed heaviest first, followed by the
    vector's non-zero count and sparsity.

    Returns a user-facing error string instead when the model is not loaded or
    its output has no ``logits``.
    """
    if model_splade is None or tokenizer_splade is None:
        return "SPLADE (cocondenser) model is not loaded. Please check the console for loading errors."

    encoded = tokenizer_splade(text, return_tensors="pt", padding=True, truncation=True)
    device = model_splade.device
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.no_grad():
        output = model_splade(**encoded)
        if not hasattr(output, 'logits'):
            return "Model output structure not as expected for SPLADE (cocondenser). 'logits' not found."
        activations = torch.log(1 + torch.relu(output.logits))
        # Zero padding positions before pooling over the sequence dimension.
        padding_mask = encoded['attention_mask'].unsqueeze(-1)
        splade_vector = (activations * padding_mask).max(dim=1).values.squeeze()

    nonzero_ids = torch.nonzero(splade_vector, as_tuple=True)[0]
    num_nonzero = nonzero_ids.numel()
    id_list = nonzero_ids.cpu().tolist()
    weight_list = splade_vector[nonzero_ids].cpu().tolist()

    # Insertion order follows ascending token id, so ties keep that order below.
    term_weights = {}
    for token_id, weight in zip(id_list, weight_list):
        term = tokenizer_splade.decode([token_id])
        if term.strip() and term not in ("[CLS]", "[SEP]", "[PAD]", "[UNK]"):
            term_weights[term] = weight

    ranked = sorted(term_weights.items(), key=lambda pair: pair[1], reverse=True)

    lines = ["SPLADE (cocondenser) Representation (All Non-Zero Terms):\n"]
    if ranked:
        lines.extend(f"- **{term}**: {weight:.4f}\n" for term, weight in ranked)
    else:
        lines.append("No significant terms found for this input.\n")
    lines.append("\n--- Raw SPLADE Vector Info ---\n")
    lines.append(f"Total non-zero terms in vector: {num_nonzero}\n")
    lines.append(f"Sparsity: {1 - (num_nonzero / tokenizer_splade.vocab_size):.2%}\n")
    return "".join(lines)
def get_splade_lexical_representation(text):
    """Render the SPLADE-v3-Lexical sparse vector of ``text`` as Markdown.

    Weights are ``log(1 + relu(logit))`` max-pooled over non-padding positions,
    then multiplied by the bag-of-words mask from ``create_lexical_bow_mask`` so
    only vocabulary entries that appear in the input (minus pad/cls/sep/mask/unk)
    can remain non-zero. Terms are listed heaviest first, followed by the
    non-zero count and sparsity.

    Returns a user-facing error string instead when the model is not loaded or
    its output has no ``logits``.
    """
    if model_splade_lexical is None or tokenizer_splade_lexical is None:
        return "SPLADE v3 Lexical model is not loaded. Please check the console for loading errors."

    encoded = tokenizer_splade_lexical(text, return_tensors="pt", padding=True, truncation=True)
    device = model_splade_lexical.device
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.no_grad():
        output = model_splade_lexical(**encoded)
        if not hasattr(output, 'logits'):
            return "Model output structure not as expected for SPLADE v3 Lexical. 'logits' not found."
        activations = torch.log(1 + torch.relu(output.logits))
        padding_mask = encoded['attention_mask'].unsqueeze(-1)
        pooled = (activations * padding_mask).max(dim=1).values.squeeze()

    # Lexical-only: zero every vocabulary entry that is not one of the input tokens.
    lexical_mask = create_lexical_bow_mask(
        encoded['input_ids'], tokenizer_splade_lexical.vocab_size, tokenizer_splade_lexical
    ).squeeze()
    splade_vector = pooled * lexical_mask

    nonzero_ids = torch.nonzero(splade_vector, as_tuple=True)[0]
    num_nonzero = nonzero_ids.numel()
    id_list = nonzero_ids.cpu().tolist()
    weight_list = splade_vector[nonzero_ids].cpu().tolist()

    term_weights = {}
    for token_id, weight in zip(id_list, weight_list):
        term = tokenizer_splade_lexical.decode([token_id])
        if term.strip() and term not in ("[CLS]", "[SEP]", "[PAD]", "[UNK]"):
            term_weights[term] = weight

    ranked = sorted(term_weights.items(), key=lambda pair: pair[1], reverse=True)

    lines = ["SPLADE v3 Lexical Representation (All Non-Zero Terms):\n"]
    if ranked:
        lines.extend(f"- **{term}**: {weight:.4f}\n" for term, weight in ranked)
    else:
        lines.append("No significant terms found for this input.\n")
    lines.append("\n--- Raw SPLADE Vector Info ---\n")
    lines.append(f"Total non-zero terms in vector: {num_nonzero}\n")
    lines.append(f"Sparsity: {1 - (num_nonzero / tokenizer_splade_lexical.vocab_size):.2%}\n")
    return "".join(lines)
# Function for SPLADE-v3-Doc representation (Binary Sparse)
def get_splade_doc_representation(text):
    """Render a binary (presence-only) SPLADE-v3-Doc representation of ``text``.

    Only tokens that literally occur in the tokenized input are activated, each
    with implicit weight 1 (no weighting, no expansion). Terms are listed
    alphabetically (at most 50 shown), followed by the activated-term count and
    sparsity.

    No masked-LM forward pass is run: a presence-only vector is fully determined
    by the tokenizer output, so the model's logits would not affect the result.
    (Thresholding the model's activations instead would be an alternative, but
    could activate expansion terms.) The model must still be loaded so loading
    failures are reported the same way as for the other representations.
    """
    if tokenizer_splade_doc is None or model_splade_doc is None:
        return "SPLADE v3 Doc model is not loaded. Please check the console for loading errors."

    inputs = tokenizer_splade_doc(text, return_tensors="pt", padding=True, truncation=True)
    vocab_size = tokenizer_splade_doc.vocab_size
    binary_splade_vector = create_lexical_bow_mask(
        inputs['input_ids'], vocab_size, tokenizer_splade_doc
    ).squeeze()

    # flatten() always yields a 1-D list, including for zero or one active id
    # (the previous squeeze-based code dropped a lone id 0 as "falsy").
    indices = torch.nonzero(binary_splade_vector).flatten().cpu().tolist()

    present_terms = set()
    for token_id in indices:
        decoded_token = tokenizer_splade_doc.decode([token_id])
        if decoded_token.strip() and decoded_token not in ("[CLS]", "[SEP]", "[PAD]", "[UNK]"):
            present_terms.add(decoded_token)
    sorted_terms = sorted(present_terms)  # Sort alphabetically for binary

    max_terms_shown = 50  # Limit display for very long lists for readability
    lines = ["SPLADE v3 Doc Representation (Binary Sparse - Lexical Only):\n"]
    if not sorted_terms:
        lines.append("No significant terms found for this input.\n")
    else:
        # Display terms without weights since they are binary (value 1)
        lines.extend(f"- **{term}**\n" for term in sorted_terms[:max_terms_shown])
        if len(sorted_terms) > max_terms_shown:
            lines.append(f"...and {len(sorted_terms) - max_terms_shown} more terms.\n")
    lines.append("\n--- Raw Binary Sparse Vector Info ---\n")
    lines.append(f"Total activated terms: {len(indices)}\n")
    lines.append(f"Sparsity: {1 - (len(indices) / vocab_size):.2%}\n")
    return "".join(lines)
# --- Unified Prediction Function for Gradio ---
def predict_representation(model_choice, text):
    """Dispatch ``text`` to the representation function named by ``model_choice``.

    ``model_choice`` is one of the Radio labels in the interface. Any other value
    (including None) yields the "Please select a model." prompt.
    """
    # Ordered (label, handler) table; the lambdas defer looking up the handler
    # until its label actually matches.
    routes = (
        ("SPLADE (cocondenser)", lambda: get_splade_representation(text)),
        # Always applies the lexical mask for this option
        ("SPLADE-v3-Lexical", lambda: get_splade_lexical_representation(text)),
        # Binary, lexical-only output
        ("SPLADE-v3-Doc", lambda: get_splade_doc_representation(text)),
    )
    for label, handler in routes:
        if model_choice == label:
            return handler()
    return "Please select a model."
# --- Gradio Interface Setup ---
# The Radio labels must match the strings predict_representation dispatches on.
demo = gr.Interface(
    fn=predict_representation,
    inputs=[
        gr.Radio(
            [
                "SPLADE (cocondenser)",
                "SPLADE-v3-Lexical",
                "SPLADE-v3-Doc",
            ],
            label="Choose Representation Model",
            value="SPLADE (cocondenser)",  # Default selection
        ),
        gr.Textbox(
            lines=5,
            label="Enter your query or document text here:",
            placeholder="e.g., Why is Padua the nicest city in Italy?",
        ),
    ],
    outputs=gr.Markdown(),
    title="🌌 Sparse Representation Generator",
    description="Enter any text to see its sparse vector representation.",
    allow_flagging="never",
)

# Launch the Gradio app only when this file is executed as a script, so that
# importing the module (e.g. to reuse `demo` or the helper functions) does not
# start a server as a side effect.
if __name__ == "__main__":
    demo.launch()