Yuchan5386 committed on
Commit
5480bd4
·
verified ·
1 Parent(s): b2b2dad

Update api.py

Browse files
Files changed (1) hide show
  1. api.py +82 -49
api.py CHANGED
@@ -126,55 +126,88 @@ print("모델 가중치 로드 완료!")
126
 
127
  repetition_penalty = 1.2
128
 
129
async def generate_text_stream(prompt: str):
    """Stream generated text for *prompt*, yielding one decoded token at a time.

    Encodes the prompt as ``<start> {prompt} <sep>``, then repeatedly samples the
    next token (top-k with k=100, repetition-penalized), yielding each non-special
    token's decoded text until ``<end>`` is sampled or the context window fills.

    Relies on module-level globals: ``text_to_ids``, ``model``, ``sp``,
    ``max_len``, ``pad_id``, ``end_id``, ``repetition_penalty``.

    Yields:
        str: the decoded text of each newly sampled (non-special) token.
    """
    model_input = text_to_ids(f"<start> {prompt} <sep>")
    model_input = model_input[:max_len]
    generated = list(model_input)

    # Fix: the original `while True` let `generated` grow past `max_len`, at
    # which point `logits[0, len(generated) - 1]` indexes beyond the model's
    # output sequence. Bound the loop at the context-window size instead.
    while len(generated) < max_len:
        pad_length = max(0, max_len - len(generated))
        input_padded = np.pad(generated, (0, pad_length), constant_values=pad_id)
        input_tensor = tf.convert_to_tensor([input_padded])
        logits = model(input_tensor, training=False)
        # Logits at the position of the last real (non-pad) token.
        next_token_logits = logits[0, len(generated) - 1].numpy()

        # Repetition penalty: divide each already-seen token's logit by
        # penalty ** (times seen), discouraging loops.
        token_counts = {}
        for t in generated:
            token_counts[t] = token_counts.get(t, 0) + 1
        for token_id, count in token_counts.items():
            next_token_logits[token_id] /= (repetition_penalty ** count)

        # NOTE(review): this *suppresses* <end>/<pad> once 20 tokens exist,
        # i.e. it pushes generations to run long — confirm that is intended.
        if len(generated) >= 20:
            next_token_logits[end_id] -= 5.0
            next_token_logits[pad_id] -= 10.0
        next_token_logits = next_token_logits / 1.0  # temperature fixed at 1.0

        # Numerically stable softmax.
        logits_stable = next_token_logits - np.max(next_token_logits)
        probs = np.exp(logits_stable)
        probs /= probs.sum()

        # Top-k sampling: keep the 100 most probable tokens, renormalize.
        sorted_indices = np.argsort(-probs)
        top_indices = sorted_indices[:100]
        top_probs = probs[top_indices]
        top_probs /= top_probs.sum()

        sampled_index = np.random.choice(top_indices, p=top_probs)
        generated.append(int(sampled_index))

        new_token_text = sp.decode([int(sampled_index)])

        # Special tokens are never yielded; <end> terminates the stream.
        if any(tok in new_token_text for tok in ["<start>", "<sep>", "<end>", "<pad>"]):
            if sampled_index == end_id:
                break
            continue

        yield new_token_text
        # Throttle the stream so clients receive tokens gradually.
        await asyncio.sleep(0.1)
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  @app.get("/generate")
179
  async def generate(request: Request):
180
  prompt = request.query_params.get("prompt", "안녕하세요")
 
126
 
127
  repetition_penalty = 1.2
128
 
129
async def generate_text_mirostat_top_p(model, prompt, max_len=100, max_gen=98,
                                       temperature=1.0, min_len=20,
                                       repetition_penalty=1.2, eta=0.1, m=100, p=0.9):
    """Generate text with top-m + top-p (nucleus) sampling, streaming snapshots.

    Encodes the prompt as ``<start> {prompt} <sep>`` and samples up to
    ``max_gen`` tokens. Each step applies a repetition penalty, temperature
    scaling, restricts to the ``m`` most probable tokens, then nucleus-filters
    to cumulative probability ``p`` before sampling. After every non-special
    token (and on ``<end>``), yields the full decoded text so far with special
    tokens stripped.

    Relies on module-level globals: ``text_to_ids``, ``sp``, ``pad_id``,
    ``start_id``, ``end_id``.

    Args:
        model: callable returning logits of shape (1, max_len, vocab).
        prompt: user prompt string.
        max_len: model context-window length (input is padded/truncated to it).
        max_gen: maximum number of tokens to sample.
        temperature: logit temperature divisor.
        min_len: length after which <end>/<pad> logits are adjusted.
        repetition_penalty: per-occurrence logit divisor for seen tokens.
        eta: mirostat learning rate for the tau update.
        m: top-m candidate-pool size.
        p: nucleus (top-p) cumulative-probability cutoff.

    Yields:
        str: cumulative decoded text (special tokens removed, stripped).
    """
    model_input = text_to_ids(f"<start> {prompt} <sep>")
    model_input = model_input[:max_len]
    generated = list(model_input)

    tau = 5.0  # initial target surprise (mirostat)

    for step in range(max_gen):
        # Fix: stop once the context window is full. Without this guard,
        # `generated` can exceed `max_len` (the prompt alone may be up to
        # `max_len` tokens), making `logits[0, len(generated) - 1]` index
        # past the model's output sequence.
        if len(generated) >= max_len:
            break

        pad_length = max(0, max_len - len(generated))
        input_padded = np.pad(generated, (0, pad_length), constant_values=pad_id)
        input_tensor = tf.convert_to_tensor([input_padded])
        logits = model(input_tensor, training=False)
        # Logits at the position of the last real (non-pad) token.
        next_token_logits = logits[0, len(generated) - 1].numpy()

        # Repetition penalty: divide each already-seen token's logit by
        # penalty ** (times seen), discouraging loops.
        token_counts = {}
        for t in generated:
            token_counts[t] = token_counts.get(t, 0) + 1
        for token_id, count in token_counts.items():
            next_token_logits[token_id] /= (repetition_penalty ** count)

        # NOTE(review): this *suppresses* <end>/<pad> once min_len is reached,
        # which pushes generations to run long — confirm that is the intent
        # (the comment in the original implied the opposite).
        if len(generated) >= min_len:
            next_token_logits[end_id] -= 5.0
            next_token_logits[pad_id] -= 10.0

        # Temperature scaling.
        next_token_logits = next_token_logits / temperature

        # --- mirostat + top-p sampling ---
        # Numerically stable softmax.
        logits_stable = next_token_logits - np.max(next_token_logits)
        probs = np.exp(logits_stable)
        probs /= probs.sum()

        # 1. Restrict to the top-m candidates and renormalize.
        sorted_indices = np.argsort(-probs)
        top_indices = sorted_indices[:m]
        top_probs = probs[top_indices]
        top_probs /= top_probs.sum()

        # 2. Mirostat surprise tracking.
        # NOTE(review): `tau` is updated here but never fed back into the
        # sampling step, so the mirostat loop currently has no effect —
        # confirm whether tau was meant to control the candidate pool.
        sampled_index = np.random.choice(top_indices, p=top_probs)
        sampled_prob = probs[sampled_index]
        observed_surprise = -np.log(sampled_prob + 1e-9)
        tau += eta * (observed_surprise - tau)

        # 3. Top-p (nucleus) filtering over the top-m pool.
        sorted_top_indices = top_indices[np.argsort(-top_probs)]
        sorted_top_probs = np.sort(top_probs)[::-1]
        cumulative_probs = np.cumsum(sorted_top_probs)
        cutoff = np.searchsorted(cumulative_probs, p, side='left') + 1
        filtered_indices = sorted_top_indices[:cutoff]
        filtered_probs = sorted_top_probs[:cutoff]
        filtered_probs /= filtered_probs.sum()

        # 4. Final token draw from the nucleus.
        final_token = np.random.choice(filtered_indices, p=filtered_probs)
        generated.append(int(final_token))

        # <end>: emit the final cleaned snapshot and stop.
        if final_token == end_id:
            decoded_text = sp.decode(generated)
            for token in ["<start>", "<sep>", "<end>"]:
                decoded_text = decoded_text.replace(token, "")
            decoded_text = decoded_text.strip()
            yield decoded_text
            break

        # Other special tokens are kept in `generated` but never displayed.
        # Fix: cast to int — SentencePiece's id_to_piece expects a plain
        # Python int, not a NumPy integer.
        if final_token in [start_id, pad_id] or sp.id_to_piece(int(final_token)) == "<sep>":
            continue

        # Ordinary token: emit the cumulative cleaned snapshot.
        decoded_text = sp.decode(generated)
        for token in ["<start>", "<sep>", "<end>"]:
            decoded_text = decoded_text.replace(token, "")
        decoded_text = decoded_text.strip()
        yield decoded_text
211
  @app.get("/generate")
212
  async def generate(request: Request):
213
  prompt = request.query_params.get("prompt", "안녕하세요")