Yuchan5386 commited on
Commit
55632cc
ยท
verified ยท
1 Parent(s): 773bcc1

Update api.py

Browse files
Files changed (1) hide show
  1. api.py +28 -10
api.py CHANGED
@@ -288,10 +288,6 @@ class SimilarityMemory:
288
  # ํ…Œ์ŠคํŠธ
289
  memory = SimilarityMemory()
290
 
291
- memory.add("์–ด์ œ๋Š” ๊ธฐ๋ถ„์ด ๋ณ„๋กœ์˜€์–ด")
292
- memory.add("์˜ํ™” ๋ณด๋Ÿฌ ๊ฐ”๋‹ค๊ฐ€ ์นœ๊ตฌ๋ž‘ ์‹ธ์› ์–ด")
293
- memory.add("์นดํŽ˜์—์„œ ๊ณต๋ถ€ํ–ˆ๋Š”๋ฐ ์ง‘์ค‘์ด ์ž˜ ๋์–ด")
294
-
295
  def merge_prompt_with_memory(prompt: str, memories: list):
296
  context = "\n".join(f"{mem}" for mem in memories)
297
  return f"{context} {prompt}"
@@ -395,34 +391,56 @@ def parse_math_question(text):
395
  except:
396
  return "๊ณ„์‚ฐํ•  ์ˆ˜ ์—†๋Š” ์ˆ˜์‹์ด์—์š”. ๋‹ค์‹œ ํ•œ๋ฒˆ ํ™•์ธํ•ด ์ฃผ์„ธ์š”!"
397
 
 
398
  def respond(input_text):
 
 
 
399
  intent = simple_intent_classifier(input_text)
400
 
401
  if "์ด๋ฆ„" in input_text:
402
- return "์ œ ์ด๋ฆ„์€ Flexi์ž…๋‹ˆ๋‹ค."
 
 
403
 
404
  if "๋ˆ„๊ตฌ" in input_text:
405
- return "์ €๋Š” Flexi๋ผ๊ณ  ํ•ด์š”."
 
 
406
 
407
  if intent == "์ˆ˜ํ•™์งˆ๋ฌธ":
408
- return parse_math_question(input_text)
 
 
409
 
410
  if intent == "์ •๋ณด์งˆ๋ฌธ":
411
  keyword = re.sub(r"(์— ๋Œ€ํ•ด|์— ๋Œ€ํ•œ|์— ๋Œ€ํ•ด์„œ)?\s*(์„ค๋ช…ํ•ด์ค˜|์•Œ๋ ค์ค˜|๋ญ์•ผ|๊ฐœ๋…|์ •์˜|์ •๋ณด)?", "", input_text).strip()
412
  if not keyword:
413
- return "์–ด๋–ค ์ฃผ์ œ์— ๋Œ€ํ•ด ๊ถ๊ธˆํ•œ๊ฐ€์š”?"
 
 
414
  summary = summarize_from_wikipedia(keyword)
415
- return f"{summary}\n๋‹ค๋ฅธ ๊ถ๊ธˆํ•œ ์  ์žˆ์œผ์‹ ๊ฐ€์š”?"
 
 
416
 
417
- # โœ… ๊ธฐ์–ต ๊ธฐ๋ฐ˜ ๋ณ‘ํ•ฉ ์ถ”๊ฐ€
418
  related_memories = memory.retrieve(input_text, top_k=3)
419
  merged_prompt = merge_prompt_with_memory(input_text, related_memories)
420
 
 
421
  response = generate_text_sample(model, merged_prompt)
 
 
422
  if not is_valid_response(response) or mismatch_tone(input_text, response):
423
  response = generate_text_sample(model, merged_prompt)
 
 
 
 
424
  return response
425
 
 
426
  @app.get("/generate", response_class=PlainTextResponse)
427
  async def generate(request: Request):
428
  prompt = request.query_params.get("prompt", "์•ˆ๋…•ํ•˜์„ธ์š”")
 
288
  # ํ…Œ์ŠคํŠธ
289
  memory = SimilarityMemory()
290
 
 
 
 
 
291
  def merge_prompt_with_memory(prompt: str, memories: list):
292
  context = "\n".join(f"{mem}" for mem in memories)
293
  return f"{context} {prompt}"
 
391
  except:
392
  return "๊ณ„์‚ฐํ•  ์ˆ˜ ์—†๋Š” ์ˆ˜์‹์ด์—์š”. ๋‹ค์‹œ ํ•œ๋ฒˆ ํ™•์ธํ•ด ์ฃผ์„ธ์š”!"
393
 
394
+
395
  def respond(input_text):
396
+ # 1) ์‚ฌ์šฉ์ž ์ž…๋ ฅ ๊ธฐ์–ต์— ์ €์žฅ (์›ํ•˜๋ฉด)
397
+ memory.add(input_text)
398
+
399
  intent = simple_intent_classifier(input_text)
400
 
401
  if "์ด๋ฆ„" in input_text:
402
+ response = "์ œ ์ด๋ฆ„์€ Flexi์ž…๋‹ˆ๋‹ค."
403
+ memory.add(response) # ๋‹ต๋ณ€๋„ ๊ธฐ์–ต์— ์ถ”๊ฐ€ ๊ฐ€๋Šฅ
404
+ return response
405
 
406
  if "๋ˆ„๊ตฌ" in input_text:
407
+ response = "์ €๋Š” Flexi๋ผ๊ณ  ํ•ด์š”."
408
+ memory.add(response)
409
+ return response
410
 
411
  if intent == "์ˆ˜ํ•™์งˆ๋ฌธ":
412
+ response = parse_math_question(input_text)
413
+ memory.add(response)
414
+ return response
415
 
416
  if intent == "์ •๋ณด์งˆ๋ฌธ":
417
  keyword = re.sub(r"(์— ๋Œ€ํ•ด|์— ๋Œ€ํ•œ|์— ๋Œ€ํ•ด์„œ)?\s*(์„ค๋ช…ํ•ด์ค˜|์•Œ๋ ค์ค˜|๋ญ์•ผ|๊ฐœ๋…|์ •์˜|์ •๋ณด)?", "", input_text).strip()
418
  if not keyword:
419
+ response = "์–ด๋–ค ์ฃผ์ œ์— ๋Œ€ํ•ด ๊ถ๊ธˆํ•œ๊ฐ€์š”?"
420
+ memory.add(response)
421
+ return response
422
  summary = summarize_from_wikipedia(keyword)
423
+ response = f"{summary}\n๋‹ค๋ฅธ ๊ถ๊ธˆํ•œ ์  ์žˆ์œผ์‹ ๊ฐ€์š”?"
424
+ memory.add(response)
425
+ return response
426
 
427
+ # ๊ธฐ์–ต์—์„œ ์œ ์‚ฌ ๋ฌธ์žฅ ๊บผ๋‚ด์„œ ํ”„๋กฌํ”„ํŠธ ๋งŒ๋“ค๊ธฐ
428
  related_memories = memory.retrieve(input_text, top_k=3)
429
  merged_prompt = merge_prompt_with_memory(input_text, related_memories)
430
 
431
+ # ๋ชจ๋ธ๋กœ ์‘๋‹ต ์ƒ์„ฑ
432
  response = generate_text_sample(model, merged_prompt)
433
+
434
+ # ์‘๋‹ต ๊ฒ€์ฆ, ์•ˆ ๋งž์œผ๋ฉด ์žฌ์ƒ์„ฑ
435
  if not is_valid_response(response) or mismatch_tone(input_text, response):
436
  response = generate_text_sample(model, merged_prompt)
437
+
438
+ # ์ตœ์ข… ์‘๋‹ต๋„ ๊ธฐ์–ต์— ์ถ”๊ฐ€
439
+ memory.add(response)
440
+
441
  return response
442
 
443
+
444
  @app.get("/generate", response_class=PlainTextResponse)
445
  async def generate(request: Request):
446
  prompt = request.query_params.get("prompt", "์•ˆ๋…•ํ•˜์„ธ์š”")