Yuchan5386 committed on
Commit
b0c813a
·
verified ·
1 Parent(s): 41baa96

Update api.py

Browse files
Files changed (1) hide show
  1. api.py +13 -8
api.py CHANGED
@@ -140,14 +140,19 @@ _ = model(dummy_input) # 모델이 빌드됨
140
  model.load_weights("InteractGPT.weights.h5")
141
  print("모델 가중치 로드 완료!")
142
 
143
- def generate_text_topp(model, prompt, max_len=100, max_gen=98,
144
- temperature=0.50, min_len=20,
145
- repetition_penalty=1.2, top_p=0.90):
146
- def top_p_filtering(logits, top_p):
147
  probs = np.exp(logits - np.max(logits))
148
  probs /= probs.sum()
149
  sorted_idx = np.argsort(-probs)
150
  sorted_probs = probs[sorted_idx]
 
 
 
 
 
151
  cum_probs = np.cumsum(sorted_probs)
152
  cutoff = np.searchsorted(cum_probs, top_p) + 1
153
  final_idx = sorted_idx[:cutoff]
@@ -174,8 +179,8 @@ def generate_text_topp(model, prompt, max_len=100, max_gen=98,
174
  next_logits[pad_id] -= 10.0
175
  # 온도 적용
176
  next_logits = next_logits / temperature
177
- # Top-P Sampling 적용
178
- final_idx, final_probs = top_p_filtering(next_logits, top_p=top_p)
179
  sampled = np.random.choice(final_idx, p=final_probs)
180
  generated.append(int(sampled))
181
  decoded = sp.decode(generated)
@@ -266,10 +271,10 @@ def respond(input_text):
266
  summary = get_wikipedia_summary(keyword)
267
  return f"{summary}\n다른 궁금한 점 있으신가요?"
268
 
269
- return generate_text_topp(model, input_text)
270
 
271
  async def async_generator_wrapper(prompt: str):
272
- gen = generate_text_topp(model, prompt)
273
  for text_piece in gen:
274
  yield text_piece
275
  await asyncio.sleep(0.1)
 
140
  model.load_weights("InteractGPT.weights.h5")
141
  print("모델 가중치 로드 완료!")
142
 
143
+ def generate_text_topkp(model, prompt, max_len=100, max_gen=98,
144
+ temperature=0.50, min_len=20,
145
+ repetition_penalty=1.2, top_p=0.90, top_k=50):
146
+ def top_kp_filtering(logits, top_k, top_p):
147
  probs = np.exp(logits - np.max(logits))
148
  probs /= probs.sum()
149
  sorted_idx = np.argsort(-probs)
150
  sorted_probs = probs[sorted_idx]
151
+ # Top-K 필터링
152
+ if top_k > 0:
153
+ sorted_idx = sorted_idx[:top_k]
154
+ sorted_probs = sorted_probs[:top_k]
155
+ # Top-P 필터링
156
  cum_probs = np.cumsum(sorted_probs)
157
  cutoff = np.searchsorted(cum_probs, top_p) + 1
158
  final_idx = sorted_idx[:cutoff]
 
179
  next_logits[pad_id] -= 10.0
180
  # 온도 적용
181
  next_logits = next_logits / temperature
182
+ # Top-KP Sampling 적용
183
+ final_idx, final_probs = top_kp_filtering(next_logits, top_k=top_k, top_p=top_p)
184
  sampled = np.random.choice(final_idx, p=final_probs)
185
  generated.append(int(sampled))
186
  decoded = sp.decode(generated)
 
271
  summary = get_wikipedia_summary(keyword)
272
  return f"{summary}\n다른 궁금한 점 있으신가요?"
273
 
274
+ return generate_text_topkp(model, input_text)
275
 
276
  async def async_generator_wrapper(prompt: str):
277
+ gen = generate_text_topkp(model, prompt)
278
  for text_piece in gen:
279
  yield text_piece
280
  await asyncio.sleep(0.1)