File size: 2,381 Bytes
6e82314
0673a12
 
 
7fcb72d
 
4bf4a35
3f4ce15
9f79fe4
2454010
9f79fe4
 
 
 
 
 
3f4ce15
6c4916c
 
 
 
 
9f79fe4
7fcb72d
 
 
4bf4a35
0673a12
 
 
 
7fcb72d
9f79fe4
6c4916c
 
 
 
 
 
 
 
 
 
9f79fe4
6c4916c
 
9f79fe4
 
 
6c4916c
 
9f79fe4
 
 
 
7fcb72d
9f79fe4
1cf6170
7fcb72d
1cf6170
7fcb72d
 
 
1cf6170
 
7fcb72d
 
 
1cf6170
3f4ce15
99c05e6
9f79fe4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="torch._utils")

from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
import gradio as gr

# --- Model / tokenizer bootstrap -------------------------------------------
# Both objects come from the same hub checkpoint. A loading failure is echoed
# to stdout (useful in hosted logs) and then re-raised so the app stops
# instead of starting with no model.
model_id = "NlpHUST/gpt2-vietnamese"
try:
    tokenizer = GPT2Tokenizer.from_pretrained(model_id)
    model = GPT2LMHeadModel.from_pretrained(model_id)
except Exception as load_error:
    print(f"Error loading model: {load_error}")
    raise load_error

# If the tokenizer defines no pad token (presumably the case for this GPT-2
# checkpoint), reuse EOS. Mirror it into the model config so generate() sees
# a consistent pad id.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
    model.config.pad_token_id = tokenizer.eos_token_id

# Run on the GPU when one is visible, otherwise on the CPU; inference mode only.
_cuda_available = torch.cuda.is_available()
device = torch.device("cuda") if _cuda_available else torch.device("cpu")
model.to(device)
model.eval()

# Startup diagnostics: the selected device, plus allocated GPU memory when on CUDA.
print(f"Device: {device}")
if _cuda_available:
    print(f"Memory allocated: {torch.cuda.memory_allocated(device)/1e9:.2f} GB")
else:
    print("CPU only")

def generate_text(prompt, max_length=100, temperature=1.0):
    """Sample a continuation of ``prompt`` from the loaded GPT-2 model.

    Args:
        prompt: Seed text. An empty (or ``None``) prompt is seeded with the
            tokenizer's BOS token, falling back to EOS, so the model has at
            least one token to condition on.
        max_length: Maximum *total* length in tokens (prompt + generated
            text), forwarded to ``model.generate(max_length=...)``. The value
            is coerced to ``int`` because UI sliders may deliver floats. It is
            clamped to the model's position limit when the config exposes
            ``n_positions``.
        temperature: Sampling temperature for ``model.generate``. It is
            coerced to ``float``.

    Returns:
        The decoded prompt plus continuation, with special tokens removed.
        On any failure it returns an ``"Error generating text: ..."`` message
        instead of raising, so the error is shown in the UI.
    """
    try:
        max_length = int(max_length)
        temperature = float(temperature)

        # Never exceed the positional-embedding table; inputs longer than
        # this would fail inside the model's forward pass.
        position_limit = getattr(model.config, "n_positions", None)
        if position_limit is not None:
            max_length = min(max_length, position_limit)

        token_ids = tokenizer.encode(prompt or "")
        if not token_ids:
            # A zero-length input cannot be continued; start from a special token.
            start_id = tokenizer.bos_token_id
            if start_id is None:
                start_id = tokenizer.eos_token_id
            token_ids = [start_id]

        # max_length counts prompt tokens too, so leave room for at least one
        # new token. When the prompt is too long, keep its *end*: that is
        # the context the model actually continues from.
        prompt_budget = max(max_length - 1, 1)
        token_ids = token_ids[-prompt_budget:]

        input_ids = torch.tensor([token_ids], dtype=torch.long, device=device)
        # Single unpadded sequence: every position is real input.
        attention_mask = torch.ones_like(input_ids)

        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=max(max_length, input_ids.shape[-1] + 1),
            temperature=temperature,
            do_sample=True,
            num_beams=1,
            pad_token_id=tokenizer.pad_token_id,
        )
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        # UI boundary: report the failure as text rather than crashing the request.
        return f"Error generating text: {e}"

# --- Web UI ------------------------------------------------------------------
# Three inputs map positionally onto generate_text(prompt, max_length,
# temperature); the result is shown as plain text.
_ui_inputs = [
    gr.Textbox(label="Nhập văn bản đầu vào", placeholder="Viết gì đó bằng tiếng Việt..."),
    gr.Slider(minimum=20, maximum=300, value=100, step=10, label="Độ dài tối đa"),
    gr.Slider(minimum=0.5, maximum=1.5, value=1.0, step=0.1, label="Nhiệt độ (Temperature)"),
]

demo = gr.Interface(
    fn=generate_text,
    inputs=_ui_inputs,
    outputs="text",
    title="Sinh văn bản tiếng Việt",
    description="Dùng mô hình GPT-2 Vietnamese từ NlpHUST để sinh văn bản tiếng Việt.",
    # NOTE(review): newer Gradio releases appear to rename this keyword to
    # `flagging_mode`; confirm against the installed Gradio version.
    allow_flagging="never",
)

if __name__ == "__main__":
    # Listen on all interfaces on port 7860 so the app is reachable from
    # outside a container.
    demo.launch(server_name="0.0.0.0", server_port=7860)