Yuchan5386 committed on
Commit
feab31b
·
verified ·
1 Parent(s): d348825

Update api.py

Browse files
Files changed (1) hide show
  1. api.py +51 -72
api.py CHANGED
@@ -138,81 +138,60 @@ _ = model(dummy_input) # 모델이 빌드됨
138
  model.load_weights("InteractGPT.weights.h5")
139
  print("모델 가중치 로드 완료!")
140
 
 
 
 
 
 
 
141
 
142
def generate_text_mirostat_top_p(model, prompt, max_len=100, max_gen=98,
                                 temperature=1.0, min_len=20,
                                 repetition_penalty=1.2, eta=0.1, m=100, p=0.9):
    """Generate a reply with mirostat-style surprise tracking plus top-p sampling.

    Yields the decoded text once a terminating token or sentence-final
    punctuation is produced (after ``min_len`` tokens), then stops.

    Args:
        model: Keras-style LM, called as ``model(tensor, training=False)``;
            returns per-position logits of shape (1, max_len, vocab).
        prompt: User prompt; wrapped as ``"<start> {prompt} <sep>"`` and encoded.
        max_len: Model input length; the encoded prompt is truncated to it.
        max_gen: Maximum number of tokens to sample.
        temperature: Softmax temperature applied to the logits.
        min_len: Minimum sequence length before termination is allowed.
        repetition_penalty: Per-occurrence penalty for already-seen tokens.
        eta: Mirostat learning rate for the target-surprise update.
        m: Size of the top-m candidate pool for the mirostat stage.
        p: Cumulative-probability cutoff for nucleus (top-p) filtering.
    """
    model_input = text_to_ids(f"<start> {prompt} <sep>")[:max_len]
    generated = list(model_input)

    tau = 5.0  # initial target surprise (mirostat state)

    for _ in range(max_gen):
        # Right-pad to the model's fixed input length and take the logits
        # at the last real position.
        pad_length = max(0, max_len - len(generated))
        input_padded = np.pad(generated, (0, pad_length), constant_values=pad_id)
        input_tensor = tf.convert_to_tensor([input_padded])
        logits = model(input_tensor, training=False)
        next_token_logits = logits[0, len(generated) - 1].numpy()

        # Repetition penalty, sign-aware: dividing a *negative* logit would
        # raise its probability, so negatives are multiplied instead
        # (fixes the classic CTRL-style penalty bug).
        token_counts = {}
        for t in generated:
            token_counts[t] = token_counts.get(t, 0) + 1
        for token_id, count in token_counts.items():
            penalty = repetition_penalty ** count
            if next_token_logits[token_id] > 0:
                next_token_logits[token_id] /= penalty
            else:
                next_token_logits[token_id] *= penalty

        # Discourage termination tokens *until* the minimum length is reached.
        # (The original penalized them only after min_len — inverted condition —
        # which prevented the reply from ever ending cleanly.)
        if len(generated) < min_len:
            next_token_logits[end_id] -= 5.0
            next_token_logits[pad_id] -= 10.0

        next_token_logits = next_token_logits / temperature

        # Numerically stable softmax.
        logits_stable = next_token_logits - np.max(next_token_logits)
        probs = np.exp(logits_stable)
        probs /= probs.sum()

        # 1. Keep the top-m candidates (mirostat candidate pool).
        sorted_indices = np.argsort(-probs)
        top_indices = sorted_indices[:m]
        top_probs = probs[top_indices]
        top_probs /= top_probs.sum()

        # 2. Mirostat surprise tracking: sample once from the pool and nudge
        #    tau toward the observed surprise.
        #    NOTE(review): tau is never fed back into the truncation size, so
        #    this is bookkeeping only — full mirostat would adapt m (or k)
        #    from tau each step.
        sampled_index = np.random.choice(top_indices, p=top_probs)
        sampled_prob = probs[sampled_index]
        observed_surprise = -np.log(sampled_prob + 1e-9)
        tau += eta * (observed_surprise - tau)

        # 3. Nucleus (top-p) filtering over the top-m pool.
        sorted_top_indices = top_indices[np.argsort(-top_probs)]
        sorted_top_probs = np.sort(top_probs)[::-1]
        cumulative_probs = np.cumsum(sorted_top_probs)
        cutoff = np.searchsorted(cumulative_probs, p, side='left') + 1
        filtered_indices = sorted_top_indices[:cutoff]
        filtered_probs = sorted_top_probs[:cutoff]
        filtered_probs /= filtered_probs.sum()

        # 4. Final token draw.
        final_token = np.random.choice(filtered_indices, p=filtered_probs)
        generated.append(int(final_token))

        # Decode and strip special markers before checking the stop condition.
        decoded_text = sp.decode(generated)
        for token in ["<start>", "<sep>", "<end>"]:
            decoded_text = decoded_text.replace(token, "")
        decoded_text = decoded_text.strip()

        # Stop on the end token or sentence-final punctuation once long enough.
        # (The '<end>' suffix check is dead — it was stripped above — kept for
        # parity with the original.)
        if len(generated) >= min_len and (final_token == end_id or decoded_text.endswith(('.', '!', '?', '<end>'))):
            yield decoded_text
            break
213
 
214
async def async_generator_wrapper(prompt: str):
    """Bridge the synchronous token generator into an async stream.

    Yields each decoded text piece, pausing briefly between pieces so the
    event loop stays responsive while the blocking generator runs.
    """
    for piece in generate_text_mirostat_top_p(model, prompt):
        yield piece
        await asyncio.sleep(0.1)  # let other tasks run between pieces
 
138
  model.load_weights("InteractGPT.weights.h5")
139
  print("모델 가중치 로드 완료!")
140
 
141
def generate_text_top_p(model, prompt, max_len=100, max_gen=98,
                        temperature=1.0, min_len=20,
                        repetition_penalty=1.1, top_p=0.9):
    """Generate a reply using repetition penalty, temperature, and top-p sampling.

    Yields the decoded text once the end token or sentence-final punctuation
    is produced (after ``min_len`` tokens), then stops.

    Args:
        model: Keras-style LM, called as ``model(tensor, training=False)``;
            returns per-position logits of shape (1, max_len, vocab).
        prompt: User prompt; wrapped as ``"<start> {prompt} <sep>"`` and encoded.
        max_len: Model input length; the encoded prompt is truncated to it.
        max_gen: Maximum number of tokens to sample.
        temperature: Softmax temperature applied to the logits.
        min_len: Minimum sequence length before termination is allowed.
        repetition_penalty: Per-occurrence penalty for already-seen tokens.
        top_p: Cumulative-probability cutoff for nucleus filtering.
    """
    model_input = text_to_ids(f"<start> {prompt} <sep>")[:max_len]
    generated = list(model_input)

    for _ in range(max_gen):
        pad_len = max(0, max_len - len(generated))
        input_padded = np.pad(generated, (0, pad_len), constant_values=pad_id)
        input_tensor = tf.convert_to_tensor([input_padded])

        logits = model(input_tensor, training=False)
        next_logits = logits[0, len(generated) - 1].numpy()

        # Repetition penalty. Counts are built in one pass (the original's
        # set + list.count was O(n^2)). The penalty is sign-aware: dividing
        # a *negative* logit would raise its probability, so negatives are
        # multiplied instead (standard CTRL-style fix).
        token_counts = {}
        for t in generated:
            token_counts[t] = token_counts.get(t, 0) + 1
        for token_id, count in token_counts.items():
            penalty = repetition_penalty ** count
            if next_logits[token_id] > 0:
                next_logits[token_id] /= penalty
            else:
                next_logits[token_id] *= penalty

        # Block termination tokens until the minimum length is reached.
        if len(generated) < min_len:
            next_logits[end_id] -= 5.0
            next_logits[pad_id] -= 10.0

        # Temperature + numerically stable softmax.
        next_logits = next_logits / temperature
        probs = np.exp(next_logits - np.max(next_logits))
        probs /= probs.sum()

        # Nucleus (top-p) filtering: keep the smallest prefix of the sorted
        # distribution whose cumulative mass reaches top_p, then renormalize.
        sorted_idx = np.argsort(-probs)
        sorted_probs = probs[sorted_idx]
        cum_probs = np.cumsum(sorted_probs)
        cutoff = np.searchsorted(cum_probs, top_p) + 1

        filtered_idx = sorted_idx[:cutoff]
        filtered_probs = sorted_probs[:cutoff]
        filtered_probs /= filtered_probs.sum()

        sampled = np.random.choice(filtered_idx, p=filtered_probs)
        generated.append(int(sampled))

        # Decode and strip special markers before checking the stop condition.
        decoded = sp.decode(generated)
        for t in ["<start>", "<sep>", "<end>"]:
            decoded = decoded.replace(t, "")
        decoded = decoded.strip()

        # Stop on the end token or sentence-final punctuation once long enough.
        if len(generated) >= min_len and (sampled == end_id or decoded.endswith(('.', '!', '?'))):
            yield decoded
            break
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
 
193
async def async_generator_wrapper(prompt: str):
    """Bridge the synchronous token generator into an async stream.

    Yields each decoded text piece, pausing briefly between pieces so the
    event loop stays responsive while the blocking generator runs.
    """
    for piece in generate_text_top_p(model, prompt):
        yield piece
        await asyncio.sleep(0.1)  # let other tasks run between pieces