Yuchan5386 commited on
Commit
919644d
Β·
verified Β·
1 Parent(s): e89787a

Update api.py

Browse files
Files changed (1) hide show
  1. api.py +75 -5
api.py CHANGED
@@ -228,7 +228,75 @@ def generate_text_sample(model, prompt, max_len=100, max_gen=98,
228
  for t in ["<start>", "<sep>", "<end>"]:
229
  decoded = decoded.replace(t, "")
230
  return decoded.strip()
231
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
232
  def mismatch_tone(input_text, output_text):
233
  if "γ…‹γ…‹" in input_text and not re.search(r'γ…‹γ…‹|γ…Ž|재밌|놀|λ§Œλ‚˜|λ§›μ§‘|μ—¬ν–‰', output_text):
234
  return True
@@ -328,7 +396,6 @@ def parse_math_question(text):
328
  except:
329
  return "계산할 수 μ—†λŠ” μˆ˜μ‹μ΄μ—μš”. λ‹€μ‹œ ν•œλ²ˆ 확인해 μ£Όμ„Έμš”!"
330
 
331
- # μ΅œμ’… 응닡 ν•¨μˆ˜
332
  def respond(input_text):
333
  intent = simple_intent_classifier(input_text)
334
 
@@ -348,10 +415,13 @@ def respond(input_text):
348
  summary = summarize_from_wikipedia(keyword)
349
  return f"{summary}\nλ‹€λ₯Έ κΆκΈˆν•œ 점 μžˆμœΌμ‹ κ°€μš”?"
350
 
351
- # 일상 λŒ€ν™”: μƒ˜ν”Œλ§ + fallback
352
- response = generate_text_sample(model, input_text)
 
 
 
353
  if not is_valid_response(response) or mismatch_tone(input_text, response):
354
- response = generate_text_sample(model, input_text)
355
  return response
356
 
357
  @app.get("/generate", response_class=PlainTextResponse)
 
228
  for t in ["<start>", "<sep>", "<end>"]:
229
  decoded = decoded.replace(t, "")
230
  return decoded.strip()
231
+
232
+ import numpy as np
233
+ from sklearn.feature_extraction.text import TfidfVectorizer
234
+ from sklearn.decomposition import TruncatedSVD
235
+ from sklearn.metrics.pairwise import cosine_similarity
236
+
237
+ class SimilarityMemory:
238
+ def __init__(self, n_components=100):
239
+ self.memory_texts = []
240
+ self.vectorizer = TfidfVectorizer()
241
+ self.svd = TruncatedSVD(n_components=n_components)
242
+ self.embeddings = None
243
+ self.fitted = False
244
+
245
+ def add(self, text: str):
246
+ self.memory_texts.append(text)
247
+ self._update_embeddings()
248
+
249
+ def _update_embeddings(self):
250
+ # ν…μŠ€νŠΈκ°€ 1개 이상일 λ•Œλ§Œ 벑터화 및 차원 μΆ•μ†Œ μ§„ν–‰
251
+ if len(self.memory_texts) == 0:
252
+ self.embeddings = None
253
+ self.fitted = False
254
+ return
255
+
256
+ # 벑터화
257
+ X = self.vectorizer.fit_transform(self.memory_texts)
258
+
259
+ # 차원 μΆ•μ†Œ
260
+ n_comp = min(self.svd.n_components, X.shape[1], len(self.memory_texts)-1)
261
+ if n_comp <= 0:
262
+ # μΆ•μ†Œν•  차원이 μ—†μœΌλ©΄ κ·Έλƒ₯ TF-IDF 벑터 μ‚¬μš©
263
+ self.embeddings = X.toarray()
264
+ self.fitted = True
265
+ return
266
+
267
+ self.svd = TruncatedSVD(n_components=n_comp)
268
+ self.embeddings = self.svd.fit_transform(X)
269
+ self.fitted = True
270
+
271
+ def retrieve(self, query: str, top_k=3):
272
+ if not self.fitted or self.embeddings is None or len(self.memory_texts) == 0:
273
+ return []
274
+
275
+ # 쿼리 벑터화 + 차원 μΆ•μ†Œ (fit_transform이 μ•„λ‹ˆλΌ transform ν•΄μ•Ό 함)
276
+ Xq = self.vectorizer.transform([query])
277
+ if self.svd.n_components > Xq.shape[1] or self.svd.n_components > len(self.memory_texts) - 1:
278
+ q_emb = Xq.toarray()
279
+ else:
280
+ q_emb = self.svd.transform(Xq)
281
+
282
+ # 코사인 μœ μ‚¬λ„ 계산
283
+ sims = cosine_similarity(q_emb, self.embeddings)[0]
284
+ top_indices = sims.argsort()[::-1][:top_k]
285
+
286
+ return [self.memory_texts[i] for i in top_indices]
287
+
288
+ # ν…ŒμŠ€νŠΈ
289
+ memory = SimilarityMemory()
290
+
291
+ memory.add("μ–΄μ œλŠ” 기뢄이 λ³„λ‘œμ˜€μ–΄")
292
+ memory.add("μ˜ν™” 보러 κ°”λ‹€κ°€ μΉœκ΅¬λž‘ μ‹Έμ› μ–΄")
293
+ memory.add("μΉ΄νŽ˜μ—μ„œ κ³΅λΆ€ν–ˆλŠ”λ° 집쀑이 잘 됐어")
294
+
295
+ def merge_prompt_with_memory(prompt: str, memories: list):
296
+ context = "\n".join(f"κ³Όκ±°: {mem}" for mem in memories)
297
+ return f"{context} {prompt}"
298
+
299
+
300
  def mismatch_tone(input_text, output_text):
301
  if "γ…‹γ…‹" in input_text and not re.search(r'γ…‹γ…‹|γ…Ž|재밌|놀|λ§Œλ‚˜|λ§›μ§‘|μ—¬ν–‰', output_text):
302
  return True
 
396
  except:
397
  return "계산할 수 μ—†λŠ” μˆ˜μ‹μ΄μ—μš”. λ‹€μ‹œ ν•œλ²ˆ 확인해 μ£Όμ„Έμš”!"
398
 
 
399
  def respond(input_text):
400
  intent = simple_intent_classifier(input_text)
401
 
 
415
  summary = summarize_from_wikipedia(keyword)
416
  return f"{summary}\nλ‹€λ₯Έ κΆκΈˆν•œ 점 μžˆμœΌμ‹ κ°€μš”?"
417
 
418
+ # βœ… κΈ°μ–΅ 기반 병합 μΆ”κ°€
419
+ related_memories = memory.retrieve(input_text, top_k=3)
420
+ merged_prompt = merge_prompt_with_memory(input_text, related_memories)
421
+
422
+ response = generate_text_sample(model, merged_prompt)
423
  if not is_valid_response(response) or mismatch_tone(input_text, response):
424
+ response = generate_text_sample(model, merged_prompt)
425
  return response
426
 
427
  @app.get("/generate", response_class=PlainTextResponse)