Yuchan5386 commited on
Commit
1746a1f
Β·
verified Β·
1 Parent(s): 59a8836

Update api.py

Browse files
Files changed (1) hide show
  1. api.py +50 -5
api.py CHANGED
@@ -10,10 +10,52 @@ import re
10
  import math
11
  from sklearn.feature_extraction.text import TfidfVectorizer
12
  from sklearn.metrics.pairwise import cosine_similarity
 
 
13
 
14
- app = FastAPI()
15
 
 
 
 
 
 
 
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  from fastapi.middleware.cors import CORSMiddleware
18
 
19
  origins = [
@@ -330,7 +372,6 @@ def parse_math_question(text):
330
  except:
331
  return "계산할 수 μ—†λŠ” μˆ˜μ‹μ΄μ—μš”. λ‹€μ‹œ ν•œλ²ˆ 확인해 μ£Όμ„Έμš”!"
332
 
333
- # μ΅œμ’… 응닡 ν•¨μˆ˜
334
  def respond(input_text):
335
  intent = simple_intent_classifier(input_text)
336
 
@@ -350,12 +391,16 @@ def respond(input_text):
350
  summary = summarize_from_wikipedia(keyword)
351
  return f"{summary}\nλ‹€λ₯Έ κΆκΈˆν•œ 점 μžˆμœΌμ‹ κ°€μš”?"
352
 
353
- # 일상 λŒ€ν™”: μƒ˜ν”Œλ§ + fallback
354
- response = generate_text_sample(model, input_text)
 
 
 
355
  if not is_valid_response(response) or mismatch_tone(input_text, response):
356
- response = generate_text_sample(model, input_text)
357
  return response
358
 
 
359
  @app.get("/generate", response_class=PlainTextResponse)
360
  async def generate(request: Request):
361
  prompt = request.query_params.get("prompt", "μ•ˆλ…•ν•˜μ„Έμš”")
 
10
  import math
11
  from sklearn.feature_extraction.text import TfidfVectorizer
12
  from sklearn.metrics.pairwise import cosine_similarity
13
+ from sentence_transformers import SentenceTransformer
14
+ import faiss
15
 
16
+ app = FastAPI()
17
 
18
+ class SimilarityMemory:
19
+ def __init__(self, embed_model='all-MiniLM-L6-v2'):
20
+ self.memory_texts = []
21
+ self.model = SentenceTransformer(embed_model)
22
+ self.index = None
23
+ self.vectors = []
24
 
25
+ def add(self, text: str):
26
+ vec = self.model.encode([text])[0]
27
+ self.memory_texts.append(text)
28
+ self.vectors.append(vec)
29
+ self._update_index()
30
+
31
+ def _update_index(self):
32
+ if not self.vectors:
33
+ return
34
+ vecs = np.vstack(self.vectors).astype('float32')
35
+ self.index = faiss.IndexFlatL2(vecs.shape[1])
36
+ self.index.add(vecs)
37
+
38
+ def retrieve(self, query: str, top_k=3):
39
+ if not self.vectors:
40
+ return []
41
+ q_vec = self.model.encode([query])[0].astype('float32').reshape(1, -1)
42
+ D, I = self.index.search(q_vec, top_k)
43
+ return [self.memory_texts[i] for i in I[0]]
44
+
45
+ memory = SimilarityMemory()
46
+
47
+ # μ˜ˆμ‹œ κΈ°μ–΅ μΆ”κ°€ (이건 운영 쀑에 μ‹€μ‹œκ°„μœΌλ‘œ κ°±μ‹  κ°€λŠ₯)
48
+ memory.add("μ–΄μ œλŠ” 기뢄이 λ³„λ‘œμ˜€μ–΄")
49
+ memory.add("μ˜ν™” 보러 κ°”λ‹€κ°€ μΉœκ΅¬λž‘ μ‹Έμ› μ–΄")
50
+ memory.add("μΉ΄νŽ˜μ—μ„œ κ³΅λΆ€ν–ˆλŠ”λ° 집쀑이 잘 됐어")
51
+
52
+
53
+ def merge_prompt_with_memory(prompt: str, memories: list):
54
+ context = "\n".join(f"κ³Όκ±°: {mem}" for mem in memories)
55
+ return f"{context}\nν˜„μž¬: {prompt}"
56
+
57
+
58
+
59
  from fastapi.middleware.cors import CORSMiddleware
60
 
61
  origins = [
 
372
  except:
373
  return "계산할 수 μ—†λŠ” μˆ˜μ‹μ΄μ—μš”. λ‹€μ‹œ ν•œλ²ˆ 확인해 μ£Όμ„Έμš”!"
374
 
 
375
  def respond(input_text):
376
  intent = simple_intent_classifier(input_text)
377
 
 
391
  summary = summarize_from_wikipedia(keyword)
392
  return f"{summary}\nλ‹€λ₯Έ κΆκΈˆν•œ 점 μžˆμœΌμ‹ κ°€μš”?"
393
 
394
+ # βœ… κΈ°μ–΅ 기반 병합 μΆ”κ°€
395
+ related_memories = memory.retrieve(input_text, top_k=3)
396
+ merged_prompt = merge_prompt_with_memory(input_text, related_memories)
397
+
398
+ response = generate_text_sample(model, merged_prompt)
399
  if not is_valid_response(response) or mismatch_tone(input_text, response):
400
+ response = generate_text_sample(model, merged_prompt)
401
  return response
402
 
403
+
404
  @app.get("/generate", response_class=PlainTextResponse)
405
  async def generate(request: Request):
406
  prompt = request.query_params.get("prompt", "μ•ˆλ…•ν•˜μ„Έμš”")