GPT2-PBE / app.py
tymbos's picture
Update app.py
f94c5ea verified
raw
history blame
11.5 kB
# -*- coding: utf-8 -*-
import os
import time
import gradio as gr
import requests
from io import BytesIO
import matplotlib.pyplot as plt
import tempfile
from datasets import load_dataset
from train_tokenizer import train_tokenizer
from tokenizers import Tokenizer
# Ρυθμίσεις checkpointing
CHECKPOINT_FILE = "checkpoint.txt" # αρχείο που αποθηκεύει όλα τα επεξεργασμένα κείμενα
CHUNK_SIZE = 1000 # αριθμός δειγμάτων που θα επεξεργάζονται πριν την αποθήκευση checkpoint
def fetch_splits(dataset_name):
try:
response = requests.get(
f"https://datasets-server.huggingface.co/splits?dataset={dataset_name}",
timeout=10
)
response.raise_for_status()
data = response.json()
splits_info = {}
for split in data['splits']:
config = split['config']
split_name = split['split']
if config not in splits_info:
splits_info[config] = []
splits_info[config].append(split_name)
return {
"splits": splits_info,
"viewer_template": f"https://huggingface.co/datasets/{dataset_name}/embed/viewer/{{config}}/{{split}}"
}
except Exception as e:
raise gr.Error(f"Σφάλμα κατά την ανάκτηση των splits: {str(e)}")
def update_components(dataset_name):
if not dataset_name:
return [gr.Textbox.update(value=""), gr.Dropdown.update(choices=[], value=None), gr.HTML.update(value="")]
try:
splits_data = fetch_splits(dataset_name)
config_choices = list(splits_data['splits'].keys())
first_config = config_choices[0] if config_choices else None
iframe_html = f"""
<iframe
src="{splits_data['viewer_template'].format(config=first_config, split='train')}"
frameborder="0"
width="100%"
height="560px"
></iframe>
""" if first_config else "Δεν βρέθηκαν διαθέσιμα δεδομένα"
# Προτείνουμε ως προεπιλογή για πολλαπλά configs τα ελληνικά και αγγλικά
default_configs = "20231101.el,20231101.en" if first_config and "el" in first_config else first_config
return [
gr.Textbox.update(value=default_configs),
gr.Dropdown.update(choices=splits_data['splits'].get(first_config, [])),
gr.HTML.update(value=iframe_html)
]
except Exception as e:
raise gr.Error(f"Σφάλμα: {str(e)}")
def update_split_choices(dataset_name, configs):
if not dataset_name or not configs:
return gr.Dropdown.update(choices=[])
try:
splits_data = fetch_splits(dataset_name)
# Χρησιμοποιούμε το πρώτο config της λίστας για τις επιλογές του split
first_config = configs.split(",")[0].strip()
return gr.Dropdown.update(choices=splits_data['splits'].get(first_config, []))
except:
return gr.Dropdown.update(choices=[])
def create_iterator(dataset_name, configs, split):
"""
Για κάθε config στη λίστα (χωρισμένα με κόμμα) φορτώνει το αντίστοιχο streaming dataset και παράγει τα κείμενα.
"""
configs_list = [c.strip() for c in configs.split(",") if c.strip()]
for config in configs_list:
try:
dataset = load_dataset(
dataset_name,
name=config,
split=split,
streaming=True
)
for example in dataset:
text = example.get('text', '')
if text:
yield text
except Exception as e:
print(f"Σφάλμα φόρτωσης dataset για config {config}: {e}")
def append_to_checkpoint(texts, checkpoint_file):
"""
Αποθηκεύει τα κείμενα στο αρχείο checkpoint.
"""
with open(checkpoint_file, "a", encoding="utf-8") as f:
for t in texts:
f.write(t + "\n")
def load_checkpoint(checkpoint_file):
"""
Διαβάζει και επιστρέφει τα κείμενα από το checkpoint (αν υπάρχει).
"""
if os.path.exists(checkpoint_file):
with open(checkpoint_file, "r", encoding="utf-8") as f:
return f.read().splitlines()
return []
def train_and_test_streaming(dataset_name, configs, split, vocab_size, min_freq, test_text, custom_files):
"""
Generator που εκπαιδεύει τον tokenizer σε chunks, αποθηκεύοντας τα δεδομένα σε checkpoint.
Επίσης, ενημερώνει την πρόοδο μέσω streaming στην Gradio διεπαφή.
Αν υπάρχει ήδη checkpoint, συνεχίζει από εκεί.
"""
# Φόρτωση ήδη επεξεργασμένων δεδομένων από checkpoint (αν υπάρχουν)
all_texts = load_checkpoint(CHECKPOINT_FILE)
total_processed = len(all_texts)
yield {"progress": f"Έχετε {total_processed} δείγματα ήδη αποθηκευμένα στο checkpoint.\n"}
# Δημιουργία iterator από τα streaming datasets
dataset_iterator = create_iterator(dataset_name, configs, split)
new_texts = []
chunk_count = 0
# Διατρέχουμε τα νέα δεδομένα σε chunks
for text in dataset_iterator:
new_texts.append(text)
total_processed += 1
# Κάθε CHUNK_SIZE δείγματα αποθηκεύουμε στο checkpoint και ενημερώνουμε την πρόοδο
if len(new_texts) >= CHUNK_SIZE:
append_to_checkpoint(new_texts, CHECKPOINT_FILE)
chunk_count += 1
yield {"progress": f"Επεξεργάστηκαν {total_processed} δείγματα (chunk {chunk_count}).\n"}
new_texts = [] # καθαρίζουμε το chunk
# Αποθήκευση τυχόν υπολειπόμενων νέων δεδομένων
if new_texts:
append_to_checkpoint(new_texts, CHECKPOINT_FILE)
total_processed += len(new_texts)
chunk_count += 1
yield {"progress": f"Τελικό chunk: συνολικά {total_processed} δείγματα αποθηκεύτηκαν.\n"}
# Ενσωματώνουμε επίσης τα custom files (αν υπάρχουν)
if custom_files:
custom_texts = []
for file_path in custom_files:
try:
with open(file_path, 'r', encoding='utf-8') as f:
content = f.read()
if content:
custom_texts.append(content)
except Exception as file_error:
print(f"Σφάλμα ανάγνωσης αρχείου {file_path}: {file_error}")
if custom_texts:
append_to_checkpoint(custom_texts, CHECKPOINT_FILE)
total_processed += len(custom_texts)
yield {"progress": f"Προστέθηκαν {len(custom_texts)} δείγματα από custom αρχεία.\n"}
# Συνολικά δεδομένα για εκπαίδευση: checkpoint + νέα δεδομένα
all_texts = load_checkpoint(CHECKPOINT_FILE)
yield {"progress": f"Ξεκινάει η εκπαίδευση του tokenizer σε {len(all_texts)} δείγματα...\n"}
# Εκπαίδευση του tokenizer πάνω στα συσσωρευμένα δεδομένα
tokenizer = train_tokenizer(all_texts, vocab_size, min_freq)
# Αποθήκευση και φόρτωση του εκπαιδευμένου tokenizer
with tempfile.NamedTemporaryFile(delete=False, suffix=".json") as f:
tokenizer.save(f.name)
trained_tokenizer = Tokenizer.from_file(f.name)
os.unlink(f.name)
# Validation: κωδικοποίηση και αποκωδικοποίηση του test κειμένου
encoded = trained_tokenizer.encode(test_text)
decoded = trained_tokenizer.decode(encoded.ids)
# Δημιουργία γραφήματος για την κατανομή των μηκών των tokens
token_lengths = [len(t) for t in encoded.tokens]
fig = plt.figure()
plt.hist(token_lengths, bins=20)
plt.xlabel('Μήκος Token')
plt.ylabel('Συχνότητα')
img_buffer = BytesIO()
plt.savefig(img_buffer, format='png')
plt.close()
results = {
"Πρωτότυπο Κείμενο": test_text,
"Αποκωδικοποιημένο": decoded,
"Αριθμός Tokens": len(encoded.tokens),
"Αγνώστων Tokens": sum(1 for t in encoded.tokens if t == "<unk>")
}
yield {"progress": "Η εκπαίδευση ολοκληρώθηκε!\n", "results": results, "plot": img_buffer.getvalue()}
# Gradio Interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("## Wikipedia Tokenizer Trainer with Checkpointing and Streaming")
with gr.Row():
with gr.Column():
dataset_name = gr.Textbox(
label="Dataset Name",
value="wikimedia/wikipedia",
placeholder="π.χ. 'wikimedia/wikipedia'"
)
configs = gr.Textbox(
label="Configs (π.χ. '20231101.el,20231101.en' για ελληνικά και αγγλικά)",
value="20231101.el,20231101.en",
placeholder="Εισάγετε configs χωρισμένα με κόμμα"
)
split = gr.Dropdown(
label="Split",
choices=["train"],
value="train",
allow_custom_value=True
)
vocab_size = gr.Slider(20000, 100000, value=50000, label="Μέγεθος Λεξιλογίου")
min_freq = gr.Slider(1, 100, value=3, label="Ελάχιστη Συχνότητα")
test_text = gr.Textbox(
value="Η Ακρόπολη είναι σύμβολο της αρχαίας ελληνικής πολιτισμικής κληρονομιάς.",
label="Test Text"
)
custom_files = gr.File(
label="Προσαρμοσμένα Ελληνικά Κείμενα",
file_count="multiple",
type="filepath"
)
train_btn = gr.Button("Εκπαίδευση", variant="primary")
with gr.Column():
progress_box = gr.Textbox(label="Πρόοδος", interactive=False)
results_json = gr.JSON(label="Αποτελέσματα")
results_plot = gr.Image(label="Κατανομή Μηκών Tokens")
# Event handlers
dataset_name.change(
fn=update_components,
inputs=dataset_name,
outputs=[configs, split, gr.HTML(label="Dataset Preview")]
)
split.change(
fn=update_split_choices,
inputs=[dataset_name, configs],
outputs=split
)
train_btn.click(
fn=train_and_test_streaming,
inputs=[dataset_name, configs, split, vocab_size, min_freq, test_text, custom_files],
outputs=[progress_box, results_json, results_plot],
stream=True
)
if __name__ == "__main__":
demo.launch()