VietCat commited on
Commit
6fc92b9
·
1 Parent(s): c0af0d6

adjust generation time

Browse files
Files changed (1) hide show
  1. app.py +13 -19
app.py CHANGED
@@ -60,26 +60,17 @@ print(f"Memory allocated: {torch.cuda.memory_allocated(device)/1e9:.2f} GB" if t
60
  print_system_resources()
61
 
62
  def clean_text(text):
63
- """Clean generated text by removing non-alphabetic characters and incomplete sentences."""
64
  text = re.sub(r'[^\w\s.,!?àáâãèéêìíòóôõùúýăđĩũơưạảấầẩẫậắằẳẵặẹẻẽếềểễệỉịọỏốồổỗộớờởỡợụủứừửữựỳỵỷỹ]', '', text)
65
  text = re.sub(r'\s+', ' ', text).strip()
66
- # Keep only complete sentences (ending with .!?)
67
- sentences = re.split(r'(?<=[.!?])\s+', text)
68
- complete_sentences = [s for s in sentences if re.search(r'[.!?]$', s)]
69
- if complete_sentences:
70
- return ' '.join(complete_sentences)
71
- # Fallback: Cut at last comma if present
72
- last_comma = text.rfind(',')
73
- if last_comma != -1:
74
- return text[:last_comma].strip()
75
- # Fallback: Cut at last valid word
76
- words = text.split()
77
- return ' '.join(words[:-1]) if len(words) > 1 else text
78
 
79
- def generate_text(prompt, max_length=30, temperature=1.0):
80
  try:
81
  start_time = time.time()
82
  print_system_resources()
 
 
83
  # Encode input with attention mask
84
  inputs = tokenizer(
85
  prompt,
@@ -94,10 +85,10 @@ def generate_text(prompt, max_length=30, temperature=1.0):
94
  outputs = model.generate(
95
  input_ids=inputs["input_ids"],
96
  attention_mask=inputs["attention_mask"],
97
- max_new_tokens=8, # Reduce for speed
98
  min_length=3,
99
  do_sample=True,
100
- top_k=15, # Reduce for speed
101
  temperature=temperature,
102
  pad_token_id=tokenizer.pad_token_id
103
  )
@@ -105,6 +96,7 @@ def generate_text(prompt, max_length=30, temperature=1.0):
105
  print(f"Raw output: {generated_text}")
106
  print(f"Generated token count: {len(outputs[0])}")
107
  cleaned_text = clean_text(generated_text)
 
108
  elapsed_time = time.time() - start_time
109
  print(f"Generation time: {elapsed_time:.2f} seconds")
110
  # Clear memory cache
@@ -122,12 +114,14 @@ demo = gr.Interface(
122
  placeholder="Viết gì đó bằng tiếng Việt...",
123
  value="Hôm nay là một ngày đẹp trời" # Default text
124
  ),
125
- gr.Slider(20, 100, value=30, step=10, label="Độ dài tối đa"),
126
- gr.Slider(0.7, 1.0, value=1.0, step=0.1, label="Nhiệt độ (Temperature)")
 
 
127
  ],
128
  outputs="text",
129
  title="Sinh văn bản tiếng Việt",
130
- description="Dùng mô hình GPT-2 Vietnamese từ NlpHUST để sinh văn bản tiếng Việt.",
131
  allow_flagging="never"
132
  )
133
 
 
60
  print_system_resources()
61
 
62
  def clean_text(text):
63
+ """Clean generated text by removing non-alphabetic characters and normalizing spaces."""
64
  text = re.sub(r'[^\w\s.,!?àáâãèéêìíòóôõùúýăđĩũơưạảấầẩẫậắằẳẵặẹẻẽếềểễệỉịọỏốồổỗộớờởỡợụủứừửữựỳỵỷỹ]', '', text)
65
  text = re.sub(r'\s+', ' ', text).strip()
66
+ return text
 
 
 
 
 
 
 
 
 
 
 
67
 
68
+ def generate_text(prompt, max_length=30, temperature=1.0, max_new_tokens=8, top_k=15):
69
  try:
70
  start_time = time.time()
71
  print_system_resources()
72
+ # Log parameters
73
+ print(f"Parameters: max_length={max_length}, temperature={temperature}, max_new_tokens={max_new_tokens}, top_k={top_k}")
74
  # Encode input with attention mask
75
  inputs = tokenizer(
76
  prompt,
 
85
  outputs = model.generate(
86
  input_ids=inputs["input_ids"],
87
  attention_mask=inputs["attention_mask"],
88
+ max_new_tokens=int(max_new_tokens),
89
  min_length=3,
90
  do_sample=True,
91
+ top_k=int(top_k),
92
  temperature=temperature,
93
  pad_token_id=tokenizer.pad_token_id
94
  )
 
96
  print(f"Raw output: {generated_text}")
97
  print(f"Generated token count: {len(outputs[0])}")
98
  cleaned_text = clean_text(generated_text)
99
+ print(f"Cleaned output: {cleaned_text}")
100
  elapsed_time = time.time() - start_time
101
  print(f"Generation time: {elapsed_time:.2f} seconds")
102
  # Clear memory cache
 
114
  placeholder="Viết gì đó bằng tiếng Việt...",
115
  value="Hôm nay là một ngày đẹp trời" # Default text
116
  ),
117
+ gr.Slider(20, 100, value=30, step=10, label="Độ dài tối đa (max_length)"),
118
+ gr.Slider(0.7, 1.5, value=1.0, step=0.1, label="Nhiệt độ (Temperature)"),
119
+ gr.Slider(1, 20, value=8, step=1, label="Số token mới tối đa (max_new_tokens)"),
120
+ gr.Slider(5, 50, value=15, step=5, label="Top K (top_k)")
121
  ],
122
  outputs="text",
123
  title="Sinh văn bản tiếng Việt",
124
+ description="Dùng mô hình GPT-2 Vietnamese từ NlpHUST để sinh văn bản tiếng Việt. Điều chỉnh các tham số để tối ưu thời gian và chất lượng đầu ra.",
125
  allow_flagging="never"
126
  )
127