Yuchan5386 committed on
Commit
ed109ab
·
verified ·
1 Parent(s): 11d37d9

Update api.py

Browse files
Files changed (1) hide show
  1. api.py +28 -22
api.py CHANGED
@@ -138,9 +138,30 @@ _ = model(dummy_input) # 모델이 빌드됨
138
  model.load_weights("InteractGPT.weights.h5")
139
  print("모델 가중치 로드 완료!")
140
 
141
- def generate_text_top_kp(model, prompt, max_len=100, max_gen=98,
142
- temperature=0.7, min_len=20,
143
- repetition_penalty=1.1, top_k=50, top_p=0.9):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  model_input = text_to_ids(f"<start> {prompt} <sep>")
145
  model_input = model_input[:max_len]
146
  generated = list(model_input)
@@ -165,26 +186,11 @@ def generate_text_top_kp(model, prompt, max_len=100, max_gen=98,
165
 
166
  # 온도 적용
167
  next_logits = next_logits / temperature
168
- probs = np.exp(next_logits - np.max(next_logits))
169
- probs /= probs.sum()
170
-
171
- # Top-K 적용
172
- top_k = min(top_k, len(probs))
173
- top_k_idx = np.argsort(-probs)[:top_k]
174
- top_k_probs = probs[top_k_idx]
175
- top_k_probs /= top_k_probs.sum()
176
-
177
- # Top-P 필터링
178
- sorted_idx = np.argsort(-top_k_probs)
179
- sorted_probs = top_k_probs[sorted_idx]
180
- cum_probs = np.cumsum(sorted_probs)
181
- cutoff = np.searchsorted(cum_probs, top_p) + 1
182
-
183
- final_idx = top_k_idx[sorted_idx[:cutoff]]
184
- final_probs = sorted_probs[:cutoff]
185
- final_probs /= final_probs.sum()
186
 
 
 
187
  sampled = np.random.choice(final_idx, p=final_probs)
 
188
  generated.append(int(sampled))
189
 
190
  decoded = sp.decode(generated)
@@ -197,7 +203,7 @@ def generate_text_top_kp(model, prompt, max_len=100, max_gen=98,
197
  break
198
 
199
async def async_generator_wrapper(prompt: str):
    """Asynchronously stream text pieces from the top-k/top-p generator.

    Yields each chunk produced by ``generate_text_top_kp`` and pauses
    briefly between chunks so the event loop can service other tasks.
    """
    for text_piece in generate_text_top_kp(model, prompt):
        yield text_piece
        await asyncio.sleep(0.1)
 
138
  model.load_weights("InteractGPT.weights.h5")
139
  print("모델 가중치 로드 완료!")
140
 
141
+ def generate_text_typical(model, prompt, max_len=100, max_gen=98,
142
+ temperature=0.7, min_len=20,
143
+ repetition_penalty=1.1, typical_p=0.95):
144
+
145
def typical_filtering(logits, typical_p):
    """Restrict sampling to the 'typical' token set.

    Ranks tokens by how close each token's surprisal (-log p) is to the
    entropy of the full distribution, then keeps the smallest prefix of
    that ranking whose cumulative probability mass reaches ``typical_p``.

    Returns a tuple ``(indices, probs)`` — the kept vocabulary indices
    and their probabilities renormalized to sum to 1, ready for sampling.
    """
    # Numerically stable softmax over the raw logits.
    dist = np.exp(logits - np.max(logits))
    dist = dist / np.sum(dist)

    # Per-token surprisal and the distribution's Shannon entropy.
    # The small epsilon guards log(0).
    surprisal = -np.log(dist + 1e-9)
    ent = np.sum(dist * surprisal)

    # Order tokens by |surprisal - entropy| (most "typical" first).
    order = np.argsort(np.abs(surprisal - ent))
    ranked = dist[order]

    # Smallest typical set whose cumulative mass reaches typical_p.
    mass = np.cumsum(ranked)
    keep = np.searchsorted(mass, typical_p) + 1

    chosen = order[:keep]
    weights = dist[chosen]
    return chosen, weights / np.sum(weights)
164
+
165
  model_input = text_to_ids(f"<start> {prompt} <sep>")
166
  model_input = model_input[:max_len]
167
  generated = list(model_input)
 
186
 
187
  # 온도 적용
188
  next_logits = next_logits / temperature
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
 
190
+ # Typical Sampling 적용
191
+ final_idx, final_probs = typical_filtering(next_logits, typical_p=typical_p)
192
  sampled = np.random.choice(final_idx, p=final_probs)
193
+
194
  generated.append(int(sampled))
195
 
196
  decoded = sp.decode(generated)
 
203
  break
204
 
205
async def async_generator_wrapper(prompt: str):
    """Asynchronously stream text pieces from the typical-sampling generator.

    Yields each chunk produced by ``generate_text_typical`` and pauses
    briefly between chunks so the event loop can service other tasks.
    """
    for text_piece in generate_text_typical(model, prompt):
        yield text_piece
        await asyncio.sleep(0.1)