Yuchan5386 committed on
Commit
35df76e
·
verified ·
1 Parent(s): 49d8bca

Update api.py

Browse files
Files changed (1) hide show
  1. api.py +15 -2
api.py CHANGED
@@ -167,7 +167,7 @@ def is_greedy_response_acceptable(text):
167
  return True
168
 
169
  def generate_text_sample(model, prompt, max_len=100, max_gen=98,
170
- temperature=0.7, top_k=40, min_len=12):
171
  model_input = text_to_ids(f"<start> {prompt} <sep>")
172
  model_input = model_input[:max_len]
173
  generated = list(model_input)
@@ -190,6 +190,20 @@ def generate_text_sample(model, prompt, max_len=100, max_gen=98,
190
  probs[indices_to_remove] = 0
191
  probs /= probs.sum()
192
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
  # 샘플링
194
  next_token = np.random.choice(len(probs), p=probs)
195
  generated.append(int(next_token))
@@ -210,7 +224,6 @@ def generate_text_sample(model, prompt, max_len=100, max_gen=98,
210
  for t in ["<start>", "<sep>", "<end>"]:
211
  decoded = decoded.replace(t, "")
212
  return decoded.strip()
213
-
214
 
215
  def mismatch_tone(input_text, output_text):
216
  if "ㅋㅋ" in input_text and not re.search(r'ㅋㅋ|ㅎ|재밌|놀|만나|맛집|여행', output_text):
 
167
  return True
168
 
169
  def generate_text_sample(model, prompt, max_len=100, max_gen=98,
170
+ temperature=0.7, top_k=40, top_p=0.9, min_len=12):
171
  model_input = text_to_ids(f"<start> {prompt} <sep>")
172
  model_input = model_input[:max_len]
173
  generated = list(model_input)
 
190
  probs[indices_to_remove] = 0
191
  probs /= probs.sum()
192
 
193
+ # Top-P (누적 확률) 필터링
194
+ if top_p is not None and 0 < top_p < 1:
195
+ sorted_indices = np.argsort(probs)[::-1]
196
+ sorted_probs = probs[sorted_indices]
197
+ cumulative_probs = np.cumsum(sorted_probs)
198
+ # 누적 확률이 top_p 초과하는 토큰들은 제거
199
+ cutoff_index = np.searchsorted(cumulative_probs, top_p, side='right')
200
+ probs_to_keep = sorted_indices[:cutoff_index+1]
201
+
202
+ mask = np.ones_like(probs, dtype=bool)
203
+ mask[probs_to_keep] = False
204
+ probs[mask] = 0
205
+ probs /= probs.sum()
206
+
207
  # 샘플링
208
  next_token = np.random.choice(len(probs), p=probs)
209
  generated.append(int(next_token))
 
224
  for t in ["<start>", "<sep>", "<end>"]:
225
  decoded = decoded.replace(t, "")
226
  return decoded.strip()
 
227
 
228
  def mismatch_tone(input_text, output_text):
229
  if "ㅋㅋ" in input_text and not re.search(r'ㅋㅋ|ㅎ|재밌|놀|만나|맛집|여행', output_text):