GPT2-PBE / app.py
tymbos's picture
Update app.py
24cb1cd verified
raw
history blame
8.02 kB
# -*- coding: utf-8 -*-
import os
import gradio as gr
import requests
import time
from io import BytesIO
import matplotlib.pyplot as plt
from datasets import load_dataset
from train_tokenizer import train_tokenizer
from tokenizers import Tokenizer
from langdetect import detect, DetectorFactory
# Για επαναληψιμότητα στο langdetect
DetectorFactory.seed = 0
# Ρυθμίσεις checkpointing και αποθήκευσης του tokenizer
CHECKPOINT_FILE = "checkpoint.txt"
TOKENIZER_DIR = "tokenizer_model"
TOKENIZER_FILE = os.path.join(TOKENIZER_DIR, "tokenizer.json")
MAX_SAMPLES = 50000000 # Όριο δειγμάτων
# Παγκόσμια μεταβλητή ελέγχου συλλογής
STOP_COLLECTION = False
def load_checkpoint():
"""Φόρτωση δεδομένων από το checkpoint αν υπάρχει."""
if os.path.exists(CHECKPOINT_FILE):
with open(CHECKPOINT_FILE, "r", encoding="utf-8") as f:
return f.read().splitlines()
return []
def append_to_checkpoint(texts):
"""Αποθήκευση δεδομένων στο αρχείο checkpoint."""
with open(CHECKPOINT_FILE, "a", encoding="utf-8") as f:
for t in texts:
f.write(t + "\n")
def create_iterator(dataset_name, configs, split):
"""Φορτώνει το dataset και αποδίδει τα κείμενα ως iterator."""
configs_list = [c.strip() for c in configs.split(",") if c.strip()]
for config in configs_list:
try:
dataset = load_dataset(dataset_name, name=config, split=split, streaming=True)
for example in dataset:
text = example.get('text', '')
if text:
yield text
except Exception as e:
print(f"⚠️ Σφάλμα φόρτωσης dataset για config {config}: {e}")
def analyze_checkpoint(num_samples=1000):
"""Αναλύει τα πρώτα num_samples δείγματα από το checkpoint και επιστρέφει το ποσοστό γλωσσών."""
if not os.path.exists(CHECKPOINT_FILE):
return "Το αρχείο checkpoint δεν υπάρχει."
with open(CHECKPOINT_FILE, "r", encoding="utf-8") as f:
lines = f.read().splitlines()
sample_lines = lines[:num_samples] if len(lines) >= num_samples else lines
language_counts = {}
total = 0
for line in sample_lines:
try:
lang = detect(line)
language_counts[lang] = language_counts.get(lang, 0) + 1
total += 1
except Exception:
continue
if total == 0:
return "Δεν βρέθηκαν έγκυρα δείγματα για ανάλυση."
report = "📊 Αποτελέσματα Ανάλυσης:\n"
for lang, count in language_counts.items():
report += f" - {lang}: {count / total * 100:.2f}%\n"
return report
def collect_samples(dataset_name, configs, split, chunk_size):
"""
Ξεκινά τη συλλογή δειγμάτων από το dataset μέχρι να φτάσει το MAX_SAMPLES
ή μέχρι να ζητηθεί διακοπή (STOP_COLLECTION).
"""
global STOP_COLLECTION
STOP_COLLECTION = False # Βεβαιωνόμαστε ότι η συλλογή ξεκινάει κανονικά
total_processed = len(load_checkpoint())
progress_messages = [f"📌 Ξεκινά η συλλογή... Υπάρχουν ήδη {total_processed} δείγματα στο checkpoint."]
dataset_iterator = create_iterator(dataset_name, configs, split)
new_texts = []
for text in dataset_iterator:
if STOP_COLLECTION:
progress_messages.append("⏹️ Η συλλογή σταμάτησε από το χρήστη.")
break
new_texts.append(text)
total_processed += 1
if len(new_texts) >= chunk_size:
append_to_checkpoint(new_texts)
progress_messages.append(f"✅ Αποθηκεύτηκαν {total_processed} δείγματα στο checkpoint.")
new_texts = []
if total_processed >= MAX_SAMPLES:
progress_messages.append("⚠️ Έφτασε το όριο δειγμάτων.")
break
if new_texts:
append_to_checkpoint(new_texts)
progress_messages.append(f"✅ Τελικό batch αποθηκεύτηκε ({total_processed} δείγματα).")
return "\n".join(progress_messages)
def train_tokenizer_fn(dataset_name, configs, split, vocab_size, min_freq, test_text):
"""Εκπαιδεύει τον tokenizer χρησιμοποιώντας τα δεδομένα του checkpoint."""
print("🚀 Ξεκινά η εκπαίδευση...")
all_texts = load_checkpoint()
tokenizer = train_tokenizer(all_texts, vocab_size, min_freq, TOKENIZER_DIR)
# Φόρτωση εκπαιδευμένου tokenizer
trained_tokenizer = Tokenizer.from_file(TOKENIZER_FILE)
# Δοκιμή
encoded = trained_tokenizer.encode(test_text)
decoded = trained_tokenizer.decode(encoded.ids)
# Γράφημα κατανομής tokens
token_lengths = [len(t) for t in encoded.tokens]
fig = plt.figure()
plt.hist(token_lengths, bins=20)
plt.xlabel('Μήκος Token')
plt.ylabel('Συχνότητα')
img_buffer = BytesIO()
plt.savefig(img_buffer, format='png')
plt.close()
return (f"✅ Εκπαίδευση ολοκληρώθηκε!\nΑποθηκεύτηκε στον φάκελο: {TOKENIZER_DIR}",
decoded,
img_buffer.getvalue())
def stop_collection():
"""Σταματά τη συλλογή δειγμάτων."""
global STOP_COLLECTION
STOP_COLLECTION = True
return "⏹️ Η συλλογή σταμάτησε από το χρήστη."
def restart_collection():
"""Διαγράφει το checkpoint και επανεκκινεί τη συλλογή."""
global STOP_COLLECTION
STOP_COLLECTION = False
if os.path.exists(CHECKPOINT_FILE):
os.remove(CHECKPOINT_FILE)
return "🔄 Το checkpoint διαγράφηκε. Μπορείς να ξεκινήσεις νέα συλλογή."
# Gradio Interface
with gr.Blocks() as demo:
gr.Markdown("## Wikipedia Tokenizer Trainer with Logs & Control")
with gr.Row():
with gr.Column():
dataset_name = gr.Textbox(value="wikimedia/wikipedia", label="Dataset Name")
configs = gr.Textbox(value="20231101.el,20231101.en", label="Configs")
split = gr.Dropdown(choices=["train"], value="train", label="Split")
chunk_size = gr.Slider(500, 50000, value=50000, label="Chunk Size")
vocab_size = gr.Slider(20000, 100000, value=50000, label="Vocabulary Size")
min_freq = gr.Slider(1, 100, value=3, label="Minimum Frequency")
test_text = gr.Textbox(value="Η Ακρόπολη είναι σύμβολο της αρχαίας Ελλάδας.", label="Test Text")
start_btn = gr.Button("Start Collection")
stop_btn = gr.Button("Stop Collection")
restart_btn = gr.Button("Restart Collection")
analyze_btn = gr.Button("Analyze Samples")
train_btn = gr.Button("Train Tokenizer")
progress = gr.Textbox(label="Progress", interactive=False, lines=10)
decoded_text = gr.Textbox(label="Decoded Text", interactive=False)
token_distribution = gr.Image(label="Token Distribution")
start_btn.click(collect_samples, [dataset_name, configs, split, chunk_size], progress)
stop_btn.click(stop_collection, [], progress)
restart_btn.click(restart_collection, [], progress)
analyze_btn.click(analyze_checkpoint, [], progress)
train_btn.click(train_tokenizer_fn, [dataset_name, configs, split, vocab_size, min_freq, test_text],
[progress, decoded_text, token_distribution])
demo.launch()