Yuchan5386 committed (verified)
Commit 35d657e
1 Parent(s): 9cad8b1

Update api.py

Files changed (1)
  api.py +134 -98
api.py CHANGED
@@ -1,12 +1,14 @@
- from fastapi import FastAPI, Request
- from fastapi.responses import StreamingResponse
- import asyncio
- import json
- import numpy as np
- import tensorflow as tf
- from tensorflow.keras import layers
- import sentencepiece as spm
- import requests

  app = FastAPI()

@@ -124,93 +126,127 @@ dummy_input = tf.zeros((1, max_len), dtype=tf.int32)  # batch of 1, sequence length max_len
  _ = model(dummy_input)  # builds the model
  model.load_weights("InteractGPT.weights.h5")
  print("모델 가중치 로드 완료!")
-
- def generate_text_mirostat_top_p(model, prompt, max_len=100, max_gen=98,
-                                  temperature=1.0, min_len=20,
-                                  repetition_penalty=1.2, eta=0.1, m=100, p=0.9):
-     model_input = text_to_ids(f"<start> {prompt} <sep>")
-     model_input = model_input[:max_len]
-     generated = list(model_input)
-
-     tau = 5.0  # initial target surprise
-
-     for step in range(max_gen):
-         pad_length = max(0, max_len - len(generated))
-         input_padded = np.pad(generated, (0, pad_length), constant_values=pad_id)
-         input_tensor = tf.convert_to_tensor([input_padded])
-         logits = model(input_tensor, training=False)
-         next_token_logits = logits[0, len(generated) - 1].numpy()
-
-         # apply repetition penalty
-         token_counts = {}
-         for t in generated:
-             token_counts[t] = token_counts.get(t, 0) + 1
-         for token_id, count in token_counts.items():
-             next_token_logits[token_id] /= (repetition_penalty ** count)
-
-         # once past the minimum length, lower the end/pad token probabilities
-         if len(generated) >= min_len:
-             next_token_logits[end_id] -= 5.0
-             next_token_logits[pad_id] -= 10.0
-
-         # temperature scaling
-         next_token_logits = next_token_logits / temperature
-
-         # --- Mirostat + top-p sampling ---
-         logits_stable = next_token_logits - np.max(next_token_logits)
-         probs = np.exp(logits_stable)
-         probs /= probs.sum()
-
-         # 1. select the top-m mirostat candidates
-         sorted_indices = np.argsort(-probs)
-         top_indices = sorted_indices[:m]
-         top_probs = probs[top_indices]
-         top_probs /= top_probs.sum()
-
-         # 2. mirostat sampling
-         sampled_index = np.random.choice(top_indices, p=top_probs)
-         sampled_prob = probs[sampled_index]
-         observed_surprise = -np.log(sampled_prob + 1e-9)
-         tau += eta * (observed_surprise - tau)
-
-         # 3. top-p filtering
-         sorted_top_indices = top_indices[np.argsort(-top_probs)]
-         sorted_top_probs = np.sort(top_probs)[::-1]
-         cumulative_probs = np.cumsum(sorted_top_probs)
-         cutoff = np.searchsorted(cumulative_probs, p, side='left') + 1
-         filtered_indices = sorted_top_indices[:cutoff]
-         filtered_probs = sorted_top_probs[:cutoff]
-         filtered_probs /= filtered_probs.sum()
-
-         # 4. sample the final token
-         final_token = np.random.choice(filtered_indices, p=filtered_probs)
-         generated.append(int(final_token))
-
-         decoded_text = sp.decode(generated)
-         # strip special tokens
-         for token in ["<start>", "<sep>", "<end>"]:
-             decoded_text = decoded_text.replace(token, "")
-
-         decoded_text = decoded_text.strip()
-
-         if len(generated) >= min_len and (final_token == end_id or decoded_text.endswith(('.', '!', '?'))):
-             yield decoded_text
-             break
-
- async def async_generator_wrapper(prompt: str):
-     # wrap the synchronous generator asynchronously
-     loop = asyncio.get_event_loop()
-     gen = generate_text_mirostat_top_p(model, prompt)
-
-     for text_piece in gen:
-         yield text_piece
-         # throttle token generation (0.1 s delay)
-         await asyncio.sleep(0.1)
-
- @app.get("/generate")
- async def generate(request: Request):
-     # read the prompt from the query parameter, with a default
-     prompt = request.query_params.get("prompt", "안녕하세요")
-
-     # send as a streaming response
      return StreamingResponse(async_generator_wrapper(prompt), media_type="text/plain")
 
+ import requests
+ import numpy as np
+ import tensorflow as tf
+ import asyncio
+ from fastapi import FastAPI, Request
+ from fastapi.responses import StreamingResponse
+ from sklearn.feature_extraction.text import TfidfVectorizer
+ from sklearn.metrics.pairwise import cosine_similarity
+ import nltk
+ nltk.download('punkt')
+ from nltk.tokenize import sent_tokenize

  app = FastAPI()

 
  _ = model(dummy_input)  # builds the model
  model.load_weights("InteractGPT.weights.h5")
  print("모델 가중치 로드 완료!")
+
+ def extract_main_query(query):
+     words = query.split()
+     return " ".join(words[:3])
+
+ def get_wikipedia_summary(query):
+     cleaned_query = extract_main_query(query)
+     url = f"https://ko.wikipedia.org/api/rest_v1/page/summary/{cleaned_query}"
+     res = requests.get(url)
+     if res.status_code == 200:
+         return res.json().get("extract", "요약 정보를 찾을 수 없습니다.")
+     else:
+         return "위키백과에서 정보를 가져올 수 없습니다."
+
+ def summarize_text(text, top_n=3):
+     sentences = sent_tokenize(text)
+     if len(sentences) <= top_n:
+         return text
+     vectorizer = TfidfVectorizer(ngram_range=(1, 2), stop_words=['은', '는', '이', '가', '을', '를', '에', '에서'])
+     tfidf_matrix = vectorizer.fit_transform(sentences)
+     sim_matrix = cosine_similarity(tfidf_matrix, tfidf_matrix)
+     np.fill_diagonal(sim_matrix, 0)
+     scores = sim_matrix.sum(axis=1)
+     ranked_idx = np.argsort(scores)[::-1]
+     selected_idx = sorted(ranked_idx[:top_n])
+     summary = " ".join([sentences[i] for i in selected_idx])
+     return summary
+
+ def simple_intent_classifier(text):
+     text = text.lower()
+     greet_keywords = ["안녕", "반가워", "이름", "누구", "소개", "어디서 왔", "정체", "몇 살", "너 뭐야"]
+     info_keywords = ["설명", "정보", "무엇", "뭐야", "어디", "누구", "왜", "어떻게", "종류", "개념"]
+     if any(kw in text for kw in greet_keywords):
+         return "인사"
+     elif any(kw in text for kw in info_keywords):
+         return "정보질문"
+     else:
+         return "일상대화"
+
+
+ def generate_text_mirostat_top_p(model, prompt, max_len=100, max_gen=98,
+                                  temperature=1.0, min_len=20,
+                                  repetition_penalty=1.2, eta=0.1, m=100, p=0.9):
+     model_input = text_to_ids(f"<start> {prompt} <sep>")
+     model_input = model_input[:max_len]
+     generated = list(model_input)
+
+     tau = 5.0  # initial target surprise
+
+     for step in range(max_gen):
+         pad_length = max(0, max_len - len(generated))
+         input_padded = np.pad(generated, (0, pad_length), constant_values=pad_id)
+         input_tensor = tf.convert_to_tensor([input_padded])
+         logits = model(input_tensor, training=False)
+         next_token_logits = logits[0, len(generated) - 1].numpy()
+
+         # apply repetition penalty
+         token_counts = {}
+         for t in generated:
+             token_counts[t] = token_counts.get(t, 0) + 1
+         for token_id, count in token_counts.items():
+             next_token_logits[token_id] /= (repetition_penalty ** count)
+
+         # once past the minimum length, lower the end/pad token probabilities
+         if len(generated) >= min_len:
+             next_token_logits[end_id] -= 5.0
+             next_token_logits[pad_id] -= 10.0
+
+         # temperature scaling
+         next_token_logits = next_token_logits / temperature
+
+         # --- Mirostat + top-p sampling ---
+         logits_stable = next_token_logits - np.max(next_token_logits)
+         probs = np.exp(logits_stable)
+         probs /= probs.sum()
+
+         sorted_indices = np.argsort(-probs)
+         top_indices = sorted_indices[:m]
+         top_probs = probs[top_indices]
+         top_probs /= top_probs.sum()
+
+         sampled_index = np.random.choice(top_indices, p=top_probs)
+         sampled_prob = probs[sampled_index]
+         observed_surprise = -np.log(sampled_prob + 1e-9)
+         tau += eta * (observed_surprise - tau)
+
+         sorted_top_indices = top_indices[np.argsort(-top_probs)]
+         sorted_top_probs = np.sort(top_probs)[::-1]
+         cumulative_probs = np.cumsum(sorted_top_probs)
+         cutoff = np.searchsorted(cumulative_probs, p, side='left') + 1
+         filtered_indices = sorted_top_indices[:cutoff]
+         filtered_probs = sorted_top_probs[:cutoff]
+         filtered_probs /= filtered_probs.sum()
+
+         final_token = np.random.choice(filtered_indices, p=filtered_probs)
+         generated.append(int(final_token))
+
+         decoded_text = decode_ids(generated)
+         for token in ["<start>", "<sep>", "<end>"]:
+             decoded_text = decoded_text.replace(token, "")
+         decoded_text = decoded_text.strip()
+
+         if len(generated) >= min_len and (final_token == end_id or decoded_text.endswith(('.', '!', '?'))):
+             yield decoded_text
+             break
+
+ async def async_generator_wrapper(prompt: str):
+     intent = simple_intent_classifier(prompt)
+
+     if intent == "정보질문":
+         wiki_summary = get_wikipedia_summary(prompt)
+         summarized = summarize_text(wiki_summary, top_n=3)
+         yield f"『 \"{prompt}\" 에 대한 위키백과 요약입니다. 』\n\n{summarized}\n\n"
+
+     # then continue with normal generation (streaming)
+     gen = generate_text_mirostat_top_p(model, prompt)
+     for text_piece in gen:
+         yield text_piece
+         await asyncio.sleep(0.1)
+
+ @app.get("/generate")
+ async def generate(request: Request):
+     prompt = request.query_params.get("prompt", "안녕하세요")
      return StreamingResponse(async_generator_wrapper(prompt), media_type="text/plain")
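
As a quick sanity check of the new streaming endpoint, a minimal client sketch follows. It is not part of the commit: the host, port, run command (`uvicorn api:app --port 8000`), and prompt are assumptions; the prompt is an arbitrary info-style question chosen to trigger the Wikipedia-summary branch before the model's own generation.

    # Hypothetical smoke test for /generate; host and port are assumed.
    import requests

    with requests.get(
        "http://localhost:8000/generate",
        params={"prompt": "파이썬이 뭐야"},
        stream=True,  # read the StreamingResponse chunk by chunk
    ) as resp:
        resp.raise_for_status()
        resp.encoding = "utf-8"  # bare text/plain may default to latin-1, garbling Korean
        for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
            print(chunk, end="", flush=True)

For an info-type prompt, the stream should yield the summary block first and then, once the sampler reaches a sentence-ending token past min_len, the single generated completion.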