Spaces:
Sleeping
Sleeping
Update api.py
Browse files
api.py
CHANGED
@@ -203,77 +203,6 @@ def generate_text_sample(model, prompt, max_len=100, max_gen=98,
|
|
203 |
decoded = decoded.replace(t, "")
|
204 |
return decoded.strip()
|
205 |
|
206 |
-
|
207 |
-
from sklearn.feature_extraction.text import TfidfVectorizer
|
208 |
-
from sklearn.decomposition import TruncatedSVD
|
209 |
-
from sklearn.metrics.pairwise import cosine_similarity
|
210 |
-
|
211 |
-
class SimilarityMemory:
|
212 |
-
def __init__(self, n_components=100):
|
213 |
-
self.memory_texts = []
|
214 |
-
self.vectorizer = TfidfVectorizer()
|
215 |
-
self.svd = TruncatedSVD(n_components=n_components)
|
216 |
-
self.embeddings = None
|
217 |
-
self.fitted = False
|
218 |
-
|
219 |
-
def add(self, text: str):
|
220 |
-
self.memory_texts.append(text)
|
221 |
-
self._update_embeddings()
|
222 |
-
|
223 |
-
def _update_embeddings(self):
|
224 |
-
if len(self.memory_texts) == 0:
|
225 |
-
self.embeddings = None
|
226 |
-
self.fitted = False
|
227 |
-
return
|
228 |
-
|
229 |
-
X = self.vectorizer.fit_transform(self.memory_texts)
|
230 |
-
n_comp = min(self.svd.n_components, X.shape[1], len(self.memory_texts)-1)
|
231 |
-
if n_comp <= 0:
|
232 |
-
self.embeddings = X.toarray()
|
233 |
-
self.fitted = True
|
234 |
-
return
|
235 |
-
|
236 |
-
self.svd = TruncatedSVD(n_components=n_comp)
|
237 |
-
self.embeddings = self.svd.fit_transform(X)
|
238 |
-
self.fitted = True
|
239 |
-
|
240 |
-
def retrieve(self, query: str, top_k=3):
|
241 |
-
if not self.fitted or self.embeddings is None or len(self.memory_texts) == 0:
|
242 |
-
return []
|
243 |
-
|
244 |
-
Xq = self.vectorizer.transform([query])
|
245 |
-
if self.svd.n_components > Xq.shape[1] or self.svd.n_components > len(self.memory_texts) - 1:
|
246 |
-
q_emb = Xq.toarray()
|
247 |
-
else:
|
248 |
-
q_emb = self.svd.transform(Xq)
|
249 |
-
|
250 |
-
sims = cosine_similarity(q_emb, self.embeddings)[0]
|
251 |
-
top_indices = sims.argsort()[::-1][:top_k]
|
252 |
-
|
253 |
-
return [self.memory_texts[i] for i in top_indices]
|
254 |
-
|
255 |
-
def process_input(self, new_text: str, top_k=3):
|
256 |
-
"""자동으로 기억 저장하고, 유사한 기억 찾아서 합친 프롬프트 생성"""
|
257 |
-
related_memories = self.retrieve(new_text, top_k=top_k)
|
258 |
-
self.add(new_text)
|
259 |
-
return self.merge_prompt(new_text, related_memories)
|
260 |
-
|
261 |
-
def merge_prompt(self, prompt: str, memories: list):
|
262 |
-
context = "\n".join(memories)
|
263 |
-
full_prompt = ""
|
264 |
-
if context:
|
265 |
-
full_prompt += f"기억:\n{context}\n\n"
|
266 |
-
full_prompt += f"현재 질문:\n{prompt}\n\n응답:"
|
267 |
-
return full_prompt
|
268 |
-
|
269 |
-
memory = SimilarityMemory()
|
270 |
-
|
271 |
-
with open("base_texts.txt", "r", encoding="utf-8") as f:
|
272 |
-
for line in f:
|
273 |
-
line = line.strip()
|
274 |
-
if line:
|
275 |
-
memory.add(line)
|
276 |
-
|
277 |
def mismatch_tone(input_text, output_text):
|
278 |
if "ㅋㅋ" in input_text and not re.search(r'ㅋㅋ|ㅎ|재밌|놀|만나|맛집|여행', output_text):
|
279 |
return True
|
@@ -295,39 +224,30 @@ def is_valid_response(response):
|
|
295 |
|
296 |
|
297 |
def respond(input_text):
|
298 |
-
|
299 |
-
|
300 |
if "이름" in input_text:
|
301 |
response = "제 이름은 Flexi입니다."
|
302 |
-
memory.process_input(response)
|
303 |
return response
|
304 |
|
305 |
if "누구" in input_text:
|
306 |
response = "저는 Flexi라고 해요."
|
307 |
-
memory.process_input(response)
|
308 |
return response
|
309 |
|
310 |
-
|
311 |
-
|
312 |
|
313 |
for _ in range(3): # 최대 3번 재시도
|
314 |
-
full_response = generate_text_sample(model,
|
315 |
|
316 |
-
# 여기서 '응답:' 뒤의 텍스트만 뽑기
|
317 |
if "응답:" in full_response:
|
318 |
response = full_response.split("응답:")[-1].strip()
|
319 |
else:
|
320 |
response = full_response.strip()
|
321 |
|
322 |
if is_valid_response(response) and not mismatch_tone(input_text, response):
|
323 |
-
memory.process_input(response)
|
324 |
return response
|
325 |
|
326 |
-
|
327 |
-
fallback_response = "죄송해요, 제대로 답변을 못했어요."
|
328 |
-
memory.process_input(fallback_response)
|
329 |
-
return fallback_response
|
330 |
-
|
331 |
|
332 |
@app.get("/generate", response_class=PlainTextResponse)
|
333 |
async def generate(request: Request):
|
|
|
203 |
decoded = decoded.replace(t, "")
|
204 |
return decoded.strip()
|
205 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
206 |
def mismatch_tone(input_text, output_text):
|
207 |
if "ㅋㅋ" in input_text and not re.search(r'ㅋㅋ|ㅎ|재밌|놀|만나|맛집|여행', output_text):
|
208 |
return True
|
|
|
224 |
|
225 |
|
226 |
def respond(input_text):
|
227 |
+
# 이름 관련 질문에 딱 반응하는 부분 유지
|
|
|
228 |
if "이름" in input_text:
|
229 |
response = "제 이름은 Flexi입니다."
|
|
|
230 |
return response
|
231 |
|
232 |
if "누구" in input_text:
|
233 |
response = "저는 Flexi라고 해요."
|
|
|
234 |
return response
|
235 |
|
236 |
+
# 메모리 관련 부분 싹 제거하고, 단순 프롬프트 생성
|
237 |
+
full_prompt = f"현재 질문:\n{input_text}\n\n응답:"
|
238 |
|
239 |
for _ in range(3): # 최대 3번 재시도
|
240 |
+
full_response = generate_text_sample(model, full_prompt)
|
241 |
|
|
|
242 |
if "응답:" in full_response:
|
243 |
response = full_response.split("응답:")[-1].strip()
|
244 |
else:
|
245 |
response = full_response.strip()
|
246 |
|
247 |
if is_valid_response(response) and not mismatch_tone(input_text, response):
|
|
|
248 |
return response
|
249 |
|
250 |
+
return "죄송해요, 제대로 답변을 못했어요."
|
|
|
|
|
|
|
|
|
251 |
|
252 |
@app.get("/generate", response_class=PlainTextResponse)
|
253 |
async def generate(request: Request):
|