File size: 8,023 Bytes
7c5aa99 f94c5ea 0430da2 4a4435c 5c35386 0430da2 4a4435c 0430da2 f94c5ea 4a4435c 5d41434 f94c5ea 5d41434 3d37920 92bb5cb 9dd78b5 0430da2 5c35386 a9ae246 3d37920 a9ae246 3d37920 a9ae246 3d37920 1c51cb8 f94c5ea 5d41434 5c35386 5d41434 af09211 5d41434 af09211 5d41434 5c35386 5d41434 5c35386 9dd78b5 5c35386 d6a5933 5c35386 f94c5ea 5d41434 f94c5ea 9dd78b5 af09211 9dd78b5 5c35386 f94c5ea d6a5933 5c35386 3d37920 d6a5933 9d049fd 5c35386 9dd78b5 d6a5933 9dd78b5 d6a5933 3d37920 d6a5933 5d41434 d6a5933 5c35386 d6a5933 5c35386 3d37920 5c35386 3d37920 5c35386 f94c5ea 3d37920 f94c5ea 9dd78b5 d6a5933 9dd78b5 5c35386 9dd78b5 af09211 9dd78b5 5c35386 9dd78b5 5c35386 9dd78b5 d6a5933 af09211 0430da2 5c35386 4a4435c 3d37920 5c35386 af09211 5d41434 24cb1cd 5d41434 d6a5933 9dd78b5 5c35386 d6a5933 5c35386 60ccdc1 5c35386 60ccdc1 0044b58 3d37920 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 |
# -*- coding: utf-8 -*-
import os
import gradio as gr
import requests
import time
from io import BytesIO
import matplotlib.pyplot as plt
from datasets import load_dataset
from train_tokenizer import train_tokenizer
from tokenizers import Tokenizer
from langdetect import detect, DetectorFactory
# Για επαναληψιμότητα στο langdetect
DetectorFactory.seed = 0
# Ρυθμίσεις checkpointing και αποθήκευσης του tokenizer
CHECKPOINT_FILE = "checkpoint.txt"
TOKENIZER_DIR = "tokenizer_model"
TOKENIZER_FILE = os.path.join(TOKENIZER_DIR, "tokenizer.json")
MAX_SAMPLES = 50000000 # Όριο δειγμάτων
# Παγκόσμια μεταβλητή ελέγχου συλλογής
STOP_COLLECTION = False
def load_checkpoint():
"""Φόρτωση δεδομένων από το checkpoint αν υπάρχει."""
if os.path.exists(CHECKPOINT_FILE):
with open(CHECKPOINT_FILE, "r", encoding="utf-8") as f:
return f.read().splitlines()
return []
def append_to_checkpoint(texts):
"""Αποθήκευση δεδομένων στο αρχείο checkpoint."""
with open(CHECKPOINT_FILE, "a", encoding="utf-8") as f:
for t in texts:
f.write(t + "\n")
def create_iterator(dataset_name, configs, split):
"""Φορτώνει το dataset και αποδίδει τα κείμενα ως iterator."""
configs_list = [c.strip() for c in configs.split(",") if c.strip()]
for config in configs_list:
try:
dataset = load_dataset(dataset_name, name=config, split=split, streaming=True)
for example in dataset:
text = example.get('text', '')
if text:
yield text
except Exception as e:
print(f"⚠️ Σφάλμα φόρτωσης dataset για config {config}: {e}")
def analyze_checkpoint(num_samples=1000):
"""Αναλύει τα πρώτα num_samples δείγματα από το checkpoint και επιστρέφει το ποσοστό γλωσσών."""
if not os.path.exists(CHECKPOINT_FILE):
return "Το αρχείο checkpoint δεν υπάρχει."
with open(CHECKPOINT_FILE, "r", encoding="utf-8") as f:
lines = f.read().splitlines()
sample_lines = lines[:num_samples] if len(lines) >= num_samples else lines
language_counts = {}
total = 0
for line in sample_lines:
try:
lang = detect(line)
language_counts[lang] = language_counts.get(lang, 0) + 1
total += 1
except Exception:
continue
if total == 0:
return "Δεν βρέθηκαν έγκυρα δείγματα για ανάλυση."
report = "📊 Αποτελέσματα Ανάλυσης:\n"
for lang, count in language_counts.items():
report += f" - {lang}: {count / total * 100:.2f}%\n"
return report
def collect_samples(dataset_name, configs, split, chunk_size):
"""
Ξεκινά τη συλλογή δειγμάτων από το dataset μέχρι να φτάσει το MAX_SAMPLES
ή μέχρι να ζητηθεί διακοπή (STOP_COLLECTION).
"""
global STOP_COLLECTION
STOP_COLLECTION = False # Βεβαιωνόμαστε ότι η συλλογή ξεκινάει κανονικά
total_processed = len(load_checkpoint())
progress_messages = [f"📌 Ξεκινά η συλλογή... Υπάρχουν ήδη {total_processed} δείγματα στο checkpoint."]
dataset_iterator = create_iterator(dataset_name, configs, split)
new_texts = []
for text in dataset_iterator:
if STOP_COLLECTION:
progress_messages.append("⏹️ Η συλλογή σταμάτησε από το χρήστη.")
break
new_texts.append(text)
total_processed += 1
if len(new_texts) >= chunk_size:
append_to_checkpoint(new_texts)
progress_messages.append(f"✅ Αποθηκεύτηκαν {total_processed} δείγματα στο checkpoint.")
new_texts = []
if total_processed >= MAX_SAMPLES:
progress_messages.append("⚠️ Έφτασε το όριο δειγμάτων.")
break
if new_texts:
append_to_checkpoint(new_texts)
progress_messages.append(f"✅ Τελικό batch αποθηκεύτηκε ({total_processed} δείγματα).")
return "\n".join(progress_messages)
def train_tokenizer_fn(dataset_name, configs, split, vocab_size, min_freq, test_text):
"""Εκπαιδεύει τον tokenizer χρησιμοποιώντας τα δεδομένα του checkpoint."""
print("🚀 Ξεκινά η εκπαίδευση...")
all_texts = load_checkpoint()
tokenizer = train_tokenizer(all_texts, vocab_size, min_freq, TOKENIZER_DIR)
# Φόρτωση εκπαιδευμένου tokenizer
trained_tokenizer = Tokenizer.from_file(TOKENIZER_FILE)
# Δοκιμή
encoded = trained_tokenizer.encode(test_text)
decoded = trained_tokenizer.decode(encoded.ids)
# Γράφημα κατανομής tokens
token_lengths = [len(t) for t in encoded.tokens]
fig = plt.figure()
plt.hist(token_lengths, bins=20)
plt.xlabel('Μήκος Token')
plt.ylabel('Συχνότητα')
img_buffer = BytesIO()
plt.savefig(img_buffer, format='png')
plt.close()
return (f"✅ Εκπαίδευση ολοκληρώθηκε!\nΑποθηκεύτηκε στον φάκελο: {TOKENIZER_DIR}",
decoded,
img_buffer.getvalue())
def stop_collection():
"""Σταματά τη συλλογή δειγμάτων."""
global STOP_COLLECTION
STOP_COLLECTION = True
return "⏹️ Η συλλογή σταμάτησε από το χρήστη."
def restart_collection():
"""Διαγράφει το checkpoint και επανεκκινεί τη συλλογή."""
global STOP_COLLECTION
STOP_COLLECTION = False
if os.path.exists(CHECKPOINT_FILE):
os.remove(CHECKPOINT_FILE)
return "🔄 Το checkpoint διαγράφηκε. Μπορείς να ξεκινήσεις νέα συλλογή."
# Gradio Interface
with gr.Blocks() as demo:
gr.Markdown("## Wikipedia Tokenizer Trainer with Logs & Control")
with gr.Row():
with gr.Column():
dataset_name = gr.Textbox(value="wikimedia/wikipedia", label="Dataset Name")
configs = gr.Textbox(value="20231101.el,20231101.en", label="Configs")
split = gr.Dropdown(choices=["train"], value="train", label="Split")
chunk_size = gr.Slider(500, 50000, value=50000, label="Chunk Size")
vocab_size = gr.Slider(20000, 100000, value=50000, label="Vocabulary Size")
min_freq = gr.Slider(1, 100, value=3, label="Minimum Frequency")
test_text = gr.Textbox(value="Η Ακρόπολη είναι σύμβολο της αρχαίας Ελλάδας.", label="Test Text")
start_btn = gr.Button("Start Collection")
stop_btn = gr.Button("Stop Collection")
restart_btn = gr.Button("Restart Collection")
analyze_btn = gr.Button("Analyze Samples")
train_btn = gr.Button("Train Tokenizer")
progress = gr.Textbox(label="Progress", interactive=False, lines=10)
decoded_text = gr.Textbox(label="Decoded Text", interactive=False)
token_distribution = gr.Image(label="Token Distribution")
start_btn.click(collect_samples, [dataset_name, configs, split, chunk_size], progress)
stop_btn.click(stop_collection, [], progress)
restart_btn.click(restart_collection, [], progress)
analyze_btn.click(analyze_checkpoint, [], progress)
train_btn.click(train_tokenizer_fn, [dataset_name, configs, split, vocab_size, min_freq, test_text],
[progress, decoded_text, token_distribution])
demo.launch() |