Yuchan5386 committed · verified
Commit c0437ed · Parent: 8c00d7b

Update api.py

Files changed (1):
  1. api.py +27 -25
api.py CHANGED
@@ -172,13 +172,11 @@ def is_greedy_response_acceptable(text):
 
     return True
 
-def generate_text_with_mirostat(model, prompt, max_len=100, max_gen=98,
-                                repetition_penalty=1.2, tau_init=5.0, mu=5.0,
-                                eta=0.1, min_len=20):
+def generate_text_typical(model, prompt, max_len=100, max_gen=98,
+                          typical_p=0.9, min_len=20):
     model_input = text_to_ids(f"<start> {prompt} <sep>")
     model_input = model_input[:max_len]
     generated = list(model_input)
-    tau = tau_init
 
     for _ in range(max_gen):
         pad_len = max(0, max_len - len(generated))
@@ -187,29 +185,33 @@ def generate_text_with_mirostat(model, prompt, max_len=100, max_gen=98,
         logits = model(input_tensor, training=False)
         next_logits = logits[0, len(generated) - 1].numpy()
 
-        # Repetition penalty
-        for t in set(generated):
-            count = generated.count(t)
-            next_logits[t] /= (repetition_penalty ** count)
-
-        # Mirostat: softmax with current τ
-        scaled_logits = next_logits / tau
-        exp_logits = np.exp(scaled_logits - np.max(scaled_logits))
-        probs = exp_logits / exp_logits.sum()
-
-        # Sampling
-        next_token = np.random.choice(len(probs), p=probs)
-        generated.append(int(next_token))
-
-        # Information content (r = -log(p(x)))
-        prob = probs[next_token]
-        r = -np.log(prob + 1e-10)
-
-        # τ update
-        tau = tau - eta * (r - mu)
-
-        # Decode and check conditions
-        decoded = sp.decode(generated).replace("<start>", "").replace("<sep>", "").replace("<end>", "").strip()
+        # 🔥 Typical Sampling
+        probs = tf.nn.softmax(next_logits).numpy()
+        log_probs = -np.log(probs + 1e-10)
+        info_content = log_probs
+        mean_info = np.mean(info_content)
+        deviation = np.abs(info_content - mean_info)
+        sorted_indices = np.argsort(deviation)
+
+        filtered_indices = []
+        cumulative_prob = 0.0
+        for idx in sorted_indices:
+            cumulative_prob += probs[idx]
+            filtered_indices.append(idx)
+            if cumulative_prob >= typical_p:
+                break
+
+        filtered_probs = np.zeros_like(probs)
+        filtered_probs[filtered_indices] = probs[filtered_indices]
+        filtered_probs /= filtered_probs.sum()
+        next_token = np.random.choice(len(filtered_probs), p=filtered_probs)
+
+        generated.append(int(next_token))
+        decoded = sp.decode(generated)
+        for t in ["<start>", "<sep>", "<end>"]:
+            decoded = decoded.replace(t, "")
+        decoded = decoded.strip()
+
         if len(generated) >= min_len and (next_token == end_id or decoded.endswith(('요', '다', '.', '!', '?'))):
             if is_greedy_response_acceptable(decoded):
                 return decoded
@@ -341,9 +343,9 @@ def respond(input_text):
         return f"{summary}\n다른 궁금한 점 있으신가요?"
 
     # Casual conversation: sampling + fallback
-    response = generate_text_with_mirostat(model, input_text)
+    response = generate_text_typical(model, input_text)
     if not is_valid_response(response) or mismatch_tone(input_text, response):
-        response = generate_text_with_mirostat(model, input_text)
+        response = generate_text_typical(model, input_text)
     return response
 
 @app.get("/generate", response_class=PlainTextResponse)
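Note on the new decoding strategy: typical sampling ranks tokens by how close their information content (-log p) is to a reference level and keeps the most "typical" tokens until their cumulative probability reaches typical_p. In the locally typical sampling paper (Meister et al., 2022) the reference level is the entropy of the next-token distribution; the code in this commit instead uses the unweighted mean of -log p across the vocabulary. A minimal, self-contained sketch of that filtering step, assuming NumPy only (typical_filter is an illustrative helper, not part of api.py):

import numpy as np

def typical_filter(probs, typical_p=0.9):
    # Illustrative helper (not in api.py): mirrors the filtering step above.
    info = -np.log(probs + 1e-10)           # information content of each token
    deviation = np.abs(info - info.mean())  # distance from the mean information
    order = np.argsort(deviation)           # most "typical" tokens first
    kept, mass = [], 0.0
    for idx in order:                       # greedily keep typical tokens...
        kept.append(idx)
        mass += probs[idx]
        if mass >= typical_p:               # ...until their mass reaches typical_p
            break
    filtered = np.zeros_like(probs)
    filtered[kept] = probs[kept]
    return filtered / filtered.sum()        # renormalize over the kept set

# Toy distribution: with typical_p=0.5, the most probable token (index 0) is
# dropped, because its information content sits far below the mean.
probs = np.array([0.50, 0.20, 0.15, 0.10, 0.03, 0.02])
print(typical_filter(probs, typical_p=0.5))

With the commit's default typical_p=0.9 the cut is much looser; lowering typical_p tightens the kept set and, as the toy example shows, can exclude even the argmax token when it is too "unsurprising" relative to the rest of the distribution.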
 
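One behavioral note on the respond() change: because generate_text_typical is stochastic, the fallback branch simply draws a second sample with the same function and returns it whether or not it passes the checks. A bounded retry loop is a natural generalization; a sketch under that assumption (respond_with_retries and max_retries are hypothetical, not part of this commit):

def respond_with_retries(input_text, max_retries=3):
    # Hypothetical variant of the fallback in respond(): re-sample up to
    # max_retries times instead of exactly once, reusing api.py's helpers.
    response = generate_text_typical(model, input_text)
    for _ in range(max_retries):
        if is_valid_response(response) and not mismatch_tone(input_text, response):
            break
        response = generate_text_typical(model, input_text)
    return response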