VietCat commited on
Commit
e3ea79a
·
1 Parent(s): 6c05290

adjust generation time

Browse files
Files changed (1) hide show
  1. app.py +20 -11
app.py CHANGED
@@ -3,7 +3,7 @@ import time
3
  import warnings
4
  import re
5
  import gc
6
- warnings.filterwarnings("ignore", category=UserWarning, module="torch._utils")
7
 
8
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
9
  import torch
@@ -65,15 +65,18 @@ def clean_text(text):
65
  text = re.sub(r'\s+', ' ', text).strip()
66
  return text
67
 
68
- def generate_text(prompt, temperature=0.7, max_new_tokens=10):
69
  try:
70
  start_time = time.time()
71
  print_system_resources()
 
 
 
 
72
  # Log parameters
73
- max_length = 50 # Fixed
74
- top_k = 20 # Fixed
75
- print(f"Parameters: max_length={max_length}, temperature={temperature}, max_new_tokens={max_new_tokens}, top_k={top_k}")
76
  # Encode input with attention mask
 
77
  inputs = tokenizer(
78
  prompt,
79
  return_tensors="pt",
@@ -81,9 +84,13 @@ def generate_text(prompt, temperature=0.7, max_new_tokens=10):
81
  truncation=True,
82
  max_length=max_length
83
  ).to(device)
 
84
  print(f"Input tokens: {len(inputs['input_ids'][0])}")
85
-
 
 
86
  # Generate text
 
87
  outputs = model.generate(
88
  input_ids=inputs["input_ids"],
89
  attention_mask=inputs["attention_mask"],
@@ -92,16 +99,18 @@ def generate_text(prompt, temperature=0.7, max_new_tokens=10):
92
  do_sample=True,
93
  top_k=int(top_k),
94
  temperature=temperature,
 
95
  pad_token_id=tokenizer.pad_token_id,
96
- eos_token_id=tokenizer.encode('.')[0] # Stop at period
97
  )
 
98
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
99
  print(f"Raw output: {generated_text}")
100
  print(f"Generated token count: {len(outputs[0])}")
101
  cleaned_text = clean_text(generated_text)
102
  print(f"Cleaned output: {cleaned_text}")
103
  elapsed_time = time.time() - start_time
104
- print(f"Generation time: {elapsed_time:.2f} seconds")
105
  # Clear memory cache
106
  gc.collect()
107
  return cleaned_text
@@ -117,12 +126,12 @@ demo = gr.Interface(
117
  placeholder="Viết gì đó bằng tiếng Việt...",
118
  value="Hôm nay là một ngày đẹp trời" # Default text
119
  ),
120
- gr.Slider(0.5, 1.5, value=0.7, step=0.1, label="Nhiệt độ (Temperature)"),
121
- gr.Slider(5, 30, value=10, step=1, label="Số token mới tối đa (max_new_tokens)")
122
  ],
123
  outputs="text",
124
  title="Sinh văn bản tiếng Việt",
125
- description="Dùng mô hình GPT-2 Vietnamese từ NlpHUST để sinh văn bản tiếng Việt. Điều chỉnh temperature và max_new_tokens để tối ưu thời gian và chất lượng đầu ra.",
126
  allow_flagging="never"
127
  )
128
 
 
3
  import warnings
4
  import re
5
  import gc
6
+ warnings.filterWarnings("ignore", category=UserWarning, module="torch._utils")
7
 
8
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
9
  import torch
 
65
  text = re.sub(r'\s+', ' ', text).strip()
66
  return text
67
 
68
+ def generate_text(prompt, temperature=0.5, max_new_tokens=30):
69
  try:
70
  start_time = time.time()
71
  print_system_resources()
72
+ # Fixed parameters
73
+ max_length = 50
74
+ top_k = 20
75
+ repetition_penalty = 1.2
76
  # Log parameters
77
+ print(f"Parameters: max_length={max_length}, temperature={temperature}, max_new_tokens={max_new_tokens}, top_k={top_k}, repetition_penalty={repetition_penalty}")
 
 
78
  # Encode input with attention mask
79
+ encode_time = time.time()
80
  inputs = tokenizer(
81
  prompt,
82
  return_tensors="pt",
 
84
  truncation=True,
85
  max_length=max_length
86
  ).to(device)
87
+ print(f"Encoding time: {time.time() - encode_time:.2f} seconds")
88
  print(f"Input tokens: {len(inputs['input_ids'][0])}")
89
+ # Define EOS token IDs for '.', '!', '?'
90
+ eos_token_ids = [tokenizer.encode(s)[0] for s in ['.', '!', '?']]
91
+ print(f"EOS token IDs: {eos_token_ids}")
92
  # Generate text
93
+ gen_time = time.time()
94
  outputs = model.generate(
95
  input_ids=inputs["input_ids"],
96
  attention_mask=inputs["attention_mask"],
 
99
  do_sample=True,
100
  top_k=int(top_k),
101
  temperature=temperature,
102
+ repetition_penalty=repetition_penalty,
103
  pad_token_id=tokenizer.pad_token_id,
104
+ eos_token_id=eos_token_ids
105
  )
106
+ print(f"Generation time: {time.time() - gen_time:.2f} seconds")
107
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
108
  print(f"Raw output: {generated_text}")
109
  print(f"Generated token count: {len(outputs[0])}")
110
  cleaned_text = clean_text(generated_text)
111
  print(f"Cleaned output: {cleaned_text}")
112
  elapsed_time = time.time() - start_time
113
+ print(f"Total time: {elapsed_time:.2f} seconds")
114
  # Clear memory cache
115
  gc.collect()
116
  return cleaned_text
 
126
  placeholder="Viết gì đó bằng tiếng Việt...",
127
  value="Hôm nay là một ngày đẹp trời" # Default text
128
  ),
129
+ gr.Slider(0.3, 0.7, value=0.5, step=0.1, label="Nhiệt độ (Temperature, 0.3-0.5 cho tốc độ nhanh, 0.6-0.7 cho đa dạng hơn)"),
130
+ gr.Slider(20, 50, value=30, step=5, label="Số token mới tối đa (max_new_tokens, 20-30 cho tốc độ nhanh, 40-50 cho câu dài hơn)")
131
  ],
132
  outputs="text",
133
  title="Sinh văn bản tiếng Việt",
134
+ description="Dùng mô hình GPT-2 Vietnamese từ NlpHUST để sinh văn bản tiếng Việt. Chọn temperature 0.3-0.5 và max_new_tokens 20-30 để đạt thời gian <2 giây. Dùng temperature 0.6-0.7 max_new_tokens 40-50 cho câu dài và đa dạng hơn.",
135
  allow_flagging="never"
136
  )
137