Yuchan5386 committed on
Commit
257ebb3
·
verified ·
1 Parent(s): 9b8d3c9

Update api.py

Browse files
Files changed (1) hide show
  1. api.py +47 -80
api.py CHANGED
@@ -166,88 +166,55 @@ def is_greedy_response_acceptable(text):
166
 
167
  return True
168
 
169
def generate_text_flex(model, prompt, max_len=100, max_gen=98,
                       repetition_penalty=1.2, temperature=0.55,
                       top_k=50, top_p=0.70, typical_p=None,
                       min_len=12):
    """Sample a chat response token-by-token with repetition penalty,
    temperature, top-k / top-p (and optional typical-p) filtering.

    Args:
        model: Keras-style language model; called as ``model(batch, training=False)``
            and expected to return per-position logits. (Assumes the model accepts
            a fixed-length padded window of ``max_len`` ids — confirm against training code.)
        prompt: User text; wrapped as ``<start> {prompt} <sep>`` before encoding.
        max_len: Model context window (ids); also caps the prompt length.
        max_gen: Maximum number of tokens to sample.
        repetition_penalty: >1 discourages tokens already generated.
        temperature: Softmax temperature (<1 sharpens the distribution).
        top_k / top_p / typical_p: Standard sampling filters; ``None``/0 disables.
        min_len: Minimum sequence length before an early stop is allowed.

    Returns:
        Decoded response text with special tokens removed.
    """
    model_input = text_to_ids(f"<start> {prompt} <sep>")
    model_input = model_input[:max_len]
    generated = list(model_input)

    for _ in range(max_gen):
        # Guard: once the context window is full, position len(generated)-1
        # would fall outside the padded input, so stop generating.
        if len(generated) >= max_len:
            break

        pad_len = max(0, max_len - len(generated))
        input_padded = np.pad(generated, (0, pad_len), constant_values=pad_id)
        input_tensor = tf.convert_to_tensor([input_padded])
        logits = model(input_tensor, training=False)
        next_logits = logits[0, len(generated) - 1].numpy()

        # Repetition penalty — sign-aware (CTRL paper / HF
        # RepetitionPenaltyLogitsProcessor): dividing a *negative* logit by a
        # penalty > 1 would move it toward 0 and RAISE its probability, so
        # negative logits must be multiplied instead.
        for t in set(generated):
            count = generated.count(t)
            factor = repetition_penalty ** count
            if next_logits[t] > 0:
                next_logits[t] /= factor
            else:
                next_logits[t] *= factor

        # Temperature scaling.
        next_logits = next_logits / temperature

        # Numerically stable softmax.
        probs = np.exp(next_logits - np.max(next_logits))
        probs = probs / probs.sum()

        # Top-K filtering: zero out everything below the k-th largest prob.
        if top_k is not None and top_k > 0:
            indices_to_remove = probs < np.sort(probs)[-top_k]
            probs[indices_to_remove] = 0
            probs /= probs.sum()

        # Top-P (nucleus) filtering: keep the smallest prefix of the sorted
        # distribution whose cumulative mass reaches top_p.
        if top_p is not None and 0 < top_p < 1:
            sorted_indices = np.argsort(probs)[::-1]
            sorted_probs = probs[sorted_indices]
            cumulative_probs = np.cumsum(sorted_probs)
            cutoff_index = np.searchsorted(cumulative_probs, top_p, side='right')
            keep_indices = sorted_indices[:cutoff_index + 1]

            filtered_probs = np.zeros_like(probs)
            filtered_probs[keep_indices] = probs[keep_indices]
            filtered_probs /= filtered_probs.sum()
            probs = filtered_probs

        # Typical-p filtering: keep tokens whose information content is
        # closest to the distribution's mean, up to cumulative mass typical_p.
        if typical_p is not None and 0 < typical_p < 1:
            log_probs = -np.log(probs + 1e-10)
            mean_info = np.mean(log_probs)
            deviation = np.abs(log_probs - mean_info)
            sorted_indices = np.argsort(deviation)

            filtered_indices = []
            cumulative_prob = 0.0
            for idx in sorted_indices:
                cumulative_prob += probs[idx]
                filtered_indices.append(idx)
                if cumulative_prob >= typical_p:
                    break

            filtered_probs = np.zeros_like(probs)
            filtered_probs[filtered_indices] = probs[filtered_indices]
            filtered_probs /= filtered_probs.sum()
            probs = filtered_probs

        # Sample the next token from the filtered distribution.
        next_token = np.random.choice(len(probs), p=probs)
        generated.append(int(next_token))

        decoded = sp.decode(generated)
        for t in ["<start>", "<sep>", "<end>"]:
            decoded = decoded.replace(t, "")
        decoded = decoded.strip()

        # Early stop once the text is long enough and ends on EOS or a
        # sentence-final Korean ending / punctuation mark.
        if len(generated) >= min_len and (next_token == end_id or decoded.endswith(('요', '다', '.', '!', '?'))):
            if is_greedy_response_acceptable(decoded):
                return decoded
            else:
                continue

    # Fallback: strip special tokens here too — the original returned the raw
    # decode, inconsistently leaking <start>/<sep>/<end> into the response.
    decoded = sp.decode(generated)
    for t in ["<start>", "<sep>", "<end>"]:
        decoded = decoded.replace(t, "")
    return decoded.strip()
 
 
 
251
 
252
  def mismatch_tone(input_text, output_text):
253
  if "ㅋㅋ" in input_text and not re.search(r'ㅋㅋ|ㅎ|재밌|놀|만나|맛집|여행', output_text):
@@ -372,9 +339,9 @@ def respond(input_text):
372
  return f"{summary}\n다른 궁금한 점 있으신가요?"
373
 
374
  # 일상 대화: 샘플링 + fallback
375
- response = generate_text_flex(model, input_text)
376
  if not is_valid_response(response) or mismatch_tone(input_text, response):
377
- response = generate_text_flex(model, input_text)
378
  return response
379
 
380
  @app.get("/generate", response_class=PlainTextResponse)
 
166
 
167
  return True
168
 
169
def generate_text_beam(model, prompt, max_len=100, beam_width=4, length_penalty=0.7):
    """Decode a chat response with beam search.

    Args:
        model: Keras-style language model; called as ``model(batch, training=False)``
            returning per-position logits. (Assumes a fixed padded window of
            ``max_len`` ids — confirm against training code.)
        prompt: User text; wrapped as ``<start> {prompt} <sep>`` before encoding.
        max_len: Context window size in ids; also bounds generation steps.
        beam_width: Number of hypotheses kept per step.
        length_penalty: Exponent for length normalization when ranking beams
            (Wu et al. / GNMT style); applied at ranking time only.

    Returns:
        Decoded best-beam text with special tokens removed.
    """
    model_input = text_to_ids(f"<start> {prompt} <sep>")
    model_input = model_input[:max_len]

    # Each beam keeps the raw (unnormalized) cumulative log-probability.
    beams = [{
        "sequence": list(model_input),
        "score": 0.0
    }]

    def _clean(token_ids):
        # Decode and strip special tokens, matching the sampling decoder.
        text = sp.decode(token_ids)
        for t in ["<start>", "<sep>", "<end>"]:
            text = text.replace(t, "")
        return text.strip()

    for _ in range(max_len):
        # Guard: all beams share the same length; once the window is full,
        # position len(seq)-1 would index past the padded input.
        if len(beams[0]["sequence"]) >= max_len:
            break

        all_candidates = []

        for beam in beams:
            seq = beam["sequence"]
            pad_len = max(0, max_len - len(seq))
            input_padded = np.pad(seq, (0, pad_len), constant_values=pad_id)
            input_tensor = tf.convert_to_tensor([input_padded])
            logits = model(input_tensor, training=False)[0, len(seq) - 1].numpy()

            # Numerically stable softmax.
            probs = np.exp(logits - np.max(logits))
            probs = probs / probs.sum()

            # Expand each beam by its beam_width most likely next tokens.
            top_indices = probs.argsort()[-beam_width:][::-1]

            for idx in top_indices:
                new_seq = seq + [int(idx)]
                # Epsilon avoids log(0) = -inf poisoning the beam score.
                new_score = beam["score"] + np.log(probs[idx] + 1e-12)
                all_candidates.append({
                    "sequence": new_seq,
                    "score": new_score
                })

        # Rank by length-normalized score WITHOUT mutating the stored raw
        # log-probability: the previous version divided cand["score"] in
        # place every step, so the penalty compounded each iteration.
        beams = sorted(
            all_candidates,
            key=lambda c: c["score"] / (len(c["sequence"]) ** length_penalty),
            reverse=True,
        )[:beam_width]

        # Early exit: a beam that produced EOS and passes the quality gate
        # is returned immediately (special tokens stripped).
        for b in beams:
            if end_id in b["sequence"]:
                decoded = _clean(b["sequence"])
                if is_greedy_response_acceptable(decoded):
                    return decoded

    # No acceptable early exit — return the highest-scoring hypothesis,
    # also with special tokens stripped.
    return _clean(beams[0]["sequence"])
218
 
219
  def mismatch_tone(input_text, output_text):
220
  if "ㅋㅋ" in input_text and not re.search(r'ㅋㅋ|ㅎ|재밌|놀|만나|맛집|여행', output_text):
 
339
  return f"{summary}\n다른 궁금한 점 있으신가요?"
340
 
341
  # 일상 대화: 샘플링 + fallback
342
+ response = generate_text_beam(model, input_text)
343
  if not is_valid_response(response) or mismatch_tone(input_text, response):
344
+ response = generate_text_beam(model, input_text)
345
  return response
346
 
347
  @app.get("/generate", response_class=PlainTextResponse)