Spaces:
Sleeping
Sleeping
Update api.py
Browse files
api.py
CHANGED
@@ -140,19 +140,28 @@ _ = model(dummy_input) # 모델이 빌드됨
|
|
140 |
model.load_weights("InteractGPT.weights.h5")
|
141 |
print("모델 가중치 로드 완료!")
|
142 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
143 |
def generate_text_topkp(model, prompt, max_len=100, max_gen=98,
|
144 |
-
temperature=0.
|
145 |
repetition_penalty=1.2, top_p=0.90, top_k=50):
|
146 |
def top_kp_filtering(logits, top_k, top_p):
|
147 |
probs = np.exp(logits - np.max(logits))
|
148 |
probs /= probs.sum()
|
149 |
sorted_idx = np.argsort(-probs)
|
150 |
sorted_probs = probs[sorted_idx]
|
151 |
-
# Top-K 필터링
|
152 |
if top_k > 0:
|
153 |
sorted_idx = sorted_idx[:top_k]
|
154 |
sorted_probs = sorted_probs[:top_k]
|
155 |
-
# Top-P 필터링
|
156 |
cum_probs = np.cumsum(sorted_probs)
|
157 |
cutoff = np.searchsorted(cum_probs, top_p) + 1
|
158 |
final_idx = sorted_idx[:cutoff]
|
@@ -169,17 +178,13 @@ def generate_text_topkp(model, prompt, max_len=100, max_gen=98,
|
|
169 |
input_tensor = tf.convert_to_tensor([input_padded])
|
170 |
logits = model(input_tensor, training=False)
|
171 |
next_logits = logits[0, len(generated) - 1].numpy()
|
172 |
-
# 반복 억제
|
173 |
for t in set(generated):
|
174 |
count = generated.count(t)
|
175 |
next_logits[t] /= (repetition_penalty ** count)
|
176 |
-
# 조기 종료 방지
|
177 |
if len(generated) < min_len:
|
178 |
next_logits[end_id] -= 5.0
|
179 |
next_logits[pad_id] -= 10.0
|
180 |
-
# 온도 적용
|
181 |
next_logits = next_logits / temperature
|
182 |
-
# Top-KP Sampling 적용
|
183 |
final_idx, final_probs = top_kp_filtering(next_logits, top_k=top_k, top_p=top_p)
|
184 |
sampled = np.random.choice(final_idx, p=final_probs)
|
185 |
generated.append(int(sampled))
|
@@ -189,7 +194,37 @@ def generate_text_topkp(model, prompt, max_len=100, max_gen=98,
|
|
189 |
decoded = decoded.strip()
|
190 |
if len(generated) >= min_len and (sampled == end_id or decoded.endswith(('.', '!', '?'))):
|
191 |
return decoded
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
192 |
|
|
|
193 |
def is_valid_response(response):
|
194 |
if len(response.strip()) < 2:
|
195 |
return False
|
@@ -203,6 +238,7 @@ def is_valid_response(response):
|
|
203 |
return False
|
204 |
return True
|
205 |
|
|
|
206 |
def extract_main_query(text):
|
207 |
sentences = re.split(r'[.?!]\s*', text)
|
208 |
sentences = [s.strip() for s in sentences if s.strip()]
|
@@ -215,13 +251,6 @@ def extract_main_query(text):
|
|
215 |
last = re.sub(rf'\b(\w+){p}\b', r'\1', last)
|
216 |
return last.strip()
|
217 |
|
218 |
-
import re
|
219 |
-
import requests
|
220 |
-
import numpy as np
|
221 |
-
from sklearn.feature_extraction.text import TfidfVectorizer
|
222 |
-
from sklearn.metrics.pairwise import cosine_similarity
|
223 |
-
|
224 |
-
# 2. 위키백과 요약 가져오기
|
225 |
def get_wikipedia_summary(query):
|
226 |
cleaned_query = extract_main_query(query)
|
227 |
url = f"https://ko.wikipedia.org/api/rest_v1/page/summary/{cleaned_query}"
|
@@ -231,24 +260,20 @@ def get_wikipedia_summary(query):
|
|
231 |
else:
|
232 |
return "위키백과에서 정보를 가져올 수 없습니다."
|
233 |
|
234 |
-
# 3. TextRank 요약기
|
235 |
def textrank_summarize(text, top_n=3):
|
236 |
sentences = re.split(r'(?<=[.!?])\s+', text.strip())
|
237 |
sentences = [s.strip() for s in sentences if len(s.strip()) > 10]
|
238 |
-
|
239 |
if len(sentences) <= top_n:
|
240 |
-
return text
|
241 |
-
|
242 |
vectorizer = TfidfVectorizer()
|
243 |
tfidf_matrix = vectorizer.fit_transform(sentences)
|
244 |
sim_matrix = cosine_similarity(tfidf_matrix)
|
245 |
np.fill_diagonal(sim_matrix, 0)
|
246 |
-
|
247 |
def pagerank(matrix, damping=0.85, max_iter=100, tol=1e-4):
|
248 |
N = matrix.shape[0]
|
249 |
ranks = np.ones(N) / N
|
250 |
row_sums = np.sum(matrix, axis=1)
|
251 |
-
row_sums[row_sums == 0] = 1
|
252 |
for _ in range(max_iter):
|
253 |
prev_ranks = ranks.copy()
|
254 |
for i in range(N):
|
@@ -257,25 +282,22 @@ def textrank_summarize(text, top_n=3):
|
|
257 |
if np.linalg.norm(ranks - prev_ranks) < tol:
|
258 |
break
|
259 |
return ranks
|
260 |
-
|
261 |
scores = pagerank(sim_matrix)
|
262 |
ranked_idx = np.argsort(scores)[::-1]
|
263 |
selected_idx = sorted(ranked_idx[:top_n])
|
264 |
summary = ' '.join([sentences[i] for i in selected_idx])
|
265 |
-
|
266 |
return summary
|
267 |
|
268 |
-
# 4. 전체 파이프라인
|
269 |
def summarize_from_wikipedia(query, top_n=3):
    """Look up *query* on Wikipedia and condense the result with TextRank.

    Args:
        query: Free-form user text; it is handed unchanged to
            ``get_wikipedia_summary``.
        top_n: Number of sentences requested from ``textrank_summarize``.

    Returns:
        Whatever ``textrank_summarize`` returns for the fetched summary text.
    """
    return textrank_summarize(get_wikipedia_summary(query), top_n=top_n)
|
272 |
|
|
|
273 |
def simple_intent_classifier(text):
|
274 |
text = text.lower()
|
275 |
greet_keywords = ["안녕", "반가워", "이름", "누구", "소개", "어디서 왔", "정체", "몇 살", "너 뭐야"]
|
276 |
info_keywords = ["설명", "정보", "무엇", "뭐야", "어디", "누구", "왜", "어떻게", "종류", "개념"]
|
277 |
math_keywords = ["더하기", "빼기", "곱하기", "나누기", "루트", "제곱", "+", "-", "*", "/", "=", "^", "√", "계산", "몇이야", "얼마야"]
|
278 |
-
|
279 |
if any(kw in text for kw in greet_keywords):
|
280 |
return "인사"
|
281 |
elif any(kw in text for kw in info_keywords):
|
@@ -294,7 +316,7 @@ def parse_math_question(text):
|
|
294 |
except:
|
295 |
return "계산할 수 없는 수식이에요. 다시 한번 확인해 주세요!"
|
296 |
|
297 |
-
#
|
298 |
def respond(input_text):
|
299 |
intent = simple_intent_classifier(input_text)
|
300 |
|
@@ -309,7 +331,7 @@ def respond(input_text):
|
|
309 |
|
310 |
if intent == "인사":
|
311 |
return "반가워요! 무엇을 도와드릴까요?"
|
312 |
-
|
313 |
if intent == "정보질문":
|
314 |
keyword = re.sub(r"(에 대해|에 대한|에 대해서)?\s*(설명해줘|알려줘|뭐야|개념|정의|정보)?", "", input_text).strip()
|
315 |
if not keyword:
|
@@ -317,16 +339,11 @@ def respond(input_text):
|
|
317 |
summary = summarize_from_wikipedia(keyword)
|
318 |
return f"{summary}\n다른 궁금한 점 있으신가요?"
|
319 |
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
yield text_piece
|
326 |
-
await asyncio.sleep(0.1)
|
327 |
-
|
328 |
-
|
329 |
-
from fastapi.responses import PlainTextResponse
|
330 |
|
331 |
@app.get("/generate", response_class=PlainTextResponse)
|
332 |
async def generate(request: Request):
|
|
|
140 |
model.load_weights("InteractGPT.weights.h5")
|
141 |
print("모델 가중치 로드 완료!")
|
142 |
|
143 |
+
import re
|
144 |
+
import math
|
145 |
+
import numpy as np
|
146 |
+
import requests
|
147 |
+
import tensorflow as tf
|
148 |
+
from sklearn.feature_extraction.text import TfidfVectorizer
|
149 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
150 |
+
from fastapi import Request
|
151 |
+
from fastapi.responses import PlainTextResponse
|
152 |
+
|
153 |
+
# 1. Top-KP 기반 생성기
|
154 |
def generate_text_topkp(model, prompt, max_len=100, max_gen=98,
|
155 |
+
temperature=0.90, min_len=20,
|
156 |
repetition_penalty=1.2, top_p=0.90, top_k=50):
|
157 |
def top_kp_filtering(logits, top_k, top_p):
|
158 |
probs = np.exp(logits - np.max(logits))
|
159 |
probs /= probs.sum()
|
160 |
sorted_idx = np.argsort(-probs)
|
161 |
sorted_probs = probs[sorted_idx]
|
|
|
162 |
if top_k > 0:
|
163 |
sorted_idx = sorted_idx[:top_k]
|
164 |
sorted_probs = sorted_probs[:top_k]
|
|
|
165 |
cum_probs = np.cumsum(sorted_probs)
|
166 |
cutoff = np.searchsorted(cum_probs, top_p) + 1
|
167 |
final_idx = sorted_idx[:cutoff]
|
|
|
178 |
input_tensor = tf.convert_to_tensor([input_padded])
|
179 |
logits = model(input_tensor, training=False)
|
180 |
next_logits = logits[0, len(generated) - 1].numpy()
|
|
|
181 |
for t in set(generated):
|
182 |
count = generated.count(t)
|
183 |
next_logits[t] /= (repetition_penalty ** count)
|
|
|
184 |
if len(generated) < min_len:
|
185 |
next_logits[end_id] -= 5.0
|
186 |
next_logits[pad_id] -= 10.0
|
|
|
187 |
next_logits = next_logits / temperature
|
|
|
188 |
final_idx, final_probs = top_kp_filtering(next_logits, top_k=top_k, top_p=top_p)
|
189 |
sampled = np.random.choice(final_idx, p=final_probs)
|
190 |
generated.append(int(sampled))
|
|
|
194 |
decoded = decoded.strip()
|
195 |
if len(generated) >= min_len and (sampled == end_id or decoded.endswith(('.', '!', '?'))):
|
196 |
return decoded
|
197 |
+
return sp.decode(generated)
|
198 |
+
|
199 |
+
# Greedy 버전 생성기
|
200 |
+
def generate_text_greedy(model, prompt, max_len=100, max_gen=98):
    """Generate a reply for *prompt* using greedy (argmax) decoding.

    The prompt is wrapped as ``"<start> {prompt} <sep>"``, converted to ids
    with ``text_to_ids`` and truncated to ``max_len`` ids. Tokens are then
    appended one at a time, always taking the highest-scoring id, until the
    end token is produced, the cleaned text ends with sentence punctuation,
    or ``max_gen`` tokens have been generated.

    Args:
        model: Callable invoked as ``model(tensor, training=False)``; its
            output is indexed as ``[0, position]`` to get next-token scores.
        prompt: User text to respond to.
        max_len: Length the id sequence is right-padded to with ``pad_id``.
        max_gen: Maximum number of tokens to generate.

    Returns:
        The decoded sequence with the ``<start>``/``<sep>``/``<end>`` markers
        removed and surrounding whitespace stripped. The same cleanup is
        applied whether generation stops early or ``max_gen`` is exhausted.
    """

    def _decode_clean(ids):
        # Single place for the marker cleanup so both exit paths agree.
        # `sp` is a module-level object with a `decode` method
        # (presumably a SentencePiece processor -- TODO confirm).
        text = sp.decode(ids)
        for marker in ("<start>", "<sep>", "<end>"):
            text = text.replace(marker, "")
        return text.strip()

    model_input = text_to_ids(f"<start> {prompt} <sep>")
    generated = list(model_input[:max_len])

    for _ in range(max_gen):
        # NOTE(review): once len(generated) exceeds max_len no padding is
        # added and the model receives a sequence longer than max_len --
        # confirm the model accepts that.
        pad_len = max(0, max_len - len(generated))
        input_padded = np.pad(generated, (0, pad_len), constant_values=pad_id)
        input_tensor = tf.convert_to_tensor([input_padded])
        logits = model(input_tensor, training=False)

        # Scores for the position right after the last real token.
        next_logits = logits[0, len(generated) - 1].numpy()
        # Push the padding id down so it is not picked as the next token.
        next_logits[pad_id] -= 10.0

        next_token = int(np.argmax(next_logits))
        generated.append(next_token)

        decoded = _decode_clean(generated)
        if next_token == end_id or decoded.endswith(('.', '!', '?')):
            return decoded

    # Generation budget exhausted: return the cleaned text rather than the
    # raw decode, so marker strings never reach the caller.
    return _decode_clean(generated)
|
220 |
+
|
221 |
+
# 톤 불일치 체크
|
222 |
+
def mismatch_tone(input_text, output_text):
    """Report whether a playful input got a reply with no playful cue.

    Args:
        input_text: The user's message.
        output_text: The generated reply to check.

    Returns:
        True when ``input_text`` contains "ㅋㅋ" and ``output_text`` matches
        none of the cue patterns in the regex below; False otherwise.
    """
    if "ㅋㅋ" not in input_text:
        return False
    cue = re.search(r'ㅋㅋ|ㅎ|재밌|놀|만나|맛집|여행', output_text)
    return cue is None
|
226 |
|
227 |
+
# 유효한 응답인지 검사
|
228 |
def is_valid_response(response):
|
229 |
if len(response.strip()) < 2:
|
230 |
return False
|
|
|
238 |
return False
|
239 |
return True
|
240 |
|
241 |
+
# 위키 요약 관련
|
242 |
def extract_main_query(text):
|
243 |
sentences = re.split(r'[.?!]\s*', text)
|
244 |
sentences = [s.strip() for s in sentences if s.strip()]
|
|
|
251 |
last = re.sub(rf'\b(\w+){p}\b', r'\1', last)
|
252 |
return last.strip()
|
253 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
254 |
def get_wikipedia_summary(query):
|
255 |
cleaned_query = extract_main_query(query)
|
256 |
url = f"https://ko.wikipedia.org/api/rest_v1/page/summary/{cleaned_query}"
|
|
|
260 |
else:
|
261 |
return "위키백과에서 정보를 가져올 수 없습니다."
|
262 |
|
|
|
263 |
def textrank_summarize(text, top_n=3):
|
264 |
sentences = re.split(r'(?<=[.!?])\s+', text.strip())
|
265 |
sentences = [s.strip() for s in sentences if len(s.strip()) > 10]
|
|
|
266 |
if len(sentences) <= top_n:
|
267 |
+
return text
|
|
|
268 |
vectorizer = TfidfVectorizer()
|
269 |
tfidf_matrix = vectorizer.fit_transform(sentences)
|
270 |
sim_matrix = cosine_similarity(tfidf_matrix)
|
271 |
np.fill_diagonal(sim_matrix, 0)
|
|
|
272 |
def pagerank(matrix, damping=0.85, max_iter=100, tol=1e-4):
|
273 |
N = matrix.shape[0]
|
274 |
ranks = np.ones(N) / N
|
275 |
row_sums = np.sum(matrix, axis=1)
|
276 |
+
row_sums[row_sums == 0] = 1
|
277 |
for _ in range(max_iter):
|
278 |
prev_ranks = ranks.copy()
|
279 |
for i in range(N):
|
|
|
282 |
if np.linalg.norm(ranks - prev_ranks) < tol:
|
283 |
break
|
284 |
return ranks
|
|
|
285 |
scores = pagerank(sim_matrix)
|
286 |
ranked_idx = np.argsort(scores)[::-1]
|
287 |
selected_idx = sorted(ranked_idx[:top_n])
|
288 |
summary = ' '.join([sentences[i] for i in selected_idx])
|
|
|
289 |
return summary
|
290 |
|
|
|
291 |
def summarize_from_wikipedia(query, top_n=3):
    """Look up *query* on Wikipedia and condense the result with TextRank.

    Args:
        query: Free-form user text; it is handed unchanged to
            ``get_wikipedia_summary``.
        top_n: Number of sentences requested from ``textrank_summarize``.

    Returns:
        Whatever ``textrank_summarize`` returns for the fetched summary text.
    """
    return textrank_summarize(get_wikipedia_summary(query), top_n=top_n)
|
294 |
|
295 |
+
# 의도 분류기
|
296 |
def simple_intent_classifier(text):
|
297 |
text = text.lower()
|
298 |
greet_keywords = ["안녕", "반가워", "이름", "누구", "소개", "어디서 왔", "정체", "몇 살", "너 뭐야"]
|
299 |
info_keywords = ["설명", "정보", "무엇", "뭐야", "어디", "누구", "왜", "어떻게", "종류", "개념"]
|
300 |
math_keywords = ["더하기", "빼기", "곱하기", "나누기", "루트", "제곱", "+", "-", "*", "/", "=", "^", "√", "계산", "몇이야", "얼마야"]
|
|
|
301 |
if any(kw in text for kw in greet_keywords):
|
302 |
return "인사"
|
303 |
elif any(kw in text for kw in info_keywords):
|
|
|
316 |
except:
|
317 |
return "계산할 수 없는 수식이에요. 다시 한번 확인해 주세요!"
|
318 |
|
319 |
+
# 최종 응답 함수
|
320 |
def respond(input_text):
|
321 |
intent = simple_intent_classifier(input_text)
|
322 |
|
|
|
331 |
|
332 |
if intent == "인사":
|
333 |
return "반가워요! 무엇을 도와드릴까요?"
|
334 |
+
|
335 |
if intent == "정보질문":
|
336 |
keyword = re.sub(r"(에 대해|에 대한|에 대해서)?\s*(설명해줘|알려줘|뭐야|개념|정의|정보)?", "", input_text).strip()
|
337 |
if not keyword:
|
|
|
339 |
summary = summarize_from_wikipedia(keyword)
|
340 |
return f"{summary}\n다른 궁금한 점 있으신가요?"
|
341 |
|
342 |
+
# 일상 대화: 샘플링 + fallback
|
343 |
+
response = generate_text_topkp(model, input_text)
|
344 |
+
if not is_valid_response(response) or mismatch_tone(input_text, response):
|
345 |
+
response = generate_text_greedy(model, input_text)
|
346 |
+
return response
|
|
|
|
|
|
|
|
|
|
|
347 |
|
348 |
@app.get("/generate", response_class=PlainTextResponse)
|
349 |
async def generate(request: Request):
|