Yuchan5386 committed on
Commit 7f80084 · verified · 1 Parent(s): 0abe2b5

Update api.py

Files changed (1)
  1. api.py +20 -9
api.py CHANGED
@@ -166,23 +166,35 @@ def is_greedy_response_acceptable(text):
 
     return True
 
-def generate_text_greedy(model, prompt, max_len=100, min_len=12):
+def generate_text_sample(model, prompt, max_len=100, max_gen=98,
+                         temperature=0.7, top_k=40, min_len=12):
     model_input = text_to_ids(f"<start> {prompt} <sep>")
     model_input = model_input[:max_len]
     generated = list(model_input)
 
-    for _ in range(max_len):
+    for _ in range(max_gen):
         pad_len = max(0, max_len - len(generated))
         input_padded = np.pad(generated, (0, pad_len), constant_values=pad_id)
         input_tensor = tf.convert_to_tensor([input_padded])
         logits = model(input_tensor, training=False)
         next_logits = logits[0, len(generated) - 1].numpy()
 
-        # Greedy: pick only the highest-probability token
-        next_token = int(np.argmax(next_logits))
-        generated.append(next_token)
+        # Apply temperature
+        next_logits = next_logits / temperature
+        probs = np.exp(next_logits - np.max(next_logits))
+        probs = probs / probs.sum()
 
-        # Decoding filter
+        # Top-K filtering
+        if top_k is not None and top_k > 0:
+            indices_to_remove = probs < np.sort(probs)[-top_k]
+            probs[indices_to_remove] = 0
+            probs /= probs.sum()
+
+        # Sampling
+        next_token = np.random.choice(len(probs), p=probs)
+        generated.append(int(next_token))
+
+        # Decode and post-process
         decoded = sp.decode(generated)
         for t in ["<start>", "<sep>", "<end>"]:
             decoded = decoded.replace(t, "")
@@ -194,7 +206,6 @@ def generate_text_greedy(model, prompt, max_len=100, min_len=12):
         else:
             continue
 
-    # If nothing satisfies the filters by the end, just return the final output
     decoded = sp.decode(generated)
     for t in ["<start>", "<sep>", "<end>"]:
         decoded = decoded.replace(t, "")
@@ -324,9 +335,9 @@ def respond(input_text):
         return f"{summary}\n다른 궁금한 점 있으신가요?"
 
     # Casual conversation: sampling + fallback
-    response = generate_text_greedy(model, input_text)
+    response = generate_text_sample(model, input_text)
     if not is_valid_response(response) or mismatch_tone(input_text, response):
-        response = generate_text_greedy(model, input_text)
+        response = generate_text_sample(model, input_text)
     return response
 
 @app.get("/generate", response_class=PlainTextResponse)