Spaces:
Sleeping
Sleeping
Update api.py
Browse files
api.py
CHANGED
@@ -150,75 +150,47 @@ from sklearn.metrics.pairwise import cosine_similarity
|
|
150 |
from fastapi import Request
|
151 |
from fastapi.responses import PlainTextResponse
|
152 |
|
153 |
-
|
154 |
-
|
155 |
-
temperature=0.90, min_len=20,
|
156 |
-
repetition_penalty=1.2, top_p=0.90, top_k=50):
|
157 |
-
def top_kp_filtering(logits, top_k, top_p):
|
158 |
-
probs = np.exp(logits - np.max(logits))
|
159 |
-
probs /= probs.sum()
|
160 |
-
sorted_idx = np.argsort(-probs)
|
161 |
-
sorted_probs = probs[sorted_idx]
|
162 |
-
if top_k > 0:
|
163 |
-
sorted_idx = sorted_idx[:top_k]
|
164 |
-
sorted_probs = sorted_probs[:top_k]
|
165 |
-
cum_probs = np.cumsum(sorted_probs)
|
166 |
-
cutoff = np.searchsorted(cum_probs, top_p) + 1
|
167 |
-
final_idx = sorted_idx[:cutoff]
|
168 |
-
final_probs = probs[final_idx]
|
169 |
-
final_probs /= final_probs.sum()
|
170 |
-
return final_idx, final_probs
|
171 |
-
|
172 |
model_input = text_to_ids(f"<start> {prompt} <sep>")
|
173 |
model_input = model_input[:max_len]
|
174 |
generated = list(model_input)
|
175 |
-
|
|
|
176 |
pad_len = max(0, max_len - len(generated))
|
177 |
input_padded = np.pad(generated, (0, pad_len), constant_values=pad_id)
|
178 |
input_tensor = tf.convert_to_tensor([input_padded])
|
179 |
logits = model(input_tensor, training=False)
|
180 |
next_logits = logits[0, len(generated) - 1].numpy()
|
|
|
|
|
181 |
for t in set(generated):
|
182 |
count = generated.count(t)
|
183 |
-
next_logits[t] /= (repetition_penalty ** count)
|
184 |
-
if len(generated) < min_len:
|
185 |
-
next_logits[end_id] -= 5.0
|
186 |
-
next_logits[pad_id] -= 10.0
|
187 |
-
next_logits = next_logits / temperature
|
188 |
-
final_idx, final_probs = top_kp_filtering(next_logits, top_k=top_k, top_p=top_p)
|
189 |
-
sampled = np.random.choice(final_idx, p=final_probs)
|
190 |
-
generated.append(int(sampled))
|
191 |
-
decoded = sp.decode(generated)
|
192 |
-
for t in ["<start>", "<sep>", "<end>"]:
|
193 |
-
decoded = decoded.replace(t, "")
|
194 |
-
decoded = decoded.strip()
|
195 |
-
if len(generated) >= min_len and (sampled == end_id or decoded.endswith(('.', '!', '?'))):
|
196 |
-
return decoded
|
197 |
-
return sp.decode(generated)
|
198 |
|
199 |
-
#
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
input_tensor = tf.convert_to_tensor([input_padded])
|
208 |
-
logits = model(input_tensor, training=False)
|
209 |
-
next_logits = logits[0, len(generated) - 1].numpy()
|
210 |
-
next_logits[pad_id] -= 10.0
|
211 |
-
next_token = np.argmax(next_logits)
|
212 |
generated.append(int(next_token))
|
|
|
213 |
decoded = sp.decode(generated)
|
214 |
for t in ["<start>", "<sep>", "<end>"]:
|
215 |
decoded = decoded.replace(t, "")
|
216 |
-
decoded = decoded.strip()
|
217 |
-
|
218 |
-
|
|
|
|
|
|
|
|
|
|
|
219 |
return sp.decode(generated)
|
220 |
|
221 |
-
# 톤 불일치 체크
|
222 |
def mismatch_tone(input_text, output_text):
|
223 |
if "ㅋㅋ" in input_text and not re.search(r'ㅋㅋ|ㅎ|재밌|놀|만나|맛집|여행', output_text):
|
224 |
return True
|
@@ -340,9 +312,9 @@ def respond(input_text):
|
|
340 |
return f"{summary}\n다른 궁금한 점 있으신가요?"
|
341 |
|
342 |
# 일상 대화: 샘플링 + fallback
|
343 |
-
response =
|
344 |
if not is_valid_response(response) or mismatch_tone(input_text, response):
|
345 |
-
response =
|
346 |
return response
|
347 |
|
348 |
@app.get("/generate", response_class=PlainTextResponse)
|
|
|
150 |
from fastapi import Request
|
151 |
from fastapi.responses import PlainTextResponse
|
152 |
|
153 |
+
def generate_text_greedy_strong(model, prompt, max_len=100, max_gen=98,
|
154 |
+
repetition_penalty=1.2, min_len=20):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
155 |
model_input = text_to_ids(f"<start> {prompt} <sep>")
|
156 |
model_input = model_input[:max_len]
|
157 |
generated = list(model_input)
|
158 |
+
|
159 |
+
for _ in range(max_gen):
|
160 |
pad_len = max(0, max_len - len(generated))
|
161 |
input_padded = np.pad(generated, (0, pad_len), constant_values=pad_id)
|
162 |
input_tensor = tf.convert_to_tensor([input_padded])
|
163 |
logits = model(input_tensor, training=False)
|
164 |
next_logits = logits[0, len(generated) - 1].numpy()
|
165 |
+
|
166 |
+
# Repetition Penalty
|
167 |
for t in set(generated):
|
168 |
count = generated.count(t)
|
169 |
+
next_logits[t] /= (repetition_penalty ** count)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
170 |
|
171 |
+
# Stop token filtering
|
172 |
+
stop_tokens = ["음", "어", "그", "뭐지", "..."]
|
173 |
+
for tok in stop_tokens:
|
174 |
+
tok_id = sp.piece_to_id(tok)
|
175 |
+
next_logits[tok_id] -= 5.0
|
176 |
+
|
177 |
+
next_logits[pad_id] -= 10.0
|
178 |
+
next_token = np.argmax(next_logits)
|
|
|
|
|
|
|
|
|
|
|
179 |
generated.append(int(next_token))
|
180 |
+
|
181 |
decoded = sp.decode(generated)
|
182 |
for t in ["<start>", "<sep>", "<end>"]:
|
183 |
decoded = decoded.replace(t, "")
|
184 |
+
decoded = decoded.strip()
|
185 |
+
|
186 |
+
if len(generated) >= min_len and (next_token == end_id or decoded.endswith(('요', '다', '.', '!', '?'))):
|
187 |
+
if is_greedy_response_acceptable(decoded):
|
188 |
+
return decoded
|
189 |
+
else:
|
190 |
+
continue
|
191 |
+
|
192 |
return sp.decode(generated)
|
193 |
|
|
|
194 |
def mismatch_tone(input_text, output_text):
|
195 |
if "ㅋㅋ" in input_text and not re.search(r'ㅋㅋ|ㅎ|재밌|놀|만나|맛집|여행', output_text):
|
196 |
return True
|
|
|
312 |
return f"{summary}\n다른 궁금한 점 있으신가요?"
|
313 |
|
314 |
# 일상 대화: 샘플링 + fallback
|
315 |
+
response = generate_text_greedy_strong(model, input_text)
|
316 |
if not is_valid_response(response) or mismatch_tone(input_text, response):
|
317 |
+
response = generate_text_greedy_strong(model, input_text)
|
318 |
return response
|
319 |
|
320 |
@app.get("/generate", response_class=PlainTextResponse)
|