|
import gradio as gr |
|
import pandas as pd |
|
import matplotlib.pyplot as plt |
|
from pathlib import Path |
|
from huggingface_hub import HfApi, Repository |
|
|
|
|
|
# Process-wide matplotlib text defaults: sans-serif family at 10 pt.
plt.rcParams["font.family"] = "sans-serif"
plt.rcParams["font.size"] = 10
|
|
|
|
|
# Shared working copy of the annotations. It starts empty and is rebound by
# the handlers in this module that declare ``global df``.
df: pd.DataFrame = pd.DataFrame()
|
|
|
def _upload_outputs(table_value, message, controls_visible):
    """Build the five outputs of the upload handler.

    The order is: table update, status text, then one visibility update for
    each of the three containers wired as outputs of the upload button.
    Fresh update objects are created for every output rather than shared.
    """
    return (
        gr.update(value=table_value, visible=controls_visible),
        message,
        gr.update(visible=controls_visible),
        gr.update(visible=controls_visible),
        gr.update(visible=controls_visible),
    )


def upload_csv(file):
    """Load an uploaded CSV into the module-level ``df`` and reveal the editor.

    The file is read into a local frame and validated *before* ``df`` is
    replaced, so a rejected upload never clobbers previously loaded data.

    Args:
        file: The value of the file input. ``None`` when nothing was
            selected; otherwise an object exposing the path as ``.name`` or
            the path itself.

    Returns:
        A 5-tuple ``(table, status, container, container, container)``.
        On success the table update carries the ``text``/``label`` columns
        and makes the table visible (a bare value would leave a hidden table
        hidden); on failure the table is cleared and everything stays hidden.
    """
    global df

    if file is None:
        return _upload_outputs(None, "❌ Please choose a CSV file first.", False)

    # Accept both a file wrapper with ``.name`` and a plain path string.
    csv_path = getattr(file, "name", file)
    try:
        loaded = pd.read_csv(csv_path)
    except (
        pd.errors.ParserError,
        pd.errors.EmptyDataError,
        UnicodeDecodeError,
        OSError,
    ) as exc:
        return _upload_outputs(None, f"❌ Could not read CSV: {exc}", False)

    if "text" not in loaded.columns or "label" not in loaded.columns:
        return _upload_outputs(
            None, "❌ CSV must contain 'text' and 'label' columns.", False
        )

    # Unlabelled rows become empty strings so they show up as blank cells.
    loaded["label"] = loaded["label"].fillna("")
    df = loaded
    return _upload_outputs(
        df[["text", "label"]],
        "✅ Uploaded! Edit below or use the buttons.",
        True,
    )
|
|
|
def save_changes(table):
    """Replace the module-level ``df`` with the table's current contents.

    Args:
        table: The edited rows; any input ``pd.DataFrame`` accepts.

    Returns:
        The status message shown to the user.
    """
    global df

    columns = ["text", "label"]
    df = pd.DataFrame(data=table, columns=columns)
    return "💾 Changes saved."
|
|
|
def download_csv():
    """Write the module-level ``df`` to ``annotated_data.csv`` and return the path.

    The file is written without the index column, relative to the current
    working directory.
    """
    # ``df`` is only read here, so no ``global`` declaration is needed.
    csv_path = "annotated_data.csv"
    df.to_csv(csv_path, index=False)
    return csv_path
|
|
|
def make_figure():
    """Build a two-panel figure of how often each label occurs in ``df``.

    The left panel is a Label/Count table and the right panel a horizontal
    bar chart, both ordered from most to least frequent. The figure height
    grows with the number of distinct labels.

    Returns:
        The matplotlib figure. It is created through pyplot, so the caller
        is responsible for closing it when done.
    """
    counts = df["label"].value_counts().sort_values(ascending=False)
    # Labels can be a mix of strings and numbers (e.g. a numeric column whose
    # gaps were filled with ""). Converting to str keeps the bar chart's
    # y-axis categorical and avoids feeding mixed types to matplotlib.
    labels = [str(label) for label in counts.index]
    values = counts.tolist()

    fig, (ax_table, ax_bar) = plt.subplots(
        ncols=2,
        gridspec_kw={"width_ratios": [1, 2]},
        figsize=(8, max(2, len(labels) * 0.4)),
        tight_layout=True,
    )
    ax_table.axis("off")
    ax_bar.set_xlabel("Count")

    if not labels:
        # Axes.table() reads cellText[0] to find the column count, so it
        # cannot be called with zero rows; show a placeholder instead.
        ax_table.text(0.5, 0.5, "No labels to display", ha="center", va="center")
        return fig

    rows = [[label, value] for label, value in zip(labels, values)]
    tbl = ax_table.table(cellText=rows, colLabels=["Label", "Count"], loc="center")
    tbl.auto_set_font_size(False)
    tbl.set_fontsize(10)
    tbl.scale(1, 1.2)

    ax_bar.barh(labels, values, color="#222")
    # barh draws the first entry at the bottom; flip so the top label is first.
    ax_bar.invert_yaxis()
    return fig
|
|
|
def visualize_and_download():
    """Render the label-distribution figure and save it as a PNG.

    Returns:
        ``(fig, png_path)``: the matplotlib figure for display and the path
        of the saved ``label_distribution.png`` (150 dpi, tight bounding box).
    """
    fig = make_figure()
    png_path = "label_distribution.png"
    try:
        fig.savefig(png_path, dpi=150, bbox_inches="tight")
    finally:
        # pyplot keeps a reference to every figure it creates until it is
        # closed, so without this each click would leak one figure. Closing
        # only unregisters it from pyplot; the returned Figure object is
        # still expected to render for the caller.
        plt.close(fig)
    return fig, png_path
|
|
|
def push_to_hub(repo_name, hf_token):
    """Upload the module-level ``df`` to a Hugging Face dataset repo as ``data.csv``.

    The dataset repo is created if it does not exist, then the CSV is
    serialised in memory and committed through the Hub HTTP API. No local
    git clone is made, which avoids two problems of a clone-based flow:
    cleaning up a previous clone (its ``.git`` directory cannot be removed
    with ``unlink``) and leaving the token behind in a working copy on disk.

    Args:
        repo_name: Target repo id, e.g. ``"username/dataset"``.
        hf_token: Access token with write permission. A blank value is passed
            on as ``None``, which huggingface_hub documents as "use the
            locally saved token".

    Returns:
        A status message. Failures are reported in the message rather than
        raised, because the result is displayed directly in the UI.
    """
    repo_id = (repo_name or "").strip()
    if not repo_id:
        return "❌ Please enter a repo name (username/dataset)."
    token = (hf_token or "").strip() or None

    try:
        api = HfApi()
        api.create_repo(
            repo_id=repo_id,
            token=token,
            repo_type="dataset",
            exist_ok=True,
        )
        api.upload_file(
            path_or_fileobj=df.to_csv(index=False).encode("utf-8"),
            path_in_repo="data.csv",
            repo_id=repo_id,
            repo_type="dataset",
            token=token,
            commit_message="📑 Updated data",
        )
    except Exception as e:
        # UI boundary: surface any Hub/network error as text instead of
        # crashing the handler.
        return f"❌ Push failed: {e}"
    return f"🚀 Pushed to datasets/{repo_id}"
|
|
|
# ---------------------------------------------------------------------------
# UI layout and event wiring
# ---------------------------------------------------------------------------
with gr.Blocks() as app:
    gr.Markdown("## 🏷️ Label It! Text Annotation Tool")
    gr.Markdown(
        "Upload a `.csv` (columns: **text**, **label**), then annotate, export, visualize, or push."
    )

    # File picker plus the button that triggers loading.
    with gr.Row():
        csv_input = gr.File(label="📁 Upload CSV", file_types=[".csv"])
        upload_btn = gr.Button("Upload")

    # Editable annotation table (created hidden) and the shared status line.
    table = gr.Dataframe(
        headers=["text", "label"],
        interactive=True,
        visible=False,
    )
    status = gr.Textbox(label="Status", interactive=False)

    # Action buttons; the row is created hidden.
    with gr.Row(visible=False) as action_row:
        save_btn = gr.Button("💾 Save")
        download_btn = gr.Button("⬇️ Download CSV")
        visualize_btn = gr.Button("📊 Visualize")

    # Results of the download / visualize actions; the row is created hidden.
    with gr.Row(visible=False) as output_row:
        download_csv_out = gr.File(label="📥 CSV File")
        chart_plot = gr.Plot(label="Label Distribution")
        download_chart_out = gr.File(label="📥 Chart PNG")

    # Collapsed, initially hidden section for publishing to the Hub.
    with gr.Accordion(
        "📦 Push to Hugging Face Hub", open=False, visible=False
    ) as push_accordion:
        repo_in = gr.Textbox(label="Repo (username/dataset)")
        token_in = gr.Textbox(label="🔑 HF Token", type="password")
        push_btn = gr.Button("🚀 Push")
        push_status = gr.Textbox(label="Push Status", interactive=False)

    # Event wiring: each button calls one handler defined in this module.
    upload_btn.click(
        fn=upload_csv,
        inputs=csv_input,
        outputs=[table, status, action_row, output_row, push_accordion],
    )
    save_btn.click(fn=save_changes, inputs=table, outputs=status)
    download_btn.click(fn=download_csv, outputs=download_csv_out)
    visualize_btn.click(
        fn=visualize_and_download,
        outputs=[chart_plot, download_chart_out],
    )
    push_btn.click(
        fn=push_to_hub,
        inputs=[repo_in, token_in],
        outputs=push_status,
    )

# Start the web server as soon as this module is executed.
app.launch()
|
|