GPT2-PBE / app.py
tymbos's picture
Update app.py
0044b58 verified
raw
history blame
11.3 kB
# -*- coding: utf-8 -*-
import os
import gradio as gr
import requests
from io import BytesIO
import matplotlib.pyplot as plt
import tempfile
from datasets import load_dataset
from train_tokenizer import train_tokenizer
from tokenizers import Tokenizer
# Ρυθμίσεις checkpointing
CHECKPOINT_FILE = "checkpoint.txt" # Αρχείο που αποθηκεύει τα ήδη επεξεργασμένα κείμενα
CHUNK_SIZE = 1000 # Αριθμός δειγμάτων ανά chunk για αποθήκευση στο checkpoint
def fetch_splits(dataset_name):
try:
response = requests.get(
f"https://datasets-server.huggingface.co/splits?dataset={dataset_name}",
timeout=10
)
response.raise_for_status()
data = response.json()
splits_info = {}
for split in data['splits']:
config = split['config']
split_name = split['split']
if config not in splits_info:
splits_info[config] = []
splits_info[config].append(split_name)
return {
"splits": splits_info,
"viewer_template": f"https://huggingface.co/datasets/{dataset_name}/embed/viewer/{{config}}/{{split}}"
}
except Exception as e:
raise gr.Error(f"Σφάλμα κατά την ανάκτηση των splits: {str(e)}")
def update_components(dataset_name):
if not dataset_name:
return [gr.Textbox.update(value=""), gr.Dropdown.update(choices=[], value=None), gr.HTML.update(value="")]
try:
splits_data = fetch_splits(dataset_name)
config_choices = list(splits_data['splits'].keys())
first_config = config_choices[0] if config_choices else None
iframe_html = f"""
<iframe
src="{splits_data['viewer_template'].format(config=first_config, split='train')}"
frameborder="0"
width="100%"
height="560px"
></iframe>
""" if first_config else "Δεν βρέθηκαν διαθέσιμα δεδομένα"
# Προτείνουμε ως προεπιλογή για πολλαπλά configs τα ελληνικά και αγγλικά
default_configs = "20231101.el,20231101.en" if first_config and "el" in first_config else first_config
return [
gr.Textbox.update(value=default_configs),
gr.Dropdown.update(choices=splits_data['splits'].get(first_config, [])),
gr.HTML.update(value=iframe_html)
]
except Exception as e:
raise gr.Error(f"Σφάλμα: {str(e)}")
def update_split_choices(dataset_name, configs):
if not dataset_name or not configs:
return gr.Dropdown.update(choices=[])
try:
splits_data = fetch_splits(dataset_name)
first_config = configs.split(",")[0].strip()
return gr.Dropdown.update(choices=splits_data['splits'].get(first_config, []))
except:
return gr.Dropdown.update(choices=[])
def create_iterator(dataset_name, configs, split):
"""
Για κάθε config (χωρισμένα με κόμμα) φορτώνει το αντίστοιχο streaming dataset και παράγει τα κείμενα.
"""
configs_list = [c.strip() for c in configs.split(",") if c.strip()]
for config in configs_list:
try:
dataset = load_dataset(
dataset_name,
name=config,
split=split,
streaming=True
)
for example in dataset:
text = example.get('text', '')
if text:
yield text
except Exception as e:
print(f"Σφάλμα φόρτωσης dataset για config {config}: {e}")
def append_to_checkpoint(texts, checkpoint_file):
"""
Αποθηκεύει τα δεδομένα στο αρχείο checkpoint.
"""
with open(checkpoint_file, "a", encoding="utf-8") as f:
for t in texts:
f.write(t + "\n")
def load_checkpoint(checkpoint_file):
"""
Διαβάζει τα δεδομένα από το checkpoint (αν υπάρχει).
"""
if os.path.exists(checkpoint_file):
with open(checkpoint_file, "r", encoding="utf-8") as f:
return f.read().splitlines()
return []
def train_and_test(dataset_name, configs, split, vocab_size, min_freq, test_text, custom_files):
"""
Εκπαιδεύει τον tokenizer με checkpointing.
Επιστρέφει στο τέλος την τελική πρόοδο, τα αποτελέσματα και το plot.
(Σημείωση: Σε αυτήν την έκδοση δεν υπάρχει streaming progress λόγω περιορισμών στο Gradio στο Spaces.)
"""
progress_messages = []
# Φόρτωση ήδη επεξεργασμένων δεδομένων από το checkpoint (αν υπάρχουν)
all_texts = load_checkpoint(CHECKPOINT_FILE)
total_processed = len(all_texts)
progress_messages.append(f"Έχετε {total_processed} δείγματα ήδη αποθηκευμένα στο checkpoint.")
# Δημιουργία iterator από τα streaming datasets
dataset_iterator = create_iterator(dataset_name, configs, split)
new_texts = []
chunk_count = 0
# Επεξεργασία νέων δεδομένων σε chunks
for text in dataset_iterator:
new_texts.append(text)
total_processed += 1
if len(new_texts) >= CHUNK_SIZE:
append_to_checkpoint(new_texts, CHECKPOINT_FILE)
chunk_count += 1
progress_messages.append(f"Επεξεργάστηκαν {total_processed} δείγματα (chunk {chunk_count}).")
new_texts = []
# Αποθήκευση υπολειπόμενων δεδομένων
if new_texts:
append_to_checkpoint(new_texts, CHECKPOINT_FILE)
total_processed += len(new_texts)
chunk_count += 1
progress_messages.append(f"Τελικό chunk: συνολικά {total_processed} δείγματα αποθηκεύτηκαν.")
# Επεξεργασία των custom αρχείων, αν υπάρχουν
if custom_files:
custom_texts = []
for file_path in custom_files:
try:
with open(file_path, 'r', encoding='utf-8') as f:
content = f.read()
if content:
custom_texts.append(content)
except Exception as file_error:
progress_messages.append(f"Σφάλμα ανάγνωσης αρχείου {file_path}: {file_error}")
if custom_texts:
append_to_checkpoint(custom_texts, CHECKPOINT_FILE)
total_processed += len(custom_texts)
progress_messages.append(f"Προστέθηκαν {len(custom_texts)} δείγματα από custom αρχεία.")
# Φόρτωση όλων των δεδομένων για εκπαίδευση
all_texts = load_checkpoint(CHECKPOINT_FILE)
progress_messages.append(f"Ξεκινάει η εκπαίδευση του tokenizer σε {len(all_texts)} δείγματα...")
# Εκπαίδευση του tokenizer
tokenizer = train_tokenizer(all_texts, vocab_size, min_freq)
# Αποθήκευση και φόρτωση του εκπαιδευμένου tokenizer
with tempfile.NamedTemporaryFile(delete=False, suffix=".json") as f:
tokenizer.save(f.name)
trained_tokenizer = Tokenizer.from_file(f.name)
os.unlink(f.name)
# Validation: κωδικοποίηση και αποκωδικοποίηση του test κειμένου
encoded = trained_tokenizer.encode(test_text)
decoded = trained_tokenizer.decode(encoded.ids)
# Δημιουργία γραφήματος για την κατανομή των μηκών των tokens
token_lengths = [len(t) for t in encoded.tokens]
fig = plt.figure()
plt.hist(token_lengths, bins=20)
plt.xlabel('Μήκος Token')
plt.ylabel('Συχνότητα')
img_buffer = BytesIO()
plt.savefig(img_buffer, format='png')
plt.close()
results = {
"Πρωτότυπο Κείμενο": test_text,
"Αποκωδικοποιημένο": decoded,
"Αριθμός Tokens": len(encoded.tokens),
"Αγνώστων Tokens": sum(1 for t in encoded.tokens if t == "<unk>")
}
progress_messages.append("Η εκπαίδευση ολοκληρώθηκε!")
# Επιστρέφουμε τα μηνύματα προόδου μαζί με τα τελικά αποτελέσματα και το plot
final_progress = "\n".join(progress_messages)
return final_progress, results, img_buffer.getvalue()
# Gradio Interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("## Wikipedia Tokenizer Trainer with Checkpointing")
with gr.Row():
with gr.Column():
dataset_name = gr.Textbox(
label="Dataset Name",
value="wikimedia/wikipedia",
placeholder="π.χ. 'wikimedia/wikipedia'"
)
configs = gr.Textbox(
label="Configs (π.χ. '20231101.el,20231101.en' για ελληνικά και αγγλικά)",
value="20231101.el,20231101.en",
placeholder="Εισάγετε configs χωρισμένα με κόμμα"
)
split = gr.Dropdown(
label="Split",
choices=["train"],
value="train",
allow_custom_value=True
)
vocab_size = gr.Slider(20000, 100000, value=50000, label="Μέγεθος Λεξιλογίου")
min_freq = gr.Slider(1, 100, value=3, label="Ελάχιστη Συχνότητα")
test_text = gr.Textbox(
value="Η Ακρόπολη είναι σύμβολο της αρχαίας ελληνικής πολιτισμικής κληρονομιάς.",
label="Test Text"
)
custom_files = gr.File(
label="Προσαρμοσμένα Ελληνικά Κείμενα",
file_count="multiple",
type="filepath"
)
train_btn = gr.Button("Εκπαίδευση", variant="primary")
with gr.Column():
progress_box = gr.Textbox(label="Πρόοδος", interactive=False, lines=10)
results_json = gr.JSON(label="Αποτελέσματα")
results_plot = gr.Image(label="Κατανομή Μηκών Tokens")
# Event handlers
dataset_name.change(
fn=update_components,
inputs=dataset_name,
outputs=[configs, split, gr.HTML(label="Dataset Preview")]
)
split.change(
fn=update_split_choices,
inputs=[dataset_name, configs],
outputs=split
)
train_btn.click(
fn=train_and_test,
inputs=[dataset_name, configs, split, vocab_size, min_freq, test_text, custom_files],
outputs=[progress_box, results_json, results_plot]
)
if __name__ == "__main__":
demo.launch()