Yuchan5386 committed on
Commit
5226c06
·
verified ·
1 Parent(s): 004f427

Update api.py

Browse files
Files changed (1) hide show
  1. api.py +20 -19
api.py CHANGED
@@ -229,7 +229,6 @@ def generate_text_sample(model, prompt, max_len=100, max_gen=98,
229
  decoded = decoded.replace(t, "")
230
  return decoded.strip()
231
 
232
- import numpy as np
233
  from sklearn.feature_extraction.text import TfidfVectorizer
234
  from sklearn.decomposition import TruncatedSVD
235
  from sklearn.metrics.pairwise import cosine_similarity
@@ -238,9 +237,9 @@ class SimilarityMemory:
238
  def __init__(self, n_components=100):
239
  self.memory_texts = []
240
  self.vectorizer = TfidfVectorizer()
241
- self.svd = None
242
  self.embeddings = None
243
- self.n_components = n_components
244
 
245
  def add(self, text: str):
246
  self.memory_texts.append(text)
@@ -249,42 +248,44 @@ class SimilarityMemory:
249
  def _update_embeddings(self):
250
  if len(self.memory_texts) == 0:
251
  self.embeddings = None
252
- self.svd = None
253
  return
254
 
255
  X = self.vectorizer.fit_transform(self.memory_texts)
256
-
257
- n_comp = min(self.n_components, X.shape[1], len(self.memory_texts) - 1)
258
  if n_comp <= 0:
259
  self.embeddings = X.toarray()
260
- self.svd = None
261
  return
262
 
263
- svd = TruncatedSVD(n_components=n_comp)
264
- self.embeddings = svd.fit_transform(X)
265
- self.svd = svd
266
 
267
  def retrieve(self, query: str, top_k=3):
268
- if self.embeddings is None:
269
  return []
270
 
271
  Xq = self.vectorizer.transform([query])
272
- if self.svd is not None:
273
- q_emb = self.svd.transform(Xq)
274
- else:
275
  q_emb = Xq.toarray()
 
 
276
 
277
  sims = cosine_similarity(q_emb, self.embeddings)[0]
278
  top_indices = sims.argsort()[::-1][:top_k]
279
 
280
  return [self.memory_texts[i] for i in top_indices]
281
 
282
- # 테스트
283
- memory = SimilarityMemory()
 
 
 
284
 
285
- def merge_prompt_with_memory(prompt: str, memories: list):
286
- context = "\n".join(memories)
287
- return f"{context}\n{prompt}"
288
 
289
 
290
  def mismatch_tone(input_text, output_text):
 
229
  decoded = decoded.replace(t, "")
230
  return decoded.strip()
231
 
 
232
  from sklearn.feature_extraction.text import TfidfVectorizer
233
  from sklearn.decomposition import TruncatedSVD
234
  from sklearn.metrics.pairwise import cosine_similarity
 
237
  def __init__(self, n_components=100):
238
  self.memory_texts = []
239
  self.vectorizer = TfidfVectorizer()
240
+ self.svd = TruncatedSVD(n_components=n_components)
241
  self.embeddings = None
242
+ self.fitted = False
243
 
244
  def add(self, text: str):
245
  self.memory_texts.append(text)
 
248
  def _update_embeddings(self):
249
  if len(self.memory_texts) == 0:
250
  self.embeddings = None
251
+ self.fitted = False
252
  return
253
 
254
  X = self.vectorizer.fit_transform(self.memory_texts)
255
+ n_comp = min(self.svd.n_components, X.shape[1], len(self.memory_texts)-1)
 
256
  if n_comp <= 0:
257
  self.embeddings = X.toarray()
258
+ self.fitted = True
259
  return
260
 
261
+ self.svd = TruncatedSVD(n_components=n_comp)
262
+ self.embeddings = self.svd.fit_transform(X)
263
+ self.fitted = True
264
 
265
  def retrieve(self, query: str, top_k=3):
266
+ if not self.fitted or self.embeddings is None or len(self.memory_texts) == 0:
267
  return []
268
 
269
  Xq = self.vectorizer.transform([query])
270
+ if self.svd.n_components > Xq.shape[1] or self.svd.n_components > len(self.memory_texts) - 1:
 
 
271
  q_emb = Xq.toarray()
272
+ else:
273
+ q_emb = self.svd.transform(Xq)
274
 
275
  sims = cosine_similarity(q_emb, self.embeddings)[0]
276
  top_indices = sims.argsort()[::-1][:top_k]
277
 
278
  return [self.memory_texts[i] for i in top_indices]
279
 
280
+ def process_input(self, new_text: str, top_k=3):
281
+ """자동으로 기억 저장하고, 유사한 기억 찾아서 합친 프롬프트 생성"""
282
+ related_memories = self.retrieve(new_text, top_k=top_k)
283
+ self.add(new_text)
284
+ return self.merge_prompt(new_text, related_memories)
285
 
286
+ def merge_prompt(self, prompt: str, memories: list):
287
+ context = "\n".join(memories)
288
+ return f"{context}\n\n{prompt}" if context else prompt
289
 
290
 
291
  def mismatch_tone(input_text, output_text):