Spaces:
Sleeping
Sleeping
Update api.py
Browse files
api.py
CHANGED
@@ -140,14 +140,19 @@ _ = model(dummy_input) # 모델이 빌드됨
|
|
140 |
model.load_weights("InteractGPT.weights.h5")
|
141 |
print("모델 가중치 로드 완료!")
|
142 |
|
143 |
-
def generate_text_topp(model, prompt, max_len=100, max_gen=98,
|
144 |
-
|
145 |
-
|
146 |
-
def
|
147 |
probs = np.exp(logits - np.max(logits))
|
148 |
probs /= probs.sum()
|
149 |
sorted_idx = np.argsort(-probs)
|
150 |
sorted_probs = probs[sorted_idx]
|
|
|
|
|
|
|
|
|
|
|
151 |
cum_probs = np.cumsum(sorted_probs)
|
152 |
cutoff = np.searchsorted(cum_probs, top_p) + 1
|
153 |
final_idx = sorted_idx[:cutoff]
|
@@ -174,8 +179,8 @@ def generate_text_topp(model, prompt, max_len=100, max_gen=98,
|
|
174 |
next_logits[pad_id] -= 10.0
|
175 |
# 온도 적용
|
176 |
next_logits = next_logits / temperature
|
177 |
-
# Top-
|
178 |
-
final_idx, final_probs =
|
179 |
sampled = np.random.choice(final_idx, p=final_probs)
|
180 |
generated.append(int(sampled))
|
181 |
decoded = sp.decode(generated)
|
@@ -266,10 +271,10 @@ def respond(input_text):
|
|
266 |
summary = get_wikipedia_summary(keyword)
|
267 |
return f"{summary}\n다른 궁금한 점 있으신가요?"
|
268 |
|
269 |
-
return generate_text_topp(model, input_text)
|
270 |
|
271 |
async def async_generator_wrapper(prompt: str):
|
272 |
-
gen = generate_text_topp(model, prompt)
|
273 |
for text_piece in gen:
|
274 |
yield text_piece
|
275 |
await asyncio.sleep(0.1)
|
|
|
140 |
model.load_weights("InteractGPT.weights.h5")
|
141 |
print("모델 가중치 로드 완료!")
|
142 |
|
143 |
+
def generate_text_topkp(model, prompt, max_len=100, max_gen=98,
|
144 |
+
temperature=0.50, min_len=20,
|
145 |
+
repetition_penalty=1.2, top_p=0.90, top_k=50):
|
146 |
+
def top_kp_filtering(logits, top_k, top_p):
|
147 |
probs = np.exp(logits - np.max(logits))
|
148 |
probs /= probs.sum()
|
149 |
sorted_idx = np.argsort(-probs)
|
150 |
sorted_probs = probs[sorted_idx]
|
151 |
+
# Top-K 필터링
|
152 |
+
if top_k > 0:
|
153 |
+
sorted_idx = sorted_idx[:top_k]
|
154 |
+
sorted_probs = sorted_probs[:top_k]
|
155 |
+
# Top-P 필터링
|
156 |
cum_probs = np.cumsum(sorted_probs)
|
157 |
cutoff = np.searchsorted(cum_probs, top_p) + 1
|
158 |
final_idx = sorted_idx[:cutoff]
|
|
|
179 |
next_logits[pad_id] -= 10.0
|
180 |
# 온도 적용
|
181 |
next_logits = next_logits / temperature
|
182 |
+
# Top-KP Sampling 적용
|
183 |
+
final_idx, final_probs = top_kp_filtering(next_logits, top_k=top_k, top_p=top_p)
|
184 |
sampled = np.random.choice(final_idx, p=final_probs)
|
185 |
generated.append(int(sampled))
|
186 |
decoded = sp.decode(generated)
|
|
|
271 |
summary = get_wikipedia_summary(keyword)
|
272 |
return f"{summary}\n다른 궁금한 점 있으신가요?"
|
273 |
|
274 |
+
return generate_text_topkp(model, input_text)
|
275 |
|
276 |
async def async_generator_wrapper(prompt: str):
|
277 |
+
gen = generate_text_topkp(model, prompt)
|
278 |
for text_piece in gen:
|
279 |
yield text_piece
|
280 |
await asyncio.sleep(0.1)
|