Yuchan5386 committed on
Commit
0abe2b5
·
verified ·
1 Parent(s): 257ebb3

Update api.py

Browse files
Files changed (1) hide show
  1. api.py +29 -44
api.py CHANGED
@@ -166,54 +166,39 @@ def is_greedy_response_acceptable(text):
166
 
167
  return True
168
 
169
def generate_text_beam(model, prompt, max_len=100, beam_width=4, length_penalty=0.7):
    """Decode a response with beam search.

    Args:
        model: Keras-style model mapping a [1, max_len] token tensor to
            per-position logits (called with ``training=False``).
        prompt: Raw user text; wrapped as ``"<start> {prompt} <sep>"``.
        max_len: Fixed model window length and maximum number of steps.
        beam_width: Number of hypotheses kept after each step.
        length_penalty: Exponent used to length-normalize scores at
            ranking time (higher favors shorter outputs).

    Returns:
        The decoded text of the first finished beam that contains the end
        token and passes ``is_greedy_response_acceptable``; otherwise the
        highest-scoring beam after ``max_len`` steps.
    """
    model_input = text_to_ids(f"<start> {prompt} <sep>")
    model_input = model_input[:max_len]

    # Each beam stores the RAW cumulative log-probability. Length
    # normalization happens only in the sort key below and is never
    # written back, so the penalty is not compounded across iterations
    # (the original divided the stored score every step).
    beams = [{"sequence": list(model_input), "score": 0.0}]

    for _ in range(max_len):
        all_candidates = []

        for beam in beams:
            seq = beam["sequence"]

            # A sequence that already fills the model window cannot be
            # extended: padding would clamp to 0 and indexing position
            # len(seq) - 1 would fall outside the model's output. Keep
            # it as-is so it can still win the final ranking.
            if len(seq) >= max_len:
                all_candidates.append(beam)
                continue

            pad_len = max_len - len(seq)
            input_padded = np.pad(seq, (0, pad_len), constant_values=pad_id)
            input_tensor = tf.convert_to_tensor([input_padded])
            logits = model(input_tensor, training=False)[0, len(seq) - 1].numpy()

            # Numerically stable softmax over the vocabulary.
            probs = np.exp(logits - np.max(logits))
            probs = probs / probs.sum()

            # Expand this beam with its beam_width most probable tokens.
            top_indices = probs.argsort()[-beam_width:][::-1]
            for idx in top_indices:
                all_candidates.append({
                    "sequence": seq + [int(idx)],
                    "score": beam["score"] + np.log(probs[idx]),
                })

        # Keep the top beam_width candidates, ranked by the
        # length-normalized score (normalization applied here only).
        beams = sorted(
            all_candidates,
            key=lambda c: c["score"] / (len(c["sequence"]) ** length_penalty),
            reverse=True,
        )[:beam_width]

        # Early exit: first surviving beam that emitted the end token and
        # passes the quality filter. Control tokens are stripped so the
        # early path matches the final-return path.
        for b in beams:
            decoded = sp.decode(b["sequence"])
            for t in ("<start>", "<sep>", "<end>"):
                decoded = decoded.replace(t, "")
            decoded = decoded.strip()
            if end_id in b["sequence"] and is_greedy_response_acceptable(decoded):
                return decoded

    # No beam finished acceptably: return the best-scoring sequence.
    decoded = sp.decode(beams[0]["sequence"])
    for t in ("<start>", "<sep>", "<end>"):
        decoded = decoded.replace(t, "")
    return decoded.strip()
 
 
217
 
218
 
219
  def mismatch_tone(input_text, output_text):
@@ -339,9 +324,9 @@ def respond(input_text):
339
  return f"{summary}\n다른 궁금한 점 있으신가요?"
340
 
341
  # 일상 대화: 샘플링 + fallback
342
- response = generate_text_beam(model, input_text)
343
  if not is_valid_response(response) or mismatch_tone(input_text, response):
344
- response = generate_text_beam(model, input_text)
345
  return response
346
 
347
  @app.get("/generate", response_class=PlainTextResponse)
 
166
 
167
  return True
168
 
169
def generate_text_greedy(model, prompt, max_len=100, min_len=12):
    """Decode a response greedily (argmax at every step).

    Args:
        model: Keras-style model mapping a [1, max_len] token tensor to
            per-position logits (called with ``training=False``).
        prompt: Raw user text; wrapped as ``"<start> {prompt} <sep>"``.
        max_len: Fixed model window length and maximum number of steps.
        min_len: Minimum total token count before a stop is allowed.

    Returns:
        The first decoded text that ends naturally (end token, or Korean
        sentence-final '요'/'다', or closing punctuation) and passes
        ``is_greedy_response_acceptable``; otherwise whatever was
        generated when the loop ends, with control tokens stripped.
    """

    def _clean(token_ids):
        # Decode and strip control tokens; shared by both exit paths.
        text = sp.decode(token_ids)
        for t in ("<start>", "<sep>", "<end>"):
            text = text.replace(t, "")
        return text.strip()

    model_input = text_to_ids(f"<start> {prompt} <sep>")
    model_input = model_input[:max_len]
    generated = list(model_input)

    for _ in range(max_len):
        # Never grow past the model window: with len(generated) >= max_len
        # the pad length would clamp to 0 and logits[0, len - 1] would
        # index outside the model's fixed sequence length.
        if len(generated) >= max_len:
            break

        pad_len = max_len - len(generated)
        input_padded = np.pad(generated, (0, pad_len), constant_values=pad_id)
        input_tensor = tf.convert_to_tensor([input_padded])
        logits = model(input_tensor, training=False)
        next_logits = logits[0, len(generated) - 1].numpy()

        # Greedy: always take the single most probable token.
        next_token = int(np.argmax(next_logits))
        generated.append(next_token)

        decoded = _clean(generated)

        # Stop once long enough AND the text ends naturally (end token,
        # Korean sentence-final particle, or punctuation) — but only if
        # it also passes the quality filter; otherwise keep generating.
        if len(generated) >= min_len and (
            next_token == end_id or decoded.endswith(('요', '다', '.', '!', '?'))
        ):
            if is_greedy_response_acceptable(decoded):
                return decoded

    # No acceptable stopping point found: return the final output as-is.
    return _clean(generated)
202
 
203
 
204
  def mismatch_tone(input_text, output_text):
 
324
  return f"{summary}\n다른 궁금한 점 있으신가요?"
325
 
326
  # 일상 대화: 샘플링 + fallback
327
+ response = generate_text_greedy(model, input_text)
328
  if not is_valid_response(response) or mismatch_tone(input_text, response):
329
+ response = generate_text_greedy(model, input_text)
330
  return response
331
 
332
  @app.get("/generate", response_class=PlainTextResponse)