Yuchan5386 committed on
Commit
70b0917
·
verified ·
1 Parent(s): 2812b06

Update api.py

Browse files
Files changed (1) hide show
  1. api.py +36 -32
api.py CHANGED
@@ -150,33 +150,35 @@ from sklearn.metrics.pairwise import cosine_similarity
150
  from fastapi import Request
151
  from fastapi.responses import PlainTextResponse
152
 
153
def is_greedy_response_acceptable(text):
    """Return True when *text* passes a set of heuristic quality gates.

    A reply is rejected when it is too short, has too few words, contains
    a run of bare jamo (without the allowed 'ㅋㅋ'), or does not close on
    a natural Korean sentence ending.
    """
    candidate = text.strip()

    # Length gates: overall character count and word count.
    too_short = len(candidate) < 5
    too_few_words = len(candidate.split()) < 3
    if too_short or too_few_words:
        return False

    # Runs of 3+ bare jamo (e.g. 'ㅋㅋㅋ') are rejected unless 'ㅋㅋ' appears.
    has_jamo_run = re.search(r'[ㄱ-ㅎㅏ-ㅣ]{3,}', candidate) is not None
    if has_jamo_run and 'ㅋㅋ' not in candidate:
        return False

    # The sentence must close on a natural ending (다/요/죠 or punctuation).
    ending_ok = re.search(r'(다|요|죠|다\.|요\.|죠\.|다!|요!|죠!|\!|\?|\.)$', candidate)
    return ending_ok is not None
173
 
174
- def generate_text_with_temp_and_rep_penalty(model, prompt, max_len=100, max_gen=98,
175
- repetition_penalty=1.2, temperature=0.7,
176
- min_len=20):
177
  model_input = text_to_ids(f"<start> {prompt} <sep>")
178
  model_input = model_input[:max_len]
179
  generated = list(model_input)
 
180
 
181
  for _ in range(max_gen):
182
  pad_len = max(0, max_len - len(generated))
@@ -190,22 +192,24 @@ def generate_text_with_temp_and_rep_penalty(model, prompt, max_len=100, max_gen=
190
  count = generated.count(t)
191
  next_logits[t] /= (repetition_penalty ** count)
192
 
193
- # Temperature scaling
194
- next_logits = next_logits / temperature
195
-
196
- # Softmax로 확률 계산
197
- exp_logits = np.exp(next_logits - np.max(next_logits))
198
  probs = exp_logits / exp_logits.sum()
199
 
200
- # 다음 토큰 샘플링
201
  next_token = np.random.choice(len(probs), p=probs)
202
  generated.append(int(next_token))
203
 
204
- decoded = sp.decode(generated)
205
- for t in ["<start>", "<sep>", "<end>"]:
206
- decoded = decoded.replace(t, "")
207
- decoded = decoded.strip()
 
 
208
 
 
 
209
  if len(generated) >= min_len and (next_token == end_id or decoded.endswith(('요', '다', '.', '!', '?'))):
210
  if is_greedy_response_acceptable(decoded):
211
  return decoded
 
150
  from fastapi import Request
151
  from fastapi.responses import PlainTextResponse
152
 
153
+
154
def is_greedy_response_acceptable(text):
    """Heuristically decide whether a generated reply is acceptable.

    Returns False for replies that are too short, contain too few words,
    include a run of 3+ bare jamo (e.g. 'ㅎㅎㅎ') without the allowed
    'ㅋㅋ', or do not end in a natural Korean sentence ending; otherwise
    returns True.
    """
    text = text.strip()

    # Filter out sentences that are too short overall.
    if len(text) < 5:
        return False

    # Filter out sentences with too few words.
    if len(text.split()) < 3:
        return False

    # Filter out runs of bare jamo like 'ㅋㅋㅋ'; allow when 'ㅋㅋ' is present.
    if re.search(r'[ㄱ-ㅎㅏ-ㅣ]{3,}', text) and 'ㅋㅋ' not in text:
        return False

    # Filter out awkward endings. The original regex listed redundant
    # alternatives ('다\.', '요!', etc. are already matched by the bare
    # '\.'/'!' branches); the accepted final characters reduce to exactly
    # 다 / 요 / 죠 / . / ! / ?, which endswith checks directly.
    if not text.endswith(('다', '요', '죠', '.', '!', '?')):
        return False

    return True
174
 
175
+ def generate_text_with_mirostat(model, prompt, max_len=100, max_gen=98,
176
+ repetition_penalty=1.2, tau_init=5.0, mu=5.0,
177
+ eta=0.1, min_len=20):
178
  model_input = text_to_ids(f"<start> {prompt} <sep>")
179
  model_input = model_input[:max_len]
180
  generated = list(model_input)
181
+ tau = tau_init
182
 
183
  for _ in range(max_gen):
184
  pad_len = max(0, max_len - len(generated))
 
192
  count = generated.count(t)
193
  next_logits[t] /= (repetition_penalty ** count)
194
 
195
+ # Mirostat: softmax with current τ
196
+ scaled_logits = next_logits / tau
197
+ exp_logits = np.exp(scaled_logits - np.max(scaled_logits))
 
 
198
  probs = exp_logits / exp_logits.sum()
199
 
200
+ # 샘플링
201
  next_token = np.random.choice(len(probs), p=probs)
202
  generated.append(int(next_token))
203
 
204
+ # 정보량 계산 (r = -log(p(x)))
205
+ prob = probs[next_token]
206
+ r = -np.log(prob + 1e-10)
207
+
208
+ # τ 업데이트
209
+ tau = tau - eta * (r - mu)
210
 
211
+ # 디코딩 및 조건 검증
212
+ decoded = sp.decode(generated).replace("<start>", "").replace("<sep>", "").replace("<end>", "").strip()
213
  if len(generated) >= min_len and (next_token == end_id or decoded.endswith(('요', '다', '.', '!', '?'))):
214
  if is_greedy_response_acceptable(decoded):
215
  return decoded