Yuchan5386 committed on
Commit dcd1d9c · verified · 1 Parent(s): 025726a

Update api.py

Files changed (1)
  1. api.py +53 -36
api.py CHANGED
@@ -140,19 +140,28 @@ _ = model(dummy_input)  # model is now built
 model.load_weights("InteractGPT.weights.h5")
 print("모델 가중치 로드 완료!")
 
 def generate_text_topkp(model, prompt, max_len=100, max_gen=98,
-                        temperature=0.50, min_len=20,
                         repetition_penalty=1.2, top_p=0.90, top_k=50):
     def top_kp_filtering(logits, top_k, top_p):
         probs = np.exp(logits - np.max(logits))
         probs /= probs.sum()
         sorted_idx = np.argsort(-probs)
         sorted_probs = probs[sorted_idx]
-        # Top-K filtering
         if top_k > 0:
             sorted_idx = sorted_idx[:top_k]
             sorted_probs = sorted_probs[:top_k]
-        # Top-P filtering
         cum_probs = np.cumsum(sorted_probs)
         cutoff = np.searchsorted(cum_probs, top_p) + 1
         final_idx = sorted_idx[:cutoff]
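The hunk cuts off before the tail of top_kp_filtering; since the caller later does np.random.choice(final_idx, p=final_probs), the function presumably renormalizes the kept probabilities before returning. A minimal stand-alone sketch of the filter under that assumption, with toy logits:

    import numpy as np

    def top_kp_filtering(logits, top_k=50, top_p=0.90):
        # numerically stable softmax over the raw logits
        probs = np.exp(logits - np.max(logits))
        probs /= probs.sum()
        # rank tokens by descending probability
        sorted_idx = np.argsort(-probs)
        sorted_probs = probs[sorted_idx]
        if top_k > 0:  # keep at most the k most likely tokens
            sorted_idx = sorted_idx[:top_k]
            sorted_probs = sorted_probs[:top_k]
        # keep the shortest prefix whose cumulative mass reaches top_p
        cum_probs = np.cumsum(sorted_probs)
        cutoff = np.searchsorted(cum_probs, top_p) + 1
        final_idx = sorted_idx[:cutoff]
        # assumed tail: renormalize so the kept probabilities sum to 1
        final_probs = sorted_probs[:cutoff] / sorted_probs[:cutoff].sum()
        return final_idx, final_probs

    idx, p = top_kp_filtering(np.array([2.0, 1.0, 0.5, 0.1, -1.0]), top_k=3)
    print(idx, p)  # [0 1 2], probabilities renormalized to sum to 1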
@@ -169,17 +178,13 @@ def generate_text_topkp(model, prompt, max_len=100, max_gen=98,
         input_tensor = tf.convert_to_tensor([input_padded])
         logits = model(input_tensor, training=False)
         next_logits = logits[0, len(generated) - 1].numpy()
-        # repetition suppression
         for t in set(generated):
             count = generated.count(t)
             next_logits[t] /= (repetition_penalty ** count)
-        # avoid ending too early
         if len(generated) < min_len:
             next_logits[end_id] -= 5.0
             next_logits[pad_id] -= 10.0
-        # apply temperature
         next_logits = next_logits / temperature
-        # apply Top-KP sampling
         final_idx, final_probs = top_kp_filtering(next_logits, top_k=top_k, top_p=top_p)
         sampled = np.random.choice(final_idx, p=final_probs)
         generated.append(int(sampled))
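Before sampling, the loop reshapes the next-token distribution in three steps: each already-generated token's logit is divided by repetition_penalty**count, <end> and <pad> are pushed down until min_len is reached, and everything is divided by temperature (values below 1.0 sharpen the distribution). A toy numeric illustration of the first and last steps, independent of the model; note that dividing logits acts as a penalty only while logits are positive:

    import numpy as np

    next_logits = np.array([3.0, 2.0, 1.0])  # toy logits for a 3-token vocabulary
    generated = [0, 0, 2]                    # token 0 seen twice, token 2 once

    for t in set(generated):
        count = generated.count(t)
        next_logits[t] /= 1.2 ** count       # token 0: 3.0 / 1.44 ≈ 2.08

    next_logits /= 0.90                      # temperature < 1 sharpens
    print(next_logits)                       # ≈ [2.31, 2.22, 0.93]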
@@ -189,7 +194,37 @@ def generate_text_topkp(model, prompt, max_len=100, max_gen=98,
         decoded = decoded.strip()
         if len(generated) >= min_len and (sampled == end_id or decoded.endswith(('.', '!', '?'))):
             return decoded
 
 def is_valid_response(response):
     if len(response.strip()) < 2:
         return False
@@ -203,6 +238,7 @@ def is_valid_response(response):
         return False
     return True
 
 def extract_main_query(text):
     sentences = re.split(r'[.?!]\s*', text)
     sentences = [s.strip() for s in sentences if s.strip()]
@@ -215,13 +251,6 @@ def extract_main_query(text):
         last = re.sub(rf'\b(\w+){p}\b', r'\1', last)
     return last.strip()
 
-import re
-import requests
-import numpy as np
-from sklearn.feature_extraction.text import TfidfVectorizer
-from sklearn.metrics.pairwise import cosine_similarity
-
-# 2. Fetch a Wikipedia summary
 def get_wikipedia_summary(query):
     cleaned_query = extract_main_query(query)
     url = f"https://ko.wikipedia.org/api/rest_v1/page/summary/{cleaned_query}"
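The request/parse step of get_wikipedia_summary is elided by this hunk. For context, the Wikimedia REST summary endpoint returns JSON whose "extract" field holds a plain-text page summary; a minimal stand-alone sketch of consuming it (field handling assumed, not copied from api.py):

    import requests

    def fetch_summary(keyword):
        url = f"https://ko.wikipedia.org/api/rest_v1/page/summary/{keyword}"
        res = requests.get(url, timeout=5)
        if res.status_code == 200:
            # "extract" carries the plain-text summary of the page
            return res.json().get("extract", "")
        return None

    print(fetch_summary("인공지능"))  # summary text, or None on a miss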
@@ -231,24 +260,20 @@ def get_wikipedia_summary(query):
     else:
         return "위키백과에서 정보를 가져올 수 없습니다."
 
-# 3. TextRank summarizer
 def textrank_summarize(text, top_n=3):
     sentences = re.split(r'(?<=[.!?])\s+', text.strip())
     sentences = [s.strip() for s in sentences if len(s.strip()) > 10]
-
     if len(sentences) <= top_n:
-        return text  # too few sentences: return the original text
-
     vectorizer = TfidfVectorizer()
     tfidf_matrix = vectorizer.fit_transform(sentences)
     sim_matrix = cosine_similarity(tfidf_matrix)
     np.fill_diagonal(sim_matrix, 0)
-
     def pagerank(matrix, damping=0.85, max_iter=100, tol=1e-4):
         N = matrix.shape[0]
         ranks = np.ones(N) / N
         row_sums = np.sum(matrix, axis=1)
-        row_sums[row_sums == 0] = 1  # guard against NaN
@@ -257,25 +282,22 @@ def textrank_summarize(text, top_n=3):
             if np.linalg.norm(ranks - prev_ranks) < tol:
                 break
         return ranks
-
     scores = pagerank(sim_matrix)
     ranked_idx = np.argsort(scores)[::-1]
     selected_idx = sorted(ranked_idx[:top_n])
     summary = ' '.join([sentences[i] for i in selected_idx])
-
     return summary
 
-# 4. Full pipeline
 def summarize_from_wikipedia(query, top_n=3):
     raw_summary = get_wikipedia_summary(query)
     return textrank_summarize(raw_summary, top_n=top_n)
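The pagerank helper is a damped power iteration over the TF-IDF cosine-similarity matrix. Its per-node update line falls between the two hunks above and is not shown, so the standard TextRank formula is assumed in this self-contained toy run:

    import numpy as np

    def pagerank(matrix, damping=0.85, max_iter=100, tol=1e-4):
        N = matrix.shape[0]
        ranks = np.ones(N) / N
        row_sums = np.sum(matrix, axis=1)
        row_sums[row_sums == 0] = 1      # avoid division by zero for isolated nodes
        for _ in range(max_iter):
            prev_ranks = ranks.copy()
            for i in range(N):
                # assumed update: incoming similarity from every sentence j,
                # normalized by j's total outgoing similarity
                ranks[i] = (1 - damping) / N + damping * np.sum(
                    matrix[:, i] * prev_ranks / row_sums)
            if np.linalg.norm(ranks - prev_ranks) < tol:
                break
        return ranks

    sim = np.array([[0.0, 0.8, 0.1],
                    [0.8, 0.0, 0.3],
                    [0.1, 0.3, 0.0]])   # hand-made similarities for 3 "sentences"
    print(pagerank(sim))  # the two mutually similar sentences rank highest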
272
 
 
273
  def simple_intent_classifier(text):
274
  text = text.lower()
275
  greet_keywords = ["안녕", "반가워", "이름", "누구", "소개", "어디서 왔", "정체", "몇 살", "너 뭐야"]
276
  info_keywords = ["설명", "정보", "무엇", "뭐야", "어디", "누구", "왜", "어떻게", "종류", "개념"]
277
  math_keywords = ["더하기", "빼기", "곱하기", "나누기", "루트", "제곱", "+", "-", "*", "/", "=", "^", "√", "계산", "몇이야", "얼마야"]
278
-
279
  if any(kw in text for kw in greet_keywords):
280
  return "인사"
281
  elif any(kw in text for kw in info_keywords):
@@ -294,7 +316,7 @@ def parse_math_question(text):
294
  except:
295
  return "계산할 수 없는 수식이에요. 다시 한번 확인해 주세요!"
296
 
297
- # 전체 응답 함수
298
  def respond(input_text):
299
  intent = simple_intent_classifier(input_text)
300
 
@@ -309,7 +331,7 @@ def respond(input_text):
 
     if intent == "인사":
         return "반가워요! 무엇을 도와드릴까요?"
-
     if intent == "정보질문":
         keyword = re.sub(r"(에 대해|에 대한|에 대해서)?\s*(설명해줘|알려줘|뭐야|개념|정의|정보)?", "", input_text).strip()
         if not keyword:
@@ -317,16 +339,11 @@ def respond(input_text):
         summary = summarize_from_wikipedia(keyword)
         return f"{summary}\n다른 궁금한 점 있으신가요?"
 
-    return generate_text_topkp(model, input_text)
-
-async def async_generator_wrapper(prompt: str):
-    gen = generate_text_topkp(model, prompt)
-    for text_piece in gen:
-        yield text_piece
-        await asyncio.sleep(0.1)
-
-
-from fastapi.responses import PlainTextResponse
 
 @app.get("/generate", response_class=PlainTextResponse)
 async def generate(request: Request):
 
 model.load_weights("InteractGPT.weights.h5")
 print("모델 가중치 로드 완료!")
 
+import re
+import math
+import numpy as np
+import requests
+import tensorflow as tf
+from sklearn.feature_extraction.text import TfidfVectorizer
+from sklearn.metrics.pairwise import cosine_similarity
+from fastapi import Request
+from fastapi.responses import PlainTextResponse
+
+# 1. Top-KP based generator
 def generate_text_topkp(model, prompt, max_len=100, max_gen=98,
+                        temperature=0.90, min_len=20,
                         repetition_penalty=1.2, top_p=0.90, top_k=50):
     def top_kp_filtering(logits, top_k, top_p):
         probs = np.exp(logits - np.max(logits))
         probs /= probs.sum()
         sorted_idx = np.argsort(-probs)
         sorted_probs = probs[sorted_idx]
         if top_k > 0:
             sorted_idx = sorted_idx[:top_k]
             sorted_probs = sorted_probs[:top_k]
         cum_probs = np.cumsum(sorted_probs)
         cutoff = np.searchsorted(cum_probs, top_p) + 1
         final_idx = sorted_idx[:cutoff]
 
         input_tensor = tf.convert_to_tensor([input_padded])
         logits = model(input_tensor, training=False)
         next_logits = logits[0, len(generated) - 1].numpy()
         for t in set(generated):
             count = generated.count(t)
             next_logits[t] /= (repetition_penalty ** count)
         if len(generated) < min_len:
             next_logits[end_id] -= 5.0
             next_logits[pad_id] -= 10.0
         next_logits = next_logits / temperature
         final_idx, final_probs = top_kp_filtering(next_logits, top_k=top_k, top_p=top_p)
         sampled = np.random.choice(final_idx, p=final_probs)
         generated.append(int(sampled))
 
         decoded = decoded.strip()
         if len(generated) >= min_len and (sampled == end_id or decoded.endswith(('.', '!', '?'))):
             return decoded
+    return sp.decode(generated)
+
+# Greedy version of the generator
+def generate_text_greedy(model, prompt, max_len=100, max_gen=98):
+    model_input = text_to_ids(f"<start> {prompt} <sep>")
+    model_input = model_input[:max_len]
+    generated = list(model_input)
+    for _ in range(max_gen):
+        pad_len = max(0, max_len - len(generated))
+        input_padded = np.pad(generated, (0, pad_len), constant_values=pad_id)
+        input_tensor = tf.convert_to_tensor([input_padded])
+        logits = model(input_tensor, training=False)
+        next_logits = logits[0, len(generated) - 1].numpy()
+        next_logits[pad_id] -= 10.0
+        next_token = np.argmax(next_logits)
+        generated.append(int(next_token))
+        decoded = sp.decode(generated)
+        for t in ["<start>", "<sep>", "<end>"]:
+            decoded = decoded.replace(t, "")
+        decoded = decoded.strip()
+        if next_token == end_id or decoded.endswith(('.', '!', '?')):
+            return decoded
+    return sp.decode(generated)
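generate_text_greedy is the deterministic counterpart of generate_text_topkp: argmax at every step, with no temperature, repetition penalty, or min_len floor, so identical prompts yield identical outputs. A hypothetical session (model, sp, text_to_ids, pad_id, and end_id as defined earlier in api.py):

    prompt = "오늘 뭐 했어?"
    print(generate_text_topkp(model, prompt))   # stochastic: varies per call
    print(generate_text_topkp(model, prompt))
    print(generate_text_greedy(model, prompt))  # deterministic: repeats exactly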
+
+# Tone-mismatch check
+def mismatch_tone(input_text, output_text):
+    if "ㅋㅋ" in input_text and not re.search(r'ㅋㅋ|ㅎ|재밌|놀|만나|맛집|여행', output_text):
+        return True
+    return False
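mismatch_tone fires only when the user wrote "ㅋㅋ" but the reply contains none of the casual markers in the pattern, e.g.:

    print(mismatch_tone("오늘 진짜 웃겼어 ㅋㅋ", "무엇을 도와드릴까요?"))        # True
    print(mismatch_tone("오늘 진짜 웃겼어 ㅋㅋ", "ㅎㅎ 뭐가 그렇게 재밌었어?"))  # False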
 
+# Check that a generated response is usable
 def is_valid_response(response):
     if len(response.strip()) < 2:
         return False
 
         return False
     return True
 
+# Wikipedia summary helpers
 def extract_main_query(text):
     sentences = re.split(r'[.?!]\s*', text)
     sentences = [s.strip() for s in sentences if s.strip()]
 
         last = re.sub(rf'\b(\w+){p}\b', r'\1', last)
     return last.strip()
 
 def get_wikipedia_summary(query):
     cleaned_query = extract_main_query(query)
     url = f"https://ko.wikipedia.org/api/rest_v1/page/summary/{cleaned_query}"
 
     else:
         return "위키백과에서 정보를 가져올 수 없습니다."
 
 def textrank_summarize(text, top_n=3):
     sentences = re.split(r'(?<=[.!?])\s+', text.strip())
     sentences = [s.strip() for s in sentences if len(s.strip()) > 10]
     if len(sentences) <= top_n:
+        return text
     vectorizer = TfidfVectorizer()
     tfidf_matrix = vectorizer.fit_transform(sentences)
     sim_matrix = cosine_similarity(tfidf_matrix)
     np.fill_diagonal(sim_matrix, 0)
     def pagerank(matrix, damping=0.85, max_iter=100, tol=1e-4):
         N = matrix.shape[0]
         ranks = np.ones(N) / N
         row_sums = np.sum(matrix, axis=1)
+        row_sums[row_sums == 0] = 1
         for _ in range(max_iter):
             prev_ranks = ranks.copy()
             for i in range(N):
 
             if np.linalg.norm(ranks - prev_ranks) < tol:
                 break
         return ranks
     scores = pagerank(sim_matrix)
     ranked_idx = np.argsort(scores)[::-1]
     selected_idx = sorted(ranked_idx[:top_n])
     summary = ' '.join([sentences[i] for i in selected_idx])
     return summary
 
 def summarize_from_wikipedia(query, top_n=3):
     raw_summary = get_wikipedia_summary(query)
     return textrank_summarize(raw_summary, top_n=top_n)
 
+# Intent classifier
 def simple_intent_classifier(text):
     text = text.lower()
     greet_keywords = ["안녕", "반가워", "이름", "누구", "소개", "어디서 왔", "정체", "몇 살", "너 뭐야"]
     info_keywords = ["설명", "정보", "무엇", "뭐야", "어디", "누구", "왜", "어떻게", "종류", "개념"]
     math_keywords = ["더하기", "빼기", "곱하기", "나누기", "루트", "제곱", "+", "-", "*", "/", "=", "^", "√", "계산", "몇이야", "얼마야"]
     if any(kw in text for kw in greet_keywords):
         return "인사"
     elif any(kw in text for kw in info_keywords):
 
     except:
         return "계산할 수 없는 수식이에요. 다시 한번 확인해 주세요!"
 
+# Final response function
 def respond(input_text):
     intent = simple_intent_classifier(input_text)
 
     if intent == "인사":
         return "반가워요! 무엇을 도와드릴까요?"
+
     if intent == "정보질문":
         keyword = re.sub(r"(에 대해|에 대한|에 대해서)?\s*(설명해줘|알려줘|뭐야|개념|정의|정보)?", "", input_text).strip()
         if not keyword:
 
         summary = summarize_from_wikipedia(keyword)
         return f"{summary}\n다른 궁금한 점 있으신가요?"
 
+    # Casual conversation: sampling with a greedy fallback
+    response = generate_text_topkp(model, input_text)
+    if not is_valid_response(response) or mismatch_tone(input_text, response):
+        response = generate_text_greedy(model, input_text)
+    return response
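Taken together, respond routes by intent: a fixed greeting, math parsing, Wikipedia-plus-TextRank for information questions, and, for everything else, sampled generation with a greedy retry when the output is invalid or tonally off. A hypothetical session (actual model output varies; the math branch is elided by this diff):

    print(respond("안녕!"))                  # "인사" -> fixed greeting
    print(respond("블랙홀에 대해 설명해줘"))  # "정보질문" -> Wikipedia + TextRank
    print(respond("오늘 너무 피곤해 ㅋㅋ"))   # casual -> sampling, greedy fallback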
 
 @app.get("/generate", response_class=PlainTextResponse)
 async def generate(request: Request):