import os
import time
import warnings
import re
import gc
warnings.filterwarnings("ignore", category=UserWarning, module="torch._utils")
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
import gradio as gr
import psutil
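# Optional CPU tuning (an assumption: free Spaces expose ~2 vCPUs; adjust or drop):
# torch.set_num_threads(2)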
# Print system resources for debugging
def print_system_resources():
    memory = psutil.virtual_memory()
    cpu_percent = psutil.cpu_percent(interval=1)
    # Get the container memory limit (cgroup v1, with a cgroup v2 fallback)
    try:
        with open('/sys/fs/cgroup/memory/memory.limit_in_bytes', 'r') as f:
            mem_limit = min(int(f.read().strip()) / 1e9, 16.0)  # Cap at 16 GB for the HF Spaces free tier
    except (OSError, ValueError):
        try:
            # cgroup v2 hosts expose the limit here; the file holds a byte count or "max"
            with open('/sys/fs/cgroup/memory.max', 'r') as f:
                raw = f.read().strip()
            mem_limit = 16.0 if raw == 'max' else min(int(raw) / 1e9, 16.0)
        except (OSError, ValueError):
            mem_limit = 16.0  # Fallback for the HF Spaces free tier
    print(f"Total physical memory (psutil): {memory.total/1e9:.2f} GB")
    print(f"Container memory limit: {mem_limit:.2f} GB")
    print(f"CPU usage: {cpu_percent}%")
    print(f"Memory usage: {min(memory.used / (mem_limit * 1e9) * 100, 100):.1f}% ({memory.used/1e9:.2f}/{mem_limit:.2f} GB)")
    print(f"Active processes: {len(psutil.pids())}")
# Print Gradio version for debugging
print(f"Gradio version: {gr.__version__}")
# Load model and tokenizer
model_id = "NlpHUST/gpt2-vietnamese"
try:
    tokenizer = GPT2Tokenizer.from_pretrained(model_id)
    model = GPT2LMHeadModel.from_pretrained(model_id)
except Exception as e:
    print(f"Error loading model: {e}")
    raise
# Set pad_token_id to eos_token_id if not set (GPT-2 has no pad token by default)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
    model.config.pad_token_id = tokenizer.eos_token_id
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
# Apply dynamic quantization to reduce memory use and speed up CPU inference.
# Note: dynamic quantization only takes effect on CPU, and GPT-2 implements most of
# its layers as transformers' Conv1D rather than torch.nn.Linear, so in practice
# only the lm_head projection is quantized here.
model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)
print(f"Model quantized: {model.__class__.__name__}")
# Print device and memory info for debugging
print(f"Device: {device}")
print(f"Memory allocated: {torch.cuda.memory_allocated(device)/1e9:.2f} GB" if torch.cuda.is_available() else "CPU only")
print_system_resources()
def clean_text(text):
    """Normalize text by removing invalid characters and extra spaces."""
    text = re.sub(r'[^\w\s.,!?àáâãèéêìíòóôõùúýăđĩũơưạảấầẩẫậắằẳẵặẹẻẽếềểễệỉịọỏốồổỗộớờởỡợụủứừửữựỳỵỷỹ]', '', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text
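# Quick check of clean_text (a hypothetical example, not part of the app flow):
# clean_text("Xin   chào @#$ thế giới!")  ->  "Xin chào thế giới!"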
def generate_text(prompt, temperature=0.5, max_new_tokens=30):
    try:
        start_time = time.time()
        print_system_resources()
        # Fixed parameters
        max_length = 50
        top_k = 20
        repetition_penalty = 1.2
        # Log parameters
        print(f"Parameters: max_length={max_length}, temperature={temperature}, max_new_tokens={max_new_tokens}, top_k={top_k}, repetition_penalty={repetition_penalty}")
        # Encode input with attention mask
        encode_time = time.time()
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length
        ).to(device)
        print(f"Encoding time: {time.time() - encode_time:.2f} seconds")
        print(f"Input tokens: {len(inputs['input_ids'][0])}")
        # Treat '.', '!' and '?' as additional EOS tokens so generation stops at a sentence end.
        # This assumes each mark encodes to a single token, which typically holds for GPT-2-style BPE vocabularies.
        eos_token_ids = [tokenizer.encode(s)[0] for s in ['.', '!', '?']]
        print(f"EOS token IDs: {eos_token_ids}")
        # Generate text
        gen_time = time.time()
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=int(max_new_tokens),
            min_length=3,
            do_sample=True,
            top_k=int(top_k),
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=eos_token_ids
        )
        print(f"Generation time: {time.time() - gen_time:.2f} seconds")
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        print(f"Raw output: {generated_text}")
        print(f"Generated token count: {len(outputs[0])}")
        cleaned_text = clean_text(generated_text)
        print(f"Cleaned output: {cleaned_text}")
        elapsed_time = time.time() - start_time
        print(f"Total time: {elapsed_time:.2f} seconds")
        # Collect Python-level garbage to release memory between requests
        gc.collect()
        return cleaned_text
    except Exception as e:
        return f"Error generating text: {e}"
# Gradio interface (UI strings are Vietnamese, matching the app's target audience;
# English glosses are given in comments)
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(
            label="Nhập văn bản đầu vào",  # "Enter input text"
            placeholder="Viết gì đó bằng tiếng Việt...",  # "Write something in Vietnamese..."
            value="Hôm nay là một ngày đẹp trời"  # Default text: "Today is a beautiful day"
        ),
        # Temperature: 0.3-0.5 for speed, 0.6-0.7 for more variety
        gr.Slider(0.3, 0.7, value=0.5, step=0.1, label="Nhiệt độ (Temperature, 0.3-0.5 cho tốc độ nhanh, 0.6-0.7 cho đa dạng hơn)"),
        # Max new tokens: 20-30 for speed, 40-50 for longer sentences
        gr.Slider(20, 50, value=30, step=5, label="Số token mới tối đa (max_new_tokens, 20-30 cho tốc độ nhanh, 40-50 cho câu dài hơn)")
    ],
    outputs="text",
    title="Sinh văn bản tiếng Việt",  # "Vietnamese text generation"
    # Description gloss: generates Vietnamese text with NlpHUST's GPT-2; temperature 0.3-0.5
    # and max_new_tokens 20-30 target <2 s responses; higher values give longer, more varied output.
    description="Dùng mô hình GPT-2 Vietnamese từ NlpHUST để sinh văn bản tiếng Việt. Chọn temperature 0.3-0.5 và max_new_tokens 20-30 để đạt thời gian <2 giây. Dùng temperature 0.6-0.7 và max_new_tokens 40-50 cho câu dài và đa dạng hơn.",
    allow_flagging="never"
)
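# Optional: enable request queuing so concurrent users do not time out on a single
# CPU worker (an assumption about load on a free Space, using Gradio's queue() API):
# demo = demo.queue(max_size=8)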
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)