Spaces:
Running
Running
adjust generation time
Browse files
app.py
CHANGED
@@ -3,7 +3,7 @@ import time
|
|
3 |
import warnings
|
4 |
import re
|
5 |
import gc
|
6 |
-
warnings.
|
7 |
|
8 |
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
9 |
import torch
|
@@ -65,15 +65,18 @@ def clean_text(text):
|
|
65 |
text = re.sub(r'\s+', ' ', text).strip()
|
66 |
return text
|
67 |
|
68 |
-
def generate_text(prompt, temperature=0.
|
69 |
try:
|
70 |
start_time = time.time()
|
71 |
print_system_resources()
|
|
|
|
|
|
|
|
|
72 |
# Log parameters
|
73 |
-
max_length =
|
74 |
-
top_k = 20 # Fixed
|
75 |
-
print(f"Parameters: max_length={max_length}, temperature={temperature}, max_new_tokens={max_new_tokens}, top_k={top_k}")
|
76 |
# Encode input with attention mask
|
|
|
77 |
inputs = tokenizer(
|
78 |
prompt,
|
79 |
return_tensors="pt",
|
@@ -81,9 +84,13 @@ def generate_text(prompt, temperature=0.7, max_new_tokens=10):
|
|
81 |
truncation=True,
|
82 |
max_length=max_length
|
83 |
).to(device)
|
|
|
84 |
print(f"Input tokens: {len(inputs['input_ids'][0])}")
|
85 |
-
|
|
|
|
|
86 |
# Generate text
|
|
|
87 |
outputs = model.generate(
|
88 |
input_ids=inputs["input_ids"],
|
89 |
attention_mask=inputs["attention_mask"],
|
@@ -92,16 +99,18 @@ def generate_text(prompt, temperature=0.7, max_new_tokens=10):
|
|
92 |
do_sample=True,
|
93 |
top_k=int(top_k),
|
94 |
temperature=temperature,
|
|
|
95 |
pad_token_id=tokenizer.pad_token_id,
|
96 |
-
eos_token_id=
|
97 |
)
|
|
|
98 |
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
99 |
print(f"Raw output: {generated_text}")
|
100 |
print(f"Generated token count: {len(outputs[0])}")
|
101 |
cleaned_text = clean_text(generated_text)
|
102 |
print(f"Cleaned output: {cleaned_text}")
|
103 |
elapsed_time = time.time() - start_time
|
104 |
-
print(f"
|
105 |
# Clear memory cache
|
106 |
gc.collect()
|
107 |
return cleaned_text
|
@@ -117,12 +126,12 @@ demo = gr.Interface(
|
|
117 |
placeholder="Viết gì đó bằng tiếng Việt...",
|
118 |
value="Hôm nay là một ngày đẹp trời" # Default text
|
119 |
),
|
120 |
-
gr.Slider(0.
|
121 |
-
gr.Slider(
|
122 |
],
|
123 |
outputs="text",
|
124 |
title="Sinh văn bản tiếng Việt",
|
125 |
-
description="Dùng mô hình GPT-2 Vietnamese từ NlpHUST để sinh văn bản tiếng Việt.
|
126 |
allow_flagging="never"
|
127 |
)
|
128 |
|
|
|
3 |
import warnings
|
4 |
import re
|
5 |
import gc
|
6 |
+
warnings.filterWarnings("ignore", category=UserWarning, module="torch._utils")
|
7 |
|
8 |
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
9 |
import torch
|
|
|
65 |
text = re.sub(r'\s+', ' ', text).strip()
|
66 |
return text
|
67 |
|
68 |
+
def generate_text(prompt, temperature=0.5, max_new_tokens=30):
|
69 |
try:
|
70 |
start_time = time.time()
|
71 |
print_system_resources()
|
72 |
+
# Fixed parameters
|
73 |
+
max_length = 50
|
74 |
+
top_k = 20
|
75 |
+
repetition_penalty = 1.2
|
76 |
# Log parameters
|
77 |
+
print(f"Parameters: max_length={max_length}, temperature={temperature}, max_new_tokens={max_new_tokens}, top_k={top_k}, repetition_penalty={repetition_penalty}")
|
|
|
|
|
78 |
# Encode input with attention mask
|
79 |
+
encode_time = time.time()
|
80 |
inputs = tokenizer(
|
81 |
prompt,
|
82 |
return_tensors="pt",
|
|
|
84 |
truncation=True,
|
85 |
max_length=max_length
|
86 |
).to(device)
|
87 |
+
print(f"Encoding time: {time.time() - encode_time:.2f} seconds")
|
88 |
print(f"Input tokens: {len(inputs['input_ids'][0])}")
|
89 |
+
# Define EOS token IDs for '.', '!', '?'
|
90 |
+
eos_token_ids = [tokenizer.encode(s)[0] for s in ['.', '!', '?']]
|
91 |
+
print(f"EOS token IDs: {eos_token_ids}")
|
92 |
# Generate text
|
93 |
+
gen_time = time.time()
|
94 |
outputs = model.generate(
|
95 |
input_ids=inputs["input_ids"],
|
96 |
attention_mask=inputs["attention_mask"],
|
|
|
99 |
do_sample=True,
|
100 |
top_k=int(top_k),
|
101 |
temperature=temperature,
|
102 |
+
repetition_penalty=repetition_penalty,
|
103 |
pad_token_id=tokenizer.pad_token_id,
|
104 |
+
eos_token_id=eos_token_ids
|
105 |
)
|
106 |
+
print(f"Generation time: {time.time() - gen_time:.2f} seconds")
|
107 |
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
108 |
print(f"Raw output: {generated_text}")
|
109 |
print(f"Generated token count: {len(outputs[0])}")
|
110 |
cleaned_text = clean_text(generated_text)
|
111 |
print(f"Cleaned output: {cleaned_text}")
|
112 |
elapsed_time = time.time() - start_time
|
113 |
+
print(f"Total time: {elapsed_time:.2f} seconds")
|
114 |
# Clear memory cache
|
115 |
gc.collect()
|
116 |
return cleaned_text
|
|
|
126 |
placeholder="Viết gì đó bằng tiếng Việt...",
|
127 |
value="Hôm nay là một ngày đẹp trời" # Default text
|
128 |
),
|
129 |
+
gr.Slider(0.3, 0.7, value=0.5, step=0.1, label="Nhiệt độ (Temperature, 0.3-0.5 cho tốc độ nhanh, 0.6-0.7 cho đa dạng hơn)"),
|
130 |
+
gr.Slider(20, 50, value=30, step=5, label="Số token mới tối đa (max_new_tokens, 20-30 cho tốc độ nhanh, 40-50 cho câu dài hơn)")
|
131 |
],
|
132 |
outputs="text",
|
133 |
title="Sinh văn bản tiếng Việt",
|
134 |
+
description="Dùng mô hình GPT-2 Vietnamese từ NlpHUST để sinh văn bản tiếng Việt. Chọn temperature 0.3-0.5 và max_new_tokens 20-30 để đạt thời gian <2 giây. Dùng temperature 0.6-0.7 và max_new_tokens 40-50 cho câu dài và đa dạng hơn.",
|
135 |
allow_flagging="never"
|
136 |
)
|
137 |
|