Yuchan5386 committed
Commit 12bb000 · verified · 1 Parent(s): 14686d0

Update api.py

Files changed (1):
  1. api.py +58 -23
api.py CHANGED
@@ -172,8 +172,10 @@ def is_greedy_response_acceptable(text):

    return True

- def generate_text_typical(model, prompt, max_len=100, max_gen=98,
-                           typical_p=0.5, min_len=20):
+ def generate_text_flex(model, prompt, max_len=100, max_gen=98,
+                        repetition_penalty=1.2, temperature=1.0,
+                        top_k=50, top_p=0.85, typical_p=0.72,
+                        min_len=20):
    model_input = text_to_ids(f"<start> {prompt} <sep>")
    model_input = model_input[:max_len]
    generated = list(model_input)
@@ -185,28 +187,61 @@ def generate_text_typical(model, prompt, max_len=100, max_gen=98,
        logits = model(input_tensor, training=False)
        next_logits = logits[0, len(generated) - 1].numpy()

-         # 🔥 Typical Sampling
-         probs = tf.nn.softmax(next_logits).numpy()
-         log_probs = -np.log(probs + 1e-10)
-         info_content = log_probs
-         mean_info = np.mean(info_content)
-         deviation = np.abs(info_content - mean_info)
-         sorted_indices = np.argsort(deviation)
-
-         filtered_indices = []
-         cumulative_prob = 0.0
-         for idx in sorted_indices:
-             cumulative_prob += probs[idx]
-             filtered_indices.append(idx)
-             if cumulative_prob >= typical_p:
-                 break
-
-         filtered_probs = np.zeros_like(probs)
-         filtered_probs[filtered_indices] = probs[filtered_indices]
-         filtered_probs /= filtered_probs.sum()
-         next_token = np.random.choice(len(filtered_probs), p=filtered_probs)
-
+         # Apply repetition penalty
+         for t in set(generated):
+             count = generated.count(t)
+             next_logits[t] /= (repetition_penalty ** count)
+
+         # Temperature scaling
+         next_logits = next_logits / temperature
+
+         # Compute probabilities
+         probs = np.exp(next_logits - np.max(next_logits))
+         probs = probs / probs.sum()
+
+         # Top-K filtering
+         if top_k is not None and top_k > 0:
+             indices_to_remove = probs < np.sort(probs)[-top_k]
+             probs[indices_to_remove] = 0
+             probs /= probs.sum()
+
+         # Top-P (nucleus) filtering
+         if top_p is not None and 0 < top_p < 1:
+             sorted_indices = np.argsort(probs)[::-1]
+             sorted_probs = probs[sorted_indices]
+             cumulative_probs = np.cumsum(sorted_probs)
+             cutoff_index = np.searchsorted(cumulative_probs, top_p, side='right')
+             keep_indices = sorted_indices[:cutoff_index + 1]
+
+             filtered_probs = np.zeros_like(probs)
+             filtered_probs[keep_indices] = probs[keep_indices]
+             filtered_probs /= filtered_probs.sum()
+             probs = filtered_probs
+
+         # Typical-p filtering
+         if typical_p is not None and 0 < typical_p < 1:
+             log_probs = -np.log(probs + 1e-10)
+             mean_info = np.mean(log_probs)
+             deviation = np.abs(log_probs - mean_info)
+             sorted_indices = np.argsort(deviation)
+
+             filtered_indices = []
+             cumulative_prob = 0.0
+             for idx in sorted_indices:
+                 cumulative_prob += probs[idx]
+                 filtered_indices.append(idx)
+                 if cumulative_prob >= typical_p:
+                     break
+
+             filtered_probs = np.zeros_like(probs)
+             filtered_probs[filtered_indices] = probs[filtered_indices]
+             filtered_probs /= filtered_probs.sum()
+             probs = filtered_probs
+
+         # Sample the next token
+         next_token = np.random.choice(len(probs), p=probs)
        generated.append(int(next_token))
+
        decoded = sp.decode(generated)
        for t in ["<start>", "<sep>", "<end>"]:
            decoded = decoded.replace(t, "")
@@ -343,9 +378,9 @@ def respond(input_text):
        return f"{summary}\n다른 궁금한 점 있으신가요?"

    # Casual conversation: sampling + fallback
-     response = generate_text_typical(model, input_text)
+     response = generate_text_flex(model, input_text)
    if not is_valid_response(response) or mismatch_tone(input_text, response):
-         response = generate_text_typical(model, input_text)
+         response = generate_text_flex(model, input_text)
    return response

@app.get("/generate", response_class=PlainTextResponse)
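
For reference, below is a minimal, self-contained sketch of the per-step filtering pipeline this commit introduces (repetition penalty, temperature, top-k, top-p, typical-p), applied to a toy logits vector. The helper name filter_and_sample, the rng argument, and the toy vocabulary are illustrative assumptions and are not part of api.py; the filtering logic mirrors the loop body of generate_text_flex above.

import numpy as np

def filter_and_sample(next_logits, generated, repetition_penalty=1.2,
                      temperature=1.0, top_k=50, top_p=0.85, typical_p=0.72,
                      rng=np.random.default_rng(0)):
    """Illustrative mirror of one decoding step from generate_text_flex."""
    logits = np.asarray(next_logits, dtype=np.float64).copy()

    # Penalize each token in proportion to how often it was already generated
    for t in set(generated):
        logits[t] /= repetition_penalty ** generated.count(t)

    # Temperature scaling followed by a numerically stable softmax
    z = logits / temperature
    probs = np.exp(z - z.max())
    probs /= probs.sum()

    # Top-K: drop everything below the K-th largest probability
    if top_k and 0 < top_k < len(probs):
        probs[probs < np.sort(probs)[-top_k]] = 0
        probs /= probs.sum()

    # Top-P (nucleus): keep the smallest high-probability prefix with mass >= top_p
    if top_p and 0 < top_p < 1:
        order = np.argsort(probs)[::-1]
        cutoff = np.searchsorted(np.cumsum(probs[order]), top_p, side="right")
        keep = order[:cutoff + 1]
        mask = np.zeros_like(probs)
        mask[keep] = probs[keep]
        probs = mask / mask.sum()

    # Typical-p: keep tokens whose information content is closest to the mean
    if typical_p and 0 < typical_p < 1:
        info = -np.log(probs + 1e-10)
        order = np.argsort(np.abs(info - info.mean()))
        kept, cum = [], 0.0
        for idx in order:
            cum += probs[idx]
            kept.append(idx)
            if cum >= typical_p:
                break
        mask = np.zeros_like(probs)
        mask[kept] = probs[kept]
        probs = mask / mask.sum()

    # Sample the next token id from the filtered distribution
    return int(rng.choice(len(probs), p=probs))

# Toy example: 8-token vocabulary, token 3 was already generated twice
toy_logits = [0.1, 2.0, 0.5, 3.0, 1.5, 0.2, 0.0, 2.5]
print(filter_and_sample(toy_logits, generated=[3, 3], top_k=5))

In the commit itself, the same sequence of filters runs once per generated token inside the decoding loop, using the defaults shown in the new generate_text_flex signature.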