Spaces:
Sleeping
Sleeping
Update api.py
Browse files
api.py
CHANGED
@@ -166,88 +166,55 @@ def is_greedy_response_acceptable(text):
|
|
166 |
|
167 |
return True
|
168 |
|
169 |
-
def
|
170 |
-
repetition_penalty=1.2, temperature=0.55,
|
171 |
-
top_k=50, top_p=0.70, typical_p=None,
|
172 |
-
min_len=12):
|
173 |
model_input = text_to_ids(f"<start> {prompt} <sep>")
|
174 |
model_input = model_input[:max_len]
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
# Typical-p 필터링
|
216 |
-
if typical_p is not None and 0 < typical_p < 1:
|
217 |
-
log_probs = -np.log(probs + 1e-10)
|
218 |
-
mean_info = np.mean(log_probs)
|
219 |
-
deviation = np.abs(log_probs - mean_info)
|
220 |
-
sorted_indices = np.argsort(deviation)
|
221 |
-
|
222 |
-
filtered_indices = []
|
223 |
-
cumulative_prob = 0.0
|
224 |
-
for idx in sorted_indices:
|
225 |
-
cumulative_prob += probs[idx]
|
226 |
-
filtered_indices.append(idx)
|
227 |
-
if cumulative_prob >= typical_p:
|
228 |
-
break
|
229 |
-
|
230 |
-
filtered_probs = np.zeros_like(probs)
|
231 |
-
filtered_probs[filtered_indices] = probs[filtered_indices]
|
232 |
-
filtered_probs /= filtered_probs.sum()
|
233 |
-
probs = filtered_probs
|
234 |
-
|
235 |
-
# 다음 토큰 샘플링
|
236 |
-
next_token = np.random.choice(len(probs), p=probs)
|
237 |
-
generated.append(int(next_token))
|
238 |
-
|
239 |
-
decoded = sp.decode(generated)
|
240 |
-
for t in ["<start>", "<sep>", "<end>"]:
|
241 |
-
decoded = decoded.replace(t, "")
|
242 |
-
decoded = decoded.strip()
|
243 |
-
|
244 |
-
if len(generated) >= min_len and (next_token == end_id or decoded.endswith(('요', '다', '.', '!', '?'))):
|
245 |
-
if is_greedy_response_acceptable(decoded):
|
246 |
return decoded
|
247 |
-
else:
|
248 |
-
continue
|
249 |
|
250 |
-
|
|
|
|
|
|
|
251 |
|
252 |
def mismatch_tone(input_text, output_text):
|
253 |
if "ㅋㅋ" in input_text and not re.search(r'ㅋㅋ|ㅎ|재밌|놀|만나|맛집|여행', output_text):
|
@@ -372,9 +339,9 @@ def respond(input_text):
|
|
372 |
return f"{summary}\n다른 궁금한 점 있으신가요?"
|
373 |
|
374 |
# 일상 대화: 샘플링 + fallback
|
375 |
-
response =
|
376 |
if not is_valid_response(response) or mismatch_tone(input_text, response):
|
377 |
-
response =
|
378 |
return response
|
379 |
|
380 |
@app.get("/generate", response_class=PlainTextResponse)
|
|
|
166 |
|
167 |
return True
|
168 |
|
169 |
+
def generate_text_beam(model, prompt, max_len=100, beam_width=4, length_penalty=0.7):
    """Generate a chat reply for *prompt* using beam search.

    The prompt is wrapped as ``<start> {prompt} <sep>`` and tokenized with the
    module-level ``text_to_ids`` helper, then expanded one token at a time.
    Each beam carries its raw cumulative log-probability; candidates are
    ranked by a length-normalized score (raw score / len**length_penalty).

    Args:
        model: trained TF model called as ``model(tensor, training=False)``;
            assumed to accept a (1, max_len) int tensor and return per-position
            logits — TODO confirm against the model definition.
        prompt: user input text.
        max_len: both the prompt truncation length and the number of
            generation steps.
        beam_width: number of beams kept per step.
        length_penalty: exponent for length normalization when ranking.

    Returns:
        Decoded text with ``<start>``/``<sep>``/``<end>`` markers stripped.
    """

    def _clean(token_ids):
        # Strip special markers from the decoded text, mirroring the
        # post-processing used elsewhere in this file.
        decoded = sp.decode(token_ids)
        for marker in ("<start>", "<sep>", "<end>"):
            decoded = decoded.replace(marker, "")
        return decoded.strip()

    model_input = text_to_ids(f"<start> {prompt} <sep>")
    model_input = model_input[:max_len]

    # Each beam holds the token ids so far and the RAW (un-normalized)
    # cumulative log-probability.
    beams = [{"sequence": list(model_input), "score": 0.0}]

    for _ in range(max_len):
        all_candidates = []

        for beam in beams:
            seq = beam["sequence"]
            # Right-pad to the model's fixed input length.
            pad_len = max(0, max_len - len(seq))
            input_padded = np.pad(seq, (0, pad_len), constant_values=pad_id)
            input_tensor = tf.convert_to_tensor([input_padded])
            # Logits at the position of the last real token.
            logits = model(input_tensor, training=False)[0, len(seq) - 1].numpy()

            # Numerically stable softmax (subtract the max before exp).
            probs = np.exp(logits - np.max(logits))
            probs = probs / probs.sum()

            # Expand this beam by its top-`beam_width` next tokens.
            top_indices = probs.argsort()[-beam_width:][::-1]
            for idx in top_indices:
                all_candidates.append({
                    "sequence": seq + [int(idx)],
                    # BUG FIX: keep the raw cumulative log-prob. The original
                    # divided "score" in place every step, so the length
                    # penalty compounded across iterations and corrupted the
                    # scores carried into later steps.
                    "score": beam["score"] + np.log(probs[idx]),
                })

        # Rank by length-normalized score, but carry the raw score forward.
        beams = sorted(
            all_candidates,
            key=lambda c: c["score"] / (len(c["sequence"]) ** length_penalty),
            reverse=True,
        )[:beam_width]

        # Early stop: return as soon as a surviving beam has emitted the end
        # token and its cleaned decode passes the quality filter.
        for b in beams:
            decoded = _clean(b["sequence"])
            if end_id in b["sequence"] and is_greedy_response_acceptable(decoded):
                return decoded

    # No beam terminated within max_len steps: return the top-ranked beam,
    # cleaned the same way for consistency.
    return _clean(beams[0]["sequence"])
|
217 |
+
|
218 |
|
219 |
def mismatch_tone(input_text, output_text):
|
220 |
if "ㅋㅋ" in input_text and not re.search(r'ㅋㅋ|ㅎ|재밌|놀|만나|맛집|여행', output_text):
|
|
|
339 |
return f"{summary}\n다른 궁금한 점 있으신가요?"
|
340 |
|
341 |
# 일상 대화: 샘플링 + fallback
|
342 |
+
response = generate_text_beam(model, input_text)
|
343 |
if not is_valid_response(response) or mismatch_tone(input_text, response):
|
344 |
+
response = generate_text_beam(model, input_text)
|
345 |
return response
|
346 |
|
347 |
@app.get("/generate", response_class=PlainTextResponse)
|