# TextLSRDemo / app.py
import gradio as gr
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
# --- Model Loading ---
tokenizer_splade = None
model_splade = None
tokenizer_unicoil = None
model_unicoil = None
# Load SPLADE model (naver/splade-cocondenser-selfdistil)
try:
    tokenizer_splade = AutoTokenizer.from_pretrained("naver/splade-cocondenser-selfdistil")
    model_splade = AutoModelForMaskedLM.from_pretrained("naver/splade-cocondenser-selfdistil")
    model_splade.eval()  # Set to evaluation mode for inference
    print("SPLADE model loaded successfully!")
except Exception as e:
    print(f"Error loading SPLADE model: {e}")
    print("Please ensure you have accepted any user access agreements on the Hugging Face Hub page for 'naver/splade-cocondenser-selfdistil'.")
# Load UNICOIL model for binary sparse encoding.
# UNICOIL learns a scalar weight for each input token via a small head on top
# of a BERT-like encoder. The code below reads `output.logits`, so the
# checkpoint is loaded with AutoModelForMaskedLM; this assumes the checkpoint
# exposes an MLM-style head (adjust the model class if yours differs).
try:
    unicoil_model_name = "castorini/unicoil-msmarco-passage"
    tokenizer_unicoil = AutoTokenizer.from_pretrained(unicoil_model_name)
    model_unicoil = AutoModelForMaskedLM.from_pretrained(unicoil_model_name)
    model_unicoil.eval()  # Set to evaluation mode for inference
    print(f"UNICOIL model '{unicoil_model_name}' loaded successfully!")
except Exception as e:
    print(f"Error loading UNICOIL model: {e}")
    print(f"Please ensure '{unicoil_model_name}' is accessible (check Hugging Face Hub for potential agreements).")
# --- Core Representation Functions ---
def get_splade_representation(text):
    if tokenizer_splade is None or model_splade is None:
        return "SPLADE model is not loaded. Please check the console for loading errors."

    inputs = tokenizer_splade(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(model_splade.device) for k, v in inputs.items()}

    with torch.no_grad():
        output = model_splade(**inputs)

    if hasattr(output, 'logits'):
        # SPLADE aggregation: log(1 + relu(logits)), masked, then max-pooled over the sequence
        splade_vector = torch.max(
            torch.log(1 + torch.relu(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
            dim=1
        )[0].squeeze()
    else:
        return "Model output structure not as expected for SPLADE. 'logits' not found."

    indices = torch.nonzero(splade_vector).squeeze().cpu().tolist()
    if not isinstance(indices, list):
        indices = [indices]
    values = splade_vector[indices].cpu().tolist()
    token_weights = dict(zip(indices, values))

    meaningful_tokens = {}
    for token_id, weight in token_weights.items():
        decoded_token = tokenizer_splade.decode([token_id])
        if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
            meaningful_tokens[decoded_token] = weight

    sorted_representation = sorted(meaningful_tokens.items(), key=lambda item: item[1], reverse=True)

    formatted_output = "SPLADE Representation (Top 20 Terms):\n"
    if not sorted_representation:
        formatted_output += "No significant terms found for this input.\n"
    else:
        for i, (term, weight) in enumerate(sorted_representation):
            if i >= 20:
                break
            formatted_output += f"- **{term}**: {weight:.4f}\n"

    formatted_output += "\n--- Raw SPLADE Vector Info ---\n"
    formatted_output += f"Total non-zero terms in vector: {len(indices)}\n"
    formatted_output += f"Sparsity: {1 - (len(indices) / tokenizer_splade.vocab_size):.2%}\n"
    return formatted_output
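
# A small, hedged sketch of the SPLADE aggregation used above, factored out so
# the pooling step can be inspected in isolation. The function name is ours,
# not part of the SPLADE codebase; torch.log1p(x) is equivalent to torch.log(1 + x).
def splade_pool_sketch(logits: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """log(1 + relu(logits)), masked, then max-pooled over the sequence."""
    weights = torch.log1p(torch.relu(logits)) * attention_mask.unsqueeze(-1)
    return torch.max(weights, dim=1).values  # (batch, vocab_size)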
def get_unicoil_binary_representation(text):
    if tokenizer_unicoil is None or model_unicoil is None:
        return "UNICOIL model is not loaded. Please check the console for loading errors."

    inputs = tokenizer_unicoil(text, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(model_unicoil.device) for k, v in inputs.items()}

    with torch.no_grad():
        # The checkpoint is loaded as AutoModelForMaskedLM above, so the forward
        # pass should return a MaskedLMOutput whose `logits` have shape
        # (batch, seq_len, vocab_size). Per-vocabulary term scores are derived
        # from these logits below, SPLADE-style, but with a softplus transform.
        output = model_unicoil(**inputs)

    if not hasattr(output, 'logits'):
        return "UNICOIL model output structure not as expected. 'logits' not found."
    # Per-vocabulary scores via a softplus transform of the logits,
    # log(1 + exp(logits)), masked and max-pooled over the sequence. This
    # mirrors the SPLADE aggregation above, but unlike log(1 + relu(x)),
    # softplus is strictly positive, so sparsity here comes from the
    # binarization threshold below rather than from the transform itself.
    sparse_weights = torch.max(
        torch.log(1 + torch.exp(output.logits)) * inputs['attention_mask'].unsqueeze(-1),
        dim=1
    )[0].squeeze()

    # --- Binarization Step for UNICOIL ---
    # Threshold the scores: any weight above a small epsilon becomes 1, else 0.
    threshold = 1e-6  # Small threshold for binarization
    binary_sparse_vector = (sparse_weights > threshold).int()
    # Get the indices of the activated ('1') entries in the binary vector
    binary_indices = torch.nonzero(binary_sparse_vector).squeeze().cpu().tolist()
    if not isinstance(binary_indices, list):
        binary_indices = [binary_indices]  # single activation: tolist() returns a scalar

    # Map token IDs back to terms for the binary representation
    binary_terms = {}
    for token_id in binary_indices:
        decoded_token = tokenizer_unicoil.decode([token_id])
        if decoded_token not in ["[CLS]", "[SEP]", "[PAD]", "[UNK]"] and len(decoded_token.strip()) > 0:
            binary_terms[decoded_token] = 1  # Value is always 1 for binary

    sorted_binary_terms = sorted(binary_terms.items(), key=lambda item: item[0])  # Sort by term for consistent display

    formatted_output = "UNICOIL Binary Sparse Representation (Activated Terms):\n"
    if not sorted_binary_terms:
        formatted_output += "No significant terms activated for this input.\n"
    else:
        # Display up to 50 activated terms for readability
        for i, (term, _) in enumerate(sorted_binary_terms):
            if i >= 50:
                break
            formatted_output += f"- **{term}**\n"  # Only show the term; the weight is always 1
        if len(sorted_binary_terms) > 50:
            formatted_output += f"...and {len(sorted_binary_terms) - 50} more terms.\n"

    formatted_output += "\n--- Raw Binary Sparse Vector Info ---\n"
    formatted_output += f"Total activated terms: {len(binary_indices)}\n"
    # Sparsity: fraction of the vocabulary that is *not* activated
    formatted_output += f"Sparsity: {1 - (len(binary_indices) / tokenizer_unicoil.vocab_size):.2%}\n"
    return formatted_output
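
# Hedged note on the binarization above: because softplus(x) = log(1 + exp(x))
# is strictly positive, the epsilon threshold does the real work. A vocabulary
# entry survives binarization whenever its max masked logit exceeds roughly
# log(exp(1e-6) - 1) ≈ -13.8. This hypothetical helper makes that easy to check:
def count_softplus_survivors(logits: torch.Tensor, threshold: float = 1e-6) -> int:
    """Count entries whose softplus-transformed score clears the threshold."""
    return int((torch.log1p(torch.exp(logits)) > threshold).sum().item())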
# --- Unified Prediction Function for Gradio ---
def predict_representation(model_choice, text):
    if model_choice == "SPLADE":
        return get_splade_representation(text)
    elif model_choice == "UNICOIL (Binary Sparse)":
        return get_unicoil_binary_representation(text)
    else:
        return "Please select a model."
# --- Gradio Interface Setup ---
demo = gr.Interface(
    fn=predict_representation,
    inputs=[
        gr.Radio(
            ["SPLADE", "UNICOIL (Binary Sparse)"],
            label="Choose Representation Model",
            value="SPLADE"  # Default selection
        ),
        gr.Textbox(
            lines=5,
            label="Enter your query or document text here:",
            placeholder="e.g., Why is Padua the nicest city in Italy?"
        )
    ],
    outputs=gr.Markdown(),
    title="🌌 Sparse and Binary Sparse Representation Generator",
    description="Enter any text to see its SPLADE sparse vector or UNICOIL binary sparse representation.",
    allow_flagging="never"
)
# Launch the Gradio app
demo.launch()