Yuchan5386 commited on
Commit
d3506eb
Β·
verified Β·
1 Parent(s): 67196c6

Update api.py

Browse files
Files changed (1) hide show
  1. api.py +19 -2
api.py CHANGED
@@ -260,7 +260,12 @@ class SimilarityMemory:
260
 
261
  def merge_prompt(self, prompt: str, memories: list):
262
  context = "\n".join(memories)
263
- return f"{context}\n\n{prompt}" if context else prompt
 
 
 
 
 
264
 
265
  memory = SimilarityMemory()
266
 
@@ -302,11 +307,23 @@ def respond(input_text):
302
  merged_prompt = memory.merge_prompt(input_text, related_memories)
303
 
304
  for _ in range(3): # μ΅œλŒ€ 3번 μž¬μ‹œλ„
305
- response = generate_text_sample(model, merged_prompt)
 
 
 
 
 
 
 
306
  if is_valid_response(response) and not mismatch_tone(input_text, response):
307
  memory.process_input(response)
308
  return response
309
 
 
 
 
 
 
310
 
311
  @app.get("/generate", response_class=PlainTextResponse)
312
  async def generate(request: Request):
 
260
 
261
  def merge_prompt(self, prompt: str, memories: list):
262
  context = "\n".join(memories)
263
+ full_prompt = ""
264
+ if context:
265
+ full_prompt += f"κΈ°μ–΅:\n{context}\n\n"
266
+ full_prompt += f"ν˜„μž¬ 질문:\n{prompt}\n\n응닡:"
267
+ return full_prompt
268
+
269
 
270
  memory = SimilarityMemory()
271
 
 
307
  merged_prompt = memory.merge_prompt(input_text, related_memories)
308
 
309
  for _ in range(3): # μ΅œλŒ€ 3번 μž¬μ‹œλ„
310
+ full_response = generate_text_sample(model, merged_prompt)
311
+
312
+ # μ—¬κΈ°μ„œ '응닡:' λ’€μ˜ ν…μŠ€νŠΈλ§Œ 뽑기
313
+ if "응닡:" in full_response:
314
+ response = full_response.split("응닡:")[-1].strip()
315
+ else:
316
+ response = full_response.strip()
317
+
318
  if is_valid_response(response) and not mismatch_tone(input_text, response):
319
  memory.process_input(response)
320
  return response
321
 
322
+ # 3번 λͺ¨λ‘ μ‹€νŒ¨ μ‹œ fallback
323
+ fallback_response = "μ£„μ†‘ν•΄μš”, μ œλŒ€λ‘œ 닡변을 λͺ»ν–ˆμ–΄μš”."
324
+ memory.process_input(fallback_response)
325
+ return fallback_response
326
+
327
 
328
  @app.get("/generate", response_class=PlainTextResponse)
329
  async def generate(request: Request):