Yuchan5386 committed
Commit b3998ae · verified · 1 Parent(s): 5bef4d8

Update api.py

Files changed (1):
  1. api.py +33 -31
api.py CHANGED
@@ -171,47 +171,49 @@ def is_greedy_response_acceptable(text):
 
     return True
 
-def generate_text_greedy_strong(model, prompt, max_len=100, max_gen=98,
-                                repetition_penalty=1.2, min_len=20):
-    model_input = text_to_ids(f"<start> {prompt} <sep>")
-    model_input = model_input[:max_len]
-    generated = list(model_input)
-
-    for _ in range(max_gen):
-        pad_len = max(0, max_len - len(generated))
-        input_padded = np.pad(generated, (0, pad_len), constant_values=pad_id)
-        input_tensor = tf.convert_to_tensor([input_padded])
-        logits = model(input_tensor, training=False)
-        next_logits = logits[0, len(generated) - 1].numpy()
-
-        # Repetition Penalty
-        for t in set(generated):
-            count = generated.count(t)
+def generate_text_with_temp_and_rep_penalty(model, prompt, max_len=100, max_gen=98,
+                                            repetition_penalty=1.2, temperature=0.7,
+                                            min_len=20):
+    model_input = text_to_ids(f"<start> {prompt} <sep>")
+    model_input = model_input[:max_len]
+    generated = list(model_input)
+
+    for _ in range(max_gen):
+        pad_len = max(0, max_len - len(generated))
+        input_padded = np.pad(generated, (0, pad_len), constant_values=pad_id)
+        input_tensor = tf.convert_to_tensor([input_padded])
+        logits = model(input_tensor, training=False)
+        next_logits = logits[0, len(generated) - 1].numpy()
+
+        # Repetition penalty
+        for t in set(generated):
+            count = generated.count(t)
             next_logits[t] /= (repetition_penalty ** count)
 
-        # Stop token filtering
-        stop_tokens = ["음", "어", "그", "뭐지", "..."]
-        for tok in stop_tokens:
-            tok_id = sp.piece_to_id(tok)
-            next_logits[tok_id] -= 5.0
+        # Temperature scaling
+        next_logits = next_logits / temperature
 
-        next_logits[pad_id] -= 10.0
-        next_token = np.argmax(next_logits)
-        generated.append(int(next_token))
+        # Compute probabilities with softmax
+        exp_logits = np.exp(next_logits - np.max(next_logits))
+        probs = exp_logits / exp_logits.sum()
 
-        decoded = sp.decode(generated)
-        for t in ["<start>", "<sep>", "<end>"]:
-            decoded = decoded.replace(t, "")
+        # Sample the next token
+        next_token = np.random.choice(len(probs), p=probs)
+        generated.append(int(next_token))
+
+        decoded = sp.decode(generated)
+        for t in ["<start>", "<sep>", "<end>"]:
+            decoded = decoded.replace(t, "")
         decoded = decoded.strip()
 
-        if len(generated) >= min_len and (next_token == end_id or decoded.endswith(('요', '다', '.', '!', '?'))):
+        if len(generated) >= min_len and (next_token == end_id or decoded.endswith(('요', '다', '.', '!', '?'))):
             if is_greedy_response_acceptable(decoded):
                 return decoded
             else:
                 continue
 
     return sp.decode(generated)
-
+
 def mismatch_tone(input_text, output_text):
     if "ㅋㅋ" in input_text and not re.search(r'ㅋㅋ|ㅎ|재밌|놀|만나|맛집|여행', output_text):
         return True
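
For context on the hunk above: the old generate_text_greedy_strong picked each token with argmax after penalizing repeats, filler tokens, and the pad token, while the new generate_text_with_temp_and_rep_penalty keeps the repetition penalty but replaces greedy argmax (and the hard stop-token/pad penalties) with temperature-scaled softmax sampling. The sketch below isolates that per-step logic; sample_next_token is a hypothetical helper, not part of api.py, and it assumes plain NumPy.

import numpy as np

def sample_next_token(next_logits, generated, repetition_penalty=1.2,
                      temperature=0.7, rng=None):
    # Hypothetical standalone version of the new per-step logic:
    # repetition penalty -> temperature scaling -> softmax -> sampling.
    rng = np.random.default_rng() if rng is None else rng
    logits = np.asarray(next_logits, dtype=np.float64).copy()

    # Divide a token's logit once per prior occurrence, as the loop in api.py does
    for t in set(generated):
        logits[t] /= repetition_penalty ** generated.count(t)

    # Temperature < 1 sharpens the distribution; > 1 flattens it
    logits = logits / temperature

    # Numerically stable softmax: subtract the max before exponentiating
    exp_logits = np.exp(logits - logits.max())
    probs = exp_logits / exp_logits.sum()

    # Draw one token id from the resulting categorical distribution
    return int(rng.choice(len(probs), p=probs))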
@@ -333,9 +335,9 @@ def respond(input_text):
         return f"{summary}\n다른 궁금한 점 있으신가요?"
 
     # Casual conversation: sampling + fallback
-    response = generate_text_greedy_strong(model, input_text)
+    response = generate_text_with_temp_and_rep_penalty(model, input_text)
     if not is_valid_response(response) or mismatch_tone(input_text, response):
-        response = generate_text_greedy_strong(model, input_text)
+        response = generate_text_with_temp_and_rep_penalty(model, input_text)
     return response
 
 @app.get("/generate", response_class=PlainTextResponse)
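
The respond() change keeps the existing fallback shape: generate once, and if is_valid_response or mismatch_tone rejects the output, sample once more, which can now actually produce a different answer because decoding is stochastic. A minimal sketch of that pattern with the retry count made explicit; respond_with_fallback and max_attempts are illustrative names only, and the helpers it calls are the ones already defined in api.py.

def respond_with_fallback(model, input_text, max_attempts=2):
    # Illustrative wrapper; api.py inlines this logic with exactly one retry.
    # is_valid_response, mismatch_tone and the generator come from api.py.
    response = generate_text_with_temp_and_rep_penalty(model, input_text)
    for _ in range(max_attempts - 1):
        if is_valid_response(response) and not mismatch_tone(input_text, response):
            break
        response = generate_text_with_temp_and_rep_penalty(model, input_text)
    return response

As in the committed code, the last attempt is returned even if it still fails the checks.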
 
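
On the choice of temperature=0.7: dividing the logits by a value below 1 before the softmax sharpens the distribution toward the top tokens while still leaving room for variation. A small self-contained illustration with toy logits (values not taken from the model):

import numpy as np

def softmax_with_temperature(logits, temperature):
    # Same transform as in generate_text_with_temp_and_rep_penalty
    scaled = np.asarray(logits, dtype=np.float64) / temperature
    exp = np.exp(scaled - scaled.max())
    return exp / exp.sum()

toy_logits = [2.0, 1.0, 0.5]                        # three candidate tokens
print(softmax_with_temperature(toy_logits, 1.0))    # ~[0.63, 0.23, 0.14]
print(softmax_with_temperature(toy_logits, 0.7))    # ~[0.74, 0.18, 0.09]  (sharper)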