Spaces:
Sleeping
Sleeping
Update api.py
Browse files
api.py
CHANGED
@@ -229,7 +229,6 @@ def generate_text_sample(model, prompt, max_len=100, max_gen=98,
|
|
229 |
decoded = decoded.replace(t, "")
|
230 |
return decoded.strip()
|
231 |
|
232 |
-
import numpy as np
|
233 |
from sklearn.feature_extraction.text import TfidfVectorizer
|
234 |
from sklearn.decomposition import TruncatedSVD
|
235 |
from sklearn.metrics.pairwise import cosine_similarity
|
@@ -238,9 +237,9 @@ class SimilarityMemory:
|
|
238 |
def __init__(self, n_components=100):
|
239 |
self.memory_texts = []
|
240 |
self.vectorizer = TfidfVectorizer()
|
241 |
-
self.svd =
|
242 |
self.embeddings = None
|
243 |
-
self.
|
244 |
|
245 |
def add(self, text: str):
|
246 |
self.memory_texts.append(text)
|
@@ -249,42 +248,44 @@ class SimilarityMemory:
|
|
249 |
def _update_embeddings(self):
|
250 |
if len(self.memory_texts) == 0:
|
251 |
self.embeddings = None
|
252 |
-
self.
|
253 |
return
|
254 |
|
255 |
X = self.vectorizer.fit_transform(self.memory_texts)
|
256 |
-
|
257 |
-
n_comp = min(self.n_components, X.shape[1], len(self.memory_texts) - 1)
|
258 |
if n_comp <= 0:
|
259 |
self.embeddings = X.toarray()
|
260 |
-
self.
|
261 |
return
|
262 |
|
263 |
-
svd = TruncatedSVD(n_components=n_comp)
|
264 |
-
self.embeddings = svd.fit_transform(X)
|
265 |
-
self.
|
266 |
|
267 |
def retrieve(self, query: str, top_k=3):
|
268 |
-
if self.embeddings is None:
|
269 |
return []
|
270 |
|
271 |
Xq = self.vectorizer.transform([query])
|
272 |
-
if self.svd
|
273 |
-
q_emb = self.svd.transform(Xq)
|
274 |
-
else:
|
275 |
q_emb = Xq.toarray()
|
|
|
|
|
276 |
|
277 |
sims = cosine_similarity(q_emb, self.embeddings)[0]
|
278 |
top_indices = sims.argsort()[::-1][:top_k]
|
279 |
|
280 |
return [self.memory_texts[i] for i in top_indices]
|
281 |
|
282 |
-
|
283 |
-
|
|
|
|
|
|
|
284 |
|
285 |
-
def
|
286 |
-
|
287 |
-
|
288 |
|
289 |
|
290 |
def mismatch_tone(input_text, output_text):
|
|
|
229 |
decoded = decoded.replace(t, "")
|
230 |
return decoded.strip()
|
231 |
|
|
|
232 |
from sklearn.feature_extraction.text import TfidfVectorizer
|
233 |
from sklearn.decomposition import TruncatedSVD
|
234 |
from sklearn.metrics.pairwise import cosine_similarity
|
|
|
237 |
def __init__(self, n_components=100):
    """Create an empty similarity memory.

    Parameters
    ----------
    n_components : int
        Requested dimensionality of the TruncatedSVD projection used to
        embed stored texts.  The effective rank may be capped later by
        the amount of stored data.
    """
    self.memory_texts = []                # raw stored texts, insertion order
    self.vectorizer = TfidfVectorizer()   # refit whenever memory contents change
    self.svd = TruncatedSVD(n_components=n_components)
    # Keep the originally requested rank: _update_embeddings replaces
    # self.svd with a lower-rank instance when the data cannot support the
    # full rank, and reading the cap back from self.svd.n_components would
    # make the rank shrink permanently even as more texts arrive.
    self.n_components = n_components
    self.embeddings = None  # dense embedding matrix of memory_texts; None until first update
    self.fitted = False     # True once self.embeddings is valid
|
243 |
|
244 |
def add(self, text: str):
|
245 |
self.memory_texts.append(text)
|
|
|
248 |
def _update_embeddings(self):
    """Re-embed every stored text.

    Re-fits the TF-IDF vectorizer on the full memory, then projects with
    TruncatedSVD when enough samples/features exist; otherwise falls back
    to the raw (dense) TF-IDF matrix.  Sets self.embeddings and
    self.fitted accordingly.
    """
    if len(self.memory_texts) == 0:
        self.embeddings = None
        self.fitted = False
        return

    X = self.vectorizer.fit_transform(self.memory_texts)

    # Cache the originally requested rank before self.svd is ever replaced
    # below.  Reading the cap from self.svd.n_components on every call is a
    # bug: once the rank shrinks (e.g. two texts force rank 1), the shrunk
    # instance becomes the cap and the rank can never grow again even as
    # more texts arrive.
    if not hasattr(self, "n_components"):
        self.n_components = self.svd.n_components

    # TruncatedSVD requires n_components < min(n_samples, n_features).
    n_comp = min(self.n_components, X.shape[1], len(self.memory_texts) - 1)
    if n_comp <= 0:
        # Too few samples for SVD; keep the dense TF-IDF vectors as-is.
        self.embeddings = X.toarray()
        self.fitted = True
        return

    self.svd = TruncatedSVD(n_components=n_comp)
    self.embeddings = self.svd.fit_transform(X)
    self.fitted = True
|
264 |
|
265 |
def retrieve(self, query: str, top_k: int = 3) -> list:
    """Return up to ``top_k`` stored texts most similar to ``query``,
    best match first.  Returns an empty list when no embeddings exist yet.
    """
    if not self.fitted or self.embeddings is None or len(self.memory_texts) == 0:
        return []

    Xq = self.vectorizer.transform([query])
    # Heuristic "was the SVD actually fitted?" check: it mirrors the rank
    # cap used when embedding (n_comp = min(rank, n_features, n_texts - 1)).
    # When the configured rank exceeds either bound, the embeddings were
    # presumably left as raw dense TF-IDF, so skip the SVD projection to
    # keep query and memory vectors in the same space — NOTE(review):
    # relies on self.svd.n_components staying in sync with how
    # self.embeddings was built; verify against _update_embeddings.
    if self.svd.n_components > Xq.shape[1] or self.svd.n_components > len(self.memory_texts) - 1:
        q_emb = Xq.toarray()
    else:
        q_emb = self.svd.transform(Xq)

    sims = cosine_similarity(q_emb, self.embeddings)[0]
    # argsort ascending, reversed -> indices of highest similarity first.
    top_indices = sims.argsort()[::-1][:top_k]

    return [self.memory_texts[i] for i in top_indices]
|
279 |
|
280 |
+
def process_input(self, new_text: str, top_k=3):
    """Automatically store the incoming text as a memory and build a prompt
    that merges it with the most similar previously stored memories.

    Retrieval runs before storage, so ``new_text`` is matched only against
    earlier memories.
    """
    recalled = self.retrieve(new_text, top_k=top_k)
    self.add(new_text)
    return self.merge_prompt(new_text, recalled)
|
285 |
|
286 |
+
def merge_prompt(self, prompt: str, memories: list):
    """Prepend the retrieved memory texts to *prompt*, separated by a blank
    line; return the prompt unchanged when there is no memory context.
    """
    context = "\n".join(memories)
    if not context:
        return prompt
    return f"{context}\n\n{prompt}"
|
289 |
|
290 |
|
291 |
def mismatch_tone(input_text, output_text):
|