Spaces:
Sleeping
Sleeping
Update api.py
Browse files
api.py
CHANGED
@@ -10,10 +10,52 @@ import re
|
|
10 |
import math
|
11 |
from sklearn.feature_extraction.text import TfidfVectorizer
|
12 |
from sklearn.metrics.pairwise import cosine_similarity
|
|
|
|
|
13 |
|
14 |
-
app = FastAPI()
|
15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
from fastapi.middleware.cors import CORSMiddleware
|
18 |
|
19 |
origins = [
|
@@ -330,7 +372,6 @@ def parse_math_question(text):
|
|
330 |
except:
|
331 |
return "κ³μ°ν μ μλ μμμ΄μμ. λ€μ νλ² νμΈν΄ μ£ΌμΈμ!"
|
332 |
|
333 |
-
# Final response function
|
334 |
def respond(input_text):
|
335 |
intent = simple_intent_classifier(input_text)
|
336 |
|
@@ -350,12 +391,16 @@ def respond(input_text):
|
|
350 |
summary = summarize_from_wikipedia(keyword)
|
351 |
return f"{summary}\nλ€λ₯Έ κΆκΈν μ μμΌμ κ°μ?"
|
352 |
|
353 |
-
#
|
354 |
-
|
|
|
|
|
|
|
355 |
if not is_valid_response(response) or mismatch_tone(input_text, response):
|
356 |
-
response = generate_text_sample(model,
|
357 |
return response
|
358 |
|
|
|
359 |
@app.get("/generate", response_class=PlainTextResponse)
|
360 |
async def generate(request: Request):
|
361 |
prompt = request.query_params.get("prompt", "μλ
νμΈμ")
|
|
|
10 |
import math
|
11 |
from sklearn.feature_extraction.text import TfidfVectorizer
|
12 |
from sklearn.metrics.pairwise import cosine_similarity
|
13 |
+
from sentence_transformers import SentenceTransformer
|
14 |
+
import faiss
|
15 |
|
16 |
+
app = FastAPI()
|
17 |
|
18 |
+
class SimilarityMemory:
    """In-memory store of texts searchable by embedding similarity.

    Each added text is embedded with a SentenceTransformer model and kept in
    a flat FAISS L2 index, so `retrieve` returns the stored texts whose
    embeddings are closest to the embedding of the query.

    Attributes are plain and public, as in the original implementation:
    `memory_texts` (stored strings), `vectors` (their embeddings, same
    order), `index` (the FAISS index, or None until the first `add`) and
    `model` (the embedding model).
    """

    def __init__(self, embed_model='all-MiniLM-L6-v2'):
        """Load the embedding model named `embed_model` and start empty."""
        self.memory_texts = []
        self.model = SentenceTransformer(embed_model)
        self.index = None
        self.vectors = []

    def add(self, text: str):
        """Embed `text` and append it to the memory and to the index."""
        vec = self.model.encode([text])[0]
        self.memory_texts.append(text)
        self.vectors.append(vec)

        # Append only the new row when the index is in step with `vectors`;
        # rebuilding from scratch on every insert is quadratic overall.
        if self.index is not None and self.index.ntotal == len(self.vectors) - 1:
            row = np.asarray(vec, dtype='float32').reshape(1, -1)
            self.index.add(row)
        else:
            self._update_index()

    def _update_index(self):
        """Rebuild the FAISS index from every vector in `self.vectors`."""
        if not self.vectors:
            return
        vecs = np.vstack(self.vectors).astype('float32')
        self.index = faiss.IndexFlatL2(vecs.shape[1])
        self.index.add(vecs)

    def retrieve(self, query: str, top_k=3):
        """Return up to `top_k` stored texts nearest to `query`.

        Returns an empty list when nothing is stored, when no index has been
        built yet, or when `top_k` is not positive.
        """
        if not self.vectors or self.index is None or top_k <= 0:
            return []

        # Never ask for more neighbours than exist: FAISS pads missing
        # results with -1, which Python would resolve to the *last* memory.
        stored = len(self.memory_texts)
        k = min(top_k, stored)

        q_vec = np.asarray(
            self.model.encode([query])[0], dtype='float32'
        ).reshape(1, -1)
        _, neighbours = self.index.search(q_vec, k)

        # Defensive filter in case the index still reports padding or an
        # id that has no matching text.
        return [self.memory_texts[i] for i in neighbours[0] if 0 <= i < stored]
|
44 |
+
|
45 |
+
memory = SimilarityMemory()
|
46 |
+
|
47 |
+
# Seed example memories (these can now be updated at runtime)
|
48 |
+
memory.add("μ΄μ λ κΈ°λΆμ΄ λ³λ‘μμ΄")
|
49 |
+
memory.add("μν λ³΄λ¬ κ°λ€κ° μΉκ΅¬λ μΈμ μ΄")
|
50 |
+
memory.add("μΉ΄νμμ 곡λΆνλλ° μ§μ€μ΄ μ λμ΄")
|
51 |
+
|
52 |
+
|
53 |
+
def merge_prompt_with_memory(prompt: str, memories: list) -> str:
    """Build a generation prompt from retrieved memories plus the new input.

    Each memory becomes one labelled line, followed by a final labelled line
    holding `prompt`. When `memories` is empty only the prompt line is
    returned, so the result never starts with a blank line.

    NOTE(review): the two label strings look like mis-encoded Korean in this
    view (presumably "past" and "present"); they are reproduced unchanged —
    confirm against the real file's encoding.

    Args:
        prompt: The current user input.
        memories: Previously stored texts to prepend as context.

    Returns:
        The merged prompt string, lines separated by "\\n".
    """
    current_line = f"νμ¬: {prompt}"
    if not memories:
        return current_line
    context = "\n".join(f"κ³Όκ±°: {mem}" for mem in memories)
    return f"{context}\n{current_line}"
|
56 |
+
|
57 |
+
|
58 |
+
|
59 |
from fastapi.middleware.cors import CORSMiddleware
|
60 |
|
61 |
origins = [
|
|
|
372 |
except:
|
373 |
return "κ³μ°ν μ μλ μμμ΄μμ. λ€μ νλ² νμΈν΄ μ£ΌμΈμ!"
|
374 |
|
|
|
375 |
def respond(input_text):
|
376 |
intent = simple_intent_classifier(input_text)
|
377 |
|
|
|
391 |
summary = summarize_from_wikipedia(keyword)
|
392 |
return f"{summary}\nλ€λ₯Έ κΆκΈν μ μμΌμ κ°μ?"
|
393 |
|
394 |
+
# Add memory-based prompt merging
|
395 |
+
related_memories = memory.retrieve(input_text, top_k=3)
|
396 |
+
merged_prompt = merge_prompt_with_memory(input_text, related_memories)
|
397 |
+
|
398 |
+
response = generate_text_sample(model, merged_prompt)
|
399 |
if not is_valid_response(response) or mismatch_tone(input_text, response):
|
400 |
+
response = generate_text_sample(model, merged_prompt)
|
401 |
return response
|
402 |
|
403 |
+
|
404 |
@app.get("/generate", response_class=PlainTextResponse)
|
405 |
async def generate(request: Request):
|
406 |
prompt = request.query_params.get("prompt", "μλ
νμΈμ")
|