File size: 7,741 Bytes
7c5aa99 f94c5ea 0430da2 4a4435c 3d37920 0430da2 4a4435c 0430da2 f94c5ea 4a4435c 5d41434 f94c5ea 5d41434 3d37920 5d41434 0430da2 4a4435c 3d37920 4a4435c 3d37920 4a4435c 7c5aa99 0430da2 a9ae246 3d37920 a9ae246 3d37920 a9ae246 3d37920 1c51cb8 3d37920 f94c5ea 3d37920 f94c5ea 5d41434 3d37920 f94c5ea 3d37920 f94c5ea 5d41434 f94c5ea 5d41434 f94c5ea 3d37920 9d049fd f94c5ea 3d37920 5d41434 9d049fd 3d37920 f94c5ea 3d37920 f94c5ea 3d37920 f94c5ea 3d37920 0430da2 4a4435c 3d37920 9d049fd 0430da2 5d41434 3d37920 5d41434 0044b58 3d37920 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 |
# -*- coding: utf-8 -*-
import os
import gradio as gr
import requests
import tempfile
from io import BytesIO
import matplotlib.pyplot as plt
from datasets import load_dataset
from train_tokenizer import train_tokenizer
from tokenizers import Tokenizer
from langdetect import detect, DetectorFactory
# Για επαναληψιμότητα στο langdetect
DetectorFactory.seed = 0
# Ρυθμίσεις checkpointing και αποθήκευσης του tokenizer
CHECKPOINT_FILE = "checkpoint.txt"
TOKENIZER_DIR = "tokenizer_model"
TOKENIZER_FILE = os.path.join(TOKENIZER_DIR, "tokenizer.json")
CHUNK_SIZE = 1000 # Μέγεθος batch για checkpoint
MAX_SAMPLES = 3000000 # Όριο δειγμάτων (μπορείς να το προσαρμόσεις)
def fetch_splits(dataset_name):
"""Ανάκτηση των splits του dataset από το Hugging Face."""
try:
response = requests.get(f"https://datasets-server.huggingface.co/splits?dataset={dataset_name}", timeout=10)
response.raise_for_status()
data = response.json()
splits_info = {}
for split in data['splits']:
config = split['config']
split_name = split['split']
if config not in splits_info:
splits_info[config] = []
splits_info[config].append(split_name)
return {
"splits": splits_info,
"viewer_template": f"https://huggingface.co/datasets/{dataset_name}/embed/viewer/{{config}}/{{split}}"
}
except Exception as e:
raise gr.Error(f"Σφάλμα κατά την ανάκτηση των splits: {str(e)}")
def create_iterator(dataset_name, configs, split):
"""Φορτώνει το dataset και αποδίδει τα κείμενα ως iterator."""
configs_list = [c.strip() for c in configs.split(",") if c.strip()]
for config in configs_list:
try:
dataset = load_dataset(dataset_name, name=config, split=split, streaming=True)
for example in dataset:
text = example.get('text', '')
if text:
yield text
except Exception as e:
print(f"⚠️ Σφάλμα φόρτωσης dataset για config {config}: {e}")
def append_to_checkpoint(texts):
"""Αποθήκευση δεδομένων στο αρχείο checkpoint."""
with open(CHECKPOINT_FILE, "a", encoding="utf-8") as f:
for t in texts:
f.write(t + "\n")
def load_checkpoint():
"""Φόρτωση δεδομένων από το checkpoint αν υπάρχει."""
if os.path.exists(CHECKPOINT_FILE):
with open(CHECKPOINT_FILE, "r", encoding="utf-8") as f:
return f.read().splitlines()
return []
def analyze_checkpoint(num_samples=1000):
"""
Διαβάζει τα πρώτα num_samples δείγματα από το checkpoint και επιστρέφει το ποσοστό γλωσσών.
"""
if not os.path.exists(CHECKPOINT_FILE):
return "Το αρχείο checkpoint δεν υπάρχει."
with open(CHECKPOINT_FILE, "r", encoding="utf-8") as f:
lines = f.read().splitlines()
sample_lines = lines[:num_samples] if len(lines) >= num_samples else lines
language_counts = {}
total = 0
for line in sample_lines:
try:
lang = detect(line)
language_counts[lang] = language_counts.get(lang, 0) + 1
total += 1
except Exception as e:
continue
if total == 0:
return "Δεν βρέθηκαν έγκυρα δείγματα για ανάλυση."
report = "Αποτελέσματα Ανάλυσης:\n"
for lang, count in language_counts.items():
report += f"Γλώσσα {lang}: {count/total*100:.2f}%\n"
return report
def train_and_test(dataset_name, configs, split, vocab_size, min_freq, test_text):
"""Εκπαίδευση του tokenizer και δοκιμή του."""
print("🚀 Ξεκινά η διαδικασία εκπαίδευσης...")
all_texts = load_checkpoint()
total_processed = len(all_texts)
print(f"📌 Υπάρχουν ήδη {total_processed} δείγματα στο checkpoint.")
dataset_iterator = create_iterator(dataset_name, configs, split)
new_texts = []
for text in dataset_iterator:
if total_processed >= MAX_SAMPLES:
break # Διακοπή εάν ξεπεραστεί το όριο
new_texts.append(text)
total_processed += 1
if len(new_texts) >= CHUNK_SIZE:
append_to_checkpoint(new_texts)
print(f"✅ Αποθηκεύτηκαν {total_processed} δείγματα στο checkpoint.")
new_texts = []
if new_texts:
append_to_checkpoint(new_texts)
print(f"✅ Τελικό batch αποθηκεύτηκε ({total_processed} δείγματα).")
print("🚀 Η αποθήκευση δεδομένων ολοκληρώθηκε! Ξεκινάει η εκπαίδευση του tokenizer...")
# Εκπαίδευση του tokenizer
all_texts = load_checkpoint()
tokenizer = train_tokenizer(all_texts, vocab_size, min_freq, TOKENIZER_DIR)
# Φόρτωση εκπαιδευμένου tokenizer
trained_tokenizer = Tokenizer.from_file(TOKENIZER_FILE)
# Δοκιμή
encoded = trained_tokenizer.encode(test_text)
decoded = trained_tokenizer.decode(encoded.ids)
# Γράφημα κατανομής tokens
token_lengths = [len(t) for t in encoded.tokens]
fig = plt.figure()
plt.hist(token_lengths, bins=20)
plt.xlabel('Μήκος Token')
plt.ylabel('Συχνότητα')
img_buffer = BytesIO()
plt.savefig(img_buffer, format='png')
plt.close()
return f"✅ Εκπαίδευση ολοκληρώθηκε!\nΑποθηκεύτηκε στον φάκελο: {TOKENIZER_DIR}", decoded, img_buffer.getvalue()
# Gradio Interface
with gr.Blocks() as demo:
gr.Markdown("## Wikipedia Tokenizer Trainer with Checkpointing")
with gr.Row():
with gr.Column():
dataset_name = gr.Textbox(value="wikimedia/wikipedia", label="Dataset Name")
configs = gr.Textbox(value="20231101.el,20231101.en", label="Configs")
split = gr.Dropdown(choices=["train"], value="train", label="Split")
vocab_size = gr.Slider(20000, 100000, value=50000, label="Vocabulary Size")
min_freq = gr.Slider(1, 100, value=3, label="Minimum Frequency")
test_text = gr.Textbox(value="Η Ακρόπολη είναι σύμβολο της αρχαίας Ελλάδας.", label="Test Text")
train_btn = gr.Button("Train")
analyze_btn = gr.Button("Analyze Samples")
with gr.Column():
progress = gr.Textbox(label="Progress", interactive=False, lines=10)
results_text = gr.Textbox(label="Test Decoded Text", interactive=False)
results_plot = gr.Image(label="Token Length Distribution")
# Έλεγχος ύπαρξης του tokenizer για download
initial_file_value = TOKENIZER_FILE if os.path.exists(TOKENIZER_FILE) else None
download_button = gr.File(label="Download Tokenizer", value=initial_file_value)
train_btn.click(train_and_test,
inputs=[dataset_name, configs, split, vocab_size, min_freq, test_text],
outputs=[progress, results_text, results_plot])
analyze_btn.click(fn=lambda: analyze_checkpoint(1000),
inputs=[],
outputs=progress)
demo.launch() |