File size: 5,008 Bytes
20e7095
 
2dccd10
d18e6c8
a4cec6f
20e7095
729db5b
 
c91426b
3ea3aae
20e7095
 
 
 
 
 
c91426b
729db5b
2dccd10
729db5b
 
 
c91426b
9e6c3bb
1d6c7cd
729db5b
 
 
 
 
1d6c7cd
20e7095
2dccd10
20e7095
2dccd10
1d6c7cd
20e7095
 
a4cec6f
c91426b
 
 
20e7095
729db5b
2dccd10
 
 
9e6c3bb
c91426b
729db5b
c91426b
 
9e6c3bb
 
 
2dccd10
 
 
9e6c3bb
2dccd10
 
3ea3aae
 
729db5b
 
 
 
 
9e6c3bb
c91426b
a4cec6f
d18e6c8
 
c91426b
 
 
d18e6c8
c91426b
d18e6c8
 
 
 
3ea3aae
d18e6c8
 
c91426b
729db5b
 
d18e6c8
 
a4cec6f
729db5b
 
 
 
2dccd10
3ea3aae
729db5b
2dccd10
729db5b
 
2dccd10
 
729db5b
 
2dccd10
729db5b
 
 
 
 
 
 
2dccd10
729db5b
 
 
 
 
 
 
 
2dccd10
729db5b
 
c91426b
 
729db5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c91426b
 
d18e6c8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import gradio as gr
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from huggingface_hub import HfApi, Repository

# Default font settings applied to every matplotlib chart drawn by this app.
plt.rcParams["font.family"] = "sans-serif"
plt.rcParams["font.size"] = 10

# Module-level working copy of the dataset. The handlers below read it and
# rebind it via `global df`.
# NOTE(review): as a module global this is presumably shared by every
# concurrent user of the app — confirm that is acceptable.
df = pd.DataFrame()

def upload_csv(file):
    """Load an uploaded CSV into the module-level ``df`` and reveal the editor.

    Args:
        file: Value produced by the ``gr.File`` input. Depending on the Gradio
            version this is a tempfile-like object exposing ``.name`` or a
            plain filepath string; ``None`` when no file was chosen.

    Returns:
        A 5-tuple of outputs for ``(table, status, action_row, output_row,
        push_accordion)``. On failure the dependent components are hidden,
        the status explains why, and the previously loaded ``df`` is kept.
    """
    global df

    def _fail(message):
        # Hide (and clear) everything that depends on a valid dataset.
        return (
            gr.update(value=None, visible=False),  # hide table
            message,
            gr.update(visible=False),  # hide actions
            gr.update(visible=False),  # hide outputs
            gr.update(visible=False),  # hide push accordion
        )

    if file is None:
        return _fail("❌ Please choose a CSV file first.")

    # Older Gradio hands over a tempfile wrapper, newer versions a path string.
    path = getattr(file, "name", file)
    try:
        loaded = pd.read_csv(path)
    except (OSError, ValueError) as err:
        # ValueError covers pandas ParserError / EmptyDataError and
        # UnicodeDecodeError for malformed or non-UTF-8 uploads.
        return _fail(f"❌ Could not read CSV: {err}")

    if "text" not in loaded.columns or "label" not in loaded.columns:
        return _fail("❌ CSV must contain 'text' and 'label' columns.")

    loaded["label"] = loaded["label"].fillna("")
    # Only replace the shared frame once the upload has been validated.
    df = loaded
    return (
        # The table component starts hidden, so visibility must be set here
        # explicitly; returning the bare DataFrame would leave it invisible.
        gr.update(value=df[["text", "label"]], visible=True),  # show table
        "✅ Uploaded! Edit below or use the buttons.",
        gr.update(visible=True),  # show action row
        gr.update(visible=True),  # show output row
        gr.update(visible=True),  # show push accordion
    )

def save_changes(table):
    """Replace the module-level ``df`` with the contents of the edited grid.

    Args:
        table: Edited data from the ``gr.Dataframe`` component — anything
            ``pd.DataFrame`` accepts together with ``columns=["text", "label"]``
            (for example a DataFrame or a list of rows). When a DataFrame is
            passed, ``columns=`` selects those two columns rather than renaming.

    Returns:
        A status message for the UI. Any extra columns from the original
        upload are not carried over: the new frame has only text/label.
    """
    global df
    kept_columns = ["text", "label"]
    df = pd.DataFrame(data=table, columns=kept_columns)
    return "💾 Changes saved."

def download_csv():
    """Write the module-level ``df`` to ``annotated_data.csv`` for download.

    The file is written relative to the current working directory, without
    the pandas index.

    Returns:
        The relative path of the written CSV, for the ``gr.File`` output.
    """
    out_file = "annotated_data.csv"
    df.to_csv(out_file, index=False)
    return out_file

def make_figure():
    """Build a two-panel figure summarising how often each label occurs.

    The left panel is a "Label / Count" table and the right panel a horizontal
    bar chart, both ordered from most to least frequent label. Reads the
    module-level ``df``. Missing (NaN) labels are not counted, because
    ``value_counts`` drops them by default.

    Returns:
        A ``matplotlib.figure.Figure``. When there is nothing to count (no
        ``label`` column or no rows), the figure shows a placeholder message
        instead of the table and chart.
    """
    # Build the Figure directly instead of through pyplot: pyplot keeps every
    # figure it creates alive until plt.close(), so repeated "Visualize"
    # clicks would pile up figures in a long-running server process.
    from matplotlib.figure import Figure

    if "label" in df.columns:
        counts = df["label"].value_counts().sort_values(ascending=False)
    else:
        counts = pd.Series(dtype="int64")
    # Stringify so numeric or mixed-type labels are drawn as categories rather
    # than being placed at numeric y positions (or rejected by barh).
    labels = [str(label) for label in counts.index]
    values = counts.tolist()

    fig = Figure(figsize=(8, max(2, len(labels) * 0.4)), tight_layout=True)
    ax_table, ax_bar = fig.subplots(ncols=2, gridspec_kw={"width_ratios": [1, 2]})
    ax_table.axis("off")

    if not labels:
        # Axes.table() cannot build a table from an empty cellText list.
        ax_bar.axis("off")
        ax_bar.text(
            0.5, 0.5, "No labels to display",
            ha="center", va="center", transform=ax_bar.transAxes,
        )
        return fig

    # Table
    rows = [[label, count] for label, count in zip(labels, values)]
    tbl = ax_table.table(cellText=rows, colLabels=["Label", "Count"], loc="center")
    tbl.auto_set_font_size(False)
    tbl.set_fontsize(10)
    tbl.scale(1, 1.2)

    # Bar chart, inverted so the most frequent label is at the top.
    ax_bar.barh(labels, values, color="#222")
    ax_bar.invert_yaxis()
    ax_bar.set_xlabel("Count")
    return fig

def visualize_and_download():
    """Render the label-distribution figure and save a PNG copy of it.

    Returns:
        ``(figure, png_path)``: the figure for the ``gr.Plot`` output, and the
        relative path of the image written at 150 dpi for the ``gr.File``
        download output.
    """
    chart = make_figure()
    target_png = "label_distribution.png"
    chart.savefig(target_png, dpi=150, bbox_inches="tight")
    return chart, target_png

def push_to_hub(repo_name, hf_token):
    """Upload the module-level ``df`` as ``data.csv`` to a Hub dataset repo.

    The dataset repository is created if it does not exist yet. The CSV is
    serialised in memory and sent with ``HfApi.upload_file``, so no local git
    clone is needed. The earlier clone-based approach broke on repeat pushes,
    because its cleanup tried to ``unlink()`` the cloned ``.git`` directory.

    Args:
        repo_name: Target repo id, e.g. ``"username/dataset"``. Surrounding
            whitespace is ignored.
        hf_token: Hugging Face access token. When left blank, ``None`` is
            passed so ``huggingface_hub`` can fall back to a locally saved token.

    Returns:
        A human-readable status string. Failures are reported here, not raised.
    """
    repo_id = (repo_name or "").strip()
    if not repo_id:
        return "❌ Push failed: please enter a repo name (username/dataset)."
    token = hf_token or None

    try:
        api = HfApi()
        api.create_repo(repo_id=repo_id, token=token,
                        repo_type="dataset", exist_ok=True)
        api.upload_file(
            path_or_fileobj=df.to_csv(index=False).encode("utf-8"),
            path_in_repo="data.csv",
            repo_id=repo_id,
            repo_type="dataset",
            token=token,
            commit_message="📑 Updated data",
        )
    except Exception as e:  # UI boundary: surface any failure to the user.
        return f"❌ Push failed: {e}"
    return f"🚀 Pushed to datasets/{repo_id}"

# Build the Gradio UI: upload a CSV, edit it in a grid, then save, export,
# visualize or push it. Components are laid out in the order they are created.
with gr.Blocks() as app:
    gr.Markdown("## 🏷️ Label It! Text Annotation Tool")
    gr.Markdown("Upload a `.csv` (columns: **text**, **label**), then annotate, export, visualize, or push.")

    # Step 1: Upload
    with gr.Row():
        csv_input = gr.File(label="📁 Upload CSV", file_types=[".csv"])
        upload_btn = gr.Button("Upload")

    # Table + Status
    # NOTE(review): the table starts hidden. It presumably only appears if the
    # upload handler returns an update with visible=True; a bare value does
    # not appear to change visibility. Confirm against the Gradio version in use.
    table = gr.Dataframe(headers=["text","label"], interactive=True, visible=False)
    status = gr.Textbox(label="Status", interactive=False)

    # Step 2: Actions (hidden initially; revealed by the upload handler)
    with gr.Row(visible=False) as action_row:
        save_btn      = gr.Button("💾 Save")
        download_btn  = gr.Button("⬇️ Download CSV")
        visualize_btn = gr.Button("📊 Visualize")

    # Outputs row (hidden initially; receives the exported CSV, chart, PNG)
    with gr.Row(visible=False) as output_row:
        download_csv_out   = gr.File(label="📥 CSV File")
        chart_plot         = gr.Plot(label="Label Distribution")
        download_chart_out = gr.File(label="📥 Chart PNG")

    # Push section (collapsed and hidden until a CSV is uploaded)
    push_accordion = gr.Accordion("📦 Push to Hugging Face Hub", open=False, visible=False)
    with push_accordion:
        repo_in     = gr.Textbox(label="Repo (username/dataset)")
        token_in    = gr.Textbox(label="🔑 HF Token", type="password")
        push_btn    = gr.Button("🚀 Push")
        push_status = gr.Textbox(label="Push Status", interactive=False)

    # Bind events. The order of each `outputs` list must match the order of
    # the values the handler returns.
    upload_btn.click(
        upload_csv,
        inputs=csv_input,
        outputs=[table, status, action_row, output_row, push_accordion]
    )
    save_btn.click(
        save_changes,
        inputs=table,
        outputs=status
    )
    download_btn.click(
        download_csv,
        outputs=download_csv_out
    )
    visualize_btn.click(
        visualize_and_download,
        outputs=[chart_plot, download_chart_out]
    )
    push_btn.click(
        push_to_hub,
        inputs=[repo_in, token_in],
        outputs=push_status
    )

# NOTE(review): the server is launched unconditionally at import time; an
# `if __name__ == "__main__":` guard would allow importing this module.
app.launch()