Yuchan5386 committed on
Commit
004f427
·
verified ·
1 Parent(s): df8ae21

Update api.py

Browse files
Files changed (1) hide show
  1. api.py +15 -20
api.py CHANGED
@@ -238,48 +238,42 @@ class SimilarityMemory:
238
  def __init__(self, n_components=100):
239
  self.memory_texts = []
240
  self.vectorizer = TfidfVectorizer()
241
- self.svd = TruncatedSVD(n_components=n_components)
242
  self.embeddings = None
243
- self.fitted = False
244
 
245
  def add(self, text: str):
246
  self.memory_texts.append(text)
247
  self._update_embeddings()
248
 
249
  def _update_embeddings(self):
250
- # 텍스트가 1개 이상일 때만 벡터화 및 차원 축소 진행
251
  if len(self.memory_texts) == 0:
252
  self.embeddings = None
253
- self.fitted = False
254
  return
255
 
256
- # 벡터화
257
  X = self.vectorizer.fit_transform(self.memory_texts)
258
 
259
- # 차원 축소
260
- n_comp = min(self.svd.n_components, X.shape[1], len(self.memory_texts)-1)
261
  if n_comp <= 0:
262
- # 축소할 차원이 없으면 그냥 TF-IDF 벡터 사용
263
  self.embeddings = X.toarray()
264
- self.fitted = True
265
  return
266
 
267
- self.svd = TruncatedSVD(n_components=n_comp)
268
- self.embeddings = self.svd.fit_transform(X)
269
- self.fitted = True
270
 
271
  def retrieve(self, query: str, top_k=3):
272
- if not self.fitted or self.embeddings is None or len(self.memory_texts) == 0:
273
  return []
274
 
275
- # ์ฟผ๋ฆฌ ๋ฒกํ„ฐํ™” + ์ฐจ์› ์ถ•์†Œ (fit_transform์ด ์•„๋‹ˆ๋ผ transform ํ•ด์•ผ ํ•จ)
276
  Xq = self.vectorizer.transform([query])
277
- if self.svd.n_components > Xq.shape[1] or self.svd.n_components > len(self.memory_texts) - 1:
278
- q_emb = Xq.toarray()
279
- else:
280
  q_emb = self.svd.transform(Xq)
 
 
281
 
282
- # 코사인 유사도 계산
283
  sims = cosine_similarity(q_emb, self.embeddings)[0]
284
  top_indices = sims.argsort()[::-1][:top_k]
285
 
@@ -289,8 +283,9 @@ class SimilarityMemory:
289
  memory = SimilarityMemory()
290
 
291
  def merge_prompt_with_memory(prompt: str, memories: list):
292
- context = "\n".join(f"{mem}" for mem in memories)
293
- return f"{context} {prompt}"
 
294
 
295
  def mismatch_tone(input_text, output_text):
296
  if "ㅋㅋ" in input_text and not re.search(r'ㅋㅋ|ㅎ|재밌|놀|만나|맛집|여행', output_text):
 
238
  def __init__(self, n_components=100):
239
  self.memory_texts = []
240
  self.vectorizer = TfidfVectorizer()
241
+ self.svd = None
242
  self.embeddings = None
243
+ self.n_components = n_components
244
 
245
  def add(self, text: str):
246
  self.memory_texts.append(text)
247
  self._update_embeddings()
248
 
249
  def _update_embeddings(self):
 
250
  if len(self.memory_texts) == 0:
251
  self.embeddings = None
252
+ self.svd = None
253
  return
254
 
 
255
  X = self.vectorizer.fit_transform(self.memory_texts)
256
 
257
+ n_comp = min(self.n_components, X.shape[1], len(self.memory_texts) - 1)
 
258
  if n_comp <= 0:
 
259
  self.embeddings = X.toarray()
260
+ self.svd = None
261
  return
262
 
263
+ svd = TruncatedSVD(n_components=n_comp)
264
+ self.embeddings = svd.fit_transform(X)
265
+ self.svd = svd
266
 
267
  def retrieve(self, query: str, top_k=3):
268
+ if self.embeddings is None:
269
  return []
270
 
 
271
  Xq = self.vectorizer.transform([query])
272
+ if self.svd is not None:
 
 
273
  q_emb = self.svd.transform(Xq)
274
+ else:
275
+ q_emb = Xq.toarray()
276
 
 
277
  sims = cosine_similarity(q_emb, self.embeddings)[0]
278
  top_indices = sims.argsort()[::-1][:top_k]
279
 
 
283
  memory = SimilarityMemory()
284
 
285
  def merge_prompt_with_memory(prompt: str, memories: list):
286
+ context = "\n".join(memories)
287
+ return f"{context}\n{prompt}"
288
+
289
 
290
  def mismatch_tone(input_text, output_text):
291
  if "ㅋㅋ" in input_text and not re.search(r'ㅋㅋ|ㅎ|재밌|놀|만나|맛집|여행', output_text):