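# Gradio demo: Vietnamese text generation with NlpHUST/gpt2-vietnamese.
# The model is loaded once at startup; generate_text() samples a continuation
# for the user's prompt and logs memory/latency details along the way.
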
import time
import warnings
import re
import gc
import tracemalloc
from collections import Counter
warnings.filterwarnings("ignore", category=UserWarning, module="torch._utils")  # suppress UserWarnings from torch._utils (e.g. TypedStorage deprecation noise)

from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
import gradio as gr
import psutil

# Print system resources for debugging
def print_system_resources():
    memory = psutil.virtual_memory()
    cpu_percent = psutil.cpu_percent(interval=1)
    try:
        # cgroup v1 memory limit; "unlimited" reads back as a huge sentinel, hence the 16 GB cap
        with open('/sys/fs/cgroup/memory/memory.limit_in_bytes', 'r') as f:
            mem_limit = min(int(f.read().strip()) / 1e9, 16.0)
    except (OSError, ValueError):
        mem_limit = 16.0  # file absent (e.g. on cgroup v2 hosts); assume 16 GB
    print(f"Total physical memory (psutil): {memory.total/1e9:.2f} GB")
    print(f"Container memory limit: {mem_limit:.2f} GB")
    print(f"CPU usage: {cpu_percent}%")
    print(f"Memory usage: {min(memory.used / (mem_limit * 1e9) * 100, 100):.1f}% ({memory.used/1e9:.2f}/{mem_limit:.2f} GB)")
    print(f"Active processes: {len(psutil.pids())}")
    print(f"Python memory (torch): {torch.cuda.memory_allocated()/1e9:.2f} GB" if torch.cuda.is_available() else "CPU only")

# Print Gradio version for debugging
print(f"Gradio version: {gr.__version__}")

# Load model and tokenizer
model_id = "NlpHUST/gpt2-vietnamese"
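# The checkpoint is fetched from the Hugging Face Hub on first run (assumes network access; cached afterwards).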
try:
    tokenizer = GPT2Tokenizer.from_pretrained(model_id)
    model = GPT2LMHeadModel.from_pretrained(model_id)
except Exception as e:
    print(f"Error loading model: {e}")
    raise

# GPT-2 ships without a pad token, so reuse EOS for padding when none is set
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
    model.config.pad_token_id = tokenizer.eos_token_id

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# Print device and memory info for debugging
print(f"Device: {device}")
print(f"Memory allocated: {torch.cuda.memory_allocated(device)/1e9:.2f} GB" if torch.cuda.is_available() else "CPU only")
print_system_resources()

def clean_text(text):
    """Normalize text by removing invalid characters while preserving essential punctuation, quotes, and line breaks."""
    # Step 1: Keep allowed characters
    # \w: letters, digits, underscore
    # \s: whitespace, including line breaks
    # Specific punctuation: .,?!:;"'()-–/…
    # Vietnamese letters, listed explicitly for clarity (\w already matches Unicode letters in Python 3)
    allowed_pattern = r'[^\w\s.,?!:;"\'()\–\-/…àáâãèéêìíòóôõùúýăđĩũơưạảấầẩẫậắằẳẵặẹẻẽếềểễệỉịọỏốồổỗộớờởỡợụủứừửữựỳỵỷỹ]'
    # Find and log removed characters for debugging
    removed_chars = set(re.findall(allowed_pattern, text))
    if removed_chars:
        print(f"Removed characters: {removed_chars}")
    # Remove invalid characters
    text = re.sub(allowed_pattern, '', text)
    # Step 2: Normalize spaces (exclude line breaks)
    # Split by line breaks to preserve them
    lines = text.split('\n')
    # Normalize spaces within each line
    lines = [re.sub(r'[ \t]+', ' ', line).strip(' ') for line in lines]
    # Join lines back with line breaks
    text = '\n'.join(lines)
    # Step 3: Remove leading/trailing empty lines
    text = text.strip('\n')
    return text
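
# Hypothetical example of clean_text behavior:
#   clean_text("Xin  chào\tthế giới®!")  ->  "Xin chào thế giới!"
# ("®" falls outside the allowed set and is dropped; runs of spaces/tabs collapse to one.)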

def generate_text(prompt, temperature=0.9, max_new_tokens=40, min_length=10, max_length=100):
    try:
        start_time = time.time()
        # Start tracemalloc for memory debugging
        tracemalloc.start()
        # Debug memory before generation
        print("Memory before generation:")
        print_system_resources()
        # Fixed parameters
        top_k = 20
        repetition_penalty = 2.0
        top_p = 0.9
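        # top_k=20 samples from the 20 most likely tokens, top_p=0.9 further trims to the
        # smallest nucleus holding 90% of probability mass, and repetition_penalty=2.0
        # (deliberately aggressive) discourages the model from looping.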
        # Log parameters
        print(f"Parameters: max_length={max_length}, temperature={temperature}, max_new_tokens={max_new_tokens}, top_k={top_k}, top_p={top_p}, repetition_penalty={repetition_penalty}, min_length={min_length}")
        # Encode input with attention mask
        encode_time = time.time()
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length
        ).to(device)
        print(f"Encoding time: {time.time() - encode_time:.2f} seconds")
        print(f"Input tokens: {len(inputs['input_ids'][0])}")
        # Define EOS token IDs for '.', '!', '?'
        eos_token_ids = [tokenizer.encode(s)[0] for s in ['.', '!', '?']]  # assumes each mark encodes to a single token
        print(f"EOS token IDs: {eos_token_ids}")
        # Generate text with no_grad
        gen_time = time.time()
        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=int(max_new_tokens),
                min_length=int(min_length),  # note: counts prompt + generated tokens, so it can be a no-op for long prompts
                do_sample=True,
                top_k=top_k,
                top_p=top_p,
                temperature=temperature,
                repetition_penalty=repetition_penalty,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=eos_token_ids,
                early_stopping=True  # only affects beam search; ignored here since do_sample=True
            )
        gen_duration = time.time() - gen_time
        print(f"Generation time: {gen_duration:.2f} seconds")
        print(f"Tokens per second: {len(outputs[0]) / gen_duration:.2f}")
        print(f"Last token: {outputs[0][-1].item()}")
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        print(f"Raw output: {generated_text}")
        print(f"Generated token count: {len(outputs[0])}")
        cleaned_text = clean_text(generated_text)
        print(f"Cleaned output: {cleaned_text}")
        elapsed_time = time.time() - start_time
        print(f"Total time: {elapsed_time:.2f} seconds")
        # Check for repetitive text (Counter is imported at the top of the file)
        words = cleaned_text.split()
        word_counts = Counter(words)
        if any(count > 3 for count in word_counts.values()):  # crude heuristic; common stop words can trigger it
            print("Warning: Repetitive text detected")
        # Clear memory cache
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        # Debug memory after generation
        print("Memory after generation:")
        print_system_resources()
        # Log tracemalloc
        snapshot = tracemalloc.take_snapshot()
        top_stats = snapshot.statistics('lineno')
        print("[Top 5 memory]", top_stats[:5])
        tracemalloc.stop()
        return cleaned_text
    except Exception as e:
        tracemalloc.stop()  # don't leave tracing running after a failure
        return f"Error generating text: {e}"

# Gradio interface
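# UI strings are Vietnamese; roughly: "Nhập văn bản đầu vào" = "Enter input text",
# "Nhiệt độ" = "Temperature", "Số token mới tối đa" = "Max new tokens",
# "Độ dài tối thiểu/tối đa" = "Min/max length"; title: "Vietnamese text generation".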
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(
            label="Nhập văn bản đầu vào",
            placeholder="Viết gì đó bằng tiếng Việt...",
            value="Hôm nay là một ngày đẹp trời, hãy kể một câu chuyện dài về một chuyến đi du lịch đáng nhớ, mô tả chi tiết địa điểm bạn đến, người bạn gặp, và cảm xúc của bạn. Câu chuyện phải có bối cảnh rõ ràng, ít nhất một nhân vật, và kết thúc hoàn chỉnh."
        ),
        gr.Slider(0.3, 1.2, value=0.9, step=0.1, label="Nhiệt độ (Temperature, 0.3-0.5 cho mạch lạc, 0.7-1.0 cho sáng tạo)"),
        gr.Slider(25, 75, value=40, step=5, label="Số token mới tối đa (max_new_tokens, 30-50 cho câu vừa, 50-75 cho câu dài)"),
        gr.Slider(10, 50, value=10, step=5, label="Độ dài tối thiểu (min_length, 10-30 cho câu vừa, 30-50 cho câu dài)"),
        gr.Slider(50, 150, value=100, step=5, label="Độ dài tối đa (max_length, 50-100 cho câu ngắn, 100-150 cho câu dài)")
    ],
    outputs="text",
    title="Sinh văn bản tiếng Việt",
    description="Dùng mô hình GPT-2 Vietnamese từ NlpHUST để sinh văn bản tiếng Việt. Chọn temperature 0.7-1.0, max_new_tokens 30-50, min_length 10-30, và max_length 100 để tạo câu dài, mạch lạc, và hoàn chỉnh. Tránh max_new_tokens>75 (chậm, dễ cắt cụt). Dùng prompt chi tiết (bối cảnh, nhân vật, cảm xúc, kết thúc hoàn chỉnh) để cải thiện chất lượng. Top_k mặc định là 20.",
    allow_flagging="never"
)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
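
# Quick smoke test without the UI (hypothetical; run in a Python shell after importing this module):
#   print(generate_text("Hà Nội là", temperature=0.7, max_new_tokens=30))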