File size: 4,693 Bytes
20e7095
 
2dccd10
d18e6c8
a4cec6f
20e7095
729db5b
d304161
c91426b
3ea3aae
20e7095
 
 
 
 
d304161
c91426b
d304161
2dccd10
d304161
 
c91426b
9e6c3bb
1d6c7cd
d304161
 
 
 
1d6c7cd
20e7095
2dccd10
20e7095
2dccd10
1d6c7cd
20e7095
 
a4cec6f
c91426b
 
 
20e7095
729db5b
2dccd10
d304161
 
c91426b
729db5b
d304161
c91426b
9e6c3bb
d304161
 
 
 
 
 
 
2dccd10
d304161
 
 
3ea3aae
 
729db5b
 
d304161
 
 
9e6c3bb
c91426b
a4cec6f
d18e6c8
 
c91426b
 
 
d18e6c8
c91426b
d18e6c8
 
 
 
3ea3aae
d18e6c8
 
c91426b
d304161
729db5b
d18e6c8
 
a4cec6f
729db5b
 
 
 
d304161
3ea3aae
d304161
2dccd10
729db5b
d304161
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c91426b
 
d304161
 
729db5b
 
 
 
 
 
 
 
d304161
 
729db5b
 
 
 
d304161
c91426b
 
d18e6c8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import gradio as gr
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from huggingface_hub import HfApi, Repository

# Matplotlib defaults applied to every chart this app renders.
plt.rcParams["font.family"] = "sans-serif"
plt.rcParams["font.size"] = 10

# Working copy of the dataset being annotated. upload_csv and save_changes
# rebind it; download_csv, make_figure and push_to_hub read it.
df = pd.DataFrame()

def upload_csv(file):
    """Load an uploaded CSV into the global ``df`` and toggle the editor UI.

    Args:
        file: Value of the ``gr.File`` component. Either a filepath string
            (Gradio 4 default) or an object exposing the path as ``.name``
            (older Gradio). ``None`` when nothing has been uploaded.

    Returns:
        A 4-tuple matching the outputs ``[table, status, actions, push_acc]``:
        a table update, a status message, and visibility updates for the
        action area and the push accordion.
    """
    global df

    def _failure(message):
        # Hide the editor, buttons and push panel, and report why.
        return (
            gr.update(value=None, visible=False),
            message,
            gr.update(visible=False),  # hide action buttons
            gr.update(visible=False),  # hide push
        )

    if file is None:
        return _failure("❌ Please choose a CSV file first.")

    # Gradio 4 passes a path string; older versions pass a tempfile wrapper.
    path = getattr(file, "name", file)
    try:
        loaded = pd.read_csv(path)
    except (pd.errors.ParserError, pd.errors.EmptyDataError, UnicodeDecodeError) as err:
        return _failure(f"❌ Could not read CSV: {err}")

    # Validate before touching the global so a bad upload keeps the old data.
    if not {"text", "label"}.issubset(loaded.columns):
        return _failure("❌ CSV must contain 'text' and 'label' columns.")

    loaded["label"] = loaded["label"].fillna("")
    df = loaded
    return (
        # The table component starts hidden, so it must be revealed explicitly.
        gr.update(value=df[["text", "label"]], visible=True),
        "✅ Uploaded! You can now annotate and use the buttons below.",
        gr.update(visible=True),  # show action buttons
        gr.update(visible=True),  # show push accordion
    )

def save_changes(table):
    """Replace the global working DataFrame with the edited grid contents.

    ``table`` is the value handed back by the editable ``gr.Dataframe``; a new
    DataFrame with the columns ``text`` and ``label`` is built from it and
    becomes the module-level ``df``.
    """
    global df
    column_names = ["text", "label"]
    snapshot = pd.DataFrame(table, columns=column_names)
    df = snapshot
    return "💾 Changes saved."

def download_csv():
    """Write the global working DataFrame to ``annotated_data.csv``.

    The file is written to the current working directory without the index
    column; the relative path is returned so a download component can serve it.
    """
    target = "annotated_data.csv"
    df.to_csv(target, index=False)
    return target

def make_figure():
    """Build a two-panel figure summarising the label distribution of ``df``.

    Left panel: a table of label -> count. Right panel: a horizontal bar chart
    of the same counts, most frequent label at the top.

    Returns:
        matplotlib.figure.Figure: a new figure created through pyplot (so it is
        registered with pyplot until the caller closes it).
    """
    counts = df["label"].value_counts().sort_values(ascending=False)
    # Stringify labels so matplotlib always uses a categorical axis: numeric
    # labels (e.g. 0/1) would otherwise be drawn at numeric y positions, and a
    # mix of floats and "" (produced by fillna("")) breaks the category axis.
    labels = [str(label) for label in counts.index]
    values = [int(count) for count in counts.values]

    fig, (ax1, ax2) = plt.subplots(
        ncols=2,
        gridspec_kw={"width_ratios": [1, 2]},
        figsize=(8, max(2, len(labels) * 0.4)),
        tight_layout=True,
    )
    ax1.axis("off")

    if not labels:
        # Axes.table() indexes cellText[0] and raises on an empty list, so
        # show a placeholder instead of crashing on a label-less dataset.
        ax1.text(0.5, 0.5, "No labels", ha="center", va="center")
        ax2.set_xlabel("Count")
        return fig

    # table
    tbl = ax1.table(
        cellText=[[label, count] for label, count in zip(labels, values)],
        colLabels=["Label", "Count"],
        loc="center",
    )
    tbl.auto_set_font_size(False)
    tbl.set_fontsize(10)
    tbl.scale(1, 1.2)

    # bar chart (inverted so the most frequent label is on top)
    ax2.barh(labels, values, color="#222222")
    ax2.invert_yaxis()
    ax2.set_xlabel("Count")
    return fig

def visualize_and_download():
    """Render the label-distribution figure and save a PNG copy for download.

    Returns:
        tuple: ``(fig, path)`` — the matplotlib figure (for ``gr.Plot``) and the
        relative path ``label_distribution.png`` it was saved to.
    """
    fig = make_figure()
    out_png = "label_distribution.png"
    fig.savefig(out_png, dpi=150, bbox_inches="tight")
    # Every click creates a new figure and pyplot keeps each one alive (and
    # warns past 20 open figures) until closed. Closing only unregisters it
    # from pyplot; the Figure object can still be rendered by gr.Plot.
    plt.close(fig)
    return fig, out_png

def push_to_hub(repo_name, hf_token):
    """Upload the global ``df`` as ``data.csv`` to a Hugging Face dataset repo.

    The dataset repo is created if it does not exist. The CSV is then committed
    through the Hub HTTP API (``HfApi.upload_file``), so no local git clone is
    needed. A previous version cloned with ``Repository`` and tried to
    ``unlink()`` every entry of an old clone, which fails on the ``.git``
    directory and broke every re-push.

    Args:
        repo_name: Target repo id, e.g. ``"username/dataset"``.
        hf_token: Hugging Face access token used for both Hub calls.

    Returns:
        str: A status message for the UI. Failures are reported, not raised.
    """
    repo_name = (repo_name or "").strip()
    if not repo_name:
        return "❌ Push failed: please enter a repo name (username/dataset)."
    try:
        api = HfApi()
        api.create_repo(repo_id=repo_name, token=hf_token,
                        repo_type="dataset", exist_ok=True)
        api.upload_file(
            path_or_fileobj=df.to_csv(index=False).encode("utf-8"),
            path_in_repo="data.csv",
            repo_id=repo_name,
            repo_type="dataset",
            token=hf_token,
            commit_message="📑 Updated annotated data",
        )
        return f"🚀 Pushed to datasets/{repo_name}"
    except Exception as e:  # UI boundary: surface any Hub/network error as text
        return f"❌ Push failed: {e}"

with gr.Blocks() as app:
    gr.Markdown("## 🏷️ Label It! Text Annotation Tool")
    gr.Markdown("Upload a `.csv` (columns: **text**, **label**), then annotate, export, visualize, or push.")

    # STEP 1: Upload
    with gr.Row():
        csv_in = gr.File(label="📁 Upload CSV", file_types=[".csv"])
        upload_btn = gr.Button("Upload")

    status = gr.Textbox(label="Status", interactive=False)

    # STEP 2: Editor + action buttons. The whole column is hidden until
    # upload_csv's third output makes `actions` visible, so the table is
    # revealed together with the buttons.
    with gr.Column(visible=False) as actions:
        table = gr.Dataframe(headers=["text", "label"], interactive=True)
        with gr.Row():
            save_btn = gr.Button("💾 Save")
            # DownloadButton serves the file path held in its value (it has no
            # `fn` argument); the events below set that value.
            download_csv_btn = gr.DownloadButton(label="⬇️ Download CSV")
            download_png_btn = gr.DownloadButton(label="⬇️ Download Chart", visible=False)
            visualize_btn = gr.Button("📊 Visualize")
        chart_plot = gr.Plot(visible=False)

    # Push accordion
    with gr.Accordion("📦 Push to Hugging Face Hub", open=False, visible=False) as push_acc:
        repo_in = gr.Textbox(label="Repo (username/dataset)")
        token_in = gr.Textbox(label="🔑 HF Token", type="password")
        push_btn = gr.Button("🚀 Push")
        push_out = gr.Textbox(label="Push Status", interactive=False)

    def _visualize():
        """Render the chart, reveal the plot, and point the PNG button at the file."""
        fig, png_path = visualize_and_download()
        return (
            gr.update(value=fig, visible=True),
            gr.update(value=png_path, visible=True),
        )

    # Event wiring. After a successful upload or save, regenerate the CSV so
    # the download button always serves the latest saved data.
    upload_btn.click(
        upload_csv,
        inputs=csv_in,
        outputs=[table, status, actions, push_acc],
    ).success(download_csv, inputs=None, outputs=download_csv_btn)
    save_btn.click(
        save_changes,
        inputs=table,
        outputs=status,
    ).success(download_csv, inputs=None, outputs=download_csv_btn)
    visualize_btn.click(
        _visualize,
        inputs=None,
        outputs=[chart_plot, download_png_btn],
    )
    push_btn.click(
        push_to_hub,
        inputs=[repo_in, token_in],
        outputs=push_out,
    )

app.launch()