VietCat commited on
Commit
4a6a9fa
·
1 Parent(s): 58c9dd0

adjust generation time

Browse files
Files changed (1) hide show
  1. app.py +22 -21
app.py CHANGED
@@ -60,7 +60,7 @@ def clean_text(text):
60
  text = re.sub(r'\s+', ' ', text).strip()
61
  return text
62
 
63
- def generate_text(prompt, temperature=0.7, max_new_tokens=50, min_length=50, max_length=100):
64
  try:
65
  start_time = time.time()
66
  # Start tracemalloc for memory debugging
@@ -70,7 +70,7 @@ def generate_text(prompt, temperature=0.7, max_new_tokens=50, min_length=50, max
70
  print_system_resources()
71
  # Fixed parameters
72
  top_k = 20
73
- repetition_penalty = 1.5
74
  top_p = 0.9
75
  # Log parameters
76
  print(f"Parameters: max_length={max_length}, temperature={temperature}, max_new_tokens={max_new_tokens}, top_k={top_k}, top_p={top_p}, repetition_penalty={repetition_penalty}, min_length={min_length}")
@@ -88,21 +88,22 @@ def generate_text(prompt, temperature=0.7, max_new_tokens=50, min_length=50, max
88
  # Define EOS token IDs for '.', '!', '?'
89
  eos_token_ids = [tokenizer.encode(s)[0] for s in ['.', '!', '?']]
90
  print(f"EOS token IDs: {eos_token_ids}")
91
- # Generate text
92
  gen_time = time.time()
93
- outputs = model.generate(
94
- input_ids=inputs["input_ids"],
95
- attention_mask=inputs["attention_mask"],
96
- max_new_tokens=int(max_new_tokens),
97
- min_length=int(min_length),
98
- do_sample=True,
99
- top_k=top_k,
100
- top_p=top_p,
101
- temperature=temperature,
102
- repetition_penalty=repetition_penalty,
103
- pad_token_id=tokenizer.pad_token_id,
104
- eos_token_id=eos_token_ids
105
- )
 
106
  gen_duration = time.time() - gen_time
107
  print(f"Generation time: {gen_duration:.2f} seconds")
108
  print(f"Tokens per second: {len(outputs[0]) / gen_duration:.2f}")
@@ -142,16 +143,16 @@ demo = gr.Interface(
142
  gr.Textbox(
143
  label="Nhập văn bản đầu vào",
144
  placeholder="Viết gì đó bằng tiếng Việt...",
145
- value="Hôm nay là một ngày đẹp trời, hãy kể một câu chuyện dài về một trải nghiệm đáng nhớ của bạn, bao gồm chi tiết về bối cảnh, nhân vật, và cảm xúc."
146
  ),
147
- gr.Slider(0.1, 1.5, value=0.7, step=0.1, label="Nhiệt độ (Temperature, 0.1-0.3 cho mạch lạc, 0.7-1.5 cho sáng tạo và câu dài)"),
148
- gr.Slider(25, 100, value=50, step=5, label="Số token mới tối đa (max_new_tokens, 25-50 cho câu ngắn, 75-100 cho câu dài)"),
149
- gr.Slider(30, 100, value=50, step=5, label="Độ dài tối thiểu (min_length, 30-50 cho câu vừa, 75-100 cho câu dài)"),
150
  gr.Slider(50, 150, value=100, step=5, label="Độ dài tối đa (max_length, 50-100 cho câu ngắn, 100-150 cho câu dài)")
151
  ],
152
  outputs="text",
153
  title="Sinh văn bản tiếng Việt",
154
- description="Dùng mô hình GPT-2 Vietnamese từ NlpHUST để sinh văn bản tiếng Việt. Chọn temperature 0.7, max_new_tokens 50, min_length 50, và max_length 100 để tạo câu dài, mạch lạc, và hoàn chỉnh. Dùng temperature 0.1-0.3, max_new_tokens 25-50, min_length 30-50 cho câu ngắn và nhanh. Lưu ý: Dùng prompt chi tiết (bối cảnh, nhân vật, cảm xúc) để cải thiện chất lượng câu chuyện dài. Tránh max_new_tokens=100 với temperature=0.3 để giảm lặp từ. Top_k mặc định là 20.",
155
  allow_flagging="never"
156
  )
157
 
 
60
  text = re.sub(r'\s+', ' ', text).strip()
61
  return text
62
 
63
+ def generate_text(prompt, temperature=0.9, max_new_tokens=30, min_length=20, max_length=100):
64
  try:
65
  start_time = time.time()
66
  # Start tracemalloc for memory debugging
 
70
  print_system_resources()
71
  # Fixed parameters
72
  top_k = 20
73
+ repetition_penalty = 2.0
74
  top_p = 0.9
75
  # Log parameters
76
  print(f"Parameters: max_length={max_length}, temperature={temperature}, max_new_tokens={max_new_tokens}, top_k={top_k}, top_p={top_p}, repetition_penalty={repetition_penalty}, min_length={min_length}")
 
88
  # Define EOS token IDs for '.', '!', '?'
89
  eos_token_ids = [tokenizer.encode(s)[0] for s in ['.', '!', '?']]
90
  print(f"EOS token IDs: {eos_token_ids}")
91
+ # Generate text with no_grad
92
  gen_time = time.time()
93
+ with torch.no_grad():
94
+ outputs = model.generate(
95
+ input_ids=inputs["input_ids"],
96
+ attention_mask=inputs["attention_mask"],
97
+ max_new_tokens=int(max_new_tokens),
98
+ min_length=int(min_length),
99
+ do_sample=True,
100
+ top_k=top_k,
101
+ top_p=top_p,
102
+ temperature=temperature,
103
+ repetition_penalty=repetition_penalty,
104
+ pad_token_id=tokenizer.pad_token_id,
105
+ eos_token_id=eos_token_ids
106
+ )
107
  gen_duration = time.time() - gen_time
108
  print(f"Generation time: {gen_duration:.2f} seconds")
109
  print(f"Tokens per second: {len(outputs[0]) / gen_duration:.2f}")
 
143
  gr.Textbox(
144
  label="Nhập văn bản đầu vào",
145
  placeholder="Viết gì đó bằng tiếng Việt...",
146
+ value="Hôm nay là một ngày đẹp trời, hãy kể một câu chuyện dài về một chuyến đi đáng nhớ, tả chi tiết nơi bạn đến, người bạn gặp, và cảm xúc của bạn."
147
  ),
148
+ gr.Slider(0.1, 1.5, value=0.9, step=0.1, label="Nhiệt độ (Temperature, 0.1-0.3 cho mạch lạc, 0.7-1.2 cho sáng tạo và câu dài)"),
149
+ gr.Slider(25, 100, value=30, step=5, label="Số token mới tối đa (max_new_tokens, 25-50 cho câu ngắn, 75-100 cho câu dài)"),
150
+ gr.Slider(20, 100, value=20, step=5, label="Độ dài tối thiểu (min_length, 20-50 cho câu vừa, 75-100 cho câu dài)"),
151
  gr.Slider(50, 150, value=100, step=5, label="Độ dài tối đa (max_length, 50-100 cho câu ngắn, 100-150 cho câu dài)")
152
  ],
153
  outputs="text",
154
  title="Sinh văn bản tiếng Việt",
155
+ description="Dùng mô hình GPT-2 Vietnamese từ NlpHUST để sinh văn bản tiếng Việt. Chọn temperature 0.7-1.2, max_new_tokens 30-50, min_length 20-50, và max_length 100 để tạo câu dài, mạch lạc, và hoàn chỉnh. Dùng temperature 0.1-0.3, max_new_tokens 25-30, min_length 20-30 cho câu ngắn và nhanh. Lưu ý: Tránh temperature=0.1 (dễ lặp ý) và max_new_tokens=100 (chậm, dễ cắt cụt). Dùng prompt chi tiết (bối cảnh, nhân vật, cảm xúc) để cải thiện chất lượng. Top_k mặc định là 20.",
156
  allow_flagging="never"
157
  )
158