import os
import time
import warnings
import re
import gc
import tracemalloc
from collections import Counter
warnings.filterwarnings("ignore", category=UserWarning, module="torch._utils")
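# The filter above silences a noisy torch._utils UserWarning that some
# torch/transformers version combinations emit while loading checkpoints;
# whether it fires at all depends on the installed versions.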
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
import gradio as gr
import psutil
# Print system resources for debugging
def print_system_resources():
    memory = psutil.virtual_memory()
    cpu_percent = psutil.cpu_percent(interval=1)
    # Read the container memory limit; fall back to 16 GB if the file is
    # missing or unreadable
    try:
        with open('/sys/fs/cgroup/memory/memory.limit_in_bytes', 'r') as f:
            mem_limit = min(int(f.read().strip()) / 1e9, 16.0)
    except (OSError, ValueError):
        mem_limit = 16.0
    print(f"Total physical memory (psutil): {memory.total/1e9:.2f} GB")
    print(f"Container memory limit: {mem_limit:.2f} GB")
    print(f"CPU usage: {cpu_percent}%")
    print(f"Memory usage: {min(memory.used / (mem_limit * 1e9) * 100, 100):.1f}% ({memory.used/1e9:.2f}/{mem_limit:.2f} GB)")
    print(f"Active processes: {len(psutil.pids())}")
    print(f"Python memory (torch): {torch.cuda.memory_allocated()/1e9:.2f} GB" if torch.cuda.is_available() else "CPU only")
# Print Gradio version for debugging
print(f"Gradio version: {gr.__version__}")
# Load model and tokenizer
model_id = "NlpHUST/gpt2-vietnamese"
try:
    tokenizer = GPT2Tokenizer.from_pretrained(model_id)
    model = GPT2LMHeadModel.from_pretrained(model_id)
except Exception as e:
    print(f"Error loading model: {e}")
    raise
# Set pad_token_id to eos_token_id if not set
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
    model.config.pad_token_id = tokenizer.eos_token_id
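# (GPT-2 ships without a pad token, so reusing EOS for padding is the
# standard workaround.)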
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
# Print device and memory info for debugging
print(f"Device: {device}")
print(f"Memory allocated: {torch.cuda.memory_allocated(device)/1e9:.2f} GB" if torch.cuda.is_available() else "CPU only")
print_system_resources()
def clean_text(text):
    """Normalize text by removing invalid characters while preserving essential punctuation, quotes, and line breaks."""
    # Step 1: keep only allowed characters:
    #   \w  letters, digits, underscore (Unicode-aware, so it already covers
    #       Vietnamese letters; the explicit list below is belt-and-braces)
    #   \s  whitespace, including line breaks
    #   punctuation: .,?!:;"'()-–/…
    allowed_pattern = r'[^\w\s.,?!:;"\'()\–\-/…àáâãèéêìíòóôõùúýăđĩũơưạảấầẩẫậắằẳẵặẹẻẽếềểễệỉịọỏốồổỗộớờởỡợụủứừửữựỳỵỷỹ]'
    # Log removed characters for debugging
    removed_chars = set(re.findall(allowed_pattern, text))
    if removed_chars:
        print(f"Removed characters: {removed_chars}")
    # Remove invalid characters
    text = re.sub(allowed_pattern, '', text)
    # Step 2: collapse runs of spaces/tabs within each line, preserving line breaks
    lines = text.split('\n')
    lines = [re.sub(r'[ \t]+', ' ', line).strip(' ') for line in lines]
    text = '\n'.join(lines)
    # Step 3: remove leading/trailing empty lines
    return text.strip('\n')
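# Quick sanity check (hypothetical input): characters outside the whitelist
# are dropped and runs of spaces collapse, e.g.
#   clean_text('Xin  chào © thế giới!')  ->  'Xin chào thế giới!'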
def generate_text(prompt, temperature=0.9, max_new_tokens=40, min_length=10, max_length=100):
    try:
        start_time = time.time()
        # Start tracemalloc for memory debugging
        tracemalloc.start()
        # Debug memory before generation
        print("Memory before generation:")
        print_system_resources()
        # Fixed sampling parameters
        top_k = 20
        repetition_penalty = 2.0
        top_p = 0.9
        # Log parameters
        print(f"Parameters: max_length={max_length}, temperature={temperature}, max_new_tokens={max_new_tokens}, top_k={top_k}, top_p={top_p}, repetition_penalty={repetition_penalty}, min_length={min_length}")
        # Encode input with attention mask
        encode_time = time.time()
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length
        ).to(device)
        print(f"Encoding time: {time.time() - encode_time:.2f} seconds")
        print(f"Input tokens: {len(inputs['input_ids'][0])}")
        # Define EOS token IDs for '.', '!', '?'
        eos_token_ids = [tokenizer.encode(s)[0] for s in ['.', '!', '?']]
        print(f"EOS token IDs: {eos_token_ids}")
        # Generate text with no_grad
        gen_time = time.time()
        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=int(max_new_tokens),
                min_length=int(min_length),
                do_sample=True,
                top_k=top_k,
                top_p=top_p,
                temperature=temperature,
                repetition_penalty=repetition_penalty,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=eos_token_ids,
                early_stopping=True
            )
        gen_duration = time.time() - gen_time
        print(f"Generation time: {gen_duration:.2f} seconds")
        print(f"Tokens per second: {len(outputs[0]) / gen_duration:.2f}")
        print(f"Last token: {outputs[0][-1].item()}")
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        print(f"Raw output: {generated_text}")
        print(f"Generated token count: {len(outputs[0])}")
        cleaned_text = clean_text(generated_text)
        print(f"Cleaned output: {cleaned_text}")
        elapsed_time = time.time() - start_time
        print(f"Total time: {elapsed_time:.2f} seconds")
        # Flag repetitive output (any word appearing more than three times)
        words = cleaned_text.split()
        word_counts = Counter(words)
        if any(count > 3 for count in word_counts.values()):
            print("Warning: Repetitive text detected")
        # Clear memory caches
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        # Debug memory after generation
        print("Memory after generation:")
        print_system_resources()
        # Log the top allocations recorded by tracemalloc
        snapshot = tracemalloc.take_snapshot()
        top_stats = snapshot.statistics('lineno')
        print("[Top 5 memory]", top_stats[:5])
        tracemalloc.stop()
        return cleaned_text
    except Exception as e:
        return f"Error generating text: {e}"
# Gradio interface (UI strings are in Vietnamese; English glosses in comments)
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(
            label="Nhập văn bản đầu vào",  # "Enter input text"
            placeholder="Viết gì đó bằng tiếng Việt...",  # "Write something in Vietnamese..."
            # Default prompt: "Today is a beautiful day; tell a long story about
            # a memorable trip, describing in detail the place you visited, the
            # people you met, and your feelings. The story must have a clear
            # setting, at least one character, and a complete ending."
            value="Hôm nay là một ngày đẹp trời, hãy kể một câu chuyện dài về một chuyến đi du lịch đáng nhớ, mô tả chi tiết địa điểm bạn đến, người bạn gặp, và cảm xúc của bạn. Câu chuyện phải có bối cảnh rõ ràng, ít nhất một nhân vật, và kết thúc hoàn chỉnh."
        ),
        # "Temperature, 0.3-0.5 for coherence, 0.7-1.0 for creativity"
        gr.Slider(0.3, 1.2, value=0.9, step=0.1, label="Nhiệt độ (Temperature, 0.3-0.5 cho mạch lạc, 0.7-1.0 cho sáng tạo)"),
        # "Max new tokens, 30-50 for medium sentences, 50-75 for long ones"
        gr.Slider(25, 75, value=40, step=5, label="Số token mới tối đa (max_new_tokens, 30-50 cho câu vừa, 50-75 cho câu dài)"),
        # "Minimum length, 10-30 for medium sentences, 30-50 for long ones"
        gr.Slider(10, 50, value=10, step=5, label="Độ dài tối thiểu (min_length, 10-30 cho câu vừa, 30-50 cho câu dài)"),
        # "Maximum length, 50-100 for short sentences, 100-150 for long ones"
        gr.Slider(50, 150, value=100, step=5, label="Độ dài tối đa (max_length, 50-100 cho câu ngắn, 100-150 cho câu dài)")
    ],
    outputs="text",
    title="Sinh văn bản tiếng Việt",  # "Vietnamese text generation"
    # Description: "Uses NlpHUST's Vietnamese GPT-2 model to generate Vietnamese
    # text. Choose temperature 0.7-1.0, max_new_tokens 30-50, min_length 10-30,
    # and max_length 100 for long, coherent, complete sentences. Avoid
    # max_new_tokens > 75 (slow, prone to truncation). Use detailed prompts
    # (setting, characters, feelings, complete ending) to improve quality.
    # top_k defaults to 20."
    description="Dùng mô hình GPT-2 Vietnamese từ NlpHUST để sinh văn bản tiếng Việt. Chọn temperature 0.7-1.0, max_new_tokens 30-50, min_length 10-30, và max_length 100 để tạo câu dài, mạch lạc, và hoàn chỉnh. Tránh max_new_tokens>75 (chậm, dễ cắt cụt). Dùng prompt chi tiết (bối cảnh, nhân vật, cảm xúc, kết thúc hoàn chỉnh) để cải thiện chất lượng. Top_k mặc định là 20.",
    allow_flagging="never"
)
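# Hugging Face Spaces route traffic to port 7860 by default, hence the
# explicit server_name/server_port below.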
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)