Spaces:
Sleeping
Sleeping
Update api.py
Browse files
api.py
CHANGED
@@ -167,7 +167,7 @@ def is_greedy_response_acceptable(text):
|
|
167 |
return True
|
168 |
|
169 |
def generate_text_sample(model, prompt, max_len=100, max_gen=98,
|
170 |
-
temperature=0.7, top_k=40, min_len=12):
|
171 |
model_input = text_to_ids(f"<start> {prompt} <sep>")
|
172 |
model_input = model_input[:max_len]
|
173 |
generated = list(model_input)
|
@@ -190,6 +190,20 @@ def generate_text_sample(model, prompt, max_len=100, max_gen=98,
|
|
190 |
probs[indices_to_remove] = 0
|
191 |
probs /= probs.sum()
|
192 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
193 |
# 샘플링
|
194 |
next_token = np.random.choice(len(probs), p=probs)
|
195 |
generated.append(int(next_token))
|
@@ -210,7 +224,6 @@ def generate_text_sample(model, prompt, max_len=100, max_gen=98,
|
|
210 |
for t in ["<start>", "<sep>", "<end>"]:
|
211 |
decoded = decoded.replace(t, "")
|
212 |
return decoded.strip()
|
213 |
-
|
214 |
|
215 |
def mismatch_tone(input_text, output_text):
|
216 |
if "ㅋㅋ" in input_text and not re.search(r'ㅋㅋ|ㅎ|재밌|놀|만나|맛집|여행', output_text):
|
|
|
167 |
return True
|
168 |
|
169 |
def generate_text_sample(model, prompt, max_len=100, max_gen=98,
|
170 |
+
temperature=0.7, top_k=40, top_p=0.9, min_len=12):
|
171 |
model_input = text_to_ids(f"<start> {prompt} <sep>")
|
172 |
model_input = model_input[:max_len]
|
173 |
generated = list(model_input)
|
|
|
190 |
probs[indices_to_remove] = 0
|
191 |
probs /= probs.sum()
|
192 |
|
193 |
+
# Top-P (누적 확률) 필터링
|
194 |
+
if top_p is not None and 0 < top_p < 1:
|
195 |
+
sorted_indices = np.argsort(probs)[::-1]
|
196 |
+
sorted_probs = probs[sorted_indices]
|
197 |
+
cumulative_probs = np.cumsum(sorted_probs)
|
198 |
+
# 누적 확률이 top_p 초과하는 토큰들은 제거
|
199 |
+
cutoff_index = np.searchsorted(cumulative_probs, top_p, side='right')
|
200 |
+
probs_to_keep = sorted_indices[:cutoff_index+1]
|
201 |
+
|
202 |
+
mask = np.ones_like(probs, dtype=bool)
|
203 |
+
mask[probs_to_keep] = False
|
204 |
+
probs[mask] = 0
|
205 |
+
probs /= probs.sum()
|
206 |
+
|
207 |
# 샘플링
|
208 |
next_token = np.random.choice(len(probs), p=probs)
|
209 |
generated.append(int(next_token))
|
|
|
224 |
for t in ["<start>", "<sep>", "<end>"]:
|
225 |
decoded = decoded.replace(t, "")
|
226 |
return decoded.strip()
|
|
|
227 |
|
228 |
def mismatch_tone(input_text, output_text):
|
229 |
if "ㅋㅋ" in input_text and not re.search(r'ㅋㅋ|ㅎ|재밌|놀|만나|맛집|여행', output_text):
|