Spaces:
Sleeping
Sleeping
Update api.py
Browse files
api.py
CHANGED
@@ -228,7 +228,75 @@ def generate_text_sample(model, prompt, max_len=100, max_gen=98,
|
|
228 |
for t in ["<start>", "<sep>", "<end>"]:
|
229 |
decoded = decoded.replace(t, "")
|
230 |
return decoded.strip()
|
231 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
232 |
def mismatch_tone(input_text, output_text):
|
233 |
if "γ
γ
" in input_text and not re.search(r'γ
γ
|γ
|μ¬λ°|λ|λ§λ|λ§μ§|μ¬ν', output_text):
|
234 |
return True
|
@@ -328,7 +396,6 @@ def parse_math_question(text):
|
|
328 |
except:
|
329 |
return "κ³μ°ν μ μλ μμμ΄μμ. λ€μ νλ² νμΈν΄ μ£ΌμΈμ!"
|
330 |
|
331 |
-
# μ΅μ’
μλ΅ ν¨μ
|
332 |
def respond(input_text):
|
333 |
intent = simple_intent_classifier(input_text)
|
334 |
|
@@ -348,10 +415,13 @@ def respond(input_text):
|
|
348 |
summary = summarize_from_wikipedia(keyword)
|
349 |
return f"{summary}\nλ€λ₯Έ κΆκΈν μ μμΌμ κ°μ?"
|
350 |
|
351 |
-
#
|
352 |
-
|
|
|
|
|
|
|
353 |
if not is_valid_response(response) or mismatch_tone(input_text, response):
|
354 |
-
response = generate_text_sample(model,
|
355 |
return response
|
356 |
|
357 |
@app.get("/generate", response_class=PlainTextResponse)
|
|
|
228 |
for t in ["<start>", "<sep>", "<end>"]:
|
229 |
decoded = decoded.replace(t, "")
|
230 |
return decoded.strip()
|
231 |
+
|
232 |
+
import numpy as np
|
233 |
+
from sklearn.feature_extraction.text import TfidfVectorizer
|
234 |
+
from sklearn.decomposition import TruncatedSVD
|
235 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
236 |
+
|
237 |
+
class SimilarityMemory:
    """Store texts and retrieve the ones most similar to a query.

    Texts are vectorized with TF-IDF and, when there is enough data, reduced
    with truncated SVD. Similarity is the cosine similarity between the query
    embedding and the stored embeddings. The whole model is refitted on every
    ``add`` call.

    Args:
        n_components: Upper bound on the SVD embedding dimension. The
            dimension actually used is clamped to the vocabulary size and to
            ``len(memory_texts) - 1``.
    """

    def __init__(self, n_components=100):
        self.memory_texts = []
        self.vectorizer = TfidfVectorizer()
        # Requested embedding size. Kept separately from ``self.svd`` because
        # the estimator is rebuilt with a clamped size on every refit; reading
        # the cap back from the estimator would make the clamp permanent.
        self.n_components = n_components
        self.svd = TruncatedSVD(n_components=n_components)
        self.embeddings = None
        self.fitted = False
        # True when ``self.embeddings`` came from ``self.svd`` (as opposed to
        # the raw TF-IDF matrix); ``retrieve`` must embed queries the same way.
        self._use_svd = False

    def add(self, text: str) -> None:
        """Append ``text`` to the memory and refit the embeddings."""
        self.memory_texts.append(text)
        self._update_embeddings()

    def _update_embeddings(self) -> None:
        """Refit the vectorizer (and SVD when possible) on all stored texts."""
        # Nothing stored: reset to the unfitted state.
        if not self.memory_texts:
            self.embeddings = None
            self.fitted = False
            self._use_svd = False
            return

        # Vectorize every stored text.
        X = self.vectorizer.fit_transform(self.memory_texts)

        # Clamp against the *requested* size, not the previous estimator's.
        n_comp = min(self.n_components, X.shape[1], len(self.memory_texts) - 1)
        if n_comp <= 0:
            # No dimension left to reduce to: use the TF-IDF vectors directly.
            self.embeddings = X.toarray()
            self._use_svd = False
            self.fitted = True
            return

        # Dimensionality reduction.
        self.svd = TruncatedSVD(n_components=n_comp)
        self.embeddings = self.svd.fit_transform(X)
        self._use_svd = True
        self.fitted = True

    def retrieve(self, query: str, top_k=3) -> list:
        """Return up to ``top_k`` stored texts, most similar to ``query`` first.

        Returns an empty list when nothing has been fitted yet.
        """
        if not self.fitted or self.embeddings is None or not self.memory_texts:
            return []

        # Embed the query with the already-fitted models (transform, never
        # fit_transform), following the same path the stored embeddings took.
        Xq = self.vectorizer.transform([query])
        if self._use_svd:
            q_emb = self.svd.transform(Xq)
        else:
            q_emb = Xq.toarray()

        # Cosine similarity against every stored embedding, best first.
        sims = cosine_similarity(q_emb, self.embeddings)[0]
        top_indices = sims.argsort()[::-1][:top_k]

        return [self.memory_texts[i] for i in top_indices]
|
287 |
+
|
288 |
+
# Test: seed the memory with a few sample entries
|
289 |
+
memory = SimilarityMemory()
|
290 |
+
|
291 |
+
memory.add("μ΄μ λ κΈ°λΆμ΄ λ³λ‘μμ΄")
|
292 |
+
memory.add("μν λ³΄λ¬ κ°λ€κ° μΉκ΅¬λ μΈμ μ΄")
|
293 |
+
memory.add("μΉ΄νμμ 곡λΆνλλ° μ§μ€μ΄ μ λμ΄")
|
294 |
+
|
295 |
+
def merge_prompt_with_memory(prompt: str, memories: list) -> str:
    """Prefix ``prompt`` with one labelled line per retrieved memory.

    Args:
        prompt: The text to place after the memory context.
        memories: Memory strings to prepend; each becomes its own line.

    Returns:
        The memory lines joined by newlines, a single space, then ``prompt``.
        When ``memories`` is empty the prompt is returned unchanged, so no
        stray leading space is added.
    """
    if not memories:
        return prompt
    context = "\n".join(f"κ³Όκ±°: {mem}" for mem in memories)
    return f"{context} {prompt}"
|
298 |
+
|
299 |
+
|
300 |
def mismatch_tone(input_text, output_text):
|
301 |
if "γ
γ
" in input_text and not re.search(r'γ
γ
|γ
|μ¬λ°|λ|λ§λ|λ§μ§|μ¬ν', output_text):
|
302 |
return True
|
|
|
396 |
except:
|
397 |
return "κ³μ°ν μ μλ μμμ΄μμ. λ€μ νλ² νμΈν΄ μ£ΌμΈμ!"
|
398 |
|
|
|
399 |
def respond(input_text):
|
400 |
intent = simple_intent_classifier(input_text)
|
401 |
|
|
|
415 |
summary = summarize_from_wikipedia(keyword)
|
416 |
return f"{summary}\nλ€λ₯Έ κΆκΈν μ μμΌμ κ°μ?"
|
417 |
|
418 |
+
# β
κΈ°μ΅ κΈ°λ° λ³ν© μΆκ°
|
419 |
+
related_memories = memory.retrieve(input_text, top_k=3)
|
420 |
+
merged_prompt = merge_prompt_with_memory(input_text, related_memories)
|
421 |
+
|
422 |
+
response = generate_text_sample(model, merged_prompt)
|
423 |
if not is_valid_response(response) or mismatch_tone(input_text, response):
|
424 |
+
response = generate_text_sample(model, merged_prompt)
|
425 |
return response
|
426 |
|
427 |
@app.get("/generate", response_class=PlainTextResponse)
|