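"""Gradio app for building a tokenizer from streamed Wikipedia text.

The app streams text samples from a Hugging Face dataset into a local
checkpoint file, can analyze the language mix of the collected samples,
and trains a tokenizer on the checkpoint via the local train_tokenizer module.
"""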
import os
from io import BytesIO

import gradio as gr
import matplotlib

# Use a non-interactive backend so figures can be rendered from worker threads.
matplotlib.use("Agg")
import matplotlib.pyplot as plt

from datasets import load_dataset
from langdetect import detect, DetectorFactory
from PIL import Image
from tokenizers import Tokenizer

from train_tokenizer import train_tokenizer

# Make langdetect results deterministic across runs.
DetectorFactory.seed = 0

CHECKPOINT_FILE = "checkpoint.txt"
TOKENIZER_DIR = "tokenizer_model"
TOKENIZER_FILE = os.path.join(TOKENIZER_DIR, "tokenizer.json")
MAX_SAMPLES = 50_000_000

# Set by the "Stop Collection" button to interrupt an in-progress collection loop.
STOP_COLLECTION = False

def load_checkpoint():
    """Load previously collected samples from the checkpoint file, if it exists."""
    if os.path.exists(CHECKPOINT_FILE):
        with open(CHECKPOINT_FILE, "r", encoding="utf-8") as f:
            return f.read().splitlines()
    return []

def append_to_checkpoint(texts):
    """Append a batch of texts to the checkpoint file, one sample per line."""
    with open(CHECKPOINT_FILE, "a", encoding="utf-8") as f:
        for t in texts:
            f.write(t + "\n")

def create_iterator(dataset_name, configs, split):
    """Stream the dataset for each config and yield its non-empty texts."""
    configs_list = [c.strip() for c in configs.split(",") if c.strip()]
    for config in configs_list:
        try:
            dataset = load_dataset(dataset_name, name=config, split=split, streaming=True)
            for example in dataset:
                text = example.get('text', '')
                if text:
                    yield text
        except Exception as e:
            print(f"⚠️ Error loading dataset for config {config}: {e}")

def analyze_checkpoint(num_samples=1000):
    """Detect the language of the first num_samples checkpoint lines and report percentages."""
    if not os.path.exists(CHECKPOINT_FILE):
        return "The checkpoint file does not exist."

    with open(CHECKPOINT_FILE, "r", encoding="utf-8") as f:
        lines = f.read().splitlines()

    sample_lines = lines[:num_samples]

    language_counts = {}
    total = 0
    for line in sample_lines:
        try:
            lang = detect(line)
            language_counts[lang] = language_counts.get(lang, 0) + 1
            total += 1
        except Exception:
            # langdetect raises on empty or undetectable lines; skip them.
            continue

    if total == 0:
        return "No valid samples were found for analysis."

    report = "📊 Analysis Results:\n"
    for lang, count in language_counts.items():
        report += f" - {lang}: {count / total * 100:.2f}%\n"

    return report

def collect_samples(dataset_name, configs, split, chunk_size):
    """
    Collect samples from the dataset until MAX_SAMPLES is reached
    or a stop is requested via STOP_COLLECTION.
    """
    global STOP_COLLECTION
    STOP_COLLECTION = False
    total_processed = len(load_checkpoint())
    progress_messages = [f"📌 Starting collection... {total_processed} samples already in the checkpoint."]

    dataset_iterator = create_iterator(dataset_name, configs, split)
    new_texts = []

    for text in dataset_iterator:
        if STOP_COLLECTION:
            progress_messages.append("⏹️ Collection stopped by the user.")
            break

        new_texts.append(text)
        total_processed += 1

        # Flush to the checkpoint in chunks to bound memory usage.
        if len(new_texts) >= chunk_size:
            append_to_checkpoint(new_texts)
            progress_messages.append(f"✅ Saved {total_processed} samples to the checkpoint.")
            new_texts = []

        if total_processed >= MAX_SAMPLES:
            progress_messages.append("⚠️ Sample limit reached.")
            break

    if new_texts:
        append_to_checkpoint(new_texts)
        progress_messages.append(f"✅ Final batch saved ({total_processed} samples).")

    return "\n".join(progress_messages)

def train_tokenizer_fn(dataset_name, configs, split, vocab_size, min_freq, test_text):
    """Train the tokenizer on the checkpoint data and run a quick round-trip test."""
    print("🚀 Starting training...")
    all_texts = load_checkpoint()
    # Slider values may arrive as floats, so cast them before training.
    train_tokenizer(all_texts, int(vocab_size), int(min_freq), TOKENIZER_DIR)

    trained_tokenizer = Tokenizer.from_file(TOKENIZER_FILE)

    encoded = trained_tokenizer.encode(test_text)
    decoded = trained_tokenizer.decode(encoded.ids)

    # Plot the distribution of token lengths for the test text.
    token_lengths = [len(t) for t in encoded.tokens]
    plt.figure()
    plt.hist(token_lengths, bins=20)
    plt.xlabel('Token Length')
    plt.ylabel('Frequency')
    img_buffer = BytesIO()
    plt.savefig(img_buffer, format='png')
    plt.close()

    # gr.Image expects a PIL image (or filepath/array), not raw PNG bytes.
    img_buffer.seek(0)
    histogram = Image.open(img_buffer)

    return (f"✅ Training complete!\nSaved to folder: {TOKENIZER_DIR}",
            decoded,
            histogram)

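# Note: the train_tokenizer helper lives in a separate local module that is not
# shown here. Based on the call above, it is assumed to take
# (texts, vocab_size, min_freq, output_dir) and to write tokenizer.json into
# output_dir. A rough sketch of such a helper (an assumption, not the actual
# implementation) using the `tokenizers` library could look like:
#
#     from tokenizers import Tokenizer, models, pre_tokenizers, trainers
#
#     def train_tokenizer(texts, vocab_size, min_freq, output_dir):
#         tokenizer = Tokenizer(models.BPE(unk_token="[UNK]"))
#         tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()
#         trainer = trainers.BpeTrainer(vocab_size=vocab_size,
#                                       min_frequency=min_freq,
#                                       special_tokens=["[UNK]", "[PAD]"])
#         tokenizer.train_from_iterator(texts, trainer=trainer)
#         os.makedirs(output_dir, exist_ok=True)
#         tokenizer.save(os.path.join(output_dir, "tokenizer.json"))
#         return tokenizer
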
def stop_collection():
    """Request that sample collection stop."""
    global STOP_COLLECTION
    STOP_COLLECTION = True
    return "⏹️ Collection stopped by the user."

def restart_collection():
    """Delete the checkpoint so that collection can start from scratch."""
    global STOP_COLLECTION
    STOP_COLLECTION = False
    if os.path.exists(CHECKPOINT_FILE):
        os.remove(CHECKPOINT_FILE)
    return "🔄 The checkpoint was deleted. You can start a new collection."

with gr.Blocks() as demo:
    gr.Markdown("## Wikipedia Tokenizer Trainer with Logs & Control")

    with gr.Row():
        with gr.Column():
            dataset_name = gr.Textbox(value="wikimedia/wikipedia", label="Dataset Name")
            configs = gr.Textbox(value="20231101.el,20231101.en", label="Configs")
            split = gr.Dropdown(choices=["train"], value="train", label="Split")
            chunk_size = gr.Slider(500, 50000, value=50000, label="Chunk Size")
            vocab_size = gr.Slider(20000, 100000, value=50000, label="Vocabulary Size")
            min_freq = gr.Slider(1, 100, value=3, label="Minimum Frequency")
            test_text = gr.Textbox(value="Η Ακρόπολη είναι σύμβολο της αρχαίας Ελλάδας.", label="Test Text")
            start_btn = gr.Button("Start Collection")
            stop_btn = gr.Button("Stop Collection")
            restart_btn = gr.Button("Restart Collection")
            analyze_btn = gr.Button("Analyze Samples")
            train_btn = gr.Button("Train Tokenizer")

    progress = gr.Textbox(label="Progress", interactive=False, lines=10)
    decoded_text = gr.Textbox(label="Decoded Text", interactive=False)
    token_distribution = gr.Image(label="Token Distribution")

    # Wire each button to its handler; all status messages go to the Progress box.
    start_btn.click(collect_samples, [dataset_name, configs, split, chunk_size], progress)
    stop_btn.click(stop_collection, [], progress)
    restart_btn.click(restart_collection, [], progress)
    analyze_btn.click(analyze_checkpoint, [], progress)
    train_btn.click(train_tokenizer_fn,
                    [dataset_name, configs, split, vocab_size, min_freq, test_text],
                    [progress, decoded_text, token_distribution])

demo.launch()