Yuchan5386 committed on
Commit
41baa96
·
verified ·
1 Parent(s): b894ce9

Update api.py

Browse files
Files changed (1) hide show
  1. api.py +46 -64
api.py CHANGED
@@ -140,68 +140,50 @@ _ = model(dummy_input) # 모델이 빌드됨
140
  model.load_weights("InteractGPT.weights.h5")
141
  print("모델 가중치 로드 완료!")
142
 
143
def generate_text_typical(model, prompt, max_len=100, max_gen=98,
                          temperature=0.50, min_len=20,
                          repetition_penalty=1.2, typical_p=0.80):
    """Generate a reply for *prompt* using locally-typical sampling.

    The prompt is wrapped as ``"<start> {prompt} <sep>"``, encoded, and
    extended one token at a time.  At each step the logits are
    repetition-penalized, temperature-scaled, and restricted to the
    "typical" subset (tokens whose surprisal is closest to the
    distribution's entropy) before sampling.

    Returns the decoded text once a stop condition is met; if ``max_gen``
    steps elapse without one, the function falls through and returns
    ``None`` (preserved original behavior).
    """

    def _typical_subset(step_logits, mass):
        # Stable softmax over the vocabulary.
        p = np.exp(step_logits - np.max(step_logits))
        p = p / p.sum()
        logp = np.log(p + 1e-9)
        entropy = -np.sum(p * logp)
        # Rank tokens by |surprisal - entropy| (most "typical" first).
        order = np.argsort(np.abs(-logp - entropy))
        # Smallest typical prefix whose cumulative mass reaches `mass`.
        cutoff = np.searchsorted(np.cumsum(p[order]), mass) + 1
        chosen = order[:cutoff]
        weights = p[chosen]
        weights = weights / weights.sum()
        return chosen, weights

    token_ids = list(text_to_ids(f"<start> {prompt} <sep>")[:max_len])

    for _ in range(max_gen):
        # Right-pad the context up to the model's input window.
        padding = max(0, max_len - len(token_ids))
        batch = tf.convert_to_tensor(
            [np.pad(token_ids, (0, padding), constant_values=pad_id)])
        step_logits = model(batch, training=False)[0, len(token_ids) - 1].numpy()

        # Penalize tokens already emitted, scaled by their count.
        for tok in set(token_ids):
            step_logits[tok] /= repetition_penalty ** token_ids.count(tok)

        # Discourage <end>/<pad> before a minimum-length reply exists.
        if len(token_ids) < min_len:
            step_logits[end_id] -= 5.0
            step_logits[pad_id] -= 10.0

        step_logits = step_logits / temperature

        candidates, weights = _typical_subset(step_logits, typical_p)
        choice = np.random.choice(candidates, p=weights)
        token_ids.append(int(choice))

        text = sp.decode(token_ids)
        for marker in ["<start>", "<sep>", "<end>"]:
            text = text.replace(marker, "")
        text = text.strip()

        # Stop on <end> or a sentence-final punctuation mark.
        if len(token_ids) >= min_len and (choice == end_id or text.endswith(('.', '!', '?'))):
            return text
205
 
206
  def is_valid_response(response):
207
  if len(response.strip()) < 2:
@@ -284,10 +266,10 @@ def respond(input_text):
284
  summary = get_wikipedia_summary(keyword)
285
  return f"{summary}\n다른 궁금한 점 있으신가요?"
286
 
287
- return generate_text_typical(model, input_text)
288
 
289
async def async_generator_wrapper(prompt: str):
    """Stream the generated reply piecewise, pausing 0.1 s between pieces.

    Note: `generate_text_typical` returns a plain string, so iterating it
    yields one character at a time.
    """
    for text_piece in generate_text_typical(model, prompt):
        yield text_piece
        await asyncio.sleep(0.1)
 
140
  model.load_weights("InteractGPT.weights.h5")
141
  print("모델 가중치 로드 완료!")
142
 
143
def generate_text_topp(model, prompt, max_len=100, max_gen=98,
                       temperature=0.50, min_len=20,
                       repetition_penalty=1.2, top_p=0.90):
    """Generate a reply for *prompt* with nucleus (top-p) sampling.

    Args:
        model: callable TF model; ``model(ids, training=False)`` is assumed
            to return per-position logits of shape (batch, seq, vocab) —
            TODO confirm against the model definition.
        prompt: user text; wrapped as ``"<start> {prompt} <sep>"``.
        max_len: model input window; context is padded/truncated to it.
        max_gen: maximum number of sampling steps.
        temperature: logit divisor; lower values sample more greedily.
        min_len: minimum token count before <end>/<pad> are allowed.
        repetition_penalty: >1 penalizes tokens already generated.
        top_p: cumulative probability mass kept for sampling.

    Returns:
        The decoded reply with special markers stripped.  If no stop
        condition is hit within ``max_gen`` steps, the text generated so
        far is returned (the original implicitly returned ``None`` here,
        which crashed callers that iterate the result).
    """
    def top_p_filtering(logits, top_p):
        # Stable softmax over the vocabulary.
        probs = np.exp(logits - np.max(logits))
        probs /= probs.sum()
        # Smallest descending-probability prefix whose mass reaches top_p.
        sorted_idx = np.argsort(-probs)
        cum_probs = np.cumsum(probs[sorted_idx])
        cutoff = np.searchsorted(cum_probs, top_p) + 1
        final_idx = sorted_idx[:cutoff]
        final_probs = probs[final_idx]
        final_probs /= final_probs.sum()
        return final_idx, final_probs

    model_input = text_to_ids(f"<start> {prompt} <sep>")
    model_input = model_input[:max_len]
    generated = list(model_input)
    decoded = ""

    for _ in range(max_gen):
        pad_len = max(0, max_len - len(generated))
        input_padded = np.pad(generated, (0, pad_len), constant_values=pad_id)
        input_tensor = tf.convert_to_tensor([input_padded])
        logits = model(input_tensor, training=False)
        next_logits = logits[0, len(generated) - 1].numpy()

        # Repetition suppression (CTRL-style).  Bug fix: a plain division
        # *boosts* tokens whose logit is negative, so the penalty must be
        # applied sign-aware (divide positive logits, multiply negative).
        for t in set(generated):
            count = generated.count(t)
            penalty = repetition_penalty ** count
            if next_logits[t] > 0:
                next_logits[t] /= penalty
            else:
                next_logits[t] *= penalty

        # Block premature termination until min_len tokens exist.
        if len(generated) < min_len:
            next_logits[end_id] -= 5.0
            next_logits[pad_id] -= 10.0

        next_logits = next_logits / temperature

        final_idx, final_probs = top_p_filtering(next_logits, top_p=top_p)
        sampled = np.random.choice(final_idx, p=final_probs)
        generated.append(int(sampled))

        decoded = sp.decode(generated)
        for t in ["<start>", "<sep>", "<end>"]:
            decoded = decoded.replace(t, "")
        decoded = decoded.strip()

        # Stop on <end> or sentence-final punctuation.
        if len(generated) >= min_len and (sampled == end_id or decoded.endswith(('.', '!', '?'))):
            return decoded

    # Fallback: loop exhausted without a stop condition — return the
    # partial text instead of falling through with None.
    return decoded
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
 
188
  def is_valid_response(response):
189
  if len(response.strip()) < 2:
 
266
  summary = get_wikipedia_summary(keyword)
267
  return f"{summary}\n다른 궁금한 점 있으신가요?"
268
 
269
+ return generate_text_topp(model, input_text)
270
 
271
async def async_generator_wrapper(prompt: str):
    """Stream the generated reply piecewise, pausing 0.1 s between pieces.

    ``generate_text_topp`` returns a plain string, so iteration yields one
    character at a time — presumably intended as simulated streaming.
    """
    gen = generate_text_topp(model, prompt)
    # Bug fix: the sampler can return None when max_gen is exhausted
    # without a stop condition; iterating None raised TypeError here.
    for text_piece in gen or ():
        yield text_piece
        await asyncio.sleep(0.1)