|
import gradio as gr |
|
import pandas as pd |
|
import matplotlib.pyplot as plt |
|
from pathlib import Path |
|
from huggingface_hub import HfApi, Repository |
|
|
|
|
|
# Global matplotlib defaults applied to every chart the app renders.
plt.rcParams["font.family"] = "sans-serif"
plt.rcParams["font.size"] = 10

# Annotation state read and replaced by the callbacks below (upload, save,
# download, chart, push). Starts empty until a CSV is uploaded.
# NOTE(review): this is a module-level global, so presumably every connected
# browser session shares the same data — confirm this is intended.
df = pd.DataFrame(data=None)
|
|
|
def upload_csv(file):
    """Load an uploaded CSV into the shared ``df`` and reveal the editing UI.

    Args:
        file: value of the ``gr.File`` input; ``None`` when nothing was
            selected. Presumably either a tempfile-like object exposing
            ``.name`` (older Gradio) or a plain path string (newer Gradio)
            — both are accepted.

    Returns:
        tuple: four outputs, in order, for the table, the status textbox, the
        actions row and the push accordion. On success the table shows the
        ``text``/``label`` columns and all three components are made visible;
        on failure they are hidden and the status message explains why.

    The global ``df`` is only replaced once the file has parsed and contains
    both required columns, so a bad upload does not clobber earlier data.
    """
    global df

    def _fail(message):
        # Hide the editing UI and surface the problem in the status box.
        return (
            gr.update(value=None, visible=False),
            message,
            gr.update(visible=False),
            gr.update(visible=False),
        )

    if file is None:
        return _fail("❌ Please choose a CSV file first.")

    # Accept both a file object with a .name path and a bare path string.
    path = getattr(file, "name", file)
    try:
        uploaded = pd.read_csv(path)
    except (OSError, ValueError) as exc:
        # ValueError also covers pandas ParserError / EmptyDataError and
        # UnicodeDecodeError from non-UTF-8 files.
        return _fail(f"❌ Could not read CSV: {exc}")

    if not {"text", "label"}.issubset(uploaded.columns):
        return _fail("❌ CSV must contain 'text' and 'label' columns.")

    # Missing labels become "" so they appear as empty editable cells.
    uploaded["label"] = uploaded["label"].fillna("")
    df = uploaded

    return (
        # The table component starts hidden, so reveal it with its data.
        gr.update(value=df[["text", "label"]], visible=True),
        "✅ Uploaded! You can now annotate and use the buttons below.",
        gr.update(visible=True),
        gr.update(visible=True),
    )
|
|
|
def save_changes(table):
    """Replace the shared annotation data with the edited table contents.

    Args:
        table: the current value of the editable Dataframe component (a
            DataFrame or row-wise sequence); only its ``text`` and ``label``
            columns are kept.

    Returns:
        str: confirmation message for the status textbox.
    """
    global df
    kept_columns = ["text", "label"]
    edited = pd.DataFrame(table, columns=kept_columns)
    df = edited
    return "💾 Changes saved."
|
|
|
def download_csv():
    """Write the shared annotation data to ``annotated_data.csv``.

    The file is written relative to the current working directory, without
    the DataFrame index.

    Returns:
        str: the path of the written CSV file.
    """
    # Read-only use of the module-level df, so no `global` statement needed.
    out_path = "annotated_data.csv"
    df.to_csv(out_path, index=False)
    return out_path
|
|
|
def make_figure():
    """Build a two-panel label-distribution chart from the global ``df``.

    The left panel is a Label/Count table. The right panel is a horizontal
    bar chart of the same counts, with the most frequent label at the top.

    Returns:
        matplotlib.figure.Figure: the rendered figure.

    Raises:
        KeyError: if ``df`` has no ``label`` column (e.g. nothing uploaded).
    """
    # Build the Figure directly rather than via pyplot, so repeated calls do
    # not pile up figures in pyplot's global registry for the server's life.
    from matplotlib.figure import Figure

    counts = df["label"].value_counts().sort_values(ascending=False)
    # Cast to str so numeric or mixed-type labels (e.g. floats next to the ""
    # used for blanks) are drawn as categories instead of numeric positions
    # or rejected by matplotlib's categorical axis.
    labels = [str(label) for label in counts.index]
    values = [int(count) for count in counts.to_numpy()]

    fig = Figure(figsize=(8, max(2, len(labels) * 0.4)), tight_layout=True)
    ax1, ax2 = fig.subplots(ncols=2, gridspec_kw={"width_ratios": [1, 2]})

    ax1.axis("off")
    if labels:
        tbl = ax1.table(
            cellText=[[label, value] for label, value in zip(labels, values)],
            colLabels=["Label", "Count"],
            loc="center",
        )
        tbl.auto_set_font_size(False)
        tbl.set_fontsize(10)
        tbl.scale(1, 1.2)
    else:
        # ax.table() cannot be built from an empty cellText.
        ax1.text(0.5, 0.5, "No labels yet", ha="center", va="center")

    ax2.barh(labels, values, color="#222222")
    ax2.invert_yaxis()  # most frequent label on top
    ax2.set_xlabel("Count")
    return fig
|
|
|
def visualize_and_download():
    """Render the label chart, save a PNG copy, and return both.

    Returns:
        tuple: ``(plot_update, png_path)``. ``plot_update`` carries the figure
        and makes the Plot output visible. ``png_path`` is the saved
        ``label_distribution.png`` (written to the working directory at
        150 dpi), suitable as a DownloadButton value.
    """
    fig = make_figure()
    out_png = "label_distribution.png"
    fig.savefig(out_png, dpi=150, bbox_inches="tight")
    # Release pyplot's reference (a no-op if the figure isn't pyplot-managed)
    # so figures don't accumulate across clicks. The Figure object can still
    # be rendered by the Plot component after closing.
    plt.close(fig)
    # The Plot component is created hidden, so reveal it with the new chart.
    return gr.update(value=fig, visible=True), out_png
|
|
|
def push_to_hub(repo_name, hf_token):
    """Upload the current annotations as ``data.csv`` to a Hub dataset repo.

    The dataset repo is created if it does not exist. The CSV is then
    committed straight from memory with ``HfApi.upload_file``. No local git
    clone is made, so nothing is left on disk between pushes.

    Args:
        repo_name: target repo id, ``"username/dataset"``; surrounding
            whitespace is ignored.
        hf_token: Hugging Face access token. A blank value is passed as
            ``None`` so huggingface_hub can apply its own credential lookup
            (presumably an env var / cached login — confirm).

    Returns:
        str: a status message for the UI; failures are reported in the
        message rather than raised.
    """
    repo_id = (repo_name or "").strip()
    if not repo_id:
        return "❌ Push failed: please enter a repo name (username/dataset)."
    token = hf_token or None

    try:
        api = HfApi()
        api.create_repo(
            repo_id=repo_id, token=token, repo_type="dataset", exist_ok=True
        )
        api.upload_file(
            path_or_fileobj=df.to_csv(index=False).encode("utf-8"),
            path_in_repo="data.csv",
            repo_id=repo_id,
            repo_type="dataset",
            token=token,
            commit_message="📑 Updated annotated data",
        )
    except Exception as e:
        # UI boundary: network/auth/HTTP errors become a status message.
        return f"❌ Push failed: {e}"

    return f"🚀 Pushed to datasets/{repo_id}"
|
|
|
with gr.Blocks() as app:
    gr.Markdown("## 🏷️ Label It! Text Annotation Tool")
    gr.Markdown("Upload a `.csv` (columns: **text**, **label**), then annotate, export, visualize, or push.")

    with gr.Row():
        csv_in = gr.File(label="📁 Upload CSV", file_types=[".csv"])
        upload_btn = gr.Button("Upload")

    table = gr.Dataframe(headers=["text", "label"], interactive=True, visible=False)
    status = gr.Textbox(label="Status", interactive=False)

    # Hidden until upload_csv reveals it.
    with gr.Row(visible=False) as actions:
        save_btn = gr.Button("💾 Save")
        # gr.DownloadButton has no fn/file_name parameters: it downloads the
        # file path currently held as its value. The events below keep those
        # values fresh (CSV after upload/save, PNG after Visualize).
        download_csv_btn = gr.DownloadButton(label="⬇️ Download CSV")
        download_png_btn = gr.DownloadButton(label="⬇️ Download Chart")
        visualize_btn = gr.Button("📊 Visualize")

    chart_plot = gr.Plot(visible=False)

    push_acc = gr.Accordion("📦 Push to Hugging Face Hub", open=False, visible=False)
    with push_acc:
        repo_in = gr.Textbox(label="Repo (username/dataset)")
        token_in = gr.Textbox(label="🔑 HF Token", type="password")
        push_btn = gr.Button("🚀 Push")
        push_out = gr.Textbox(label="Push Status", interactive=False)

    # After each upload/save, rewrite the CSV and point the CSV download at it.
    upload_btn.click(
        upload_csv,
        inputs=csv_in,
        outputs=[table, status, actions, push_acc],
    ).success(
        download_csv,
        inputs=None,
        outputs=download_csv_btn,
    )

    save_btn.click(
        save_changes,
        inputs=table,
        outputs=status,
    ).success(
        download_csv,
        inputs=None,
        outputs=download_csv_btn,
    )

    # visualize_and_download returns (figure, png_path): the path becomes the
    # Download Chart button's value.
    visualize_btn.click(
        visualize_and_download,
        inputs=None,
        outputs=[chart_plot, download_png_btn],
    )

    push_btn.click(
        push_to_hub,
        inputs=[repo_in, token_in],
        outputs=push_out,
    )
|
|
|
# Start the server only when run as a script (e.g. `python app.py`), so that
# importing this module builds the UI without launching it.
if __name__ == "__main__":
    app.launch()
|
|