Yuchan5386 committed on
Commit
3b062a6
·
verified ·
1 Parent(s): dcd1d9c

Update api.py

Browse files
Files changed (1) hide show
  1. api.py +26 -54
api.py CHANGED
@@ -150,75 +150,47 @@ from sklearn.metrics.pairwise import cosine_similarity
150
  from fastapi import Request
151
  from fastapi.responses import PlainTextResponse
152
 
153
- # 1. Top-KP 기반 생성기
154
- def generate_text_topkp(model, prompt, max_len=100, max_gen=98,
155
- temperature=0.90, min_len=20,
156
- repetition_penalty=1.2, top_p=0.90, top_k=50):
157
- def top_kp_filtering(logits, top_k, top_p):
158
- probs = np.exp(logits - np.max(logits))
159
- probs /= probs.sum()
160
- sorted_idx = np.argsort(-probs)
161
- sorted_probs = probs[sorted_idx]
162
- if top_k > 0:
163
- sorted_idx = sorted_idx[:top_k]
164
- sorted_probs = sorted_probs[:top_k]
165
- cum_probs = np.cumsum(sorted_probs)
166
- cutoff = np.searchsorted(cum_probs, top_p) + 1
167
- final_idx = sorted_idx[:cutoff]
168
- final_probs = probs[final_idx]
169
- final_probs /= final_probs.sum()
170
- return final_idx, final_probs
171
-
172
  model_input = text_to_ids(f"<start> {prompt} <sep>")
173
  model_input = model_input[:max_len]
174
  generated = list(model_input)
175
- for step in range(max_gen):
 
176
  pad_len = max(0, max_len - len(generated))
177
  input_padded = np.pad(generated, (0, pad_len), constant_values=pad_id)
178
  input_tensor = tf.convert_to_tensor([input_padded])
179
  logits = model(input_tensor, training=False)
180
  next_logits = logits[0, len(generated) - 1].numpy()
 
 
181
  for t in set(generated):
182
  count = generated.count(t)
183
- next_logits[t] /= (repetition_penalty ** count)
184
- if len(generated) < min_len:
185
- next_logits[end_id] -= 5.0
186
- next_logits[pad_id] -= 10.0
187
- next_logits = next_logits / temperature
188
- final_idx, final_probs = top_kp_filtering(next_logits, top_k=top_k, top_p=top_p)
189
- sampled = np.random.choice(final_idx, p=final_probs)
190
- generated.append(int(sampled))
191
- decoded = sp.decode(generated)
192
- for t in ["<start>", "<sep>", "<end>"]:
193
- decoded = decoded.replace(t, "")
194
- decoded = decoded.strip()
195
- if len(generated) >= min_len and (sampled == end_id or decoded.endswith(('.', '!', '?'))):
196
- return decoded
197
- return sp.decode(generated)
198
 
199
- # Greedy 버전 생성기
200
- def generate_text_greedy(model, prompt, max_len=100, max_gen=98):
201
- model_input = text_to_ids(f"<start> {prompt} <sep>")
202
- model_input = model_input[:max_len]
203
- generated = list(model_input)
204
- for _ in range(max_gen):
205
- pad_len = max(0, max_len - len(generated))
206
- input_padded = np.pad(generated, (0, pad_len), constant_values=pad_id)
207
- input_tensor = tf.convert_to_tensor([input_padded])
208
- logits = model(input_tensor, training=False)
209
- next_logits = logits[0, len(generated) - 1].numpy()
210
- next_logits[pad_id] -= 10.0
211
- next_token = np.argmax(next_logits)
212
  generated.append(int(next_token))
 
213
  decoded = sp.decode(generated)
214
  for t in ["<start>", "<sep>", "<end>"]:
215
  decoded = decoded.replace(t, "")
216
- decoded = decoded.strip()
217
- if next_token == end_id or decoded.endswith(('.', '!', '?')):
218
- return decoded
 
 
 
 
 
219
  return sp.decode(generated)
220
 
221
- # 톤 불일치 체크
222
  def mismatch_tone(input_text, output_text):
223
  if "ㅋㅋ" in input_text and not re.search(r'ㅋㅋ|ㅎ|재밌|놀|만나|맛집|여행', output_text):
224
  return True
@@ -340,9 +312,9 @@ def respond(input_text):
340
  return f"{summary}\n다른 궁금한 점 있으신가요?"
341
 
342
  # 일상 대화: 샘플링 + fallback
343
- response = generate_text_topkp(model, input_text)
344
  if not is_valid_response(response) or mismatch_tone(input_text, response):
345
- response = generate_text_greedy(model, input_text)
346
  return response
347
 
348
  @app.get("/generate", response_class=PlainTextResponse)
 
150
  from fastapi import Request
151
  from fastapi.responses import PlainTextResponse
152
 
153
def generate_text_greedy_strong(model, prompt, max_len=100, max_gen=98,
                                repetition_penalty=1.2, min_len=20):
    """Greedy decoding with repetition penalty and filler-token suppression.

    Encodes ``prompt`` as ``<start> {prompt} <sep>``, then repeatedly picks the
    argmax next token, penalizing tokens proportionally to how often they have
    already been emitted and down-weighting Korean filler tokens and padding.
    Stops early once the sequence is long enough, ends on ``end_id`` or a
    sentence-final character, and passes ``is_greedy_response_acceptable``.

    Args:
        model: Callable TF model; ``model(tensor, training=False)`` returns
            per-position logits of shape (1, max_len, vocab).
        prompt: User input text.
        max_len: Model context length; the prompt is truncated to this.
        max_gen: Maximum number of tokens to generate.
        repetition_penalty: Logits of a token seen ``c`` times are divided by
            ``repetition_penalty ** c``.
        min_len: Minimum sequence length before early stopping is allowed.
            NOTE(review): this counts prompt tokens too, so long prompts can
            satisfy it immediately — confirm that is intended.

    Returns:
        Decoded response string with special tokens stripped, or the raw
        decode of everything generated if no acceptable stop point is found.
    """
    model_input = text_to_ids(f"<start> {prompt} <sep>")
    model_input = model_input[:max_len]
    generated = list(model_input)

    # Hoist loop-invariant work: the stop-token list is fixed, so resolve the
    # piece ids once instead of on every generation step.
    stop_tokens = ["음", "어", "그", "뭐지", "..."]
    stop_token_ids = [sp.piece_to_id(tok) for tok in stop_tokens]

    for _ in range(max_gen):
        pad_len = max(0, max_len - len(generated))
        input_padded = np.pad(generated, (0, pad_len), constant_values=pad_id)
        input_tensor = tf.convert_to_tensor([input_padded])
        logits = model(input_tensor, training=False)
        # Logits at the position of the last real (non-pad) token.
        next_logits = logits[0, len(generated) - 1].numpy()

        # Repetition penalty: one O(n) counting pass instead of calling
        # generated.count(t) per distinct token (which was O(n^2) per step).
        token_counts = {}
        for t in generated:
            token_counts[t] = token_counts.get(t, 0) + 1
        for t, count in token_counts.items():
            next_logits[t] /= (repetition_penalty ** count)

        # Suppress filler/hesitation tokens and padding.
        for tok_id in stop_token_ids:
            next_logits[tok_id] -= 5.0
        next_logits[pad_id] -= 10.0

        next_token = np.argmax(next_logits)
        generated.append(int(next_token))

        decoded = sp.decode(generated)
        for t in ["<start>", "<sep>", "<end>"]:
            decoded = decoded.replace(t, "")
        decoded = decoded.strip()

        # Early stop: long enough AND (explicit end token OR sentence-final
        # ending), but only if the quality gate accepts the decoded text;
        # otherwise keep generating.
        if len(generated) >= min_len and (next_token == end_id or decoded.endswith(('요', '다', '.', '!', '?'))):
            if is_greedy_response_acceptable(decoded):
                return decoded
            else:
                continue

    return sp.decode(generated)
193
 
 
194
  def mismatch_tone(input_text, output_text):
195
  if "ㅋㅋ" in input_text and not re.search(r'ㅋㅋ|ㅎ|재밌|놀|만나|맛집|여행', output_text):
196
  return True
 
312
  return f"{summary}\n다른 궁금한 점 있으신가요?"
313
 
314
  # 일상 대화: 샘플링 + fallback
315
+ response = generate_text_greedy_strong(model, input_text)
316
  if not is_valid_response(response) or mismatch_tone(input_text, response):
317
+ response = generate_text_greedy_strong(model, input_text)
318
  return response
319
 
320
  @app.get("/generate", response_class=PlainTextResponse)