VietCat's picture
adjust generation time
4a6a9fa
raw
history blame
7.33 kB
import os
import time
import warnings
import re
import gc
import tracemalloc
warnings.filterwarnings("ignore", category=UserWarning, module="torch._utils")
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
import gradio as gr
import psutil
# Print system resources for debugging
def print_system_resources():
memory = psutil.virtual_memory()
cpu_percent = psutil.cpu_percent(interval=1)
try:
with open('/sys/fs/cgroup/memory/memory.limit_in_bytes', 'r') as f:
mem_limit = min(int(f.read().strip()) / 1e9, 16.0)
except:
mem_limit = 16.0
print(f"Total physical memory (psutil): {memory.total/1e9:.2f} GB")
print(f"Container memory limit: {mem_limit:.2f} GB")
print(f"CPU usage: {cpu_percent}%")
print(f"Memory usage: {min(memory.used / (mem_limit * 1e9) * 100, 100):.1f}% ({memory.used/1e9:.2f}/{mem_limit:.2f} GB)")
print(f"Active processes: {len(psutil.pids())}")
print(f"Python memory (torch): {torch.cuda.memory_allocated()/1e9:.2f} GB" if torch.cuda.is_available() else "CPU only")
# Print Gradio version for debugging
print(f"Gradio version: {gr.__version__}")
# Load model and tokenizer
model_id = "NlpHUST/gpt2-vietnamese"
try:
tokenizer = GPT2Tokenizer.from_pretrained(model_id)
model = GPT2LMHeadModel.from_pretrained(model_id)
except Exception as e:
print(f"Error loading model: {e}")
raise e
# Set pad_token_id to eos_token_id if not set
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id
model.config.pad_token_id = tokenizer.eos_token_id
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
# Print device and memory info for debugging
print(f"Device: {device}")
print(f"Memory allocated: {torch.cuda.memory_allocated(device)/1e9:.2f} GB" if torch.cuda.is_available() else "CPU only")
print_system_resources()
def clean_text(text):
"""Normalize text by removing invalid characters and extra spaces."""
text = re.sub(r'[^\w\s.,!?àáâãèéêìíòóôõùúýăđĩũơưạảấầẩẫậắằẳẵặẹẻẽếềểễệỉịọỏốồổỗộớờởỡợụủứừửữựỳỵỷỹ]', '', text)
text = re.sub(r'\s+', ' ', text).strip()
return text
def generate_text(prompt, temperature=0.9, max_new_tokens=30, min_length=20, max_length=100):
try:
start_time = time.time()
# Start tracemalloc for memory debugging
tracemalloc.start()
# Debug memory before generation
print("Memory before generation:")
print_system_resources()
# Fixed parameters
top_k = 20
repetition_penalty = 2.0
top_p = 0.9
# Log parameters
print(f"Parameters: max_length={max_length}, temperature={temperature}, max_new_tokens={max_new_tokens}, top_k={top_k}, top_p={top_p}, repetition_penalty={repetition_penalty}, min_length={min_length}")
# Encode input with attention mask
encode_time = time.time()
inputs = tokenizer(
prompt,
return_tensors="pt",
padding=True,
truncation=True,
max_length=max_length
).to(device)
print(f"Encoding time: {time.time() - encode_time:.2f} seconds")
print(f"Input tokens: {len(inputs['input_ids'][0])}")
# Define EOS token IDs for '.', '!', '?'
eos_token_ids = [tokenizer.encode(s)[0] for s in ['.', '!', '?']]
print(f"EOS token IDs: {eos_token_ids}")
# Generate text with no_grad
gen_time = time.time()
with torch.no_grad():
outputs = model.generate(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
max_new_tokens=int(max_new_tokens),
min_length=int(min_length),
do_sample=True,
top_k=top_k,
top_p=top_p,
temperature=temperature,
repetition_penalty=repetition_penalty,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=eos_token_ids
)
gen_duration = time.time() - gen_time
print(f"Generation time: {gen_duration:.2f} seconds")
print(f"Tokens per second: {len(outputs[0]) / gen_duration:.2f}")
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"Raw output: {generated_text}")
print(f"Generated token count: {len(outputs[0])}")
cleaned_text = clean_text(generated_text)
print(f"Cleaned output: {cleaned_text}")
elapsed_time = time.time() - start_time
print(f"Total time: {elapsed_time:.2f} seconds")
# Check for repetitive text
from collections import Counter
words = cleaned_text.split()
word_counts = Counter(words)
if any(count > 3 for count in word_counts.values()):
print("Warning: Repetitive text detected")
# Clear memory cache
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Debug memory after generation
print("Memory after generation:")
print_system_resources()
# Log tracemalloc
snapshot = tracemalloc.take_snapshot()
top_stats = snapshot.statistics('lineno')
print("[Top 5 memory]", top_stats[:5])
tracemalloc.stop()
return cleaned_text
except Exception as e:
return f"Error generating text: {e}"
# Gradio interface
demo = gr.Interface(
fn=generate_text,
inputs=[
gr.Textbox(
label="Nhập văn bản đầu vào",
placeholder="Viết gì đó bằng tiếng Việt...",
value="Hôm nay là một ngày đẹp trời, hãy kể một câu chuyện dài về một chuyến đi đáng nhớ, mô tả chi tiết nơi bạn đến, người bạn gặp, và cảm xúc của bạn."
),
gr.Slider(0.1, 1.5, value=0.9, step=0.1, label="Nhiệt độ (Temperature, 0.1-0.3 cho mạch lạc, 0.7-1.2 cho sáng tạo và câu dài)"),
gr.Slider(25, 100, value=30, step=5, label="Số token mới tối đa (max_new_tokens, 25-50 cho câu ngắn, 75-100 cho câu dài)"),
gr.Slider(20, 100, value=20, step=5, label="Độ dài tối thiểu (min_length, 20-50 cho câu vừa, 75-100 cho câu dài)"),
gr.Slider(50, 150, value=100, step=5, label="Độ dài tối đa (max_length, 50-100 cho câu ngắn, 100-150 cho câu dài)")
],
outputs="text",
title="Sinh văn bản tiếng Việt",
description="Dùng mô hình GPT-2 Vietnamese từ NlpHUST để sinh văn bản tiếng Việt. Chọn temperature 0.7-1.2, max_new_tokens 30-50, min_length 20-50, và max_length 100 để tạo câu dài, mạch lạc, và hoàn chỉnh. Dùng temperature 0.1-0.3, max_new_tokens 25-30, min_length 20-30 cho câu ngắn và nhanh. Lưu ý: Tránh temperature=0.1 (dễ lặp ý) và max_new_tokens=100 (chậm, dễ cắt cụt). Dùng prompt chi tiết (bối cảnh, nhân vật, cảm xúc) để cải thiện chất lượng. Top_k mặc định là 20.",
allow_flagging="never"
)
if __name__ == "__main__":
demo.launch(server_name="0.0.0.0", server_port=7860)