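"""Gradio app for semantic deduplication of Hugging Face datasets.

Texts are embedded with a Model2Vec static model and near-duplicates are found
with Reach nearest-neighbor search, either within one dataset or across two.
"""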
import difflib

import gradio as gr
import numpy as np
from datasets import load_dataset
from model2vec import StaticModel
from reach import Reach
from tqdm import tqdm


def display_word_differences(x: str, y: str) -> str:
    """Return the words that differ between x and y, prefixed with '-' (only in x) or '+' (only in y)."""
    diff = difflib.ndiff(x.split(), y.split())
    return " ".join(word for word in diff if word.startswith(("+", "-")))


def deduplicate(
    embedding_matrix: np.ndarray, threshold: float, batch_size: int = 1024
) -> tuple[np.ndarray, dict[int, int]]:
    """
    Deduplicate embeddings and return the indices to keep plus a mapping of
    each removed (duplicate) index to the original index it duplicates.
    """
    reach = Reach(vectors=embedding_matrix, items=[str(i) for i in range(len(embedding_matrix))])

    # Start with every index marked as unique; duplicates are removed as found.
    deduplicated_indices = set(range(len(embedding_matrix)))
    duplicate_to_original_mapping = {}

    results = reach.nearest_neighbor_threshold(
        embedding_matrix,
        threshold=threshold,
        batch_size=batch_size,
        show_progressbar=False,  # Disable the internal progress bar
    )

    # Process duplicates
    for i, similar_items in enumerate(tqdm(results, desc="Processing duplicates")):
        if i not in deduplicated_indices:
            continue  # Skip items already marked as duplicates

        # Similar items are returned as (index, score) pairs; we only need the index.
        similar_indices = [int(item[0]) for item in similar_items if int(item[0]) != i]

        # Mark similar documents as duplicates and map them back to the original.
        for sim_idx in similar_indices:
            if sim_idx in deduplicated_indices:
                deduplicated_indices.remove(sim_idx)
                duplicate_to_original_mapping[sim_idx] = i  # Map duplicate to original

    # Sorted for a deterministic output order.
    return np.array(sorted(deduplicated_indices)), duplicate_to_original_mapping
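
# A standalone sketch of `deduplicate` on toy vectors (hypothetical shapes and
# threshold, kept in a comment so nothing extra runs when the script starts):
#
#   vectors = np.random.rand(100, 256).astype(np.float32)
#   keep_indices, dup_map = deduplicate(vectors, threshold=0.9)
#   print(f"Kept {len(keep_indices)}/100 rows; removed {len(dup_map)} duplicates")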


def deduplicate_across_datasets(
    embedding_matrix_1: np.ndarray,
    embedding_matrix_2: np.ndarray,
    threshold: float,
    batch_size: int = 1024,
) -> tuple[list[int], dict[int, int]]:
    """
    Deduplicate the second dataset against the first: return the indices of
    documents in the second dataset that duplicate the first, plus a mapping
    from each duplicate index to its original index in the first dataset.
    """
    reach = Reach(vectors=embedding_matrix_1, items=[str(i) for i in range(len(embedding_matrix_1))])

    # Track duplicates found in the second (e.g. test) dataset.
    duplicate_indices_in_test = []
    duplicate_to_original_mapping = {}

    # Find nearest neighbors of the second dataset's embeddings in the first.
    results = reach.nearest_neighbor_threshold(
        embedding_matrix_2,
        threshold=threshold,
        batch_size=batch_size,
        show_progressbar=False,  # Disable the internal progress bar
    )

    # Process duplicates
    for i, similar_items in enumerate(tqdm(results, desc="Processing duplicates")):
        # Similar items are returned as (index, score) pairs; keep only hits at
        # or above the threshold.
        similar_indices = [int(item[0]) for item in similar_items if item[1] >= threshold]

        # Any sufficiently similar item in the first dataset marks document i
        # in the second dataset as a duplicate.
        if similar_indices:
            duplicate_indices_in_test.append(i)
            duplicate_to_original_mapping[i] = similar_indices[0]  # Map duplicate to its original

    return duplicate_indices_in_test, duplicate_to_original_mapping
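
# A sketch of cross-dataset usage (hypothetical shapes): duplicate indices are
# reported relative to the *second* matrix, with values mapping back into the
# first.
#
#   train_vecs = np.random.rand(1_000, 256).astype(np.float32)
#   test_vecs = np.random.rand(200, 256).astype(np.float32)
#   dups_in_test, dup_map = deduplicate_across_datasets(train_vecs, test_vecs, threshold=0.9)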


def perform_deduplication(
    deduplication_type,
    dataset1_name,
    dataset1_split,
    dataset2_name,
    dataset2_split,
    text_column_name,
    threshold,
    progress=gr.Progress(track_tqdm=True),
):
    # Gradio injects a live Progress tracker through the default parameter;
    # track_tqdm=True also mirrors the tqdm bars used in the dedup functions.
    # The slider value may arrive as a string, so normalize to float.
    threshold = float(threshold)

    if deduplication_type == "Single dataset":
        # Load the dataset
        ds = load_dataset(dataset1_name, split=dataset1_split)

        # Extract texts
        try:
            texts = [example[text_column_name] for example in ds]
        except KeyError:
            return f"Error: Text column '{text_column_name}' not found in dataset."

        # Compute embeddings
        progress(0.1, desc="Loading model and computing embeddings...")
        model = StaticModel.from_pretrained("minishlab/M2V_base_output")
        embedding_matrix = model.encode(texts, show_progressbar=False)

        # Deduplicate
        progress(0.5, desc="Performing deduplication...")
        deduplicated_indices, duplicate_to_original_mapping = deduplicate(embedding_matrix, threshold)

        # Prepare the results
        num_duplicates = len(duplicate_to_original_mapping)
        num_total = len(texts)
        num_deduplicated = len(deduplicated_indices)

        result_text = f"**Total documents:** {num_total}\n"
        result_text += f"**Number of duplicates found:** {num_duplicates}\n"
        result_text += f"**Number of unique documents after deduplication:** {num_deduplicated}\n\n"

        # Show a few sample duplicate pairs with their word-level differences
        result_text += "### Sample Duplicate Pairs with Differences:\n\n"
        num_examples = min(5, num_duplicates)
        if num_examples > 0:
            sample_duplicates = list(duplicate_to_original_mapping.items())[:num_examples]
            for duplicate_idx, original_idx in sample_duplicates:
                original_text = texts[original_idx]
                duplicate_text = texts[duplicate_idx]
                differences = display_word_differences(original_text, duplicate_text)
                result_text += f"**Original Text (Index {original_idx}):**\n{original_text}\n\n"
                result_text += f"**Duplicate Text (Index {duplicate_idx}):**\n{duplicate_text}\n\n"
                result_text += f"**Differences:**\n{differences}\n\n"
                result_text += "---\n\n"
        else:
            result_text += "No duplicates found.\n"

        return result_text

    elif deduplication_type == "Cross-dataset":
        # Load both datasets
        ds1 = load_dataset(dataset1_name, split=dataset1_split)
        ds2 = load_dataset(dataset2_name, split=dataset2_split)

        # Extract texts
        try:
            texts1 = [example[text_column_name] for example in ds1]
            texts2 = [example[text_column_name] for example in ds2]
        except KeyError:
            return f"Error: Text column '{text_column_name}' not found in one of the datasets."

        # Compute embeddings for both datasets
        progress(0.1, desc="Computing embeddings for Dataset 1...")
        model = StaticModel.from_pretrained("minishlab/M2V_base_output")
        embedding_matrix1 = model.encode(texts1, show_progressbar=False)
        progress(0.5, desc="Computing embeddings for Dataset 2...")
        embedding_matrix2 = model.encode(texts2, show_progressbar=False)

        # Deduplicate Dataset 2 against Dataset 1
        progress(0.7, desc="Performing cross-dataset deduplication...")
        duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(
            embedding_matrix1, embedding_matrix2, threshold
        )

        num_duplicates = len(duplicate_indices_in_ds2)
        num_total_ds2 = len(texts2)
        num_unique_ds2 = num_total_ds2 - num_duplicates

        result_text = f"**Total documents in {dataset2_name}/{dataset2_split}:** {num_total_ds2}\n"
        result_text += f"**Number of duplicates found in {dataset2_name}/{dataset2_split}:** {num_duplicates}\n"
        result_text += f"**Number of unique documents in {dataset2_name}/{dataset2_split} after deduplication:** {num_unique_ds2}\n\n"

        # Show a few sample duplicate pairs with their word-level differences
        result_text += "### Sample Duplicate Pairs with Differences:\n\n"
        num_examples = min(5, num_duplicates)
        if num_examples > 0:
            sample_duplicates = list(duplicate_to_original_mapping.items())[:num_examples]
            for duplicate_idx, original_idx in sample_duplicates:
                original_text = texts1[original_idx]
                duplicate_text = texts2[duplicate_idx]
                differences = display_word_differences(original_text, duplicate_text)
                result_text += f"**Original Text in {dataset1_name}/{dataset1_split} (Index {original_idx}):**\n{original_text}\n\n"
                result_text += f"**Duplicate Text in {dataset2_name}/{dataset2_split} (Index {duplicate_idx}):**\n{duplicate_text}\n\n"
                result_text += f"**Differences:**\n{differences}\n\n"
                result_text += "---\n\n"
        else:
            result_text += "No duplicates found.\n"

        return result_text
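
# A sketch of calling the handler directly for a quick smoke test (hypothetical
# arguments; the "train[:100]" slice is standard `datasets` split syntax, and
# running outside a Gradio event may limit the progress bar — an assumption):
#
#   print(perform_deduplication("Single dataset", "ag_news", "train[:100]",
#                               "", "", "text", 0.8))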


with gr.Blocks() as demo:
    gr.Markdown("# Semantic Deduplication")

    deduplication_type = gr.Radio(
        choices=["Single dataset", "Cross-dataset"],
        label="Deduplication Type",
        value="Single dataset",
    )

    with gr.Row():
        dataset1_name = gr.Textbox(value="ag_news", label="Dataset 1 Name")
        dataset1_split = gr.Textbox(value="train", label="Dataset 1 Split")

    # Second-dataset inputs, hidden unless "Cross-dataset" is selected.
    dataset2_row = gr.Column(visible=False)
    with dataset2_row:
        dataset2_name = gr.Textbox(value="ag_news", label="Dataset 2 Name")
        dataset2_split = gr.Textbox(value="test", label="Dataset 2 Split")

    text_column_name = gr.Textbox(value="text", label="Text Column Name")
    threshold = gr.Slider(minimum=0.0, maximum=1.0, value=0.8, label="Similarity Threshold")
    compute_button = gr.Button("Compute")
    output = gr.Markdown()

    # Show the second-dataset inputs only for cross-dataset deduplication.
    def update_visibility(choice):
        return {dataset2_row: gr.update(visible=(choice == "Cross-dataset"))}

    deduplication_type.change(update_visibility, inputs=deduplication_type, outputs=[dataset2_row])

    compute_button.click(
        fn=perform_deduplication,
        inputs=[
            deduplication_type,
            dataset1_name,
            dataset1_split,
            dataset2_name,
            dataset2_split,
            text_column_name,
            threshold,
        ],
        outputs=output,
    )

demo.launch()
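
# `demo.launch()` serves the app locally by default (and works as-is on
# Hugging Face Spaces); for a temporary public link, Gradio also supports
# `demo.launch(share=True)`.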