Commit: adjust generation time
app.py (changed)
@@ -2,6 +2,7 @@ import os
 import time
 import warnings
 import re
+import gc
 warnings.filterwarnings("ignore", category=UserWarning, module="torch._utils")
 
 from transformers import GPT2LMHeadModel, GPT2Tokenizer
@@ -16,9 +17,9 @@ def print_system_resources():
     # Get container memory limit (for Docker)
     try:
         with open('/sys/fs/cgroup/memory/memory.limit_in_bytes', 'r') as f:
-            mem_limit = int(f.read().strip()) / 1e9 #
+            mem_limit = min(int(f.read().strip()) / 1e9, 16.0) # Cap at 16GB for HFS free
     except:
-        mem_limit = 16.0 # Fallback for HFS free
+        mem_limit = 16.0 # Fallback for HFS free
     print(f"Total physical memory (psutil): {memory.total/1e9:.2f} GB")
     print(f"Container memory limit: {mem_limit:.2f} GB")
     print(f"CPU usage: {cpu_percent}%")
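Note on the read above: /sys/fs/cgroup/memory/memory.limit_in_bytes exists only under cgroup v1; on cgroup v2 hosts the limit lives in /sys/fs/cgroup/memory.max (which reads "max" when unlimited), so the bare except: silently falls back to 16 GB there. A minimal version-agnostic sketch, assuming the same 16 GB cap as the diff (the helper name is hypothetical):

    def container_mem_limit_gb(cap_gb=16.0):
        """Best-effort container memory limit in GB (hypothetical helper)."""
        candidates = [
            '/sys/fs/cgroup/memory/memory.limit_in_bytes',  # cgroup v1
            '/sys/fs/cgroup/memory.max',                    # cgroup v2
        ]
        for path in candidates:
            try:
                with open(path) as f:
                    raw = f.read().strip()
                if raw == 'max':  # cgroup v2 reports "max" when unlimited
                    continue
                return min(int(raw) / 1e9, cap_gb)
            except (OSError, ValueError):
                continue
        return cap_gb  # fallback for the HF Spaces free tier, as in the diff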
@@ -67,12 +68,12 @@ def clean_text(text):
     if complete_sentences:
         text = ' '.join(complete_sentences)
     else:
-        # Fallback: Keep until last valid word
+        # Fallback: Keep until last valid word
         words = text.split()
         text = ' '.join(words[:-1]) if len(words) > 1 else text
     return text
 
-def generate_text(prompt, max_length=50, temperature=0.9):
+def generate_text(prompt, max_length=40, temperature=1.0):
     try:
         start_time = time.time()
         print_system_resources()
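For context, the clean_text fallback shown above drops the last, likely cut-off word when no complete sentence survived. A self-contained sketch of that behavior; the sentence split is an assumption, since the start of clean_text is not part of this diff:

    import re

    def clean_text_sketch(text):
        # Assumed split on ., ! or ? followed by whitespace; the real
        # clean_text in app.py may tokenize sentences differently.
        sentences = re.split(r'(?<=[.!?])\s+', text.strip())
        complete_sentences = [s for s in sentences if s and s[-1] in '.!?']
        if complete_sentences:
            text = ' '.join(complete_sentences)
        else:
            # Fallback: keep everything up to the last (possibly truncated) word
            words = text.split()
            text = ' '.join(words[:-1]) if len(words) > 1 else text
        return text

    print(clean_text_sketch("Hôm nay trời đẹp. Câu này bị cắt ngan"))
    # -> "Hôm nay trời đẹp."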
@@ -84,18 +85,18 @@ def generate_text(prompt, max_length=50, temperature=0.9):
             truncation=True,
             max_length=max_length
         ).to(device)
+        print(f"Input tokens: {len(inputs['input_ids'][0])}")
 
         # Generate text
         outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
-            max_new_tokens=
-            min_length=
+            max_new_tokens=10, # Reduce for speed
+            min_length=5,
            do_sample=True,
-            top_k=
-            top_p=0.
+            top_k=20, # Reduce for speed
+            top_p=0.85, # Prioritize high-quality tokens
            temperature=temperature,
-            no_repeat_ngram_size=2,
            pad_token_id=tokenizer.pad_token_id
        )
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
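These parameter changes are what the commit title, "adjust generation time", refers to: with sampling, CPU wall-clock time grows roughly linearly with max_new_tokens, and a smaller top_k shrinks the candidate pool at each step. A rough timing sketch using the same GPT2LMHeadModel/GPT2Tokenizer imports; the checkpoint the Space actually loads is not shown in the diff, so "gpt2" here is a stand-in:

    import time
    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")  # stand-in checkpoint
    model = GPT2LMHeadModel.from_pretrained("gpt2")

    inputs = tokenizer("Hôm nay là một ngày đẹp trời", return_tensors="pt")
    for n in (10, 20, 40):
        start = time.time()
        model.generate(
            **inputs,
            max_new_tokens=n,
            do_sample=True,
            top_k=20,
            top_p=0.85,
            pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no pad token
        )
        print(f"max_new_tokens={n}: {time.time() - start:.2f}s")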
@@ -104,6 +105,9 @@ def generate_text(prompt, max_length=50, temperature=0.9):
         cleaned_text = clean_text(generated_text)
         elapsed_time = time.time() - start_time
         print(f"Generation time: {elapsed_time:.2f} seconds")
+        # Clear memory cache
+        gc.collect()
+        torch.cuda.empty_cache() if torch.cuda.is_available() else None
         return cleaned_text
     except Exception as e:
         return f"Error generating text: {e}"
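The added conditional expression on the torch.cuda.empty_cache() line works, but an expression used purely for its side effect is unidiomatic Python; a plain if reads better and behaves identically. A small cleanup sketch (the helper name is hypothetical):

    import gc
    import torch

    def free_memory():
        """Collect Python garbage and, if a GPU is present, CUDA's cached blocks."""
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()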
@@ -117,8 +121,8 @@ demo = gr.Interface(
             placeholder="Viết gì đó bằng tiếng Việt...",
             value="Hôm nay là một ngày đẹp trời" # Default text
         ),
-        gr.Slider(20, 100, value=
-        gr.Slider(0.7, 1.0, value=0
+        gr.Slider(20, 100, value=40, step=10, label="Độ dài tối đa"),
+        gr.Slider(0.7, 1.0, value=1.0, step=0.1, label="Nhiệt độ (Temperature)")
     ],
     outputs="text",
     title="Sinh văn bản tiếng Việt",
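The new slider defaults (40 tokens, temperature 1.0) now match the new generate_text defaults, keeping the UI and the function in sync. For context, a minimal reconstruction of how these inputs plug into the interface; the full gr.Interface call is not part of the diff, so everything outside the two sliders is an assumption:

    import gradio as gr

    def generate_text(prompt, max_length=40, temperature=1.0):
        return prompt  # stub; the real function lives in app.py

    demo = gr.Interface(
        fn=generate_text,
        inputs=[
            gr.Textbox(placeholder="Viết gì đó bằng tiếng Việt...",  # "Write something in Vietnamese..."
                       value="Hôm nay là một ngày đẹp trời"),        # "Today is a beautiful day"
            gr.Slider(20, 100, value=40, step=10, label="Độ dài tối đa"),  # "Maximum length"
            gr.Slider(0.7, 1.0, value=1.0, step=0.1, label="Nhiệt độ (Temperature)"),
        ],
        outputs="text",
        title="Sinh văn bản tiếng Việt",  # "Vietnamese text generation"
    )

    if __name__ == "__main__":
        demo.launch()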