Yuchan5386 commited on
Commit
0b02fb4
ยท
verified ยท
1 Parent(s): 77642ac

Update api.py

Browse files
Files changed (1) hide show
  1. api.py +6 -14
api.py CHANGED
@@ -307,12 +307,11 @@ def is_valid_response(response):
307
 
308
 
309
  def respond(input_text):
310
- # 1) ์‚ฌ์šฉ์ž ์ž…๋ ฅ ๊ธฐ์–ต์— ์ €์žฅ (์›ํ•˜๋ฉด)
311
  memory.process_input(input_text)
312
 
313
  if "์ด๋ฆ„" in input_text:
314
  response = "์ œ ์ด๋ฆ„์€ Flexi์ž…๋‹ˆ๋‹ค."
315
- memory.process_input(response) # ๋‹ต๋ณ€๋„ ๊ธฐ์–ต์— ์ถ”๊ฐ€ ๊ฐ€๋Šฅ
316
  return response
317
 
318
  if "๋ˆ„๊ตฌ" in input_text:
@@ -320,21 +319,14 @@ def respond(input_text):
320
  memory.process_input(response)
321
  return response
322
 
323
- # ๊ธฐ์–ต์—์„œ ์œ ์‚ฌ ๋ฌธ์žฅ ๊บผ๋‚ด์„œ ํ”„๋กฌํ”„ํŠธ ๋งŒ๋“ค๊ธฐ
324
  related_memories = memory.retrieve(input_text, top_k=3)
325
- merged_prompt = merge_prompt(input_text, related_memories)
326
-
327
- # ๋ชจ๋ธ๋กœ ์‘๋‹ต ์ƒ์„ฑ
328
- response = generate_text_sample(model, merged_prompt)
329
 
330
- # ์‘๋‹ต ๊ฒ€์ฆ, ์•ˆ ๋งž์œผ๋ฉด ์žฌ์ƒ์„ฑ
331
- if not is_valid_response(response) or mismatch_tone(input_text, response):
332
  response = generate_text_sample(model, merged_prompt)
333
-
334
- # ์ตœ์ข… ์‘๋‹ต๋„ ๊ธฐ์–ต์— ์ถ”๊ฐ€
335
- memory.process_input(response)
336
-
337
- return response
338
 
339
 
340
  @app.get("/generate", response_class=PlainTextResponse)
 
307
 
308
 
309
  def respond(input_text):
 
310
  memory.process_input(input_text)
311
 
312
  if "์ด๋ฆ„" in input_text:
313
  response = "์ œ ์ด๋ฆ„์€ Flexi์ž…๋‹ˆ๋‹ค."
314
+ memory.process_input(response)
315
  return response
316
 
317
  if "๋ˆ„๊ตฌ" in input_text:
 
319
  memory.process_input(response)
320
  return response
321
 
 
322
  related_memories = memory.retrieve(input_text, top_k=3)
323
+ merged_prompt = memory.merge_prompt(input_text, related_memories)
 
 
 
324
 
325
+ for _ in range(3): # ์ตœ๋Œ€ 3๋ฒˆ ์žฌ์‹œ๋„
 
326
  response = generate_text_sample(model, merged_prompt)
327
+ if is_valid_response(response) and not mismatch_tone(input_text, response):
328
+ memory.process_input(response)
329
+ return response
 
 
330
 
331
 
332
  @app.get("/generate", response_class=PlainTextResponse)