Spaces:
Sleeping
Sleeping
Update api.py
Browse files
api.py
CHANGED
@@ -138,9 +138,30 @@ _ = model(dummy_input) # 모델이 빌드됨
|
|
138 |
model.load_weights("InteractGPT.weights.h5")
|
139 |
print("모델 가중치 로드 완료!")
|
140 |
|
141 |
-
def
|
142 |
-
|
143 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
144 |
model_input = text_to_ids(f"<start> {prompt} <sep>")
|
145 |
model_input = model_input[:max_len]
|
146 |
generated = list(model_input)
|
@@ -165,26 +186,11 @@ def generate_text_top_kp(model, prompt, max_len=100, max_gen=98,
|
|
165 |
|
166 |
# 온도 적용
|
167 |
next_logits = next_logits / temperature
|
168 |
-
probs = np.exp(next_logits - np.max(next_logits))
|
169 |
-
probs /= probs.sum()
|
170 |
-
|
171 |
-
# Top-K 적용
|
172 |
-
top_k = min(top_k, len(probs))
|
173 |
-
top_k_idx = np.argsort(-probs)[:top_k]
|
174 |
-
top_k_probs = probs[top_k_idx]
|
175 |
-
top_k_probs /= top_k_probs.sum()
|
176 |
-
|
177 |
-
# Top-P 필터링
|
178 |
-
sorted_idx = np.argsort(-top_k_probs)
|
179 |
-
sorted_probs = top_k_probs[sorted_idx]
|
180 |
-
cum_probs = np.cumsum(sorted_probs)
|
181 |
-
cutoff = np.searchsorted(cum_probs, top_p) + 1
|
182 |
-
|
183 |
-
final_idx = top_k_idx[sorted_idx[:cutoff]]
|
184 |
-
final_probs = sorted_probs[:cutoff]
|
185 |
-
final_probs /= final_probs.sum()
|
186 |
|
|
|
|
|
187 |
sampled = np.random.choice(final_idx, p=final_probs)
|
|
|
188 |
generated.append(int(sampled))
|
189 |
|
190 |
decoded = sp.decode(generated)
|
@@ -197,7 +203,7 @@ def generate_text_top_kp(model, prompt, max_len=100, max_gen=98,
|
|
197 |
break
|
198 |
|
199 |
async def async_generator_wrapper(prompt: str):
|
200 |
-
gen =
|
201 |
for text_piece in gen:
|
202 |
yield text_piece
|
203 |
await asyncio.sleep(0.1)
|
|
|
138 |
model.load_weights("InteractGPT.weights.h5")
|
139 |
print("모델 가중치 로드 완료!")
|
140 |
|
141 |
+
def generate_text_typical(model, prompt, max_len=100, max_gen=98,
|
142 |
+
temperature=0.7, min_len=20,
|
143 |
+
repetition_penalty=1.1, typical_p=0.95):
|
144 |
+
|
145 |
+
def typical_filtering(logits, typical_p):
|
146 |
+
probs = np.exp(logits - np.max(logits))
|
147 |
+
probs /= probs.sum()
|
148 |
+
|
149 |
+
log_probs = np.log(probs + 1e-9)
|
150 |
+
entropy = -np.sum(probs * log_probs)
|
151 |
+
|
152 |
+
shifted = np.abs(-log_probs - entropy)
|
153 |
+
sorted_idx = np.argsort(shifted)
|
154 |
+
sorted_probs = probs[sorted_idx]
|
155 |
+
|
156 |
+
cum_probs = np.cumsum(sorted_probs)
|
157 |
+
cutoff = np.searchsorted(cum_probs, typical_p) + 1
|
158 |
+
|
159 |
+
final_idx = sorted_idx[:cutoff]
|
160 |
+
final_probs = probs[final_idx]
|
161 |
+
final_probs /= final_probs.sum()
|
162 |
+
|
163 |
+
return final_idx, final_probs
|
164 |
+
|
165 |
model_input = text_to_ids(f"<start> {prompt} <sep>")
|
166 |
model_input = model_input[:max_len]
|
167 |
generated = list(model_input)
|
|
|
186 |
|
187 |
# 온도 적용
|
188 |
next_logits = next_logits / temperature
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
189 |
|
190 |
+
# Typical Sampling 적용
|
191 |
+
final_idx, final_probs = typical_filtering(next_logits, typical_p=typical_p)
|
192 |
sampled = np.random.choice(final_idx, p=final_probs)
|
193 |
+
|
194 |
generated.append(int(sampled))
|
195 |
|
196 |
decoded = sp.decode(generated)
|
|
|
203 |
break
|
204 |
|
205 |
async def async_generator_wrapper(prompt: str):
|
206 |
+
gen = generate_text_typical(model, prompt)
|
207 |
for text_piece in gen:
|
208 |
yield text_piece
|
209 |
await asyncio.sleep(0.1)
|