Spaces:
Sleeping
Sleeping
Update api.py
Browse files
api.py
CHANGED
@@ -6,10 +6,10 @@ import asyncio
|
|
6 |
from fastapi import FastAPI, Request
|
7 |
from fastapi.responses import StreamingResponse
|
8 |
import sentencepiece as spm
|
9 |
-
from typing import List
|
10 |
import re
|
11 |
-
app = FastAPI()
|
12 |
-
|
|
|
13 |
from fastapi.middleware.cors import CORSMiddleware
|
14 |
|
15 |
origins = [
|
@@ -199,14 +199,10 @@ def generate_text_greedy_strong(model, prompt, max_len=100, max_gen=98,
|
|
199 |
next_token = np.argmax(next_logits)
|
200 |
generated.append(int(next_token))
|
201 |
|
202 |
-
decoded = sp.decode(generated)
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
decoded = decoded.replace(t, "")
|
207 |
-
|
208 |
-
# 입력 프롬프트도 제거
|
209 |
-
decoded = decoded.replace(prompt, "").strip()
|
210 |
|
211 |
if len(generated) >= min_len and (next_token == end_id or decoded.endswith(('요', '다', '.', '!', '?'))):
|
212 |
if is_greedy_response_acceptable(decoded):
|
@@ -289,14 +285,6 @@ def summarize_from_wikipedia(query, top_n=3):
|
|
289 |
raw_summary = get_wikipedia_summary(query)
|
290 |
return textrank_summarize(raw_summary, top_n=top_n)
|
291 |
|
292 |
-
def build_contexted_prompt(history: List[str], user_input: str):
|
293 |
-
# 최근 대화 3개를 합쳐서 요약
|
294 |
-
recent = ' '.join(history[-3:])
|
295 |
-
summary = textrank_summarize(recent, top_n=2)
|
296 |
-
# 요약 + 최신 사용자 입력을 같이 던져줌
|
297 |
-
prompt = f"{summary} {user_input}"
|
298 |
-
return prompt
|
299 |
-
|
300 |
# 의도 분류기
|
301 |
def simple_intent_classifier(text):
|
302 |
text = text.lower()
|
@@ -321,10 +309,8 @@ def parse_math_question(text):
|
|
321 |
except:
|
322 |
return "계산할 수 없는 수식이에요. 다시 한번 확인해 주세요!"
|
323 |
|
324 |
-
#
|
325 |
def respond(input_text):
|
326 |
-
global dialogue_history
|
327 |
-
|
328 |
intent = simple_intent_classifier(input_text)
|
329 |
|
330 |
if "이름" in input_text:
|
@@ -334,38 +320,22 @@ def respond(input_text):
|
|
334 |
return "저는 Ector.V라고 해요."
|
335 |
|
336 |
if intent == "수학질문":
|
337 |
-
|
338 |
-
response = parse_math_question(input_text)
|
339 |
-
dialogue_history.append(f"Ector: {response}")
|
340 |
-
return response
|
341 |
|
342 |
if intent == "인사":
|
343 |
-
|
344 |
-
dialogue_history.append(f"사용자: {input_text}")
|
345 |
-
dialogue_history.append(f"Ector: {response}")
|
346 |
-
return response
|
347 |
|
348 |
if intent == "정보질문":
|
349 |
keyword = re.sub(r"(에 대해|에 대한|에 대해서)?\s*(설명해줘|알려줘|뭐야|개념|정의|정보)?", "", input_text).strip()
|
350 |
if not keyword:
|
351 |
return "어떤 주제에 대해 궁금한가요?"
|
352 |
summary = summarize_from_wikipedia(keyword)
|
353 |
-
|
354 |
-
dialogue_history.append(f"사용자: {input_text}")
|
355 |
-
dialogue_history.append(f"Ector: {response}")
|
356 |
-
return response
|
357 |
|
358 |
-
# 일상 대화:
|
359 |
-
|
360 |
-
response = generate_text_greedy_strong(model, contexted_prompt)
|
361 |
-
|
362 |
-
# fallback
|
363 |
if not is_valid_response(response) or mismatch_tone(input_text, response):
|
364 |
-
response = generate_text_greedy_strong(model,
|
365 |
-
|
366 |
-
# 히스토리 추가
|
367 |
-
dialogue_history.append(f"사용자: {input_text}")
|
368 |
-
dialogue_history.append(f"Ector: {response}")
|
369 |
return response
|
370 |
|
371 |
@app.get("/generate", response_class=PlainTextResponse)
|
|
|
6 |
from fastapi import FastAPI, Request
|
7 |
from fastapi.responses import StreamingResponse
|
8 |
import sentencepiece as spm
|
|
|
9 |
import re
|
10 |
+
app = FastAPI()
|
11 |
+
|
12 |
+
|
13 |
from fastapi.middleware.cors import CORSMiddleware
|
14 |
|
15 |
origins = [
|
|
|
199 |
next_token = np.argmax(next_logits)
|
200 |
generated.append(int(next_token))
|
201 |
|
202 |
+
decoded = sp.decode(generated)
|
203 |
+
for t in ["<start>", "<sep>", "<end>"]:
|
204 |
+
decoded = decoded.replace(t, "")
|
205 |
+
decoded = decoded.strip()
|
|
|
|
|
|
|
|
|
206 |
|
207 |
if len(generated) >= min_len and (next_token == end_id or decoded.endswith(('요', '다', '.', '!', '?'))):
|
208 |
if is_greedy_response_acceptable(decoded):
|
|
|
285 |
raw_summary = get_wikipedia_summary(query)
|
286 |
return textrank_summarize(raw_summary, top_n=top_n)
|
287 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
288 |
# 의도 분류기
|
289 |
def simple_intent_classifier(text):
|
290 |
text = text.lower()
|
|
|
309 |
except:
|
310 |
return "계산할 수 없는 수식이에요. 다시 한번 확인해 주세요!"
|
311 |
|
312 |
+
# 최종 응답 함수
|
313 |
def respond(input_text):
|
|
|
|
|
314 |
intent = simple_intent_classifier(input_text)
|
315 |
|
316 |
if "이름" in input_text:
|
|
|
320 |
return "저는 Ector.V라고 해요."
|
321 |
|
322 |
if intent == "수학질문":
|
323 |
+
return parse_math_question(input_text)
|
|
|
|
|
|
|
324 |
|
325 |
if intent == "인사":
|
326 |
+
return "반가워요! 무엇을 도와드릴까요?"
|
|
|
|
|
|
|
327 |
|
328 |
if intent == "정보질문":
|
329 |
keyword = re.sub(r"(에 대해|에 대한|에 대해서)?\s*(설명해줘|알려줘|뭐야|개념|정의|정보)?", "", input_text).strip()
|
330 |
if not keyword:
|
331 |
return "어떤 주제에 대해 궁금한가요?"
|
332 |
summary = summarize_from_wikipedia(keyword)
|
333 |
+
return f"{summary}\n다른 궁금한 점 있으신가요?"
|
|
|
|
|
|
|
334 |
|
335 |
+
# 일상 대화: 샘플링 + fallback
|
336 |
+
response = generate_text_greedy_strong(model, input_text)
|
|
|
|
|
|
|
337 |
if not is_valid_response(response) or mismatch_tone(input_text, response):
|
338 |
+
response = generate_text_greedy_strong(model, input_text)
|
|
|
|
|
|
|
|
|
339 |
return response
|
340 |
|
341 |
@app.get("/generate", response_class=PlainTextResponse)
|