import os
import tempfile
from io import BytesIO

import gradio as gr
import requests
import matplotlib
matplotlib.use("Agg")  # Render off-screen: Gradio runs handlers in worker threads.
import matplotlib.pyplot as plt
from PIL import Image
from datasets import load_dataset
from tokenizers import Tokenizer

from train_tokenizer import train_tokenizer
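
# `train_tokenizer` comes from the local train_tokenizer module, which is not shown
# here. A minimal sketch of what it is assumed to provide, namely a BPE tokenizer
# built with the `tokenizers` library (names and trainer settings are assumptions):
#
#     from tokenizers import Tokenizer, models, pre_tokenizers, trainers
#
#     def train_tokenizer(texts, vocab_size, min_freq):
#         tokenizer = Tokenizer(models.BPE(unk_token="<unk>"))
#         tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()
#         trainer = trainers.BpeTrainer(
#             vocab_size=vocab_size,
#             min_frequency=min_freq,
#             special_tokens=["<unk>", "<pad>", "<s>", "</s>"],
#         )
#         tokenizer.train_from_iterator(texts, trainer=trainer)
#         return tokenizer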

CHECKPOINT_FILE = "checkpoint.txt"  # Collected texts, one per line; lets interrupted runs resume.
CHUNK_SIZE = 1000  # Texts buffered in memory before each flush to the checkpoint file.
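

# fetch_splits queries the Hugging Face datasets-server, whose /splits endpoint
# returns JSON shaped roughly like (shape inferred from the parsing below):
#   {"splits": [{"dataset": "...", "config": "20231101.el", "split": "train"}, ...]}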
def fetch_splits(dataset_name):
    try:
        response = requests.get(
            f"https://datasets-server.huggingface.co/splits?dataset={dataset_name}",
            timeout=10
        )
        response.raise_for_status()
        data = response.json()

        # Group the available split names by config.
        splits_info = {}
        for split in data['splits']:
            config = split['config']
            split_name = split['split']
            if config not in splits_info:
                splits_info[config] = []
            splits_info[config].append(split_name)

        return {
            "splits": splits_info,
            "viewer_template": f"https://huggingface.co/datasets/{dataset_name}/embed/viewer/{{config}}/{{split}}"
        }
    except Exception as e:
        raise gr.Error(f"Error while fetching the splits: {str(e)}")


def update_components(dataset_name):
    if not dataset_name:
        return [gr.update(value=""), gr.update(choices=[], value=None), gr.update(value="")]

    try:
        splits_data = fetch_splits(dataset_name)
        config_choices = list(splits_data['splits'].keys())
        first_config = config_choices[0] if config_choices else None
        iframe_html = f"""
        <iframe
            src="{splits_data['viewer_template'].format(config=first_config, split='train')}"
            frameborder="0"
            width="100%"
            height="560px"
        ></iframe>
        """ if first_config else "No available data found"

        # Preselect the Greek and English Wikipedia dumps when a Greek config exists.
        default_configs = "20231101.el,20231101.en" if first_config and "el" in first_config else first_config
        return [
            gr.update(value=default_configs),
            gr.update(choices=splits_data['splits'].get(first_config, [])),
            gr.update(value=iframe_html)
        ]
    except Exception as e:
        raise gr.Error(f"Error: {str(e)}")


def update_split_choices(dataset_name, configs):
    if not dataset_name or not configs:
        return gr.update(choices=[])
    try:
        splits_data = fetch_splits(dataset_name)
        # Only the first config drives the split choices shown in the dropdown.
        first_config = configs.split(",")[0].strip()
        return gr.update(choices=splits_data['splits'].get(first_config, []))
    except Exception:
        return gr.update(choices=[])


def create_iterator(dataset_name, configs, split):
    """
    For each config in the comma-separated list, load the corresponding
    streaming dataset and yield its texts.
    """
    configs_list = [c.strip() for c in configs.split(",") if c.strip()]
    for config in configs_list:
        try:
            dataset = load_dataset(
                dataset_name,
                name=config,
                split=split,
                streaming=True
            )
            for example in dataset:
                text = example.get('text', '')
                if text:
                    yield text
        except Exception as e:
            print(f"Error loading dataset for config {config}: {e}")
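

# Example usage (assumes network access; articles are fetched lazily via streaming):
#   it = create_iterator("wikimedia/wikipedia", "20231101.el", "train")
#   first_article = next(it)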


def append_to_checkpoint(texts, checkpoint_file):
    """
    Append the texts to the checkpoint file, one per line.
    """
    with open(checkpoint_file, "a", encoding="utf-8") as f:
        for t in texts:
            f.write(t + "\n")
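

# Note: texts are stored one per line, so a text containing newlines is split into
# several samples when re-loaded. That is acceptable here, since the corpus is only
# used to train a tokenizer.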


def load_checkpoint(checkpoint_file):
    """
    Read and return the texts from the checkpoint file (if it exists).
    """
    if os.path.exists(checkpoint_file):
        with open(checkpoint_file, "r", encoding="utf-8") as f:
            return f.read().splitlines()
    return []
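

# Deleting checkpoint.txt resets the collected corpus; otherwise every training run
# resumes from, and appends to, the previously stored texts.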


def train_and_test_streaming(dataset_name, configs, split, vocab_size, min_freq, test_text, custom_files):
    """
    Generator that trains the tokenizer in chunks, saving the collected data to a
    checkpoint and streaming progress updates to the Gradio interface.

    If a checkpoint already exists, collection resumes from it.
    """
    progress_log = ""

    all_texts = load_checkpoint(CHECKPOINT_FILE)
    total_processed = len(all_texts)
    progress_log += f"{total_processed} samples already stored in the checkpoint.\n"
    yield progress_log, None, None

    dataset_iterator = create_iterator(dataset_name, configs, split)

    new_texts = []
    chunk_count = 0

    for text in dataset_iterator:
        new_texts.append(text)
        total_processed += 1

        if len(new_texts) >= CHUNK_SIZE:
            append_to_checkpoint(new_texts, CHECKPOINT_FILE)
            chunk_count += 1
            progress_log += f"Processed {total_processed} samples (chunk {chunk_count}).\n"
            yield progress_log, None, None
            new_texts = []

    # Flush the remainder; these texts were already counted in the loop above.
    if new_texts:
        append_to_checkpoint(new_texts, CHECKPOINT_FILE)
        chunk_count += 1
        progress_log += f"Final chunk: {total_processed} samples stored in total.\n"
        yield progress_log, None, None

    if custom_files:
        custom_texts = []
        for file_path in custom_files:
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    content = f.read()
                    if content:
                        custom_texts.append(content)
            except Exception as file_error:
                print(f"Error reading file {file_path}: {file_error}")
        if custom_texts:
            append_to_checkpoint(custom_texts, CHECKPOINT_FILE)
            total_processed += len(custom_texts)
            progress_log += f"Added {len(custom_texts)} samples from custom files.\n"
            yield progress_log, None, None

    all_texts = load_checkpoint(CHECKPOINT_FILE)
    progress_log += f"Starting tokenizer training on {len(all_texts)} samples...\n"
    yield progress_log, None, None

    tokenizer = train_tokenizer(all_texts, vocab_size, min_freq)

    # Round-trip through a temporary file to verify the tokenizer serializes cleanly.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".json") as f:
        tokenizer.save(f.name)
        trained_tokenizer = Tokenizer.from_file(f.name)
    os.unlink(f.name)

    encoded = trained_tokenizer.encode(test_text)
    decoded = trained_tokenizer.decode(encoded.ids)

    # Histogram of token lengths as a quick sanity check of the learned vocabulary.
    token_lengths = [len(t) for t in encoded.tokens]
    fig = plt.figure()
    plt.hist(token_lengths, bins=20)
    plt.xlabel('Token Length')
    plt.ylabel('Frequency')
    img_buffer = BytesIO()
    plt.savefig(img_buffer, format='png')
    plt.close(fig)
    img_buffer.seek(0)
    plot_image = Image.open(img_buffer)

    results = {
        "Original Text": test_text,
        "Decoded": decoded,
        "Token Count": len(encoded.tokens),
        "Unknown Tokens": sum(1 for t in encoded.tokens if t == "<unk>")
    }
    progress_log += "Training complete!\n"
    yield progress_log, results, plot_image
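

# Gradio treats generator event handlers as streaming functions: each `yield`
# sends an interim update to the outputs, so the .click() wiring below streams
# progress automatically.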


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## Wikipedia Tokenizer Trainer with Checkpointing and Streaming")

    with gr.Row():
        with gr.Column():
            dataset_name = gr.Textbox(
                label="Dataset Name",
                value="wikimedia/wikipedia",
                placeholder="e.g. 'wikimedia/wikipedia'"
            )
            configs = gr.Textbox(
                label="Configs (e.g. '20231101.el,20231101.en' for Greek and English)",
                value="20231101.el,20231101.en",
                placeholder="Enter configs separated by commas"
            )
            split = gr.Dropdown(
                label="Split",
                choices=["train"],
                value="train",
                allow_custom_value=True
            )
            vocab_size = gr.Slider(20000, 100000, value=50000, label="Vocabulary Size")
            min_freq = gr.Slider(1, 100, value=3, label="Minimum Frequency")
            test_text = gr.Textbox(
                value="Η Ακρόπολη είναι σύμβολο της αρχαίας ελληνικής πολιτισμικής κληρονομιάς.",
                label="Test Text"
            )
            custom_files = gr.File(
                label="Custom Greek Texts",
                file_count="multiple",
                type="filepath"
            )
            train_btn = gr.Button("Train", variant="primary")
        with gr.Column():
            progress_box = gr.Textbox(label="Progress", interactive=False)
            results_json = gr.JSON(label="Results")
            results_plot = gr.Image(label="Token Length Distribution")
            dataset_preview = gr.HTML(label="Dataset Preview")

    dataset_name.change(
        fn=update_components,
        inputs=dataset_name,
        outputs=[configs, split, dataset_preview]
    )
    configs.change(
        fn=update_split_choices,
        inputs=[dataset_name, configs],
        outputs=split
    )
    train_btn.click(
        fn=train_and_test_streaming,
        inputs=[dataset_name, configs, split, vocab_size, min_freq, test_text, custom_files],
        outputs=[progress_box, results_json, results_plot]
    )


if __name__ == "__main__":
    demo.launch()
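    # To reach the app from other machines or a container, launch with e.g.
    # demo.launch(server_name="0.0.0.0", server_port=7860) instead.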