VietCat commited on
Commit
58c9dd0
·
1 Parent(s): 5655ff3

adjust generation time

Browse files
Files changed (1) hide show
  1. app.py +28 -14
app.py CHANGED
@@ -3,6 +3,7 @@ import time
3
  import warnings
4
  import re
5
  import gc
 
6
  warnings.filterwarnings("ignore", category=UserWarning, module="torch._utils")
7
 
8
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
@@ -59,16 +60,18 @@ def clean_text(text):
59
  text = re.sub(r'\s+', ' ', text).strip()
60
  return text
61
 
62
- def generate_text(prompt, temperature=0.9, max_new_tokens=100, top_k=30, max_length=100):
63
  try:
64
  start_time = time.time()
 
 
65
  # Debug memory before generation
66
  print("Memory before generation:")
67
  print_system_resources()
68
  # Fixed parameters
69
- repetition_penalty = 1.0
 
70
  top_p = 0.9
71
- min_length = 50 # Force longer output
72
  # Log parameters
73
  print(f"Parameters: max_length={max_length}, temperature={temperature}, max_new_tokens={max_new_tokens}, top_k={top_k}, top_p={top_p}, repetition_penalty={repetition_penalty}, min_length={min_length}")
74
  # Encode input with attention mask
@@ -82,8 +85,8 @@ def generate_text(prompt, temperature=0.9, max_new_tokens=100, top_k=30, max_len
82
  ).to(device)
83
  print(f"Encoding time: {time.time() - encode_time:.2f} seconds")
84
  print(f"Input tokens: {len(inputs['input_ids'][0])}")
85
- # Avoid stopping at '.', '!', '?' to allow longer text
86
- eos_token_ids = [] # Disable EOS tokens
87
  print(f"EOS token IDs: {eos_token_ids}")
88
  # Generate text
89
  gen_time = time.time()
@@ -91,14 +94,14 @@ def generate_text(prompt, temperature=0.9, max_new_tokens=100, top_k=30, max_len
91
  input_ids=inputs["input_ids"],
92
  attention_mask=inputs["attention_mask"],
93
  max_new_tokens=int(max_new_tokens),
94
- min_length=min_length,
95
  do_sample=True,
96
- top_k=int(top_k),
97
  top_p=top_p,
98
  temperature=temperature,
99
  repetition_penalty=repetition_penalty,
100
  pad_token_id=tokenizer.pad_token_id,
101
- eos_token_id=eos_token_ids if eos_token_ids else None
102
  )
103
  gen_duration = time.time() - gen_time
104
  print(f"Generation time: {gen_duration:.2f} seconds")
@@ -110,6 +113,12 @@ def generate_text(prompt, temperature=0.9, max_new_tokens=100, top_k=30, max_len
110
  print(f"Cleaned output: {cleaned_text}")
111
  elapsed_time = time.time() - start_time
112
  print(f"Total time: {elapsed_time:.2f} seconds")
 
 
 
 
 
 
113
  # Clear memory cache
114
  gc.collect()
115
  if torch.cuda.is_available():
@@ -117,6 +126,11 @@ def generate_text(prompt, temperature=0.9, max_new_tokens=100, top_k=30, max_len
117
  # Debug memory after generation
118
  print("Memory after generation:")
119
  print_system_resources()
 
 
 
 
 
120
  return cleaned_text
121
  except Exception as e:
122
  return f"Error generating text: {e}"
@@ -128,16 +142,16 @@ demo = gr.Interface(
128
  gr.Textbox(
129
  label="Nhập văn bản đầu vào",
130
  placeholder="Viết gì đó bằng tiếng Việt...",
131
- value="Hôm nay là một ngày đẹp trời, hãy kể một câu chuyện dài về trải nghiệm của bạn." # Updated prompt
132
  ),
133
- gr.Slider(0.1, 0.9, value=0.9, step=0.1, label="Nhiệt độ (Temperature, 0.1-0.3 cho mạch lạc, 0.7-0.9 cho câu dàisáng tạo)"),
134
- gr.Slider(25, 100, value=100, step=5, label="Số token mới tối đa (max_new_tokens, 25-50 cho câu ngắn, 75-100 cho câu dài)"),
135
- gr.Slider(5, 30, value=30, step=5, label="Top K (top_k, 5-10 cho mạch lạc, 20-30 cho đa dạng và dài hơn)"),
136
- gr.Slider(30, 100, value=100, step=5, label="Độ dài tối đa (max_length, 30-50 cho câu ngắn, 75-100 cho câu dài)")
137
  ],
138
  outputs="text",
139
  title="Sinh văn bản tiếng Việt",
140
- description="Dùng mô hình GPT-2 Vietnamese từ NlpHUST để sinh văn bản tiếng Việt. Chọn temperature 0.7-0.9, max_new_tokens 75-100, top_k 20-30, và max_length 75-100 để tạo câu dài và đa dạng. Dùng temperature 0.1-0.3, max_new_tokens 25-50, top_k 5-10, và max_length 30-50 để đạt thời gian <2 giây văn bản mạch lạc. Lưu ý: Để tạo câu dài, dùng prompt yêu cầu câu chuyện hoặc tả chi tiết, và thử temperature=0.9, top_k=30, min_length=50.",
141
  allow_flagging="never"
142
  )
143
 
 
3
  import warnings
4
  import re
5
  import gc
6
+ import tracemalloc
7
  warnings.filterwarnings("ignore", category=UserWarning, module="torch._utils")
8
 
9
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
 
60
  text = re.sub(r'\s+', ' ', text).strip()
61
  return text
62
 
63
+ def generate_text(prompt, temperature=0.7, max_new_tokens=50, min_length=50, max_length=100):
64
  try:
65
  start_time = time.time()
66
+ # Start tracemalloc for memory debugging
67
+ tracemalloc.start()
68
  # Debug memory before generation
69
  print("Memory before generation:")
70
  print_system_resources()
71
  # Fixed parameters
72
+ top_k = 20
73
+ repetition_penalty = 1.5
74
  top_p = 0.9
 
75
  # Log parameters
76
  print(f"Parameters: max_length={max_length}, temperature={temperature}, max_new_tokens={max_new_tokens}, top_k={top_k}, top_p={top_p}, repetition_penalty={repetition_penalty}, min_length={min_length}")
77
  # Encode input with attention mask
 
85
  ).to(device)
86
  print(f"Encoding time: {time.time() - encode_time:.2f} seconds")
87
  print(f"Input tokens: {len(inputs['input_ids'][0])}")
88
+ # Define EOS token IDs for '.', '!', '?'
89
+ eos_token_ids = [tokenizer.encode(s)[0] for s in ['.', '!', '?']]
90
  print(f"EOS token IDs: {eos_token_ids}")
91
  # Generate text
92
  gen_time = time.time()
 
94
  input_ids=inputs["input_ids"],
95
  attention_mask=inputs["attention_mask"],
96
  max_new_tokens=int(max_new_tokens),
97
+ min_length=int(min_length),
98
  do_sample=True,
99
+ top_k=top_k,
100
  top_p=top_p,
101
  temperature=temperature,
102
  repetition_penalty=repetition_penalty,
103
  pad_token_id=tokenizer.pad_token_id,
104
+ eos_token_id=eos_token_ids
105
  )
106
  gen_duration = time.time() - gen_time
107
  print(f"Generation time: {gen_duration:.2f} seconds")
 
113
  print(f"Cleaned output: {cleaned_text}")
114
  elapsed_time = time.time() - start_time
115
  print(f"Total time: {elapsed_time:.2f} seconds")
116
+ # Check for repetitive text
117
+ from collections import Counter
118
+ words = cleaned_text.split()
119
+ word_counts = Counter(words)
120
+ if any(count > 3 for count in word_counts.values()):
121
+ print("Warning: Repetitive text detected")
122
  # Clear memory cache
123
  gc.collect()
124
  if torch.cuda.is_available():
 
126
  # Debug memory after generation
127
  print("Memory after generation:")
128
  print_system_resources()
129
+ # Log tracemalloc
130
+ snapshot = tracemalloc.take_snapshot()
131
+ top_stats = snapshot.statistics('lineno')
132
+ print("[Top 5 memory]", top_stats[:5])
133
+ tracemalloc.stop()
134
  return cleaned_text
135
  except Exception as e:
136
  return f"Error generating text: {e}"
 
142
  gr.Textbox(
143
  label="Nhập văn bản đầu vào",
144
  placeholder="Viết gì đó bằng tiếng Việt...",
145
+ value="Hôm nay là một ngày đẹp trời, hãy kể một câu chuyện dài về một trải nghiệm đáng nhớ của bạn, bao gồm chi tiết về bối cảnh, nhân vật, và cảm xúc."
146
  ),
147
+ gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="Nhiệt độ (Temperature, 0.1-0.3 cho mạch lạc, 0.7-1.5 cho sáng tạocâu dài)"),
148
+ gr.Slider(25, 100, value=50, step=5, label="Số token mới tối đa (max_new_tokens, 25-50 cho câu ngắn, 75-100 cho câu dài)"),
149
+ gr.Slider(30, 100, value=50, step=5, label="Độ dài tối thiểu (min_length, 30-50 cho câu vừa, 75-100 cho câu dài)"),
150
+ gr.Slider(50, 150, value=100, step=5, label="Độ dài tối đa (max_length, 50-100 cho câu ngắn, 100-150 cho câu dài)")
151
  ],
152
  outputs="text",
153
  title="Sinh văn bản tiếng Việt",
154
+ description="Dùng mô hình GPT-2 Vietnamese từ NlpHUST để sinh văn bản tiếng Việt. Chọn temperature 0.7, max_new_tokens 50, min_length 50, và max_length 100 để tạo câu dài, mạch lạc, hoàn chỉnh. Dùng temperature 0.1-0.3, max_new_tokens 25-50, min_length 30-50 cho câu ngắn nhanh. Lưu ý: Dùng prompt chi tiết (bối cảnh, nhân vật, cảm xúc) để cải thiện chất lượng câu chuyện dài. Tránh max_new_tokens=100 với temperature=0.3 để giảm lặp từ. Top_k mặc định là 20.",
155
  allow_flagging="never"
156
  )
157