Yuchan5386 committed on
Commit
9ab8176
·
verified ·
1 Parent(s): feab31b

Update api.py

Browse files
Files changed (1) hide show
  1. api.py +20 -14
api.py CHANGED
@@ -138,9 +138,9 @@ _ = model(dummy_input) # 모델이 빌드됨
138
  model.load_weights("InteractGPT.weights.h5")
139
  print("모델 가중치 로드 완료!")
140
 
141
- def generate_text_top_p(model, prompt, max_len=100, max_gen=98,
142
- temperature=1.0, min_len=20,
143
- repetition_penalty=1.1, top_p=0.9):
144
  model_input = text_to_ids(f"<start> {prompt} <sep>")
145
  model_input = model_input[:max_len]
146
  generated = list(model_input)
@@ -153,12 +153,12 @@ def generate_text_top_p(model, prompt, max_len=100, max_gen=98,
153
  logits = model(input_tensor, training=False)
154
  next_logits = logits[0, len(generated) - 1].numpy()
155
 
156
- # 반복 억제 penalty
157
  for t in set(generated):
158
  count = generated.count(t)
159
  next_logits[t] /= (repetition_penalty ** count)
160
 
161
- # 종료 조건 방지
162
  if len(generated) < min_len:
163
  next_logits[end_id] -= 5.0
164
  next_logits[pad_id] -= 10.0
@@ -168,17 +168,23 @@ def generate_text_top_p(model, prompt, max_len=100, max_gen=98,
168
  probs = np.exp(next_logits - np.max(next_logits))
169
  probs /= probs.sum()
170
 
171
- # Top-p 필터링
172
- sorted_idx = np.argsort(-probs)
173
- sorted_probs = probs[sorted_idx]
 
 
 
 
 
 
174
  cum_probs = np.cumsum(sorted_probs)
175
  cutoff = np.searchsorted(cum_probs, top_p) + 1
176
 
177
- filtered_idx = sorted_idx[:cutoff]
178
- filtered_probs = sorted_probs[:cutoff]
179
- filtered_probs /= filtered_probs.sum()
180
 
181
- sampled = np.random.choice(filtered_idx, p=filtered_probs)
182
  generated.append(int(sampled))
183
 
184
  decoded = sp.decode(generated)
@@ -189,9 +195,9 @@ def generate_text_top_p(model, prompt, max_len=100, max_gen=98,
189
  if len(generated) >= min_len and (sampled == end_id or decoded.endswith(('.', '!', '?'))):
190
  yield decoded
191
  break
192
-
193
  async def async_generator_wrapper(prompt: str):
194
- gen = generate_text_top_p(model, prompt)
195
  for text_piece in gen:
196
  yield text_piece
197
  await asyncio.sleep(0.1)
 
138
  model.load_weights("InteractGPT.weights.h5")
139
  print("모델 가중치 로드 완료!")
140
 
141
+ def generate_text_top_kp(model, prompt, max_len=100, max_gen=98,
142
+ temperature=1.0, min_len=20,
143
+ repetition_penalty=1.1, top_k=40, top_p=0.9):
144
  model_input = text_to_ids(f"<start> {prompt} <sep>")
145
  model_input = model_input[:max_len]
146
  generated = list(model_input)
 
153
  logits = model(input_tensor, training=False)
154
  next_logits = logits[0, len(generated) - 1].numpy()
155
 
156
+ # 반복 억제
157
  for t in set(generated):
158
  count = generated.count(t)
159
  next_logits[t] /= (repetition_penalty ** count)
160
 
161
+ # 조기 종료 방지
162
  if len(generated) < min_len:
163
  next_logits[end_id] -= 5.0
164
  next_logits[pad_id] -= 10.0
 
168
  probs = np.exp(next_logits - np.max(next_logits))
169
  probs /= probs.sum()
170
 
171
+ # Top-K 적용
172
+ top_k = min(top_k, len(probs))
173
+ top_k_idx = np.argsort(-probs)[:top_k]
174
+ top_k_probs = probs[top_k_idx]
175
+ top_k_probs /= top_k_probs.sum()
176
+
177
+ # Top-P 필터링
178
+ sorted_idx = np.argsort(-top_k_probs)
179
+ sorted_probs = top_k_probs[sorted_idx]
180
  cum_probs = np.cumsum(sorted_probs)
181
  cutoff = np.searchsorted(cum_probs, top_p) + 1
182
 
183
+ final_idx = top_k_idx[sorted_idx[:cutoff]]
184
+ final_probs = sorted_probs[:cutoff]
185
+ final_probs /= final_probs.sum()
186
 
187
+ sampled = np.random.choice(final_idx, p=final_probs)
188
  generated.append(int(sampled))
189
 
190
  decoded = sp.decode(generated)
 
195
  if len(generated) >= min_len and (sampled == end_id or decoded.endswith(('.', '!', '?'))):
196
  yield decoded
197
  break
198
+
199
  async def async_generator_wrapper(prompt: str):
200
+ gen = generate_text_top_kp(model, prompt)
201
  for text_piece in gen:
202
  yield text_piece
203
  await asyncio.sleep(0.1)