import gradio as gr
from datasets import load_dataset, Dataset
from difflib import ndiff
import pandas as pd
from gradio_huggingfacehub_search import HuggingfaceHubSearch

from semhash import SemHash
from semhash.datamodels import DeduplicationResult

from model2vec import StaticModel

default_dataset_name = "SetFit/amazon_massive_scenario_en-US"
default_dataset1_split = "train"
default_dataset2_split = "test"
default_text_column = "text"
default_threshold = 0.9

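# Load the Model2Vec encoder once at import time so every request reuses it.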
model = StaticModel.from_pretrained("minishlab/potion-base-8M")


def display_word_differences(x: str, y: str) -> str:
    """
    Display the word-level differences between two texts, formatted to avoid
    misinterpretation of Markdown syntax.
    """
    diff = ndiff(x.split(), y.split())
    formatted_diff = "\n".join(word for word in diff if word.startswith(("+", "-")))
    return f"```\n{formatted_diff}\n```"


def load_dataset_texts(
    dataset_name: str, dataset_split: str, text_column: str
) -> tuple[list[str], Dataset]:
    """Load texts from a specified dataset split."""
    ds = load_dataset(dataset_name, split=dataset_split)
    return [example[text_column] for example in ds], ds


def deduplicate_single_dataset(
    texts: list[str], threshold: float
) -> DeduplicationResult:
    """
    Deduplicate within a single dataset using SemHash, treating each text
    as a raw string record.
    """
    semhash = SemHash.from_records(records=texts, model=model)
    return semhash.self_deduplicate(threshold=threshold)


def deduplicate_two_datasets(
    texts1: list[str], texts2: list[str], threshold: float
) -> DeduplicationResult:
    """Deduplicate dataset2 against dataset1, both as raw strings, using SemHash."""
    semhash = SemHash.from_records(records=texts1, model=model)
    return semhash.deduplicate(records=texts2, threshold=threshold)


def create_deduplicated_dataset(
    original_dataset: Dataset, deduplicated_texts: list[str], text_column: str
) -> Dataset:
    """Create a new dataset with only the deduplicated texts."""
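    # Map each text back to its original row so every column is preserved.
    # If several rows share the same text, the last occurrence wins.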
    text_to_row = {row[text_column]: row for row in original_dataset}

    deduplicated_rows = []
    for text in deduplicated_texts:
        if text in text_to_row:
            deduplicated_rows.append(text_to_row[text])

    return Dataset.from_list(deduplicated_rows)


def perform_deduplication(
    deduplication_type: str,
    dataset1_name: str,
    dataset1_split: str,
    dataset1_text_column: str,
    dataset2_name: str = "",
    dataset2_split: str = "",
    dataset2_text_column: str = "",
    threshold: float = default_threshold,
    progress: gr.Progress = gr.Progress(track_tqdm=True),
):
    """
    Perform deduplication on one or two datasets using SemHash. This function
    streams status updates to Gradio for user feedback.
    """
    try:
        threshold = float(threshold)

        texts1, dataset1 = load_dataset_texts(
            dataset1_name, dataset1_split, dataset1_text_column
        )

        if deduplication_type == "Single dataset":
            result = deduplicate_single_dataset(texts1, threshold=threshold)
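            # Sort each record's duplicate matches by similarity score.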
            for duprec in result.duplicates:
                duprec.duplicates.sort(key=lambda x: x[1])

            deduplicated_dataset = create_deduplicated_dataset(
                dataset1, result.deduplicated, dataset1_text_column
            )

            num_duplicates = len(result.duplicates)
            deduplicated_count = len(result.deduplicated)
            total_docs = len(texts1)

            examples_table = None
            if num_duplicates > 0:
                duplicates_with_data = [
                    duprec for duprec in result.duplicates if duprec.duplicates
                ]
                if duplicates_with_data:
                    table_data = []
                    for duprec in duplicates_with_data[:5]:
                        dup_text = duprec.record
                        orig_text, score = duprec.duplicates[0]
                        table_data.append(
                            [
                                orig_text[:200] + "..."
                                if len(orig_text) > 200
                                else orig_text,
                                dup_text[:200] + "..."
                                if len(dup_text) > 200
                                else dup_text,
                                f"{score:.4f}",
                            ]
                        )

                    examples_table = pd.DataFrame(
                        table_data,
                        columns=["Original Text", "Duplicate Text", "Similarity Score"],
                    )

            gr.Info(
                f"Deduplication completed! Found {num_duplicates} duplicates. "
                f"Dataset reduced from {total_docs} to {deduplicated_count} unique documents."
            )

            if examples_table is not None and not examples_table.empty:
                return deduplicated_dataset, gr.update(
                    visible=True, value=examples_table
                )
            else:
                return deduplicated_dataset, gr.update(visible=False)
        else:
            texts2, dataset2 = load_dataset_texts(
                dataset2_name, dataset2_split, dataset2_text_column
            )

            result = deduplicate_two_datasets(texts1, texts2, threshold=threshold)

            for duprec in result.duplicates:
                duprec.duplicates.sort(key=lambda x: x[1])

            deduplicated_dataset = create_deduplicated_dataset(
                dataset2, result.deduplicated, dataset2_text_column
            )

            num_duplicates = len(result.duplicates)
            total_docs2 = len(texts2)
            deduplicated_count = len(result.deduplicated)

            examples_table = None
            if num_duplicates > 0:
                duplicates_with_data = [
                    duprec for duprec in result.duplicates if duprec.duplicates
                ]
                if duplicates_with_data:
                    table_data = []
                    for duprec in duplicates_with_data[:5]:
                        dup_text = duprec.record
                        orig_text, score = duprec.duplicates[0]
                        table_data.append(
                            [
                                orig_text[:200] + "..."
                                if len(orig_text) > 200
                                else orig_text,
                                dup_text[:200] + "..."
                                if len(dup_text) > 200
                                else dup_text,
                                f"{score:.4f}",
                            ]
                        )

                    examples_table = pd.DataFrame(
                        table_data,
                        columns=[
                            "Original Text (Dataset 1)",
                            "Duplicate Text (Dataset 2)",
                            "Similarity Score",
                        ],
                    )

            gr.Info(
                f"Deduplication completed! Found {num_duplicates} duplicates in Dataset 2. "
                f"Dataset reduced from {total_docs2} to {deduplicated_count} unique documents."
            )

            if examples_table is not None and not examples_table.empty:
                return deduplicated_dataset, gr.update(
                    visible=True, value=examples_table
                )
            else:
                return deduplicated_dataset, gr.update(visible=False)
    except Exception as e:
        gr.Warning(f"An error occurred during deduplication: {e}")
        return None, gr.update(visible=False)

def push_to_hub(
    deduplicated_dataset: Dataset,
    output_dataset_name: str,
    oauth_profile: gr.OAuthProfile | None,
    oauth_token: gr.OAuthToken | None,
    progress: gr.Progress = gr.Progress(),
) -> str:
    """Push the deduplicated dataset to the Hugging Face Hub."""
    if oauth_token is None:
        raise gr.Error("Please log in with Hugging Face to push datasets to the Hub.")

    if not output_dataset_name.strip():
        raise gr.Error("Please provide a dataset name.")

    if deduplicated_dataset is None:
        raise gr.Error(
            "No deduplicated dataset available. Please run deduplication first."
        )

    try:
        progress(0.1, desc="Preparing dataset...")

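        # Prefix the user's namespace unless a fully qualified "user/name" id was given.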
        username = oauth_profile.username if oauth_profile else None
        if "/" not in output_dataset_name and username:
            full_dataset_name = f"{username}/{output_dataset_name}"
        else:
            full_dataset_name = output_dataset_name

        progress(0.3, desc="Pushing to Hub...")

        deduplicated_dataset.push_to_hub(
            full_dataset_name, token=oauth_token.token, private=False
        )

        progress(1.0, desc="Complete!")

        gr.Info(
            f"Successfully pushed deduplicated dataset with {len(deduplicated_dataset)} rows to the Hub!"
        )

        return (
            f"✅ **Dataset published:** [{full_dataset_name}]"
            f"(https://huggingface.co/datasets/{full_dataset_name})"
        )

    except Exception as e:
        raise gr.Error(f"Failed to push dataset to Hub: {e}")

def get_user_info(oauth_profile: gr.OAuthProfile | None) -> str:
    """Display user login status."""
    if oauth_profile is None:
        return "Not logged in. Please log in to push datasets to the Hub."
    return f"Logged in as: **{oauth_profile.username}**"


def update_push_button_state(oauth_profile: gr.OAuthProfile | None):
    """Update the push button state based on login status."""
    is_logged_in = oauth_profile is not None
    return gr.update(interactive=is_logged_in)

with gr.Blocks(
    theme=gr.themes.Ocean(), css="#status_output { height: 50px; overflow: auto; }"
) as demo:
    gr.Markdown("# SemDedup-My-Dataset: Semantic Text Deduplication Using SemHash")
    gr.Markdown("""
    This demo showcases **semantic deduplication** of Hugging Face datasets using [SemHash](https://github.com/MinishLab/semhash) with a [Model2Vec](https://github.com/MinishLab/model2vec) encoder.
    It can be used to identify duplicate texts within a **single dataset** or across **two datasets**.
    You can adjust the similarity threshold to control the strictness of the deduplication.
    """)

    deduplication_type = gr.Radio(
        choices=["Cross-dataset", "Single dataset"],
        label="Deduplication Type",
        value="Cross-dataset",
    )
    with gr.Row():
        dataset1_name = HuggingfaceHubSearch(
            label="Dataset 1 Name",
            placeholder="Search for datasets on the Hugging Face Hub",
            search_type="dataset",
            value=default_dataset_name,
        )
        dataset1_split = gr.Textbox(
            value=default_dataset1_split, label="Dataset 1 Split"
        )
        dataset1_text_column = gr.Textbox(
            value=default_text_column, label="Text Column Name"
        )
    dataset2_inputs = gr.Column(visible=True)
    with dataset2_inputs:
        with gr.Row():
            dataset2_name = HuggingfaceHubSearch(
                label="Dataset 2 Name",
                placeholder="Search for datasets on the Hugging Face Hub",
                search_type="dataset",
                value=default_dataset_name,
            )
            dataset2_split = gr.Textbox(
                value=default_dataset2_split, label="Dataset 2 Split"
            )
            dataset2_text_column = gr.Textbox(
                value=default_text_column, label="Text Column Name"
            )
    threshold = gr.Slider(
        0.0, 1.0, value=default_threshold, label="Similarity Threshold"
    )

    with gr.Row():
        compute_button = gr.Button("Deduplicate", variant="primary")

    status_output = gr.Markdown(elem_id="status_output")

    # Hidden until deduplication produces example duplicate pairs to show.
    examples_table = gr.Dataframe(
        headers=["Original Text", "Duplicate Text", "Similarity Score"],
        datatype=["str", "str", "str"],
        visible=False,
    )
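    # Holds the deduplicated Dataset between "Deduplicate" and "Push to Hub".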
    deduplicated_dataset_state = gr.State()

    gr.Markdown("## Push Deduplicated Dataset to Hub")
    with gr.Row():
        with gr.Column():
            output_dataset_name = gr.Textbox(
                label="Output Dataset Name",
                placeholder="my-deduplicated-dataset",
                info="Will be saved as username/dataset-name",
            )
        with gr.Column():
            push_button = gr.Button(
                "Push to Hub", variant="secondary", interactive=False
            )
            login_button = gr.LoginButton()

    with gr.Row():
        user_info = gr.Markdown()
        push_output = gr.Markdown()

    login_button.activate()
    def update_visibility(choice: str):
        return gr.update(visible=(choice == "Cross-dataset"))

    deduplication_type.change(
        update_visibility, inputs=deduplication_type, outputs=dataset2_inputs
    )

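    # Refresh the login status on page load and again after the login button is used.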
    demo.load(get_user_info, inputs=None, outputs=user_info)
    demo.load(update_push_button_state, inputs=None, outputs=push_button)
    login_button.click(get_user_info, inputs=None, outputs=user_info)
    login_button.click(update_push_button_state, inputs=None, outputs=push_button)

    compute_button.click(
        fn=perform_deduplication,
        inputs=[
            deduplication_type,
            dataset1_name,
            dataset1_split,
            dataset1_text_column,
            dataset2_name,
            dataset2_split,
            dataset2_text_column,
            threshold,
        ],
        outputs=[deduplicated_dataset_state, examples_table],
    )

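    # push_to_hub's gr.OAuthProfile / gr.OAuthToken parameters are injected
    # automatically by Gradio from the login session, so they are not listed
    # as inputs here.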
    push_button.click(
        fn=push_to_hub,
        inputs=[
            deduplicated_dataset_state,
            output_dataset_name,
        ],
        outputs=push_output,
    )

demo.launch()