Update api.py
api.py CHANGED
@@ -1,12 +1,14 @@
+import requests
+import numpy as np
+import tensorflow as tf
+import asyncio
+from fastapi import FastAPI, Request
+from fastapi.responses import StreamingResponse
+from sklearn.feature_extraction.text import TfidfVectorizer
+from sklearn.metrics.pairwise import cosine_similarity
+import nltk
+nltk.download('punkt')
+from nltk.tokenize import sent_tokenize
 
 app = FastAPI()
 
@@ -124,93 +126,127 @@ dummy_input = tf.zeros((1, max_len), dtype=tf.int32)  # batch 1, sequence length
 _ = model(dummy_input)  # the model gets built
 model.load_weights("InteractGPT.weights.h5")
 print("모델 가중치 로드 완료!")  # "Model weights loaded!"
+
+def extract_main_query(query):
+    words = query.split()
+    return " ".join(words[:3])  # keep only the first three words as the lookup key
+
+def get_wikipedia_summary(query):
+    cleaned_query = extract_main_query(query)
+    url = f"https://ko.wikipedia.org/api/rest_v1/page/summary/{cleaned_query}"
+    res = requests.get(url)
+    if res.status_code == 200:
+        return res.json().get("extract", "요약 정보를 찾을 수 없습니다.")  # fallback: "No summary found."
+    else:
+        return "위키백과에서 정보를 가져올 수 없습니다."  # "Could not fetch from Wikipedia."
+
+def summarize_text(text, top_n=3):
+    sentences = sent_tokenize(text)
+    if len(sentences) <= top_n:
+        return text
+    vectorizer = TfidfVectorizer(ngram_range=(1, 2), stop_words=['은', '는', '이', '가', '을', '를', '에', '에서'])  # Korean particle stop words
+    tfidf_matrix = vectorizer.fit_transform(sentences)
+    sim_matrix = cosine_similarity(tfidf_matrix, tfidf_matrix)
+    np.fill_diagonal(sim_matrix, 0)
+    scores = sim_matrix.sum(axis=1)  # centrality: total similarity to the other sentences
+    ranked_idx = np.argsort(scores)[::-1]
+    selected_idx = sorted(ranked_idx[:top_n])  # keep the top sentences in original order
+    summary = " ".join([sentences[i] for i in selected_idx])
+    return summary
+
+def simple_intent_classifier(text):
+    text = text.lower()
+    greet_keywords = ["안녕", "반가워", "이름", "누구", "소개", "어디서 왔", "정체", "몇 살", "너 뭐야"]  # greeting/self-introduction keywords
+    info_keywords = ["설명", "정보", "무엇", "뭐야", "어디", "누구", "왜", "어떻게", "종류", "개념"]  # information-question keywords
+    if any(kw in text for kw in greet_keywords):
+        return "인사"  # greeting
+    elif any(kw in text for kw in info_keywords):
+        return "정보질문"  # information question
+    else:
+        return "일상대화"  # casual conversation
+
+
+def generate_text_mirostat_top_p(model, prompt, max_len=100, max_gen=98,
+                                 temperature=1.0, min_len=20,
+                                 repetition_penalty=1.2, eta=0.1, m=100, p=0.9):
+    model_input = text_to_ids(f"<start> {prompt} <sep>")
+    model_input = model_input[:max_len]
+    generated = list(model_input)
+
+    tau = 5.0  # initial target surprise
+
+    for step in range(max_gen):
+        pad_length = max(0, max_len - len(generated))
+        input_padded = np.pad(generated, (0, pad_length), constant_values=pad_id)
+        input_tensor = tf.convert_to_tensor([input_padded])
+        logits = model(input_tensor, training=False)
+        next_token_logits = logits[0, len(generated) - 1].numpy()
+
+        # apply repetition penalty
+        token_counts = {}
+        for t in generated:
+            token_counts[t] = token_counts.get(t, 0) + 1
+        for token_id, count in token_counts.items():
+            next_token_logits[token_id] /= (repetition_penalty ** count)
+
+        # once past the minimum length, push down the end/pad token probabilities
+        if len(generated) >= min_len:
+            next_token_logits[end_id] -= 5.0
+            next_token_logits[pad_id] -= 10.0
+
+        # temperature scaling
+        next_token_logits = next_token_logits / temperature
+
+        # --- Mirostat + top-p sampling ---
+        logits_stable = next_token_logits - np.max(next_token_logits)
+        probs = np.exp(logits_stable)
+        probs /= probs.sum()
+
+        sorted_indices = np.argsort(-probs)
+        top_indices = sorted_indices[:m]
+        top_probs = probs[top_indices]
+        top_probs /= top_probs.sum()
+
+        sampled_index = np.random.choice(top_indices, p=top_probs)
+        sampled_prob = probs[sampled_index]
+        observed_surprise = -np.log(sampled_prob + 1e-9)
+        tau += eta * (observed_surprise - tau)
+
+        sorted_top_indices = top_indices[np.argsort(-top_probs)]
+        sorted_top_probs = np.sort(top_probs)[::-1]
+        cumulative_probs = np.cumsum(sorted_top_probs)
+        cutoff = np.searchsorted(cumulative_probs, p, side='left') + 1
+        filtered_indices = sorted_top_indices[:cutoff]
+        filtered_probs = sorted_top_probs[:cutoff]
+        filtered_probs /= filtered_probs.sum()
+
+        final_token = np.random.choice(filtered_indices, p=filtered_probs)
+        generated.append(int(final_token))
+
+        decoded_text = decode_ids(generated)
+        for token in ["<start>", "<sep>", "<end>"]:
+            decoded_text = decoded_text.replace(token, "")
+        decoded_text = decoded_text.strip()
+
+        if len(generated) >= min_len and (final_token == end_id or decoded_text.endswith(('.', '!', '?'))):
+            yield decoded_text
+            break  # stop once a sentence-final character or <end> appears past min_len
+
+async def async_generator_wrapper(prompt: str):
+    intent = simple_intent_classifier(prompt)
+
+    if intent == "정보질문":  # information question
+        wiki_summary = get_wikipedia_summary(prompt)
+        summarized = summarize_text(wiki_summary, top_n=3)
+        yield f"『 \"{prompt}\" 에 대한 위키백과 요약입니다. 』\n\n{summarized}\n\n"  # 『 Wikipedia summary for "<prompt>" 』
+
+    # then continue with normal generation (streaming)
+    gen = generate_text_mirostat_top_p(model, prompt)
+    for text_piece in gen:
+        yield text_piece
+        await asyncio.sleep(0.1)
+
+@app.get("/generate")
+async def generate(request: Request):
+    prompt = request.query_params.get("prompt", "안녕하세요")  # default prompt: "hello"
     return StreamingResponse(async_generator_wrapper(prompt), media_type="text/plain")
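
For reference, the nucleus cutoff and the Mirostat-style feedback that the new generate_text_mirostat_top_p performs each step can be sanity-checked in isolation. Below is a minimal sketch on a toy distribution; the toy probabilities and printed values are illustrative, not taken from the model:

import numpy as np

probs = np.array([0.40, 0.25, 0.15, 0.10, 0.06, 0.04])  # toy next-token distribution
p = 0.9              # nucleus mass, the patch's default
eta, tau = 0.1, 5.0  # Mirostat learning rate / initial target surprise

# top-p cutoff: smallest prefix of the sorted probabilities whose mass reaches p
sorted_probs = np.sort(probs)[::-1]
cumulative = np.cumsum(sorted_probs)                      # [0.40, 0.65, 0.80, 0.90, 0.96, 1.00]
cutoff = np.searchsorted(cumulative, p, side='left') + 1  # -> 4 tokens survive
nucleus = sorted_probs[:cutoff] / sorted_probs[:cutoff].sum()

# Mirostat-style update: nudge tau toward the surprise of the sampled token
sampled_prob = nucleus[0]  # pretend the top token was drawn
observed_surprise = -np.log(sampled_prob + 1e-9)
tau += eta * (observed_surprise - tau)
print(cutoff, nucleus.round(3), round(tau, 3))

Note that in the patch tau is updated every step but never feeds back into the cutoff, so the filter effectively runs as plain top-p with a fixed p; wiring tau into the truncation would complete the Mirostat loop.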
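
With the endpoint in place, one way to try it is to run the app with uvicorn and stream the plain-text response; the module name, host, and port below are assumptions for illustration, not part of the commit:

# shell: uvicorn api:app --host 0.0.0.0 --port 8000
import requests

with requests.get(
    "http://localhost:8000/generate",
    params={"prompt": "안녕하세요"},  # the endpoint's default prompt ("hello")
    stream=True,
) as res:
    res.raise_for_status()
    # chunks arrive as the async generator yields them
    for chunk in res.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)

Because async_generator_wrapper yields the Wikipedia summary first (for information questions) and the generated text afterwards, a streaming client sees them as separate chunks.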