Spaces:
Sleeping
Sleeping
Update api.py
Browse files
api.py
CHANGED
@@ -172,8 +172,10 @@ def is_greedy_response_acceptable(text):
|
|
172 |
|
173 |
return True
|
174 |
|
175 |
-
def
|
176 |
-
|
|
|
|
|
177 |
model_input = text_to_ids(f"<start> {prompt} <sep>")
|
178 |
model_input = model_input[:max_len]
|
179 |
generated = list(model_input)
|
@@ -185,28 +187,61 @@ def generate_text_typical(model, prompt, max_len=100, max_gen=98,
|
|
185 |
logits = model(input_tensor, training=False)
|
186 |
next_logits = logits[0, len(generated) - 1].numpy()
|
187 |
|
188 |
-
#
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
203 |
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
209 |
generated.append(int(next_token))
|
|
|
210 |
decoded = sp.decode(generated)
|
211 |
for t in ["<start>", "<sep>", "<end>"]:
|
212 |
decoded = decoded.replace(t, "")
|
@@ -343,9 +378,9 @@ def respond(input_text):
|
|
343 |
return f"{summary}\n다른 궁금한 점 있으신가요?"
|
344 |
|
345 |
# 일상 대화: 샘플링 + fallback
|
346 |
-
response =
|
347 |
if not is_valid_response(response) or mismatch_tone(input_text, response):
|
348 |
-
response =
|
349 |
return response
|
350 |
|
351 |
@app.get("/generate", response_class=PlainTextResponse)
|
|
|
172 |
|
173 |
return True
|
174 |
|
175 |
+
def generate_text_flex(model, prompt, max_len=100, max_gen=98,
|
176 |
+
repetition_penalty=1.2, temperature=1.0,
|
177 |
+
top_k=50, top_p=0.85, typical_p=0.72,
|
178 |
+
min_len=20):
|
179 |
model_input = text_to_ids(f"<start> {prompt} <sep>")
|
180 |
model_input = model_input[:max_len]
|
181 |
generated = list(model_input)
|
|
|
187 |
logits = model(input_tensor, training=False)
|
188 |
next_logits = logits[0, len(generated) - 1].numpy()
|
189 |
|
190 |
+
# Repetition penalty 적용
|
191 |
+
for t in set(generated):
|
192 |
+
count = generated.count(t)
|
193 |
+
next_logits[t] /= (repetition_penalty ** count)
|
194 |
+
|
195 |
+
# Temperature scaling
|
196 |
+
next_logits = next_logits / temperature
|
197 |
+
|
198 |
+
# 확률 계산
|
199 |
+
probs = np.exp(next_logits - np.max(next_logits))
|
200 |
+
probs = probs / probs.sum()
|
201 |
+
|
202 |
+
# Top-K 필터링
|
203 |
+
if top_k is not None and top_k > 0:
|
204 |
+
indices_to_remove = probs < np.sort(probs)[-top_k]
|
205 |
+
probs[indices_to_remove] = 0
|
206 |
+
probs /= probs.sum()
|
207 |
+
|
208 |
+
# Top-P (Nucleus) 필터링
|
209 |
+
if top_p is not None and 0 < top_p < 1:
|
210 |
+
sorted_indices = np.argsort(probs)[::-1]
|
211 |
+
sorted_probs = probs[sorted_indices]
|
212 |
+
cumulative_probs = np.cumsum(sorted_probs)
|
213 |
+
cutoff_index = np.searchsorted(cumulative_probs, top_p, side='right')
|
214 |
+
keep_indices = sorted_indices[:cutoff_index + 1]
|
215 |
|
216 |
+
filtered_probs = np.zeros_like(probs)
|
217 |
+
filtered_probs[keep_indices] = probs[keep_indices]
|
218 |
+
filtered_probs /= filtered_probs.sum()
|
219 |
+
probs = filtered_probs
|
220 |
|
221 |
+
# Typical-p 필터링
|
222 |
+
if typical_p is not None and 0 < typical_p < 1:
|
223 |
+
log_probs = -np.log(probs + 1e-10)
|
224 |
+
mean_info = np.mean(log_probs)
|
225 |
+
deviation = np.abs(log_probs - mean_info)
|
226 |
+
sorted_indices = np.argsort(deviation)
|
227 |
+
|
228 |
+
filtered_indices = []
|
229 |
+
cumulative_prob = 0.0
|
230 |
+
for idx in sorted_indices:
|
231 |
+
cumulative_prob += probs[idx]
|
232 |
+
filtered_indices.append(idx)
|
233 |
+
if cumulative_prob >= typical_p:
|
234 |
+
break
|
235 |
+
|
236 |
+
filtered_probs = np.zeros_like(probs)
|
237 |
+
filtered_probs[filtered_indices] = probs[filtered_indices]
|
238 |
+
filtered_probs /= filtered_probs.sum()
|
239 |
+
probs = filtered_probs
|
240 |
+
|
241 |
+
# 다음 토큰 샘플링
|
242 |
+
next_token = np.random.choice(len(probs), p=probs)
|
243 |
generated.append(int(next_token))
|
244 |
+
|
245 |
decoded = sp.decode(generated)
|
246 |
for t in ["<start>", "<sep>", "<end>"]:
|
247 |
decoded = decoded.replace(t, "")
|
|
|
378 |
return f"{summary}\n다른 궁금한 점 있으신가요?"
|
379 |
|
380 |
# 일상 대화: 샘플링 + fallback
|
381 |
+
response = generate_text_flex(model, input_text)
|
382 |
if not is_valid_response(response) or mismatch_tone(input_text, response):
|
383 |
+
response = generate_text_flex(model, input_text)
|
384 |
return response
|
385 |
|
386 |
@app.get("/generate", response_class=PlainTextResponse)
|