Spaces:
Running
Running
adjust generation time
Browse files
app.py
CHANGED
@@ -3,6 +3,7 @@ import time
|
|
3 |
import warnings
|
4 |
import re
|
5 |
import gc
|
|
|
6 |
warnings.filterwarnings("ignore", category=UserWarning, module="torch._utils")
|
7 |
|
8 |
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
@@ -59,16 +60,18 @@ def clean_text(text):
|
|
59 |
text = re.sub(r'\s+', ' ', text).strip()
|
60 |
return text
|
61 |
|
62 |
-
def generate_text(prompt, temperature=0.
|
63 |
try:
|
64 |
start_time = time.time()
|
|
|
|
|
65 |
# Debug memory before generation
|
66 |
print("Memory before generation:")
|
67 |
print_system_resources()
|
68 |
# Fixed parameters
|
69 |
-
|
|
|
70 |
top_p = 0.9
|
71 |
-
min_length = 50 # Force longer output
|
72 |
# Log parameters
|
73 |
print(f"Parameters: max_length={max_length}, temperature={temperature}, max_new_tokens={max_new_tokens}, top_k={top_k}, top_p={top_p}, repetition_penalty={repetition_penalty}, min_length={min_length}")
|
74 |
# Encode input with attention mask
|
@@ -82,8 +85,8 @@ def generate_text(prompt, temperature=0.9, max_new_tokens=100, top_k=30, max_len
|
|
82 |
).to(device)
|
83 |
print(f"Encoding time: {time.time() - encode_time:.2f} seconds")
|
84 |
print(f"Input tokens: {len(inputs['input_ids'][0])}")
|
85 |
-
#
|
86 |
-
eos_token_ids = []
|
87 |
print(f"EOS token IDs: {eos_token_ids}")
|
88 |
# Generate text
|
89 |
gen_time = time.time()
|
@@ -91,14 +94,14 @@ def generate_text(prompt, temperature=0.9, max_new_tokens=100, top_k=30, max_len
|
|
91 |
input_ids=inputs["input_ids"],
|
92 |
attention_mask=inputs["attention_mask"],
|
93 |
max_new_tokens=int(max_new_tokens),
|
94 |
-
min_length=min_length,
|
95 |
do_sample=True,
|
96 |
-
top_k=
|
97 |
top_p=top_p,
|
98 |
temperature=temperature,
|
99 |
repetition_penalty=repetition_penalty,
|
100 |
pad_token_id=tokenizer.pad_token_id,
|
101 |
-
eos_token_id=eos_token_ids
|
102 |
)
|
103 |
gen_duration = time.time() - gen_time
|
104 |
print(f"Generation time: {gen_duration:.2f} seconds")
|
@@ -110,6 +113,12 @@ def generate_text(prompt, temperature=0.9, max_new_tokens=100, top_k=30, max_len
|
|
110 |
print(f"Cleaned output: {cleaned_text}")
|
111 |
elapsed_time = time.time() - start_time
|
112 |
print(f"Total time: {elapsed_time:.2f} seconds")
|
|
|
|
|
|
|
|
|
|
|
|
|
113 |
# Clear memory cache
|
114 |
gc.collect()
|
115 |
if torch.cuda.is_available():
|
@@ -117,6 +126,11 @@ def generate_text(prompt, temperature=0.9, max_new_tokens=100, top_k=30, max_len
|
|
117 |
# Debug memory after generation
|
118 |
print("Memory after generation:")
|
119 |
print_system_resources()
|
|
|
|
|
|
|
|
|
|
|
120 |
return cleaned_text
|
121 |
except Exception as e:
|
122 |
return f"Error generating text: {e}"
|
@@ -128,16 +142,16 @@ demo = gr.Interface(
|
|
128 |
gr.Textbox(
|
129 |
label="Nhập văn bản đầu vào",
|
130 |
placeholder="Viết gì đó bằng tiếng Việt...",
|
131 |
-
value="Hôm nay là một ngày đẹp trời, hãy kể một câu chuyện dài về trải nghiệm của bạn."
|
132 |
),
|
133 |
-
gr.Slider(0.1,
|
134 |
-
gr.Slider(25, 100, value=
|
135 |
-
gr.Slider(
|
136 |
-
gr.Slider(
|
137 |
],
|
138 |
outputs="text",
|
139 |
title="Sinh văn bản tiếng Việt",
|
140 |
-
description="Dùng mô hình GPT-2 Vietnamese từ NlpHUST để sinh văn bản tiếng Việt. Chọn temperature 0.7
|
141 |
allow_flagging="never"
|
142 |
)
|
143 |
|
|
|
3 |
import warnings
|
4 |
import re
|
5 |
import gc
|
6 |
+
import tracemalloc
|
7 |
warnings.filterwarnings("ignore", category=UserWarning, module="torch._utils")
|
8 |
|
9 |
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
|
|
60 |
text = re.sub(r'\s+', ' ', text).strip()
|
61 |
return text
|
62 |
|
63 |
+
def generate_text(prompt, temperature=0.7, max_new_tokens=50, min_length=50, max_length=100):
|
64 |
try:
|
65 |
start_time = time.time()
|
66 |
+
# Start tracemalloc for memory debugging
|
67 |
+
tracemalloc.start()
|
68 |
# Debug memory before generation
|
69 |
print("Memory before generation:")
|
70 |
print_system_resources()
|
71 |
# Fixed parameters
|
72 |
+
top_k = 20
|
73 |
+
repetition_penalty = 1.5
|
74 |
top_p = 0.9
|
|
|
75 |
# Log parameters
|
76 |
print(f"Parameters: max_length={max_length}, temperature={temperature}, max_new_tokens={max_new_tokens}, top_k={top_k}, top_p={top_p}, repetition_penalty={repetition_penalty}, min_length={min_length}")
|
77 |
# Encode input with attention mask
|
|
|
85 |
).to(device)
|
86 |
print(f"Encoding time: {time.time() - encode_time:.2f} seconds")
|
87 |
print(f"Input tokens: {len(inputs['input_ids'][0])}")
|
88 |
+
# Define EOS token IDs for '.', '!', '?'
|
89 |
+
eos_token_ids = [tokenizer.encode(s)[0] for s in ['.', '!', '?']]
|
90 |
print(f"EOS token IDs: {eos_token_ids}")
|
91 |
# Generate text
|
92 |
gen_time = time.time()
|
|
|
94 |
input_ids=inputs["input_ids"],
|
95 |
attention_mask=inputs["attention_mask"],
|
96 |
max_new_tokens=int(max_new_tokens),
|
97 |
+
min_length=int(min_length),
|
98 |
do_sample=True,
|
99 |
+
top_k=top_k,
|
100 |
top_p=top_p,
|
101 |
temperature=temperature,
|
102 |
repetition_penalty=repetition_penalty,
|
103 |
pad_token_id=tokenizer.pad_token_id,
|
104 |
+
eos_token_id=eos_token_ids
|
105 |
)
|
106 |
gen_duration = time.time() - gen_time
|
107 |
print(f"Generation time: {gen_duration:.2f} seconds")
|
|
|
113 |
print(f"Cleaned output: {cleaned_text}")
|
114 |
elapsed_time = time.time() - start_time
|
115 |
print(f"Total time: {elapsed_time:.2f} seconds")
|
116 |
+
# Check for repetitive text
|
117 |
+
from collections import Counter
|
118 |
+
words = cleaned_text.split()
|
119 |
+
word_counts = Counter(words)
|
120 |
+
if any(count > 3 for count in word_counts.values()):
|
121 |
+
print("Warning: Repetitive text detected")
|
122 |
# Clear memory cache
|
123 |
gc.collect()
|
124 |
if torch.cuda.is_available():
|
|
|
126 |
# Debug memory after generation
|
127 |
print("Memory after generation:")
|
128 |
print_system_resources()
|
129 |
+
# Log tracemalloc
|
130 |
+
snapshot = tracemalloc.take_snapshot()
|
131 |
+
top_stats = snapshot.statistics('lineno')
|
132 |
+
print("[Top 5 memory]", top_stats[:5])
|
133 |
+
tracemalloc.stop()
|
134 |
return cleaned_text
|
135 |
except Exception as e:
|
136 |
return f"Error generating text: {e}"
|
|
|
142 |
gr.Textbox(
|
143 |
label="Nhập văn bản đầu vào",
|
144 |
placeholder="Viết gì đó bằng tiếng Việt...",
|
145 |
+
value="Hôm nay là một ngày đẹp trời, hãy kể một câu chuyện dài về một trải nghiệm đáng nhớ của bạn, bao gồm chi tiết về bối cảnh, nhân vật, và cảm xúc."
|
146 |
),
|
147 |
+
gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="Nhiệt độ (Temperature, 0.1-0.3 cho mạch lạc, 0.7-1.5 cho sáng tạo và câu dài)"),
|
148 |
+
gr.Slider(25, 100, value=50, step=5, label="Số token mới tối đa (max_new_tokens, 25-50 cho câu ngắn, 75-100 cho câu dài)"),
|
149 |
+
gr.Slider(30, 100, value=50, step=5, label="Độ dài tối thiểu (min_length, 30-50 cho câu vừa, 75-100 cho câu dài)"),
|
150 |
+
gr.Slider(50, 150, value=100, step=5, label="Độ dài tối đa (max_length, 50-100 cho câu ngắn, 100-150 cho câu dài)")
|
151 |
],
|
152 |
outputs="text",
|
153 |
title="Sinh văn bản tiếng Việt",
|
154 |
+
description="Dùng mô hình GPT-2 Vietnamese từ NlpHUST để sinh văn bản tiếng Việt. Chọn temperature 0.7, max_new_tokens 50, min_length 50, và max_length 100 để tạo câu dài, mạch lạc, và hoàn chỉnh. Dùng temperature 0.1-0.3, max_new_tokens 25-50, min_length 30-50 cho câu ngắn và nhanh. Lưu ý: Dùng prompt chi tiết (bối cảnh, nhân vật, cảm xúc) để cải thiện chất lượng câu chuyện dài. Tránh max_new_tokens=100 với temperature=0.3 để giảm lặp từ. Top_k mặc định là 20.",
|
155 |
allow_flagging="never"
|
156 |
)
|
157 |
|