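"""Gradio Space for semantic text deduplication of Hugging Face datasets.

Uses SemHash with a Model2Vec static encoder to deduplicate texts within a
single dataset or across two datasets, and can push the deduplicated result
back to the Hugging Face Hub.
"""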
import gradio as gr
from datasets import load_dataset, Dataset
from difflib import ndiff
import pandas as pd
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from semhash import SemHash
from semhash.datamodels import DeduplicationResult
from model2vec import StaticModel
# Default parameters
default_dataset_name = "SetFit/amazon_massive_scenario_en-US"
default_dataset1_split = "train"
default_dataset2_split = "test"
default_text_column = "text"
default_threshold = 0.9
# Load the embedding model used to encode texts for similarity search
model = StaticModel.from_pretrained("minishlab/potion-base-8M")
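# potion-base-8M is a small Model2Vec static-embedding model; encoding runs
# quickly on CPU, so no GPU is required for this Space.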
def display_word_differences(x: str, y: str) -> str:
"""
    Display the word-level differences between two texts, wrapped in a code
    fence so the "+"/"-" diff markers are not rendered as Markdown.
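    For example, comparing "the cat sat" with "the dog sat" yields a code block
    containing the lines "- cat" and "+ dog".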
"""
diff = ndiff(x.split(), y.split())
formatted_diff = "\n".join(word for word in diff if word.startswith(("+", "-")))
return f"```\n{formatted_diff}\n```"
def load_dataset_texts(
dataset_name: str, dataset_split: str, text_column: str
) -> tuple[list[str], Dataset]:
"""Load texts from a specified dataset split."""
ds = load_dataset(dataset_name, split=dataset_split)
return [example[text_column] for example in ds], ds
def deduplicate_single_dataset(
texts: list[str], threshold: float
) -> DeduplicationResult:
"""
Deduplicate within a single dataset using SemHash, treating each text
as a raw string record.
"""
# Build a SemHash index from the raw texts
semhash = SemHash.from_records(records=texts, model=model)
# Deduplicate the entire dataset
return semhash.self_deduplicate(threshold=threshold)
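# The DeduplicationResult consumed in perform_deduplication below exposes:
#   result.deduplicated -> the texts that were kept
#   result.duplicates   -> DuplicateRecord objects with .record (the dropped text)
#                          and .duplicates, a list of (matched_text, score) tuples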
def deduplicate_two_datasets(
texts1: list[str], texts2: list[str], threshold: float
) -> DeduplicationResult:
"""Deduplicate dataset2 against dataset1, both as raw strings, using SemHash."""
# Build SemHash index on dataset1
semhash = SemHash.from_records(records=texts1, model=model)
# Deduplicate texts2 against dataset1
return semhash.deduplicate(records=texts2, threshold=threshold)
def create_deduplicated_dataset(
original_dataset: Dataset, deduplicated_texts: list[str], text_column: str
) -> Dataset:
"""Create a new dataset with only the deduplicated texts."""
# Create a mapping from text to original row
text_to_row = {row[text_column]: row for row in original_dataset}
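    # If several rows share identical text, later rows overwrite earlier ones,
    # so exactly one row per unique text is kept.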
# Build new dataset with deduplicated texts
deduplicated_rows = []
for text in deduplicated_texts:
if text in text_to_row:
deduplicated_rows.append(text_to_row[text])
return Dataset.from_list(deduplicated_rows)
def perform_deduplication(
deduplication_type: str,
dataset1_name: str,
dataset1_split: str,
dataset1_text_column: str,
dataset2_name: str = "",
dataset2_split: str = "",
dataset2_text_column: str = "",
threshold: float = default_threshold,
progress: gr.Progress = gr.Progress(track_tqdm=True),
):
"""
Perform deduplication on one or two datasets using SemHash. This function
streams status updates to Gradio for user feedback.
"""
try:
threshold = float(threshold)
# Load Dataset 1
texts1, dataset1 = load_dataset_texts(
dataset1_name, dataset1_split, dataset1_text_column
)
if deduplication_type == "Single dataset":
# Single-dataset deduplication
result = deduplicate_single_dataset(texts1, threshold=threshold)
            # Sort each duplicate record's matches by similarity score (ascending),
            # so the least similar match comes first
            for duprec in result.duplicates:
                duprec.duplicates.sort(key=lambda x: x[1])
# Create deduplicated dataset
deduplicated_dataset = create_deduplicated_dataset(
dataset1, result.deduplicated, dataset1_text_column
)
# Summarize results
num_duplicates = len(result.duplicates)
deduplicated_count = len(result.deduplicated)
total_docs = len(texts1)
# Create examples table
examples_table = None
if num_duplicates > 0:
# Only show duplicates that actually have near-duplicate records
duplicates_with_data = [
duprec for duprec in result.duplicates if duprec.duplicates
]
if duplicates_with_data:
                    # Build table rows for up to 5 duplicate records, each paired
                    # with its least similar match
table_data = []
for duprec in duplicates_with_data[:5]:
dup_text = duprec.record
orig_text, score = duprec.duplicates[0]
table_data.append(
[
orig_text[:200] + "..."
if len(orig_text) > 200
else orig_text,
dup_text[:200] + "..."
if len(dup_text) > 200
else dup_text,
f"{score:.4f}",
]
)
examples_table = pd.DataFrame(
table_data,
columns=["Original Text", "Duplicate Text", "Similarity Score"],
)
# Show success info with stats
gr.Info(
f"Deduplication completed! Found {num_duplicates} duplicates. "
f"Dataset reduced from {total_docs} to {deduplicated_count} unique documents."
)
# Return table with visibility update
if examples_table is not None and not examples_table.empty:
return deduplicated_dataset, gr.update(
visible=True, value=examples_table
)
else:
return deduplicated_dataset, gr.update(visible=False)
else:
# Cross-dataset deduplication
texts2, dataset2 = load_dataset_texts(
dataset2_name, dataset2_split, dataset2_text_column
)
result = deduplicate_two_datasets(texts1, texts2, threshold=threshold)
            # Sort each duplicate record's matches by similarity score (ascending),
            # so the least similar match comes first
            for duprec in result.duplicates:
                duprec.duplicates.sort(key=lambda x: x[1])
# Create deduplicated dataset from dataset2
deduplicated_dataset = create_deduplicated_dataset(
dataset2, result.deduplicated, dataset2_text_column
)
num_duplicates = len(result.duplicates)
total_docs2 = len(texts2)
deduplicated_count = len(result.deduplicated)
# Create examples table
examples_table = None
if num_duplicates > 0:
# Again, only show duplicates that have records
duplicates_with_data = [
duprec for duprec in result.duplicates if duprec.duplicates
]
if duplicates_with_data:
                    # Build table rows for up to 5 duplicate records, each paired
                    # with its least similar match
table_data = []
for duprec in duplicates_with_data[:5]:
dup_text = duprec.record
orig_text, score = duprec.duplicates[0]
table_data.append(
[
orig_text[:200] + "..."
if len(orig_text) > 200
else orig_text,
dup_text[:200] + "..."
if len(dup_text) > 200
else dup_text,
f"{score:.4f}",
]
)
examples_table = pd.DataFrame(
table_data,
columns=[
"Original Text (Dataset 1)",
"Duplicate Text (Dataset 2)",
"Similarity Score",
],
)
# Show success info with stats
gr.Info(
f"Deduplication completed! Found {num_duplicates} duplicates in Dataset 2. "
f"Dataset reduced from {total_docs2} to {deduplicated_count} unique documents."
)
# Return table with visibility update
if examples_table is not None and not examples_table.empty:
return deduplicated_dataset, gr.update(
visible=True, value=examples_table
)
else:
return deduplicated_dataset, gr.update(visible=False)
    except Exception as e:
        # gr.Error must be raised (not just instantiated) for Gradio to display it;
        # the original outputs are left unchanged when the handler raises.
        raise gr.Error(f"An error occurred during deduplication: {str(e)}") from e
def push_to_hub(
deduplicated_dataset: Dataset,
output_dataset_name: str,
oauth_profile: gr.OAuthProfile | None,
oauth_token: gr.OAuthToken | None,
progress: gr.Progress = gr.Progress(),
) -> str:
"""Push the deduplicated dataset to Hugging Face Hub."""
if oauth_token is None:
raise gr.Error("Please log in with Hugging Face to push datasets to the Hub.")
if not output_dataset_name.strip():
raise gr.Error("Please provide a dataset name.")
if deduplicated_dataset is None:
raise gr.Error(
"No deduplicated dataset available. Please run deduplication first."
)
try:
progress(0.1, desc="Preparing dataset...")
# Determine the full dataset name (username/dataset_name)
username = oauth_profile.username if oauth_profile else None
if "/" not in output_dataset_name and username:
full_dataset_name = f"{username}/{output_dataset_name}"
else:
full_dataset_name = output_dataset_name
progress(0.3, desc="Pushing to Hub...")
# Push to hub using the OAuth token
deduplicated_dataset.push_to_hub(
full_dataset_name, token=oauth_token.token, private=False
)
progress(1.0, desc="Complete!")
gr.Info(
f"Successfully pushed deduplicated dataset with {len(deduplicated_dataset)} rows to the Hub!"
)
return (
f"✅ **Dataset published:** [{full_dataset_name}]"
f"(https://huggingface.co/datasets/{full_dataset_name})"
)
except Exception as e:
raise gr.Error(f"Failed to push dataset to Hub: {str(e)}")
def get_user_info(oauth_profile: gr.OAuthProfile | None) -> str:
"""Display user login status."""
if oauth_profile is None:
return "Not logged in. Please log in to push datasets to the Hub."
return f"Logged in as: **{oauth_profile.username}**"
def update_push_button_state(oauth_profile: gr.OAuthProfile | None):
"""Update the push button state based on login status."""
is_logged_in = oauth_profile is not None
return gr.update(interactive=is_logged_in)
# --- Gradio App ---
with gr.Blocks(
theme=gr.themes.Ocean(), css="#status_output { height: 50px; overflow: auto; }"
) as demo:
gr.Markdown("# SemDedup-My-Dataset: Semantic Text Deduplication Using SemHash")
gr.Markdown("""
This demo showcases **semantic deduplication** of Hugging Face datasets using [SemHash](https://github.com/MinishLab/semhash) with a [Model2Vec](https://github.com/MinishLab/model2vec) encoder.
Use it to identify duplicate texts within a **single dataset** or across **two datasets**.
Adjust the similarity threshold to control how strict the deduplication is.
""")
deduplication_type = gr.Radio(
choices=["Cross-dataset", "Single dataset"],
label="Deduplication Type",
value="Cross-dataset", # default
)
with gr.Row():
dataset1_name = HuggingfaceHubSearch(
label="Dataset 1 Name",
placeholder="Search for datasets on HuggingFace Hub",
search_type="dataset",
value=default_dataset_name,
)
dataset1_split = gr.Textbox(
value=default_dataset1_split, label="Dataset 1 Split"
)
dataset1_text_column = gr.Textbox(
value=default_text_column, label="Text Column Name"
)
dataset2_inputs = gr.Column(visible=True)
with dataset2_inputs:
with gr.Row():
dataset2_name = HuggingfaceHubSearch(
label="Dataset 2 Name",
placeholder="Search for datasets on HuggingFace Hub",
search_type="dataset",
value=default_dataset_name,
)
dataset2_split = gr.Textbox(
value=default_dataset2_split, label="Dataset 2 Split"
)
dataset2_text_column = gr.Textbox(
value=default_text_column, label="Text Column Name"
)
threshold = gr.Slider(
0.0, 1.0, value=default_threshold, label="Similarity Threshold"
)
with gr.Row():
compute_button = gr.Button("Deduplicate", variant="primary")
status_output = gr.Markdown(elem_id="status_output")
# Examples table
examples_table = gr.Dataframe(
headers=["Original Text", "Duplicate Text", "Similarity Score"],
datatype=["str", "str", "str"],
)
# Hidden state to store the deduplicated dataset
deduplicated_dataset_state = gr.State()
# Output dataset configuration
gr.Markdown("## Push Deduplicated Dataset to Hub")
with gr.Row():
with gr.Column():
output_dataset_name = gr.Textbox(
label="Output Dataset Name",
placeholder="my-deduplicated-dataset",
info="Will be saved as username/dataset-name",
)
with gr.Column():
push_button = gr.Button(
"Push to Hub", variant="secondary", interactive=False
)
login_button = gr.LoginButton()
    # Login status and push result display
with gr.Row():
user_info = gr.Markdown()
push_output = gr.Markdown()
# HACK: for some reason gradio wants this.
login_button.activate()
def update_visibility(choice: str):
return gr.update(visible=(choice == "Cross-dataset"))
deduplication_type.change(
update_visibility, inputs=deduplication_type, outputs=dataset2_inputs
)
# Update user info and button state when page loads or login status changes
demo.load(get_user_info, inputs=None, outputs=user_info)
demo.load(update_push_button_state, inputs=None, outputs=push_button)
login_button.click(get_user_info, inputs=None, outputs=user_info)
login_button.click(update_push_button_state, inputs=None, outputs=push_button)
compute_button.click(
fn=perform_deduplication,
inputs=[
deduplication_type,
dataset1_name,
dataset1_split,
dataset1_text_column,
dataset2_name,
dataset2_split,
dataset2_text_column,
threshold,
],
outputs=[deduplicated_dataset_state, examples_table],
)
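    # Gradio injects the OAuthProfile and OAuthToken arguments of push_to_hub
    # automatically based on their type hints, so they are not listed as inputs.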
push_button.click(
fn=push_to_hub,
inputs=[
deduplicated_dataset_state,
output_dataset_name,
],
outputs=push_output,
)
demo.launch()