Spaces:
Sleeping
Sleeping
Update api.py
Browse files
api.py
CHANGED
@@ -150,33 +150,35 @@ from sklearn.metrics.pairwise import cosine_similarity
|
|
150 |
from fastapi import Request
|
151 |
from fastapi.responses import PlainTextResponse
|
152 |
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
|
|
172 |
return True
|
173 |
|
174 |
-
def
|
175 |
-
|
176 |
-
|
177 |
model_input = text_to_ids(f"<start> {prompt} <sep>")
|
178 |
model_input = model_input[:max_len]
|
179 |
generated = list(model_input)
|
|
|
180 |
|
181 |
for _ in range(max_gen):
|
182 |
pad_len = max(0, max_len - len(generated))
|
@@ -190,22 +192,24 @@ def generate_text_with_temp_and_rep_penalty(model, prompt, max_len=100, max_gen=
|
|
190 |
count = generated.count(t)
|
191 |
next_logits[t] /= (repetition_penalty ** count)
|
192 |
|
193 |
-
#
|
194 |
-
|
195 |
-
|
196 |
-
# Softmax로 확률 계산
|
197 |
-
exp_logits = np.exp(next_logits - np.max(next_logits))
|
198 |
probs = exp_logits / exp_logits.sum()
|
199 |
|
200 |
-
#
|
201 |
next_token = np.random.choice(len(probs), p=probs)
|
202 |
generated.append(int(next_token))
|
203 |
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
|
|
|
|
208 |
|
|
|
|
|
209 |
if len(generated) >= min_len and (next_token == end_id or decoded.endswith(('요', '다', '.', '!', '?'))):
|
210 |
if is_greedy_response_acceptable(decoded):
|
211 |
return decoded
|
|
|
150 |
from fastapi import Request
|
151 |
from fastapi.responses import PlainTextResponse
|
152 |
|
153 |
+
|
154 |
+
def is_greedy_response_acceptable(text):
    """Return True when a generated response looks like a usable sentence.

    The text is stripped of surrounding whitespace and then rejected if any
    of the following holds:

    * it is not a string (e.g. ``None``);
    * it has fewer than 5 characters;
    * it has fewer than 3 whitespace-separated words;
    * it contains a run of 3 or more bare Hangul jamo (consonants or
      vowels on their own) and does not contain ``'ㅋㅋ'`` anywhere;
    * it does not end with one of ``다``, ``요``, ``죠``, ``.``, ``!``, ``?``.

    Args:
        text: Candidate response text.

    Returns:
        bool: True if the text passes every check, otherwise False.
    """
    # A validator should say "not acceptable" for non-text input rather
    # than raise AttributeError on .strip().
    if not isinstance(text, str):
        return False

    text = text.strip()

    # Filter out sentences that are too short.
    if len(text) < 5:
        return False

    # Filter out sentences with too few words.
    if len(text.split()) < 3:
        return False

    # Filter out runs of bare jamo, unless the text contains 'ㅋㅋ'
    # (laughter), which is explicitly allowed.
    if 'ㅋㅋ' not in text and re.search(r'[ㄱ-ㅎㅏ-ㅣ]{3,}', text):
        return False

    # Require a natural sentence ending. Every alternative of the former
    # regex ends in one of these six characters, so a single endswith call
    # on the already-stripped text accepts exactly the same strings.
    return text.endswith(('다', '요', '죠', '.', '!', '?'))
|
174 |
|
175 |
+
def generate_text_with_mirostat(model, prompt, max_len=100, max_gen=98,
|
176 |
+
repetition_penalty=1.2, tau_init=5.0, mu=5.0,
|
177 |
+
eta=0.1, min_len=20):
|
178 |
model_input = text_to_ids(f"<start> {prompt} <sep>")
|
179 |
model_input = model_input[:max_len]
|
180 |
generated = list(model_input)
|
181 |
+
tau = tau_init
|
182 |
|
183 |
for _ in range(max_gen):
|
184 |
pad_len = max(0, max_len - len(generated))
|
|
|
192 |
count = generated.count(t)
|
193 |
next_logits[t] /= (repetition_penalty ** count)
|
194 |
|
195 |
+
# Mirostat: softmax with current τ
|
196 |
+
scaled_logits = next_logits / tau
|
197 |
+
exp_logits = np.exp(scaled_logits - np.max(scaled_logits))
|
|
|
|
|
198 |
probs = exp_logits / exp_logits.sum()
|
199 |
|
200 |
+
# 샘플링
|
201 |
next_token = np.random.choice(len(probs), p=probs)
|
202 |
generated.append(int(next_token))
|
203 |
|
204 |
+
# 정보량 계산 (r = -log(p(x)))
|
205 |
+
prob = probs[next_token]
|
206 |
+
r = -np.log(prob + 1e-10)
|
207 |
+
|
208 |
+
# τ 업데이트
|
209 |
+
tau = tau - eta * (r - mu)
|
210 |
|
211 |
+
# 디코딩 및 조건 검증
|
212 |
+
decoded = sp.decode(generated).replace("<start>", "").replace("<sep>", "").replace("<end>", "").strip()
|
213 |
if len(generated) >= min_len and (next_token == end_id or decoded.endswith(('요', '다', '.', '!', '?'))):
|
214 |
if is_greedy_response_acceptable(decoded):
|
215 |
return decoded
|