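"""Semantic deduplication demo.

A Gradio app that finds near-duplicate documents within a Hugging Face
dataset, or across two datasets, using Model2Vec embeddings and Reach
nearest-neighbor search.
"""
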
import gradio as gr
from datasets import load_dataset
import numpy as np
from model2vec import StaticModel
from reach import Reach
from tqdm import tqdm
import difflib

def display_word_differences(x: str, y: str) -> str:
    """Return the word-level diff between two texts, with "+"/"-" prefixes."""
    diff = difflib.ndiff(x.split(), y.split())
    return " ".join([word for word in diff if word.startswith(('+', '-'))])

def deduplicate(
    embedding_matrix: np.ndarray, threshold: float, batch_size: int = 1024
) -> tuple[np.ndarray, dict[int, int]]:
    """
    Deduplicate embeddings and return the deduplicated indices and a mapping
    of removed indices to their corresponding original indices.
    """
    reach = Reach(vectors=embedding_matrix, items=[str(i) for i in range(len(embedding_matrix))])

    # Use a set for deduplicated indices and keep track of duplicates
    deduplicated_indices = set(range(len(embedding_matrix)))  # Start with all indices as deduplicated
    duplicate_to_original_mapping = {}

    results = reach.nearest_neighbor_threshold(
        embedding_matrix, 
        threshold=threshold, 
        batch_size=batch_size, 
        show_progressbar=False  # Disable internal progress bar
    )

    # Process duplicates
    for i, similar_items in enumerate(tqdm(results, desc="Processing duplicates")):
        if i not in deduplicated_indices:
            continue  # Skip already marked duplicates

        # Similar items are returned as (index, score); we only need the index
        similar_indices = [int(item[0]) for item in similar_items if int(item[0]) != i]

        # Mark similar documents as duplicates and map them to the original
        for sim_idx in similar_indices:
            if sim_idx in deduplicated_indices:
                deduplicated_indices.remove(sim_idx)
                duplicate_to_original_mapping[sim_idx] = i  # Map duplicate to original

    return np.array(list(deduplicated_indices)), duplicate_to_original_mapping
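
# Usage sketch for `deduplicate` (synthetic vectors, purely illustrative; real
# inputs come from StaticModel.encode in perform_deduplication below). Rows
# whose similarity, as computed by Reach, meets the threshold collapse onto
# the earliest surviving row:
#   vectors = np.random.rand(100, 64)
#   keep, mapping = deduplicate(vectors, threshold=0.99)
#   # `keep` holds the surviving row indices; `mapping` sends each removed
#   # row to the index of the row it duplicated.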

def deduplicate_across_datasets(
    embedding_matrix_1: np.ndarray,
    embedding_matrix_2: np.ndarray,
    threshold: float,
    batch_size: int = 1024,
) -> tuple[list[int], dict[int, int]]:
    """
    Deduplicate embeddings across two datasets and return the indices of
    duplicates in the second dataset, each mapped to a matching original
    in the first.
    """
    reach = Reach(vectors=embedding_matrix_1, items=[str(i) for i in range(len(embedding_matrix_1))])

    # Keep track of duplicates in the second dataset
    duplicate_indices_in_test = []
    duplicate_to_original_mapping = {}

    # Find nearest neighbors from the test set in the train set
    results = reach.nearest_neighbor_threshold(
        embedding_matrix_2, 
        threshold=threshold, 
        batch_size=batch_size, 
        show_progressbar=False  # Disable internal progress bar
    )

    # Process duplicates
    for i, similar_items in enumerate(tqdm(results, desc="Processing duplicates")):
        # Similar items are returned as (index, score); keep indices whose
        # score clears the threshold
        similar_indices = [int(item[0]) for item in similar_items if item[1] >= threshold]

        # If we find a similar item in the train set, mark it as a duplicate
        if similar_indices:
            duplicate_indices_in_test.append(i)
            duplicate_to_original_mapping[i] = similar_indices[0]  # Map duplicate in test to original in train

    return duplicate_indices_in_test, duplicate_to_original_mapping
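
# Usage sketch for `deduplicate_across_datasets` (synthetic train/test data,
# purely illustrative):
#   train, test = np.random.rand(80, 64), np.random.rand(20, 64)
#   dup_idx, mapping = deduplicate_across_datasets(train, test, threshold=0.99)
#   # `dup_idx` lists test rows with a near-duplicate in train;
#   # `mapping[test_row] = train_row` points at the matched original.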

def perform_deduplication(
    deduplication_type,
    dataset1_name,
    dataset1_split,
    dataset2_name,
    dataset2_split,
    text_column_name,
    threshold,
    # Declaring gr.Progress as a default argument is the supported way to get
    # tqdm-tracked progress updates; gr.Progress is not a context manager.
    progress=gr.Progress(track_tqdm=True),
):
    # Convert threshold to float (the slider value may arrive as a string)
    threshold = float(threshold)

    if deduplication_type == "Single dataset":
        # Load the dataset
        ds = load_dataset(dataset1_name, split=dataset1_split)

        # Extract texts
        try:
            texts = [example[text_column_name] for example in ds]
        except KeyError:
            return f"Error: Text column '{text_column_name}' not found in dataset."

        # Compute embeddings
        progress(0.1, desc="Loading model and computing embeddings...")
        model = StaticModel.from_pretrained("minishlab/M2V_base_output")
        embedding_matrix = model.encode(texts, show_progressbar=False)

        # Deduplicate
        progress(0.5, desc="Performing deduplication...")
        deduplicated_indices, duplicate_to_original_mapping = deduplicate(embedding_matrix, threshold)

        # Prepare the results
        num_duplicates = len(duplicate_to_original_mapping)
        num_total = len(texts)
        num_deduplicated = len(deduplicated_indices)

        result_text = f"**Total documents:** {num_total}\n"
        result_text += f"**Number of duplicates found:** {num_duplicates}\n"
        result_text += f"**Number of unique documents after deduplication:** {num_deduplicated}\n\n"

        # Show sample duplicates
        result_text += "### Sample Duplicate Pairs with Differences:\n\n"
        num_examples = min(5, num_duplicates)
        if num_examples > 0:
            sample_duplicates = list(duplicate_to_original_mapping.items())[:num_examples]
            for duplicate_idx, original_idx in sample_duplicates:
                original_text = texts[original_idx]
                duplicate_text = texts[duplicate_idx]
                differences = display_word_differences(original_text, duplicate_text)
                result_text += f"**Original Text (Index {original_idx}):**\n{original_text}\n\n"
                result_text += f"**Duplicate Text (Index {duplicate_idx}):**\n{duplicate_text}\n\n"
                result_text += f"**Differences:**\n{differences}\n\n"
                result_text += "---\n\n"
        else:
            result_text += "No duplicates found.\n"

        return result_text

    elif deduplication_type == "Cross-dataset":
        # Load datasets
        ds1 = load_dataset(dataset1_name, split=dataset1_split)
        ds2 = load_dataset(dataset2_name, split=dataset2_split)

        # Extract texts
        try:
            texts1 = [example[text_column_name] for example in ds1]
            texts2 = [example[text_column_name] for example in ds2]
        except KeyError:
            return f"Error: Text column '{text_column_name}' not found in one of the datasets."

        # Compute embeddings
        progress(0.1, desc="Computing embeddings for Dataset 1...")
        model = StaticModel.from_pretrained("minishlab/M2V_base_output")
        embedding_matrix1 = model.encode(texts1, show_progressbar=False)

        progress(0.5, desc="Computing embeddings for Dataset 2...")
        embedding_matrix2 = model.encode(texts2, show_progressbar=False)

        # Deduplicate across datasets
        progress(0.7, desc="Performing cross-dataset deduplication...")
        duplicate_indices_in_ds2, duplicate_to_original_mapping = deduplicate_across_datasets(
            embedding_matrix1, embedding_matrix2, threshold
        )

        num_duplicates = len(duplicate_indices_in_ds2)
        num_total_ds2 = len(texts2)
        num_unique_ds2 = num_total_ds2 - num_duplicates

        result_text = f"**Total documents in {dataset2_name}/{dataset2_split}:** {num_total_ds2}\n"
        result_text += f"**Number of duplicates found in {dataset2_name}/{dataset2_split}:** {num_duplicates}\n"
        result_text += f"**Number of unique documents in {dataset2_name}/{dataset2_split} after deduplication:** {num_unique_ds2}\n\n"

        # Show sample duplicates
        result_text += "### Sample Duplicate Pairs with Differences:\n\n"
        num_examples = min(5, num_duplicates)
        if num_examples > 0:
            sample_duplicates = list(duplicate_to_original_mapping.items())[:num_examples]
            for duplicate_idx, original_idx in sample_duplicates:
                original_text = texts1[original_idx]
                duplicate_text = texts2[duplicate_idx]
                differences = display_word_differences(original_text, duplicate_text)
                result_text += f"**Original Text in {dataset1_name}/{dataset1_split} (Index {original_idx}):**\n{original_text}\n\n"
                result_text += f"**Duplicate Text in {dataset2_name}/{dataset2_split} (Index {duplicate_idx}):**\n{duplicate_text}\n\n"
                result_text += f"**Differences:**\n{differences}\n\n"
                result_text += "---\n\n"
        else:
            result_text += "No duplicates found.\n"

        return result_text

with gr.Blocks() as demo:
    gr.Markdown("# Semantic Deduplication")

    deduplication_type = gr.Radio(choices=["Single dataset", "Cross-dataset"], label="Deduplication Type", value="Single dataset")

    with gr.Row():
        dataset1_name = gr.Textbox(value="ag_news", label="Dataset 1 Name")
        dataset1_split = gr.Textbox(value="train", label="Dataset 1 Split")

    dataset2_row = gr.Column(visible=False)
    with dataset2_row:
        dataset2_name = gr.Textbox(value="ag_news", label="Dataset 2 Name")
        dataset2_split = gr.Textbox(value="test", label="Dataset 2 Split")

    text_column_name = gr.Textbox(value="text", label="Text Column Name")

    threshold = gr.Slider(minimum=0.0, maximum=1.0, value=0.8, label="Similarity Threshold")

    compute_button = gr.Button("Compute")

    output = gr.Markdown()

    # Function to update the visibility of dataset2_row
    def update_visibility(choice):
        if choice == "Cross-dataset":
            return {dataset2_row: gr.update(visible=True)}
        else:
            return {dataset2_row: gr.update(visible=False)}
    
    deduplication_type.change(update_visibility, inputs=deduplication_type, outputs=[dataset2_row])

    compute_button.click(
        fn=perform_deduplication,
        inputs=[deduplication_type, dataset1_name, dataset1_split, dataset2_name, dataset2_split, text_column_name, threshold],
        outputs=output
    )

demo.launch()
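
# To run locally: `python app.py`. Assumes the dependencies are installed,
# e.g. `pip install gradio datasets model2vec reach numpy tqdm` (package
# names assumed to match their import names).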