Spaces:
Sleeping
Sleeping
Update api.py
Browse files
api.py
CHANGED
@@ -172,13 +172,11 @@ def is_greedy_response_acceptable(text):
|
|
172 |
|
173 |
return True
|
174 |
|
175 |
-
def
|
176 |
-
|
177 |
-
eta=0.1, min_len=20):
|
178 |
model_input = text_to_ids(f"<start> {prompt} <sep>")
|
179 |
model_input = model_input[:max_len]
|
180 |
generated = list(model_input)
|
181 |
-
tau = tau_init
|
182 |
|
183 |
for _ in range(max_gen):
|
184 |
pad_len = max(0, max_len - len(generated))
|
@@ -187,29 +185,33 @@ def generate_text_with_mirostat(model, prompt, max_len=100, max_gen=98,
|
|
187 |
logits = model(input_tensor, training=False)
|
188 |
next_logits = logits[0, len(generated) - 1].numpy()
|
189 |
|
190 |
-
#
|
191 |
-
|
192 |
-
|
193 |
-
|
|
|
|
|
|
|
194 |
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
generated.append(int(next_token))
|
203 |
|
204 |
-
|
205 |
-
|
206 |
-
|
|
|
207 |
|
208 |
-
|
209 |
-
|
|
|
|
|
|
|
210 |
|
211 |
-
# 디코딩 및 조건 검증
|
212 |
-
decoded = sp.decode(generated).replace("<start>", "").replace("<sep>", "").replace("<end>", "").strip()
|
213 |
if len(generated) >= min_len and (next_token == end_id or decoded.endswith(('요', '다', '.', '!', '?'))):
|
214 |
if is_greedy_response_acceptable(decoded):
|
215 |
return decoded
|
@@ -341,9 +343,9 @@ def respond(input_text):
|
|
341 |
return f"{summary}\n다른 궁금한 점 있으신가요?"
|
342 |
|
343 |
# 일상 대화: 샘플링 + fallback
|
344 |
-
response =
|
345 |
if not is_valid_response(response) or mismatch_tone(input_text, response):
|
346 |
-
response =
|
347 |
return response
|
348 |
|
349 |
@app.get("/generate", response_class=PlainTextResponse)
|
|
|
172 |
|
173 |
return True
|
174 |
|
175 |
+
def generate_text_typical(model, prompt, max_len=100, max_gen=98,
|
176 |
+
typical_p=0.9, min_len=20):
|
|
|
177 |
model_input = text_to_ids(f"<start> {prompt} <sep>")
|
178 |
model_input = model_input[:max_len]
|
179 |
generated = list(model_input)
|
|
|
180 |
|
181 |
for _ in range(max_gen):
|
182 |
pad_len = max(0, max_len - len(generated))
|
|
|
185 |
logits = model(input_tensor, training=False)
|
186 |
next_logits = logits[0, len(generated) - 1].numpy()
|
187 |
|
188 |
+
# 🔥 Typical Sampling
|
189 |
+
probs = tf.nn.softmax(next_logits).numpy()
|
190 |
+
log_probs = -np.log(probs + 1e-10)
|
191 |
+
info_content = log_probs
|
192 |
+
mean_info = np.mean(info_content)
|
193 |
+
deviation = np.abs(info_content - mean_info)
|
194 |
+
sorted_indices = np.argsort(deviation)
|
195 |
|
196 |
+
filtered_indices = []
|
197 |
+
cumulative_prob = 0.0
|
198 |
+
for idx in sorted_indices:
|
199 |
+
cumulative_prob += probs[idx]
|
200 |
+
filtered_indices.append(idx)
|
201 |
+
if cumulative_prob >= typical_p:
|
202 |
+
break
|
|
|
203 |
|
204 |
+
filtered_probs = np.zeros_like(probs)
|
205 |
+
filtered_probs[filtered_indices] = probs[filtered_indices]
|
206 |
+
filtered_probs /= filtered_probs.sum()
|
207 |
+
next_token = np.random.choice(len(filtered_probs), p=filtered_probs)
|
208 |
|
209 |
+
generated.append(int(next_token))
|
210 |
+
decoded = sp.decode(generated)
|
211 |
+
for t in ["<start>", "<sep>", "<end>"]:
|
212 |
+
decoded = decoded.replace(t, "")
|
213 |
+
decoded = decoded.strip()
|
214 |
|
|
|
|
|
215 |
if len(generated) >= min_len and (next_token == end_id or decoded.endswith(('요', '다', '.', '!', '?'))):
|
216 |
if is_greedy_response_acceptable(decoded):
|
217 |
return decoded
|
|
|
343 |
return f"{summary}\n다른 궁금한 점 있으신가요?"
|
344 |
|
345 |
# 일상 대화: 샘플링 + fallback
|
346 |
+
response = generate_text_typical(model, input_text)
|
347 |
if not is_valid_response(response) or mismatch_tone(input_text, response):
|
348 |
+
response = generate_text_typical(model, input_text)
|
349 |
return response
|
350 |
|
351 |
@app.get("/generate", response_class=PlainTextResponse)
|