Update api.py
api.py CHANGED
@@ -171,47 +171,49 @@ def is_greedy_response_acceptable(text):
 
     return True
 
-def
-
-
-    model_input =
-
-
-
-
-
-
-
-
-
-
-
-
+def generate_text_with_temp_and_rep_penalty(model, prompt, max_len=100, max_gen=98,
+                                            repetition_penalty=1.2, temperature=0.7,
+                                            min_len=20):
+    model_input = text_to_ids(f"<start> {prompt} <sep>")
+    model_input = model_input[:max_len]
+    generated = list(model_input)
+
+    for _ in range(max_gen):
+        pad_len = max(0, max_len - len(generated))
+        input_padded = np.pad(generated, (0, pad_len), constant_values=pad_id)
+        input_tensor = tf.convert_to_tensor([input_padded])
+        logits = model(input_tensor, training=False)
+        next_logits = logits[0, len(generated) - 1].numpy()
+
+        # Repetition penalty
+        for t in set(generated):
+            count = generated.count(t)
             next_logits[t] /= (repetition_penalty ** count)
 
-        #
-
-        for tok in stop_tokens:
-            tok_id = sp.piece_to_id(tok)
-            next_logits[tok_id] -= 5.0
+        # Temperature scaling
+        next_logits = next_logits / temperature
 
-
-
-
+        # Softmax to get probabilities
+        exp_logits = np.exp(next_logits - np.max(next_logits))
+        probs = exp_logits / exp_logits.sum()
 
-
-
-
+        # Sample the next token
+        next_token = np.random.choice(len(probs), p=probs)
+        generated.append(int(next_token))
+
+        decoded = sp.decode(generated)
+        for t in ["<start>", "<sep>", "<end>"]:
+            decoded = decoded.replace(t, "")
         decoded = decoded.strip()
 
-        if len(generated) >= min_len and (next_token == end_id or decoded.endswith(('요', '다', '.', '!', '?'))):
+        if len(generated) >= min_len and (next_token == end_id or decoded.endswith(('요', '다', '.', '!', '?'))):
             if is_greedy_response_acceptable(decoded):
                 return decoded
             else:
                 continue
 
     return sp.decode(generated)
-
+
 def mismatch_tone(input_text, output_text):
     if "ㅋㅋ" in input_text and not re.search(r'ㅋㅋ|ㅎ|재밌|놀|만나|맛집|여행', output_text):
         return True
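For reference, the sampling step this hunk adds (repetition penalty, temperature scaling, softmax, categorical draw) can be exercised on its own. The sketch below is illustrative and not part of the commit: the helper name sample_next_token, the toy vocabulary size, and the token history are made up, and it does not touch the model, text_to_ids, or SentencePiece.

import numpy as np

def sample_next_token(next_logits, generated, repetition_penalty=1.2, temperature=0.7):
    next_logits = np.array(next_logits, dtype=np.float64)

    # Repetition penalty as in the new function: divide each seen token's logit
    # by penalty ** (number of times it already appears in the history).
    for t in set(generated):
        count = generated.count(t)
        next_logits[t] /= (repetition_penalty ** count)

    # Temperature scaling: values below 1.0 sharpen the distribution.
    next_logits = next_logits / temperature

    # Numerically stable softmax.
    exp_logits = np.exp(next_logits - np.max(next_logits))
    probs = exp_logits / exp_logits.sum()

    # Draw one token id from the resulting categorical distribution.
    return int(np.random.choice(len(probs), p=probs))

# Toy usage: 10-token vocabulary, token 3 already generated twice.
print(sample_next_token(np.random.randn(10), generated=[3, 3, 7]))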
@@ -333,9 +335,9 @@ def respond(input_text):
         return f"{summary}\n다른 궁금한 점 있으신가요?"
 
     # Casual conversation: sampling + fallback
-    response =
+    response = generate_text_with_temp_and_rep_penalty(model, input_text)
     if not is_valid_response(response) or mismatch_tone(input_text, response):
-        response =
+        response = generate_text_with_temp_and_rep_penalty(model, input_text)
     return response
 
 @app.get("/generate", response_class=PlainTextResponse)
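The hunk above resamples exactly once when the first draw fails is_valid_response or mismatch_tone. If more attempts were ever wanted, the same fallback extends naturally to a bounded retry loop. The sketch below only assumes the helpers visible in the diff; respond_casual and MAX_ATTEMPTS are hypothetical names, not the committed one-retry behaviour.

MAX_ATTEMPTS = 3  # hypothetical cap; the committed code retries once

def respond_casual(input_text):
    # Keep resampling until a draw passes both checks or the budget runs out;
    # the last draw is returned either way, mirroring the diff's fallback.
    response = generate_text_with_temp_and_rep_penalty(model, input_text)
    for _ in range(MAX_ATTEMPTS - 1):
        if is_valid_response(response) and not mismatch_tone(input_text, response):
            break
        response = generate_text_with_temp_and_rep_penalty(model, input_text)
    return response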