labelit / app.py
Suzana's picture
Update app.py
d304161 verified
raw
history blame
4.69 kB
import gradio as gr
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from huggingface_hub import HfApi, Repository
# Shared Matplotlib look: sans-serif text at 10 pt for every chart drawn here.
plt.rcParams["font.family"] = "sans-serif"
plt.rcParams["font.size"] = 10
# Working copy of the uploaded annotations. The event handlers below either
# read this single module-level DataFrame or replace it wholesale.
df = pd.DataFrame()
def upload_csv(file):
    """Load an uploaded CSV into the shared ``df`` and reveal the editing UI.

    Args:
        file: Value of the ``gr.File`` component: an object exposing the
            uploaded path as ``.name``, or a plain path string. ``None`` when
            the Upload button is clicked before a file was chosen.

    Returns:
        A 4-tuple of updates for ``(table, status, actions, push_acc)``.
        On success the table is filled with the ``text``/``label`` columns and
        made visible, and the action row and push accordion are shown. On any
        failure the table, action row and accordion are hidden, a ``❌`` status
        is shown, and the previously loaded ``df`` is left untouched.
    """
    global df
    hidden = gr.update(visible=False)
    hidden_table = gr.update(value=None, visible=False)

    if file is None:
        return (hidden_table, "❌ Please choose a CSV file to upload first.", hidden, hidden)

    # Accept both file-like wrappers (with .name) and plain path strings.
    path = getattr(file, "name", file)
    try:
        loaded = pd.read_csv(path)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as e:
        return (hidden_table, f"❌ Could not read CSV: {e}", hidden, hidden)

    if not {"text", "label"}.issubset(loaded.columns):
        return (
            hidden_table,
            "❌ CSV must contain 'text' and 'label' columns.",
            hidden,  # hide action buttons
            hidden,  # hide push
        )

    # Missing labels become empty strings so they are editable/countable.
    loaded["label"] = loaded["label"].fillna("")
    # Only publish the frame globally once it has passed validation.
    df = loaded
    return (
        # The table is created hidden, so it must be revealed explicitly here.
        gr.update(value=df[["text", "label"]], visible=True),
        "✅ Uploaded! You can now annotate and use the buttons below.",
        gr.update(visible=True),  # show action buttons
        gr.update(visible=True),  # show push accordion
    )
def save_changes(table):
    """Replace the working annotations with the contents of the edited table.

    Args:
        table: Value of the editable ``gr.Dataframe`` (presumably a pandas
            DataFrame, Gradio's default), or any row data ``pd.DataFrame``
            accepts. The stored frame is built with exactly the ``text`` and
            ``label`` columns.

    Returns:
        The status message shown to the user.
    """
    global df
    edited = pd.DataFrame(table, columns=["text", "label"])
    df = edited
    return "💾 Changes saved."
def download_csv():
    """Export the current annotations to ``annotated_data.csv``.

    The file is written (without the index) into the working directory.

    Returns:
        The relative path of the written CSV.
    """
    out_path = "annotated_data.csv"
    df.to_csv(out_path, index=False)
    return out_path
def make_figure():
    """Draw the label distribution as a count table beside a bar chart.

    Counts the ``label`` column of the module-level ``df`` (most frequent
    label first). The figure has two parts:

    - left: a ``Label``/``Count`` table;
    - right: a horizontal bar chart with the most frequent label at the top.

    Returns:
        A standalone ``matplotlib.figure.Figure``. It is not registered with
        pyplot, so building one per click does not accumulate open figures in
        the long-running app process. When there are no labels to count, the
        figure contains only a short placeholder message.
    """
    # Local import: build the Figure without pyplot's global figure registry.
    from matplotlib.figure import Figure

    counts = df["label"].value_counts()  # already sorted by descending count
    # Stringify labels so numeric or mixed-type labels (e.g. numbers mixed
    # with "" fill values) are drawn as categories instead of failing or
    # being placed on a numeric y-axis.
    labels = [str(label) for label in counts.index]
    values = [int(count) for count in counts.to_numpy()]

    fig = Figure(figsize=(8, max(2, len(labels) * 0.4)), tight_layout=True)

    if not labels:
        # ax.table() cannot be built from zero rows, so show a message instead.
        ax = fig.subplots()
        ax.axis("off")
        ax.text(0.5, 0.5, "No labels to display yet.", ha="center", va="center")
        return fig

    ax1, ax2 = fig.subplots(ncols=2, gridspec_kw={"width_ratios": [1, 2]})

    # Left: label/count table
    ax1.axis("off")
    tbl = ax1.table(
        cellText=[[label, count] for label, count in zip(labels, values)],
        colLabels=["Label", "Count"],
        loc="center",
    )
    tbl.auto_set_font_size(False)
    tbl.set_fontsize(10)
    tbl.scale(1, 1.2)

    # Right: horizontal bars; invert so the first (most frequent) label is on top
    ax2.barh(labels, values, color="#222222")
    ax2.invert_yaxis()
    ax2.set_xlabel("Count")
    return fig
def visualize_and_download():
    """Render the label-distribution chart and save it as a PNG.

    Returns:
        A tuple ``(fig, out_png)``:

        - ``fig``: the figure, for the ``gr.Plot`` component;
        - ``out_png``: the path of the PNG written to the working directory
          (150 dpi, tight bounding box), for the chart download button.
    """
    fig = make_figure()
    out_png = "label_distribution.png"
    fig.savefig(out_png, dpi=150, bbox_inches="tight")
    # Release pyplot's reference so figures do not pile up across clicks in
    # this long-running app. This is a no-op if the figure was never
    # registered with pyplot. The Figure object itself remains usable for
    # rendering.
    plt.close(fig)
    return fig, out_png
def push_to_hub(repo_name, hf_token):
    """Upload the current annotations to a Hugging Face dataset repo as ``data.csv``.

    Creates the dataset repo if it does not exist yet, then commits the CSV
    through the Hub HTTP API. No local git clone is involved, so repeated
    pushes do not trip over a leftover checkout (including its ``.git``
    directory).

    Args:
        repo_name: Target repo id, e.g. ``"username/dataset"``. Surrounding
            whitespace is ignored.
        hf_token: Hugging Face access token. An empty value is passed as
            ``None``, letting ``huggingface_hub`` use any locally configured
            credentials.

    Returns:
        A human-readable status string. Failures are reported in the string
        rather than raised, because this is a UI event handler.
    """
    repo_name = (repo_name or "").strip()
    if not repo_name:
        return "❌ Push failed: please enter a repo name (username/dataset)."
    token = hf_token or None

    try:
        api = HfApi()
        api.create_repo(repo_id=repo_name, token=token, repo_type="dataset", exist_ok=True)
        api.upload_file(
            path_or_fileobj=df.to_csv(index=False).encode("utf-8"),
            path_in_repo="data.csv",
            repo_id=repo_name,
            repo_type="dataset",
            token=token,
            commit_message="📑 Updated annotated data",
        )
    except Exception as e:  # UI boundary: surface any Hub/network error to the user
        return f"❌ Push failed: {e}"
    return f"🚀 Pushed to datasets/{repo_name}"
with gr.Blocks() as app:
    gr.Markdown("## 🏷️ Label It! Text Annotation Tool")
    gr.Markdown("Upload a `.csv` (columns: **text**, **label**), then annotate, export, visualize, or push.")

    # STEP 1: Upload
    with gr.Row():
        csv_in = gr.File(label="📁 Upload CSV", file_types=[".csv"])
        upload_btn = gr.Button("Upload")

    # Editable table + status (table starts hidden)
    table = gr.Dataframe(headers=["text","label"], interactive=True, visible=False)
    status = gr.Textbox(label="Status", interactive=False)

    # STEP 2: Action buttons (hidden until upload_csv reveals the row).
    # gr.DownloadButton takes no callback: it serves the file path held in
    # its value. The event wiring below therefore refreshes that value
    # whenever the data or the chart changes.
    with gr.Row(visible=False) as actions:
        save_btn = gr.Button("💾 Save")
        download_csv_btn = gr.DownloadButton(label="⬇️ Download CSV")
        download_png_btn = gr.DownloadButton(label="⬇️ Download Chart")
        visualize_btn = gr.Button("📊 Visualize")
    chart_plot = gr.Plot(visible=False)

    # Push accordion (hidden until upload_csv reveals it)
    with gr.Accordion("📦 Push to Hugging Face Hub", open=False, visible=False) as push_acc:
        repo_in = gr.Textbox(label="Repo (username/dataset)")
        token_in = gr.Textbox(label="🔑 HF Token", type="password")
        push_btn = gr.Button("🚀 Push")
        push_out = gr.Textbox(label="Push Status", interactive=False)

    # Event wiring. `.success` steps run only if the preceding step did not raise.
    # Re-export the CSV after every upload/save so the download is current.
    upload_btn.click(
        upload_csv,
        inputs=csv_in,
        outputs=[table, status, actions, push_acc],
    ).success(download_csv, inputs=None, outputs=download_csv_btn)

    save_btn.click(
        save_changes,
        inputs=table,
        outputs=status,
    ).success(download_csv, inputs=None, outputs=download_csv_btn)

    # The plot is created hidden, so reveal it once a chart has been drawn.
    visualize_btn.click(
        visualize_and_download,
        inputs=None,
        outputs=[chart_plot, download_png_btn],
    ).success(lambda: gr.update(visible=True), inputs=None, outputs=chart_plot)

    push_btn.click(
        push_to_hub,
        inputs=[repo_in, token_in],
        outputs=push_out,
    )

if __name__ == "__main__":
    app.launch()