VietCat commited on
Commit
6c05290
·
1 Parent(s): 6fc92b9

adjust generation time

Browse files
Files changed (1) hide show
  1. app.py +9 -8
app.py CHANGED
@@ -60,16 +60,18 @@ print(f"Memory allocated: {torch.cuda.memory_allocated(device)/1e9:.2f} GB" if t
60
  print_system_resources()
61
 
62
  def clean_text(text):
63
- """Clean generated text by removing non-alphabetic characters and normalizing spaces."""
64
  text = re.sub(r'[^\w\s.,!?àáâãèéêìíòóôõùúýăđĩũơưạảấầẩẫậắằẳẵặẹẻẽếềểễệỉịọỏốồổỗộớờởỡợụủứừửữựỳỵỷỹ]', '', text)
65
  text = re.sub(r'\s+', ' ', text).strip()
66
  return text
67
 
68
- def generate_text(prompt, max_length=30, temperature=1.0, max_new_tokens=8, top_k=15):
69
  try:
70
  start_time = time.time()
71
  print_system_resources()
72
  # Log parameters
 
 
73
  print(f"Parameters: max_length={max_length}, temperature={temperature}, max_new_tokens={max_new_tokens}, top_k={top_k}")
74
  # Encode input with attention mask
75
  inputs = tokenizer(
@@ -90,7 +92,8 @@ def generate_text(prompt, max_length=30, temperature=1.0, max_new_tokens=8, top_
90
  do_sample=True,
91
  top_k=int(top_k),
92
  temperature=temperature,
93
- pad_token_id=tokenizer.pad_token_id
 
94
  )
95
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
96
  print(f"Raw output: {generated_text}")
@@ -114,14 +117,12 @@ demo = gr.Interface(
114
  placeholder="Viết gì đó bằng tiếng Việt...",
115
  value="Hôm nay là một ngày đẹp trời" # Default text
116
  ),
117
- gr.Slider(20, 100, value=30, step=10, label="Độ dài tối đa (max_length)"),
118
- gr.Slider(0.7, 1.5, value=1.0, step=0.1, label="Nhiệt độ (Temperature)"),
119
- gr.Slider(1, 20, value=8, step=1, label="Số token mới tối đa (max_new_tokens)"),
120
- gr.Slider(5, 50, value=15, step=5, label="Top K (top_k)")
121
  ],
122
  outputs="text",
123
  title="Sinh văn bản tiếng Việt",
124
- description="Dùng mô hình GPT-2 Vietnamese từ NlpHUST để sinh văn bản tiếng Việt. Điều chỉnh các tham số để tối ưu thời gian và chất lượng đầu ra.",
125
  allow_flagging="never"
126
  )
127
 
 
60
  print_system_resources()
61
 
62
  def clean_text(text):
63
+ """Normalize text by removing invalid characters and extra spaces."""
64
  text = re.sub(r'[^\w\s.,!?àáâãèéêìíòóôõùúýăđĩũơưạảấầẩẫậắằẳẵặẹẻẽếềểễệỉịọỏốồổỗộớờởỡợụủứừửữựỳỵỷỹ]', '', text)
65
  text = re.sub(r'\s+', ' ', text).strip()
66
  return text
67
 
68
+ def generate_text(prompt, temperature=0.7, max_new_tokens=10):
69
  try:
70
  start_time = time.time()
71
  print_system_resources()
72
  # Log parameters
73
+ max_length = 50 # Fixed
74
+ top_k = 20 # Fixed
75
  print(f"Parameters: max_length={max_length}, temperature={temperature}, max_new_tokens={max_new_tokens}, top_k={top_k}")
76
  # Encode input with attention mask
77
  inputs = tokenizer(
 
92
  do_sample=True,
93
  top_k=int(top_k),
94
  temperature=temperature,
95
+ pad_token_id=tokenizer.pad_token_id,
96
+ eos_token_id=tokenizer.encode('.')[0] # Stop at period
97
  )
98
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
99
  print(f"Raw output: {generated_text}")
 
117
  placeholder="Viết gì đó bằng tiếng Việt...",
118
  value="Hôm nay là một ngày đẹp trời" # Default text
119
  ),
120
+ gr.Slider(0.5, 1.5, value=0.7, step=0.1, label="Nhiệt độ (Temperature)"),
121
+ gr.Slider(5, 30, value=10, step=1, label="Số token mới tối đa (max_new_tokens)")
 
 
122
  ],
123
  outputs="text",
124
  title="Sinh văn bản tiếng Việt",
125
+ description="Dùng mô hình GPT-2 Vietnamese từ NlpHUST để sinh văn bản tiếng Việt. Điều chỉnh temperature max_new_tokens để tối ưu thời gian và chất lượng đầu ra.",
126
  allow_flagging="never"
127
  )
128