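"""Gradio app for semantic text deduplication of Hugging Face datasets.

Uses SemHash with a Model2Vec static encoder to find near-duplicate texts
within a single dataset or across two datasets, and can push the deduplicated
result back to the Hugging Face Hub.
"""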
import gradio as gr
from datasets import load_dataset, Dataset
from difflib import ndiff
import pandas as pd
from gradio_huggingfacehub_search import HuggingfaceHubSearch

from semhash import SemHash
from semhash.datamodels import DeduplicationResult

from model2vec import StaticModel

# Default parameters
default_dataset_name = "SetFit/amazon_massive_scenario_en-US"
default_dataset1_split = "train"
default_dataset2_split = "test"
default_text_column = "text"
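# Similarity threshold above which two texts are treated as duplicates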
default_threshold = 0.9

# Load the Model2Vec static embedding model used by SemHash
model = StaticModel.from_pretrained("minishlab/potion-base-8M")


def display_word_differences(x: str, y: str) -> str:
    """
    Display the word-level differences between two texts, formatted to avoid
    misinterpretation of Markdown syntax.
    """
    diff = ndiff(x.split(), y.split())
    formatted_diff = "\n".join(word for word in diff if word.startswith(("+", "-")))
    return f"```\n{formatted_diff}\n```"


def load_dataset_texts(
    dataset_name: str, dataset_split: str, text_column: str
) -> tuple[list[str], Dataset]:
    """Load texts from a specified dataset split."""
    ds = load_dataset(dataset_name, split=dataset_split)
    return [example[text_column] for example in ds], ds


def deduplicate_single_dataset(
    texts: list[str], threshold: float
) -> DeduplicationResult:
    """
    Deduplicate within a single dataset using SemHash, treating each text
    as a raw string record.
    """
    # Build a SemHash index from the raw texts
    semhash = SemHash.from_records(records=texts, model=model)
    # Deduplicate the entire dataset
    return semhash.self_deduplicate(threshold=threshold)


def deduplicate_two_datasets(
    texts1: list[str], texts2: list[str], threshold: float
) -> DeduplicationResult:
    """Deduplicate dataset2 against dataset1, both as raw strings, using SemHash."""
    # Build SemHash index on dataset1
    semhash = SemHash.from_records(records=texts1, model=model)
    # Deduplicate texts2 against dataset1
    return semhash.deduplicate(records=texts2, threshold=threshold)


def create_deduplicated_dataset(
    original_dataset: Dataset, deduplicated_texts: list[str], text_column: str
) -> Dataset:
    """Create a new dataset with only the deduplicated texts."""
    # Create a mapping from text to original row
    text_to_row = {row[text_column]: row for row in original_dataset}

    # Build new dataset with deduplicated texts
    deduplicated_rows = []
    for text in deduplicated_texts:
        if text in text_to_row:
            deduplicated_rows.append(text_to_row[text])

    return Dataset.from_list(deduplicated_rows)


def perform_deduplication(
    deduplication_type: str,
    dataset1_name: str,
    dataset1_split: str,
    dataset1_text_column: str,
    dataset2_name: str = "",
    dataset2_split: str = "",
    dataset2_text_column: str = "",
    threshold: float = default_threshold,
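    # track_tqdm surfaces any tqdm progress bars raised inside this function
    # (e.g. during encoding) in the Gradio UI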
    progress: gr.Progress = gr.Progress(track_tqdm=True),
):
    """
    Perform deduplication on one or two datasets using SemHash. This function
    streams status updates to Gradio for user feedback.
    """
    try:
        threshold = float(threshold)

        # Load Dataset 1
        texts1, dataset1 = load_dataset_texts(
            dataset1_name, dataset1_split, dataset1_text_column
        )

        if deduplication_type == "Single dataset":
            # Single-dataset deduplication
            result = deduplicate_single_dataset(texts1, threshold=threshold)

            # Sort all duplicates by score (ascending for least similar)
            for duprec in result.duplicates:
                duprec.duplicates.sort(key=lambda x: x[1])

            # Create deduplicated dataset
            deduplicated_dataset = create_deduplicated_dataset(
                dataset1, result.deduplicated, dataset1_text_column
            )

            # Summarize results
            num_duplicates = len(result.duplicates)
            deduplicated_count = len(result.deduplicated)
            total_docs = len(texts1)

            # Create examples table
            examples_table = None
            if num_duplicates > 0:
                # Only show duplicates that actually have near-duplicate records
                duplicates_with_data = [
                    duprec for duprec in result.duplicates if duprec.duplicates
                ]

                if duplicates_with_data:
                    # Build table rows for up to 5 duplicate records, pairing each
                    # with its lowest-scoring (most borderline) match
                    table_data = []
                    for duprec in duplicates_with_data[:5]:
                        dup_text = duprec.record
                        orig_text, score = duprec.duplicates[0]
                        table_data.append(
                            [
                                orig_text[:200] + "..."
                                if len(orig_text) > 200
                                else orig_text,
                                dup_text[:200] + "..."
                                if len(dup_text) > 200
                                else dup_text,
                                f"{score:.4f}",
                            ]
                        )

                    examples_table = pd.DataFrame(
                        table_data,
                        columns=["Original Text", "Duplicate Text", "Similarity Score"],
                    )

            # Show success info with stats
            gr.Info(
                f"Deduplication completed! Found {num_duplicates} duplicates. "
                f"Dataset reduced from {total_docs} to {deduplicated_count} unique documents."
            )

            # Return table with visibility update
            if examples_table is not None and not examples_table.empty:
                return deduplicated_dataset, gr.update(
                    visible=True, value=examples_table
                )
            else:
                return deduplicated_dataset, gr.update(visible=False)

        else:
            # Cross-dataset deduplication
            texts2, dataset2 = load_dataset_texts(
                dataset2_name, dataset2_split, dataset2_text_column
            )

            result = deduplicate_two_datasets(texts1, texts2, threshold=threshold)

            # Sort duplicates by score (ascending for least similar)
            for duprec in result.duplicates:
                duprec.duplicates.sort(key=lambda x: x[1])

            # Create deduplicated dataset from dataset2
            deduplicated_dataset = create_deduplicated_dataset(
                dataset2, result.deduplicated, dataset2_text_column
            )

            num_duplicates = len(result.duplicates)
            total_docs2 = len(texts2)
            deduplicated_count = len(result.deduplicated)

            # Create examples table
            examples_table = None
            if num_duplicates > 0:
                # Again, only show duplicates that have records
                duplicates_with_data = [
                    duprec for duprec in result.duplicates if duprec.duplicates
                ]
                if duplicates_with_data:
                    # Build table rows for up to 5 duplicate records, pairing each
                    # with its lowest-scoring (most borderline) match
                    table_data = []
                    for duprec in duplicates_with_data[:5]:
                        dup_text = duprec.record
                        orig_text, score = duprec.duplicates[0]
                        table_data.append(
                            [
                                orig_text[:200] + "..."
                                if len(orig_text) > 200
                                else orig_text,
                                dup_text[:200] + "..."
                                if len(dup_text) > 200
                                else dup_text,
                                f"{score:.4f}",
                            ]
                        )

                    examples_table = pd.DataFrame(
                        table_data,
                        columns=[
                            "Original Text (Dataset 1)",
                            "Duplicate Text (Dataset 2)",
                            "Similarity Score",
                        ],
                    )

            # Show success info with stats
            gr.Info(
                f"Deduplication completed! Found {num_duplicates} duplicates in Dataset 2. "
                f"Dataset reduced from {total_docs2} to {deduplicated_count} unique documents."
            )

            # Return table with visibility update
            if examples_table is not None and not examples_table.empty:
                return deduplicated_dataset, gr.update(
                    visible=True, value=examples_table
                )
            else:
                return deduplicated_dataset, gr.update(visible=False)

    except Exception as e:
        # gr.Error is only displayed when raised; use gr.Warning so the message
        # is shown while still returning a cleared state to the UI.
        gr.Warning(f"An error occurred during deduplication: {e}")
        return None, gr.update(visible=False)


def push_to_hub(
    deduplicated_dataset: Dataset,
    output_dataset_name: str,
    oauth_profile: gr.OAuthProfile | None,
    oauth_token: gr.OAuthToken | None,
    progress: gr.Progress = gr.Progress(),
) -> str:
    """Push the deduplicated dataset to Hugging Face Hub."""
    if oauth_token is None:
        raise gr.Error("Please log in with Hugging Face to push datasets to the Hub.")

    if not output_dataset_name.strip():
        raise gr.Error("Please provide a dataset name.")

    if deduplicated_dataset is None:
        raise gr.Error(
            "No deduplicated dataset available. Please run deduplication first."
        )

    try:
        progress(0.1, desc="Preparing dataset...")

        # Determine the full dataset name (username/dataset_name)
        username = oauth_profile.username if oauth_profile else None
        if "/" not in output_dataset_name and username:
            full_dataset_name = f"{username}/{output_dataset_name}"
        else:
            full_dataset_name = output_dataset_name

        progress(0.3, desc="Pushing to Hub...")

        # Push to hub using the OAuth token
        deduplicated_dataset.push_to_hub(
            full_dataset_name, token=oauth_token.token, private=False
        )

        progress(1.0, desc="Complete!")

        gr.Info(
            f"Successfully pushed deduplicated dataset with {len(deduplicated_dataset)} rows to the Hub!"
        )

        return (
            f"✅ **Dataset published:** [{full_dataset_name}]"
            f"(https://huggingface.co/datasets/{full_dataset_name})"
        )

    except Exception as e:
        raise gr.Error(f"Failed to push dataset to Hub: {str(e)}")


def get_user_info(oauth_profile: gr.OAuthProfile | None) -> str:
    """Display user login status."""
    if oauth_profile is None:
        return "Not logged in. Please log in to push datasets to the Hub."
    return f"Logged in as: **{oauth_profile.username}**"


def update_push_button_state(oauth_profile: gr.OAuthProfile | None):
    """Update the push button state based on login status."""
    is_logged_in = oauth_profile is not None
    return gr.update(interactive=is_logged_in)


# --- Gradio App ---
with gr.Blocks(
    theme=gr.themes.Ocean(), css="#status_output { height: 50px; overflow: auto; }"
) as demo:
    gr.Markdown("# SemDedup-My-Dataset: Semantic Text Deduplication Using SemHash")
    gr.Markdown("""
    This demo showcases **semantic deduplication** using [SemHash](https://github.com/MinishLab/semhash) for HuggingFace datasets, using a [Model2Vec](https://github.com/MinishLab/model2vec) encoder.
    It can be used to identify duplicate texts within a **single dataset** or across **two datasets**.
    You can adjust the similarity threshold to control the strictness of the deduplication.

    """)

    deduplication_type = gr.Radio(
        choices=["Cross-dataset", "Single dataset"],
        label="Deduplication Type",
        value="Cross-dataset",  # default
    )

    with gr.Row():
        dataset1_name = HuggingfaceHubSearch(
            label="Dataset 1 Name",
            placeholder="Search for datasets on HuggingFace Hub",
            search_type="dataset",
            value=default_dataset_name,
        )
        dataset1_split = gr.Textbox(
            value=default_dataset1_split, label="Dataset 1 Split"
        )
        dataset1_text_column = gr.Textbox(
            value=default_text_column, label="Text Column Name"
        )

    dataset2_inputs = gr.Column(visible=True)
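    # Shown only for cross-dataset deduplication; visibility is toggled by the
    # deduplication_type radio via update_visibility below.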
    with dataset2_inputs:
        with gr.Row():
            dataset2_name = HuggingfaceHubSearch(
                label="Dataset 2 Name",
                placeholder="Search for datasets on HuggingFace Hub",
                search_type="dataset",
                value=default_dataset_name,
            )
            dataset2_split = gr.Textbox(
                value=default_dataset2_split, label="Dataset 2 Split"
            )
            dataset2_text_column = gr.Textbox(
                value=default_text_column, label="Text Column Name"
            )

    threshold = gr.Slider(
        0.0, 1.0, value=default_threshold, label="Similarity Threshold"
    )

    with gr.Row():
        compute_button = gr.Button("Deduplicate", variant="primary")

    status_output = gr.Markdown(elem_id="status_output")

    # Examples table
    examples_table = gr.Dataframe(
        headers=["Original Text", "Duplicate Text", "Similarity Score"],
        datatype=["str", "str", "str"],
    )

    # Hidden state to store the deduplicated dataset
    deduplicated_dataset_state = gr.State()

    # Output dataset configuration
    gr.Markdown("## Push Deduplicated Dataset to Hub")
    with gr.Row():
        with gr.Column():
            output_dataset_name = gr.Textbox(
                label="Output Dataset Name",
                placeholder="my-deduplicated-dataset",
                info="Will be saved as username/dataset-name",
            )
        with gr.Column():
            push_button = gr.Button(
                "Push to Hub", variant="secondary", interactive=False
            )
            login_button = gr.LoginButton()

    # User login status and push result messages
    with gr.Row():
        user_info = gr.Markdown()
        push_output = gr.Markdown()

    # HACK: for some reason gradio wants this.
    login_button.activate()
    
    def update_visibility(choice: str):
        return gr.update(visible=(choice == "Cross-dataset"))

    deduplication_type.change(
        update_visibility, inputs=deduplication_type, outputs=dataset2_inputs
    )

    # Update user info and button state when page loads or login status changes
    demo.load(get_user_info, inputs=None, outputs=user_info)
    demo.load(update_push_button_state, inputs=None, outputs=push_button)
    login_button.click(get_user_info, inputs=None, outputs=user_info)
    login_button.click(update_push_button_state, inputs=None, outputs=push_button)

    compute_button.click(
        fn=perform_deduplication,
        inputs=[
            deduplication_type,
            dataset1_name,
            dataset1_split,
            dataset1_text_column,
            dataset2_name,
            dataset2_split,
            dataset2_text_column,
            threshold,
        ],
        outputs=[deduplicated_dataset_state, examples_table],
    )

    push_button.click(
        fn=push_to_hub,
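        # oauth_profile and oauth_token are not passed explicitly; Gradio
        # injects them automatically via their type annotations in push_to_hub.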
        inputs=[
            deduplicated_dataset_state,
            output_dataset_name,
        ],
        outputs=push_output,
    )

demo.launch()