Spaces:
Running
Running
Update api.py
Browse files
api.py
CHANGED
@@ -166,23 +166,35 @@ def is_greedy_response_acceptable(text):
|
|
166 |
|
167 |
return True
|
168 |
|
169 |
-
def
|
|
|
170 |
model_input = text_to_ids(f"<start> {prompt} <sep>")
|
171 |
model_input = model_input[:max_len]
|
172 |
generated = list(model_input)
|
173 |
|
174 |
-
for _ in range(
|
175 |
pad_len = max(0, max_len - len(generated))
|
176 |
input_padded = np.pad(generated, (0, pad_len), constant_values=pad_id)
|
177 |
input_tensor = tf.convert_to_tensor([input_padded])
|
178 |
logits = model(input_tensor, training=False)
|
179 |
next_logits = logits[0, len(generated) - 1].numpy()
|
180 |
|
181 |
-
#
|
182 |
-
|
183 |
-
|
|
|
184 |
|
185 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
186 |
decoded = sp.decode(generated)
|
187 |
for t in ["<start>", "<sep>", "<end>"]:
|
188 |
decoded = decoded.replace(t, "")
|
@@ -194,7 +206,6 @@ def generate_text_greedy(model, prompt, max_len=100, min_len=12):
|
|
194 |
else:
|
195 |
continue
|
196 |
|
197 |
-
# 끝까지 만족하는 문장 없으면 그냥 최종 출력
|
198 |
decoded = sp.decode(generated)
|
199 |
for t in ["<start>", "<sep>", "<end>"]:
|
200 |
decoded = decoded.replace(t, "")
|
@@ -324,9 +335,9 @@ def respond(input_text):
|
|
324 |
return f"{summary}\n다른 궁금한 점 있으신가요?"
|
325 |
|
326 |
# 일상 대화: 샘플링 + fallback
|
327 |
-
response =
|
328 |
if not is_valid_response(response) or mismatch_tone(input_text, response):
|
329 |
-
response =
|
330 |
return response
|
331 |
|
332 |
@app.get("/generate", response_class=PlainTextResponse)
|
|
|
166 |
|
167 |
return True
|
168 |
|
169 |
+
def generate_text_sample(model, prompt, max_len=100, max_gen=98,
|
170 |
+
temperature=0.7, top_k=40, min_len=12):
|
171 |
model_input = text_to_ids(f"<start> {prompt} <sep>")
|
172 |
model_input = model_input[:max_len]
|
173 |
generated = list(model_input)
|
174 |
|
175 |
+
for _ in range(max_gen):
|
176 |
pad_len = max(0, max_len - len(generated))
|
177 |
input_padded = np.pad(generated, (0, pad_len), constant_values=pad_id)
|
178 |
input_tensor = tf.convert_to_tensor([input_padded])
|
179 |
logits = model(input_tensor, training=False)
|
180 |
next_logits = logits[0, len(generated) - 1].numpy()
|
181 |
|
182 |
+
# Temperature 적용
|
183 |
+
next_logits = next_logits / temperature
|
184 |
+
probs = np.exp(next_logits - np.max(next_logits))
|
185 |
+
probs = probs / probs.sum()
|
186 |
|
187 |
+
# Top-K 필터링
|
188 |
+
if top_k is not None and top_k > 0:
|
189 |
+
indices_to_remove = probs < np.sort(probs)[-top_k]
|
190 |
+
probs[indices_to_remove] = 0
|
191 |
+
probs /= probs.sum()
|
192 |
+
|
193 |
+
# 샘플링
|
194 |
+
next_token = np.random.choice(len(probs), p=probs)
|
195 |
+
generated.append(int(next_token))
|
196 |
+
|
197 |
+
# 디코딩 및 후처리
|
198 |
decoded = sp.decode(generated)
|
199 |
for t in ["<start>", "<sep>", "<end>"]:
|
200 |
decoded = decoded.replace(t, "")
|
|
|
206 |
else:
|
207 |
continue
|
208 |
|
|
|
209 |
decoded = sp.decode(generated)
|
210 |
for t in ["<start>", "<sep>", "<end>"]:
|
211 |
decoded = decoded.replace(t, "")
|
|
|
335 |
return f"{summary}\n다른 궁금한 점 있으신가요?"
|
336 |
|
337 |
# 일상 대화: 샘플링 + fallback
|
338 |
+
response = generate_text_sample(model, input_text)
|
339 |
if not is_valid_response(response) or mismatch_tone(input_text, response):
|
340 |
+
response = generate_text_sample(model, input_text)
|
341 |
return response
|
342 |
|
343 |
@app.get("/generate", response_class=PlainTextResponse)
|