import time
import warnings
import re
warnings.filterwarnings("ignore", category=UserWarning, module="torch._utils")

from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
import gradio as gr
import psutil

# Print system resources for debugging
def print_system_resources():
    memory = psutil.virtual_memory()
    cpu_percent = psutil.cpu_percent(interval=1)  # Note: blocks ~1 s while sampling CPU usage
    print(f"Total physical memory: {memory.total/1e9:.2f} GB")
    print(f"CPU usage: {cpu_percent}%")
    print(f"Memory usage: {memory.percent}% ({memory.used/1e9:.2f}/{memory.total/1e9:.2f} GB)")

# Load model and tokenizer
model_id = "NlpHUST/gpt2-vietnamese"
try:
    tokenizer = GPT2Tokenizer.from_pretrained(model_id)
    model = GPT2LMHeadModel.from_pretrained(model_id)
except Exception as e:
    print(f"Error loading model: {e}")
    raise

# GPT-2 has no pad token by default, so reuse the EOS token for padding
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
    model.config.pad_token_id = tokenizer.eos_token_id

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# Apply dynamic int8 quantization to reduce memory use and speed up inference.
# Dynamic quantization is CPU-only in PyTorch, so it is skipped on GPU. Note
# that GPT-2's transformer blocks use Conv1D rather than nn.Linear, so only
# nn.Linear modules such as the LM head are quantized here.
if device.type == "cpu":
    model = torch.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )
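# Optional sanity check (commented out): a dynamically quantized layer prints
# as DynamicQuantizedLinear:
# print(type(model.lm_head))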

# Print device and memory info for debugging
print(f"Device: {device}")
print(f"Memory allocated: {torch.cuda.memory_allocated(device)/1e9:.2f} GB" if torch.cuda.is_available() else "CPU only")
print_system_resources()

def clean_text(text):
    """Clean generated text by removing non-alphabetic characters and extra spaces."""
    text = re.sub(r'[^\w\s.,!?àáâãèéêìíòóôõùúýăđĩũơưạảấầẩẫậắằẳẵặẹẻẽếềểễệỉịọỏốồổỗộớờởỡợụủứừửữựỳỵỷỹ]', '', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text
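# For example, clean_text("Xin  chào!!!  @#$%") returns "Xin chào!!!".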

def generate_text(prompt, max_length=50, temperature=0.9):
    try:
        start_time = time.time()
        print_system_resources()
        # Encode the prompt with an attention mask; cap the prompt at a fixed
        # input budget (128 tokens) independent of the requested output length
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=128
        ).to(device)

        # Generate text
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=max_length,  # Honor the "max length" slider
            min_length=10,
            do_sample=True,  # Enable sampling for diversity
            temperature=temperature,  # Honor the temperature slider
            top_k=50,  # Consider only the 50 most likely next tokens
            top_p=0.9,  # Nucleus sampling
            no_repeat_ngram_size=2,
            pad_token_id=tokenizer.pad_token_id
        )
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        print(f"Raw output: {generated_text}")
        cleaned_text = clean_text(generated_text)
        elapsed_time = time.time() - start_time
        print_system_resources()
        print(f"Generation time: {elapsed_time:.2f} seconds")
        return cleaned_text
    except Exception as e:
        return f"Error generating text: {e}"

# Gradio interface
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(
            label="Nhập văn bản đầu vào",  # "Enter input text"
            placeholder="Viết gì đó bằng tiếng Việt...",  # "Write something in Vietnamese..."
            value="Hôm nay là một ngày đẹp trời"  # Default text: "Today is a beautiful day"
        ),
        gr.Slider(20, 100, value=50, step=10, label="Độ dài tối đa"),  # "Maximum length"
        gr.Slider(0.7, 1.0, value=0.9, step=0.1, label="Nhiệt độ (Temperature)")
    ],
    outputs="text",
    title="Sinh văn bản tiếng Việt",
    description="Dùng mô hình GPT-2 Vietnamese từ NlpHUST để sinh văn bản tiếng Việt.",
    allow_flagging="never"
)
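
# If concurrent requests need buffering, call demo.queue() before launch().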

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)