Hugging Face Space — diff view for commit "Update api.py" (file: api.py, status: CHANGED).
@@ -166,54 +166,39 @@ def is_greedy_response_acceptable(text):
|
|
166 |
|
167 |
return True
|
168 |
|
169 |
-
def
|
170 |
model_input = text_to_ids(f"<start> {prompt} <sep>")
|
171 |
model_input = model_input[:max_len]
|
172 |
-
|
173 |
-
beams = [{
|
174 |
-
"sequence": list(model_input),
|
175 |
-
"score": 0.0
|
176 |
-
}]
|
177 |
|
178 |
for _ in range(max_len):
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
"sequence": new_seq,
|
198 |
-
"score": new_score
|
199 |
-
})
|
200 |
-
|
201 |
-
# 길이 보정
|
202 |
-
for cand in all_candidates:
|
203 |
-
cand["score"] /= (len(cand["sequence"]) ** length_penalty)
|
204 |
-
|
205 |
-
# 상위 beam_width개만 유지
|
206 |
-
beams = sorted(all_candidates, key=lambda x: x["score"], reverse=True)[:beam_width]
|
207 |
-
|
208 |
-
# 조기 종료 (EOS 토큰 또는 끝나는 말투)
|
209 |
-
for b in beams:
|
210 |
-
decoded = sp.decode(b["sequence"]).strip()
|
211 |
-
if end_id in b["sequence"] and is_greedy_response_acceptable(decoded):
|
212 |
return decoded
|
|
|
|
|
213 |
|
214 |
-
#
|
215 |
-
|
216 |
-
|
|
|
|
|
217 |
|
218 |
|
219 |
def mismatch_tone(input_text, output_text):
|
@@ -339,9 +324,9 @@ def respond(input_text):
|
|
339 |
return f"{summary}\n다른 궁금한 점 있으신가요?"
|
340 |
|
341 |
# 일상 대화: 샘플링 + fallback
|
342 |
-
response =
|
343 |
if not is_valid_response(response) or mismatch_tone(input_text, response):
|
344 |
-
response =
|
345 |
return response
|
346 |
|
347 |
@app.get("/generate", response_class=PlainTextResponse)
|
|
|
166 |
|
167 |
return True
|
168 |
|
169 |
+
def generate_text_greedy(model, prompt, max_len=100, min_len=12):
    """Greedy-decode a chat response for *prompt* using *model*.

    The prompt is wrapped as ``"<start> {prompt} <sep>"``, tokenized via
    ``text_to_ids``, truncated to ``max_len`` ids, and then extended one
    argmax token at a time.  Generation stops early once the sequence is at
    least ``min_len`` tokens long, the step produced the <end> token or the
    decoded text ends in a sentence-final character, AND the candidate passes
    ``is_greedy_response_acceptable``.

    Args:
        model: callable mapping a padded id tensor of shape (1, max_len) to
            per-position logits; indexed here as logits[0, pos] — assumes a
            (batch, seq, vocab) output, TODO confirm against the model class.
        prompt: raw user input text.
        max_len: fixed model window — total sequence length (prompt ids plus
            generated ids) never exceeds this.
        min_len: minimum sequence length before early stopping is considered.

    Returns:
        The decoded response string with control tokens stripped.
    """
    model_input = text_to_ids(f"<start> {prompt} <sep>")
    model_input = model_input[:max_len]
    generated = list(model_input)

    for _ in range(max_len):
        # Fix: stop before the sequence outgrows the fixed window.  The
        # previous code let len(generated) pass max_len, after which pad_len
        # clamped to 0 and every iteration fed a tensor LONGER than max_len
        # to a model padded for exactly max_len positions.
        if len(generated) >= max_len:
            break

        pad_len = max_len - len(generated)
        input_padded = np.pad(generated, (0, pad_len), constant_values=pad_id)
        input_tensor = tf.convert_to_tensor([input_padded])
        logits = model(input_tensor, training=False)
        # Logits for the position of the last real (non-pad) token.
        next_logits = logits[0, len(generated) - 1].numpy()

        # Greedy: always take the single highest-probability token.
        next_token = int(np.argmax(next_logits))
        generated.append(next_token)

        # Decode the running sequence and strip control tokens so the
        # stopping heuristics see only user-visible text.
        decoded = sp.decode(generated)
        for t in ["<start>", "<sep>", "<end>"]:
            decoded = decoded.replace(t, "")
        decoded = decoded.strip()

        # Early stop on <end> or a sentence-final ending ('요'/'다'/./!/?),
        # but only when the candidate also passes the acceptability filter;
        # otherwise keep generating (the redundant `else: continue` of the
        # original is dropped — the loop continues anyway).
        if len(generated) >= min_len and (next_token == end_id or decoded.endswith(('요', '다', '.', '!', '?'))):
            if is_greedy_response_acceptable(decoded):
                return decoded

    # No acceptable early stop within the window: return the final decode.
    decoded = sp.decode(generated)
    for t in ["<start>", "<sep>", "<end>"]:
        decoded = decoded.replace(t, "")
    return decoded.strip()
|
202 |
|
203 |
|
204 |
def mismatch_tone(input_text, output_text):
|
|
|
324 |
return f"{summary}\n다른 궁금한 점 있으신가요?"
|
325 |
|
326 |
# 일상 대화: 샘플링 + fallback
|
327 |
+
response = generate_text_greedy(model, input_text)
|
328 |
if not is_valid_response(response) or mismatch_tone(input_text, response):
|
329 |
+
response = generate_text_greedy(model, input_text)
|
330 |
return response
|
331 |
|
332 |
@app.get("/generate", response_class=PlainTextResponse)
|