labelit / app.py
Suzana's picture
Update app.py
045377e verified
raw
history blame
4.91 kB
import gradio as gr
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from huggingface_hub import HfApi, Repository
# Matplotlib styling: sans-serif text at 10 pt for every figure drawn by this app.
plt.rcParams["font.family"] = "sans-serif"
plt.rcParams["font.size"] = 10

# Global DataFrame: the working copy of the annotation data that the
# handler functions below read and replace.
df = pd.DataFrame()
def _upload_failure(message):
    """Build the handler result for a rejected upload.

    The table is cleared, ``message`` goes to the status box, and every
    action control is hidden.
    """
    return (
        None,
        message,
        gr.update(visible=False),  # Save
        gr.update(visible=False),  # Download CSV
        gr.update(visible=False),  # Visualize
        gr.update(visible=False),  # Push accordion
    )


def upload_csv(file):
    """Load an uploaded CSV into the global ``df`` and reveal the annotation UI.

    Args:
        file: Value delivered by the ``gr.File`` input. It may be ``None``
            (nothing selected), an object exposing the temp-file path as
            ``.name``, or a plain path string.

    Returns:
        A 6-tuple matching the click handler's outputs: table update, status
        message, then visibility updates for the Save, Download CSV and
        Visualize buttons and the Push accordion.
    """
    global df

    if file is None:
        return _upload_failure("❌ Please choose a CSV file first.")

    # Accept both a file wrapper (path in `.name`) and a bare path string.
    csv_path = getattr(file, "name", file)
    try:
        loaded = pd.read_csv(csv_path)
    except (OSError, ValueError) as exc:
        # pandas' ParserError / EmptyDataError and UnicodeDecodeError are all
        # ValueError subclasses, so malformed or empty files land here.
        return _upload_failure(f"❌ Could not read CSV: {exc}")

    if not {"text", "label"}.issubset(loaded.columns):
        return _upload_failure("❌ CSV must contain 'text' and 'label' columns.")

    loaded["label"] = loaded["label"].fillna("")
    # Commit to the global only after validation, so a bad upload cannot
    # replace previously loaded data.
    df = loaded

    return (
        # The table component is created hidden; a bare DataFrame would set
        # its value but leave it invisible, so reveal it explicitly.
        gr.update(value=df[["text", "label"]], visible=True),
        "✅ Uploaded! You can now annotate and use the buttons below.",
        gr.update(visible=True),  # Save
        gr.update(visible=True),  # Download CSV
        gr.update(visible=True),  # Visualize
        gr.update(visible=True),  # Push accordion
    )
def save_changes(table):
    """Replace the global working DataFrame with the edited table contents.

    Args:
        table: Table data from the UI; anything ``pd.DataFrame`` accepts
            together with the two column names ``text`` and ``label``.

    Returns:
        The confirmation message shown in the status box.
    """
    global df
    edited = pd.DataFrame(data=table, columns=["text", "label"])
    df = edited
    return "💾 Changes saved."
def download_csv():
    """Write the global working DataFrame to ``annotated_data.csv``.

    Returns:
        The path of the written file, relative to the current directory.
    """
    out_file = "annotated_data.csv"
    # index=False keeps the row index out of the exported file.
    df.to_csv(out_file, index=False)
    return out_file
def make_figure():
    """Build a two-panel summary of label frequencies in the global ``df``.

    The left panel is a Label/Count table and the right panel a horizontal
    bar chart, most frequent label at the top.

    Returns:
        A matplotlib ``Figure``. It is constructed directly rather than via
        ``plt.subplots`` so it is not registered in pyplot's figure list;
        repeated calls therefore do not accumulate open figures.
    """
    if "label" in df.columns:
        # value_counts() already orders by descending frequency.
        counts = df["label"].value_counts()
        # Cast to str so barh treats every label as a category, even when
        # the column holds numbers.
        labels = [str(label) for label in counts.index]
        values = [int(count) for count in counts.values]
    else:
        labels, values = [], []

    fig = plt.Figure(
        figsize=(8, max(2, len(labels) * 0.4)),
        tight_layout=True,
    )
    ax_table, ax_bars = fig.subplots(
        ncols=2,
        gridspec_kw={"width_ratios": [1, 2]},
    )

    ax_table.axis("off")
    if labels:
        tbl = ax_table.table(
            cellText=[[label, count] for label, count in zip(labels, values)],
            colLabels=["Label", "Count"],
            loc="center",
        )
        tbl.auto_set_font_size(False)
        tbl.set_fontsize(10)
        tbl.scale(1, 1.2)
    else:
        # Axes.table cannot be built from an empty cellText; show a note instead.
        ax_table.text(0.5, 0.5, "No labels yet", ha="center", va="center")

    ax_bars.barh(labels, values, color="#222222")
    ax_bars.invert_yaxis()
    ax_bars.set_xlabel("Count")
    return fig
def visualize_and_download():
    """Render the label-distribution figure and save it as a PNG.

    Returns:
        A 2-tuple for the click handler's outputs: an update that sets the
        plot's value and makes it visible, and the path of the saved PNG
        (``label_distribution.png``).
    """
    fig = make_figure()
    png_path = "label_distribution.png"
    fig.savefig(png_path, dpi=150, bbox_inches="tight")
    # The plot component is created with visible=False; returning the bare
    # figure would set its value but never show it.
    return gr.update(value=fig, visible=True), png_path
def push_to_hub(repo_name, hf_token):
    """Upload the global ``df`` as ``data.csv`` to a Hugging Face dataset repo.

    The repo is created if it does not exist. The CSV is uploaded through
    the HTTP API, so no local git clone is made and repeated pushes to the
    same repo do not depend on cleaning up a previous checkout.

    Args:
        repo_name: Target repo id, e.g. ``username/dataset``.
        hf_token: Access token used for both repo creation and upload.

    Returns:
        A status message: success text, or ``❌ Push failed: ...`` with the
        reason. Exceptions are reported in the message rather than raised,
        because the result is shown directly in the UI.
    """
    repo_id = (repo_name or "").strip()
    token = (hf_token or "").strip()
    # Validate up front so the user gets a clear message instead of an
    # opaque API error.
    if not repo_id:
        return "❌ Push failed: repo name is required (e.g. username/dataset)."
    if not token:
        return "❌ Push failed: a Hugging Face token is required."

    try:
        api = HfApi()
        api.create_repo(
            repo_id=repo_id,
            token=token,
            repo_type="dataset",
            exist_ok=True,
        )
        api.upload_file(
            path_or_fileobj=df.to_csv(index=False).encode("utf-8"),
            path_in_repo="data.csv",
            repo_id=repo_id,
            repo_type="dataset",
            token=token,
            commit_message="📑 Updated annotated data",
        )
        return f"🚀 Pushed to datasets/{repo_id}"
    except Exception as e:  # UI boundary: report any failure to the user
        return f"❌ Push failed: {e}"
# UI layout and event wiring for the annotation tool.
with gr.Blocks() as app:
    gr.Markdown("## 🏷️ Label It! Text Annotation Tool")
    gr.Markdown("Upload a `.csv` (columns: **text**, **label**), then annotate, export, visualize, or push.")

    # Step 1: Upload
    with gr.Row():
        csv_input = gr.File(label="📁 Upload CSV", file_types=[".csv"])
        upload_btn = gr.Button("Upload")

    # Table + status
    # Both start as declared here; a component created with visible=False
    # stays hidden until a handler sends it an update with visible=True.
    table = gr.Dataframe(headers=["text", "label"], interactive=True, visible=False)
    status = gr.Textbox(label="Status", interactive=False)

    # Step 2: Actions (hidden initially)
    # The upload handler toggles the three buttons individually, so the
    # buttons themselves start hidden and the Row stays visible. Hiding the
    # Row instead would keep them invisible forever, because nothing
    # updates the Row.
    with gr.Row() as action_row:
        save_btn = gr.Button("💾 Save", visible=False)
        download_btn = gr.Button("⬇️ Download CSV", visible=False)
        visualize_btn = gr.Button("📊 Visualize", visible=False)

    download_csv_out = gr.File(label="📥 Downloaded CSV", interactive=False)
    chart_plot = gr.Plot(label="Label Distribution", visible=False)
    download_chart_out = gr.File(label="📥 Downloaded Chart", interactive=False)

    # Push controls
    push_acc = gr.Accordion("📦 Push to Hugging Face Hub", open=False, visible=False)
    with push_acc:
        repo_in = gr.Textbox(label="Repo (username/dataset)")
        token_in = gr.Textbox(label="🔑 HF Token", type="password")
        push_btn = gr.Button("🚀 Push")
        push_out = gr.Textbox(label="Push Status", interactive=False)

    # Bind events
    upload_btn.click(
        upload_csv,
        inputs=csv_input,
        outputs=[table, status, save_btn, download_btn, visualize_btn, push_acc],
    )
    save_btn.click(
        save_changes,
        inputs=table,
        outputs=status,
    )
    download_btn.click(
        download_csv,
        outputs=download_csv_out,
    )
    visualize_btn.click(
        visualize_and_download,
        outputs=[chart_plot, download_chart_out],
    )
    push_btn.click(
        push_to_hub,
        inputs=[repo_in, token_in],
        outputs=push_out,
    )

app.launch()