|
import gradio as gr |
|
import pandas as pd |
|
import io |
|
import os |
|
from pathlib import Path |
|
from huggingface_hub import HfApi, Repository |
|
import matplotlib.pyplot as plt |
|
|
|
|
|
# Working copy of the annotation data. The handler functions in this module
# read and rebind it through ``global df``.
df: pd.DataFrame = pd.DataFrame()

# Hugging Face model ids offered as choices in the auto-label dropdown.
DEFAULT_MODELS: list = [
    "mistralai/Mistral-7B-Instruct-v0.2",
    "HuggingFaceH4/zephyr-7b-beta",
    "tiiuae/falcon-rw-1b",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
]
|
|
|
def upload_csv(file):
    """Load an uploaded CSV into the shared DataFrame and show it for editing.

    Args:
        file: Value delivered by the file-upload component. May be an object
            exposing the file path as ``.name``, a plain path string, or
            ``None`` when nothing has been selected.

    Returns:
        A ``(table_update, status_message)`` tuple. On success the update
        makes the table visible with the ``text`` and ``label`` columns; on
        any failure the table is hidden and the message explains why.
    """
    global df

    if file is None:
        return gr.update(visible=False), "❌ Please select a CSV file first."

    # Accept either a file wrapper carrying ``.name`` or a bare path string.
    csv_path = getattr(file, "name", file)

    try:
        loaded = pd.read_csv(csv_path)
    except (OSError, ValueError) as exc:
        # pandas' parser/empty-data errors and UnicodeDecodeError are all
        # ValueError subclasses; OSError covers unreadable files.
        return gr.update(visible=False), f"❌ Could not read CSV: {exc}"

    if "text" not in loaded.columns or "label" not in loaded.columns:
        # Validate before touching the shared state so a rejected upload
        # does not clobber the data currently being annotated.
        return gr.update(visible=False), "❌ CSV must contain ‘text’ and ‘label’ columns."

    loaded["label"] = loaded["label"].fillna("")
    df = loaded

    return (
        gr.update(value=df[["text", "label"]], visible=True),
        "✅ File uploaded — you can now edit labels.",
    )
|
|
|
def save_changes(edited_table):
    """Replace the shared DataFrame with the contents of the edited table.

    Args:
        edited_table: Tabular data coming from the editable table component;
            it is rebuilt into a DataFrame with ``text`` and ``label`` columns.

    Returns:
        A status message confirming the save.
    """
    global df

    column_names = ["text", "label"]
    df = pd.DataFrame(data=edited_table, columns=column_names)

    return "💾 Changes saved."
|
|
|
def download_csv():
    """Write the shared DataFrame to ``annotated_data.csv`` and return its path.

    Returns:
        The relative path of the CSV file that was written (without the
        index column).
    """
    # ``df`` is only read here, so no ``global`` declaration is required.
    destination = "annotated_data.csv"
    df.to_csv(path_or_buf=destination, index=False)
    return destination
|
|
|
def visualize_distribution():
    """Build a bar chart of how often each label occurs in the shared data.

    Returns:
        A Matplotlib figure with one bar per distinct label, or ``None`` when
        there is no data, no ``label`` column, or no countable labels.
    """
    if df.empty or "label" not in df.columns:
        return None

    counts = df["label"].value_counts()
    if counts.empty:
        # value_counts() skips missing values by default, so a column with
        # only missing labels leaves nothing to draw.
        return None

    fig, ax = plt.subplots()
    try:
        counts.plot(kind="bar", ax=ax)
        ax.set_title("Label Distribution")
        ax.set_xlabel("Label")
        ax.set_ylabel("Count")
        # Lay out this specific figure rather than pyplot's "current" one.
        fig.tight_layout()
    finally:
        # Deregister the figure from pyplot so repeated calls do not
        # accumulate open figures; the Figure object itself stays usable.
        plt.close(fig)

    return fig
|
|
|
def push_to_hub(repo_name: str, hf_token: str) -> str:
    """Publish the shared DataFrame as ``data.csv`` in a Hub dataset repo.

    The dataset repository is created if it does not exist yet, then the
    current data is uploaded as a single commit through the HTTP API. No
    local git clone is made, so repeated pushes do not depend on cleaning up
    a working directory.

    Args:
        repo_name: Target repository id in ``username/dataset-name`` form.
        hf_token: Access token used for both repo creation and upload.

    Returns:
        A status message: the dataset URL on success, or a message starting
        with ``❌ Push failed:`` describing what went wrong.
    """
    repo_id = (repo_name or "").strip()
    if not repo_id:
        return "❌ Push failed: repository name is required (username/dataset-name)."

    if df.empty:
        return "❌ Push failed: there is no data to push — upload a CSV first."

    try:
        api = HfApi()
        api.create_repo(
            repo_id=repo_id,
            token=hf_token,
            repo_type="dataset",
            exist_ok=True,
        )
        api.upload_file(
            # Serialize in memory; upload_file accepts raw bytes.
            path_or_fileobj=df.to_csv(index=False).encode("utf-8"),
            path_in_repo="data.csv",
            repo_id=repo_id,
            repo_type="dataset",
            token=hf_token,
            commit_message="📑 Update annotated data",
        )
    except Exception as e:
        # UI boundary: report any failure as a status message instead of
        # letting it surface as an unhandled error in the interface.
        return f"❌ Push failed: {e}"

    return f"🚀 Pushed to https://huggingface.co/datasets/{repo_id}"
|
|
|
# ---------------------------------------------------------------------------
# User interface: layout first, then event wiring, then launch.
# ---------------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Default()) as app:
    gr.Markdown("## 🏷️ Label It! Text Annotation Tool")
    gr.Markdown("Upload a `.csv` with **text** + **label** columns, annotate in-place, then export, visualize, or publish.")

    # Upload controls.
    with gr.Row():
        file_input = gr.File(label="📁 Upload CSV", file_types=[".csv"])
        upload_btn = gr.Button("Upload")

    # Editable table; hidden until a valid CSV has been loaded.
    df_table = gr.Dataframe(
        headers=["text", "label"],
        label="📝 Editable Table",
        interactive=True,
        visible=False,
    )
    status = gr.Textbox(label="Status", interactive=False)

    # Save / export / visualize actions and their outputs.
    with gr.Row():
        save_btn = gr.Button("💾 Save")
        download_btn = gr.Button("⬇️ Download CSV")
        visualize_btn = gr.Button("📊 Visualize Distribution")
    download_out = gr.File(label="📥 Downloaded File")
    viz_out = gr.Plot(label="Label Distribution")

    # Model picker; not connected to any handler in this file.
    with gr.Row():
        model_dropdown = gr.Dropdown(
            label="🤖 (Future) Auto-Label Model",
            choices=DEFAULT_MODELS,
            value=DEFAULT_MODELS[0],
        )

    # Publishing controls, collapsed by default.
    with gr.Accordion("📦 Push to Hugging Face Hub", open=False):
        repo_input = gr.Textbox(label="Repo (username/dataset-name)")
        token_input = gr.Textbox(label="🔑 HF Token", type="password")
        push_btn = gr.Button("🚀 Push")
        push_status = gr.Textbox(label="Push Status", interactive=False)

    # Event wiring: each button calls one handler defined above.
    upload_btn.click(fn=upload_csv, inputs=file_input, outputs=[df_table, status])
    save_btn.click(fn=save_changes, inputs=df_table, outputs=status)
    download_btn.click(fn=download_csv, outputs=download_out)
    visualize_btn.click(fn=visualize_distribution, outputs=viz_out)
    push_btn.click(fn=push_to_hub, inputs=[repo_input, token_input], outputs=push_status)

app.launch()
|
|