Update api.py
api.py CHANGED
@@ -140,68 +140,50 @@ _ = model(dummy_input)  # the model is built
 model.load_weights("InteractGPT.weights.h5")
 print("모델 가중치 로드 완료!")
 
-def
-        probs
-        if len(generated)
-            next_logits[pad_id] -= 10.0
-        # apply temperature
-        next_logits = next_logits / temperature
-        # apply typical sampling
-        final_idx, final_probs = typical_filtering(next_logits, typical_p=typical_p)
-        sampled = np.random.choice(final_idx, p=final_probs)
-        generated.append(int(sampled))
-        decoded = sp.decode(generated)
-        for t in ["<start>", "<sep>", "<end>"]:
-            decoded = decoded.replace(t, "")
-        decoded = decoded.strip()
-        if len(generated) >= min_len and (sampled == end_id or decoded.endswith(('.', '!', '?'))):
-            return decoded  # ← return instead of yield
+def generate_text_topp(model, prompt, max_len=100, max_gen=98,
+                       temperature=0.50, min_len=20,
+                       repetition_penalty=1.2, top_p=0.90):
+    def top_p_filtering(logits, top_p):
+        probs = np.exp(logits - np.max(logits))
+        probs /= probs.sum()
+        sorted_idx = np.argsort(-probs)
+        sorted_probs = probs[sorted_idx]
+        cum_probs = np.cumsum(sorted_probs)
+        cutoff = np.searchsorted(cum_probs, top_p) + 1
+        final_idx = sorted_idx[:cutoff]
+        final_probs = probs[final_idx]
+        final_probs /= final_probs.sum()
+        return final_idx, final_probs
+
+    model_input = text_to_ids(f"<start> {prompt} <sep>")
+    model_input = model_input[:max_len]
+    generated = list(model_input)
+    for step in range(max_gen):
+        pad_len = max(0, max_len - len(generated))
+        input_padded = np.pad(generated, (0, pad_len), constant_values=pad_id)
+        input_tensor = tf.convert_to_tensor([input_padded])
+        logits = model(input_tensor, training=False)
+        next_logits = logits[0, len(generated) - 1].numpy()
+        # suppress repetition
+        for t in set(generated):
+            count = generated.count(t)
+            next_logits[t] /= (repetition_penalty ** count)
+        # prevent ending too early
+        if len(generated) < min_len:
+            next_logits[end_id] -= 5.0
+            next_logits[pad_id] -= 10.0
+        # apply temperature
+        next_logits = next_logits / temperature
+        # apply top-p sampling
+        final_idx, final_probs = top_p_filtering(next_logits, top_p=top_p)
+        sampled = np.random.choice(final_idx, p=final_probs)
+        generated.append(int(sampled))
+        decoded = sp.decode(generated)
+        for t in ["<start>", "<sep>", "<end>"]:
+            decoded = decoded.replace(t, "")
+        decoded = decoded.strip()
+        if len(generated) >= min_len and (sampled == end_id or decoded.endswith(('.', '!', '?'))):
+            return decoded
 
 def is_valid_response(response):
     if len(response.strip()) < 2:
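Taken together, the new decoding step in generate_text_topp is: divide the logits of already-generated tokens by repetition_penalty once per occurrence, push <end> and <pad> down until min_len tokens exist, sharpen with temperature=0.50, then sample from the smallest set of tokens whose cumulative probability reaches top_p (nucleus sampling). Below is a minimal standalone sketch of that single step on a toy five-token vocabulary; the logit values, the generated history, and the seed are illustrative, not taken from the model.

import numpy as np

def top_p_filtering(logits, top_p):
    # stable softmax, then keep the shortest prefix of tokens
    # (sorted by descending probability) whose cumulative mass reaches top_p
    probs = np.exp(logits - np.max(logits))
    probs /= probs.sum()
    sorted_idx = np.argsort(-probs)
    cum_probs = np.cumsum(probs[sorted_idx])
    cutoff = np.searchsorted(cum_probs, top_p) + 1
    final_idx = sorted_idx[:cutoff]
    final_probs = probs[final_idx]
    final_probs /= final_probs.sum()  # renormalize inside the nucleus
    return final_idx, final_probs

rng = np.random.default_rng(0)                       # illustrative seed
next_logits = np.array([2.0, 1.0, 0.5, -1.0, -3.0])  # toy 5-token vocabulary
generated = [0, 0, 1]                                # token 0 sampled twice already

# repetition penalty as in the diff: each occurrence divides the logit again
for t in set(generated):
    next_logits[t] /= 1.2 ** generated.count(t)

next_logits /= 0.50                                  # temperature < 1 sharpens
idx, probs = top_p_filtering(next_logits, top_p=0.90)
print(idx, probs.round(3))                           # tokens 0, 1, 2 survive; probs sum to 1
sampled = int(rng.choice(idx, p=probs))

One caveat with the division-based penalty: dividing a negative logit moves it toward zero, so a token with a negative logit is actually boosted after it appears; the usual fix is to multiply negative logits by the penalty and divide positive ones.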
@@ -284,10 +266,10 @@ def respond(input_text):
         summary = get_wikipedia_summary(keyword)
         return f"{summary}\n다른 궁금한 점 있으신가요?"
 
-    return
+    return generate_text_topp(model, input_text)
 
 async def async_generator_wrapper(prompt: str):
-    gen =
+    gen = generate_text_topp(model, prompt)
     for text_piece in gen:
         yield text_piece
         await asyncio.sleep(0.1)
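One behavioral note on the wiring above: the comment removed in the first hunk ("return instead of yield") shows generation was already converted from a generator to a plain function, and generate_text_topp likewise returns one finished string, or implicitly None if all max_gen steps pass without meeting a stop condition. async_generator_wrapper still loops over its result, so it now streams the reply one character at a time and raises TypeError when given None. A defensive variant, a sketch only, assuming the consumer just needs small string chunks:

import asyncio

async def async_generator_wrapper(prompt: str):
    text = generate_text_topp(model, prompt)  # full string, or None on fallthrough
    for text_piece in (text or "").split():   # word-sized chunks; safe when None
        yield text_piece + " "
        await asyncio.sleep(0.1)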