Yuchan5386 committed on
Commit
bec02a2
·
verified ·
1 Parent(s): 3aa5873

Update api.py

Browse files
Files changed (1) hide show
  1. api.py +5 -85
api.py CHANGED
@@ -203,77 +203,6 @@ def generate_text_sample(model, prompt, max_len=100, max_gen=98,
203
  decoded = decoded.replace(t, "")
204
  return decoded.strip()
205
 
206
-
207
- from sklearn.feature_extraction.text import TfidfVectorizer
208
- from sklearn.decomposition import TruncatedSVD
209
- from sklearn.metrics.pairwise import cosine_similarity
210
-
211
- class SimilarityMemory:
212
- def __init__(self, n_components=100):
213
- self.memory_texts = []
214
- self.vectorizer = TfidfVectorizer()
215
- self.svd = TruncatedSVD(n_components=n_components)
216
- self.embeddings = None
217
- self.fitted = False
218
-
219
- def add(self, text: str):
220
- self.memory_texts.append(text)
221
- self._update_embeddings()
222
-
223
- def _update_embeddings(self):
224
- if len(self.memory_texts) == 0:
225
- self.embeddings = None
226
- self.fitted = False
227
- return
228
-
229
- X = self.vectorizer.fit_transform(self.memory_texts)
230
- n_comp = min(self.svd.n_components, X.shape[1], len(self.memory_texts)-1)
231
- if n_comp <= 0:
232
- self.embeddings = X.toarray()
233
- self.fitted = True
234
- return
235
-
236
- self.svd = TruncatedSVD(n_components=n_comp)
237
- self.embeddings = self.svd.fit_transform(X)
238
- self.fitted = True
239
-
240
- def retrieve(self, query: str, top_k=3):
241
- if not self.fitted or self.embeddings is None or len(self.memory_texts) == 0:
242
- return []
243
-
244
- Xq = self.vectorizer.transform([query])
245
- if self.svd.n_components > Xq.shape[1] or self.svd.n_components > len(self.memory_texts) - 1:
246
- q_emb = Xq.toarray()
247
- else:
248
- q_emb = self.svd.transform(Xq)
249
-
250
- sims = cosine_similarity(q_emb, self.embeddings)[0]
251
- top_indices = sims.argsort()[::-1][:top_k]
252
-
253
- return [self.memory_texts[i] for i in top_indices]
254
-
255
- def process_input(self, new_text: str, top_k=3):
256
- """자동으로 기억 저장하고, 유사한 기억 찾아서 합친 프롬프트 생성"""
257
- related_memories = self.retrieve(new_text, top_k=top_k)
258
- self.add(new_text)
259
- return self.merge_prompt(new_text, related_memories)
260
-
261
- def merge_prompt(self, prompt: str, memories: list):
262
- context = "\n".join(memories)
263
- full_prompt = ""
264
- if context:
265
- full_prompt += f"기억:\n{context}\n\n"
266
- full_prompt += f"현재 질문:\n{prompt}\n\n응답:"
267
- return full_prompt
268
-
269
- memory = SimilarityMemory()
270
-
271
- with open("base_texts.txt", "r", encoding="utf-8") as f:
272
- for line in f:
273
- line = line.strip()
274
- if line:
275
- memory.add(line)
276
-
277
  def mismatch_tone(input_text, output_text):
278
  if "ㅋㅋ" in input_text and not re.search(r'ㅋㅋ|ㅎ|재밌|놀|만나|맛집|여행', output_text):
279
  return True
@@ -295,39 +224,30 @@ def is_valid_response(response):
295
 
296
 
297
  def respond(input_text):
298
- memory.process_input(input_text)
299
-
300
  if "이름" in input_text:
301
  response = "제 이름은 Flexi입니다."
302
- memory.process_input(response)
303
  return response
304
 
305
  if "누구" in input_text:
306
  response = "저는 Flexi라고 해요."
307
- memory.process_input(response)
308
  return response
309
 
310
- related_memories = memory.retrieve(input_text, top_k=3)
311
- merged_prompt = memory.merge_prompt(input_text, related_memories)
312
 
313
  for _ in range(3): # 최대 3번 재시도
314
- full_response = generate_text_sample(model, merged_prompt)
315
 
316
- # 여기서 '응답:' 뒤의 텍스트만 뽑기
317
  if "응답:" in full_response:
318
  response = full_response.split("응답:")[-1].strip()
319
  else:
320
  response = full_response.strip()
321
 
322
  if is_valid_response(response) and not mismatch_tone(input_text, response):
323
- memory.process_input(response)
324
  return response
325
 
326
- # 3번 모두 실패 시 fallback
327
- fallback_response = "죄송해요, 제대로 답변을 못했어요."
328
- memory.process_input(fallback_response)
329
- return fallback_response
330
-
331
 
332
  @app.get("/generate", response_class=PlainTextResponse)
333
  async def generate(request: Request):
 
203
  decoded = decoded.replace(t, "")
204
  return decoded.strip()
205
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
  def mismatch_tone(input_text, output_text):
207
  if "ㅋㅋ" in input_text and not re.search(r'ㅋㅋ|ㅎ|재밌|놀|만나|맛집|여행', output_text):
208
  return True
 
224
 
225
 
226
def respond(input_text):
    """Produce a chat reply for *input_text*.

    Identity questions ("이름"/"누구") get fixed canned answers. Everything
    else is wrapped in a plain prompt and sent to the model, retrying up to
    three times until a reply passes validation; otherwise a fixed apology
    string is returned.

    NOTE(review): relies on module-level `model`, `generate_text_sample`,
    `is_valid_response`, and `mismatch_tone` defined elsewhere in this file.
    """
    # Canned identity answers take priority over model generation.
    if "이름" in input_text:
        return "제 이름은 Flexi입니다."
    if "누구" in input_text:
        return "저는 Flexi라고 해요."

    # Memoryless prompt: just the current question plus the answer marker.
    prompt = f"현재 질문:\n{input_text}\n\n응답:"

    for _attempt in range(3):  # up to 3 generation attempts
        generated = generate_text_sample(model, prompt)

        # Keep only the text after the last "응답:" marker; when the marker
        # is absent, rpartition leaves the whole string in the tail, so this
        # single expression covers both cases.
        reply = generated.rpartition("응답:")[2].strip()

        if is_valid_response(reply) and not mismatch_tone(input_text, reply):
            return reply

    # All attempts produced invalid or tone-mismatched output.
    return "죄송해요, 제대로 답변을 못했어요."
 
 
 
 
251
 
252
  @app.get("/generate", response_class=PlainTextResponse)
253
  async def generate(request: Request):