Spaces:
Sleeping
Sleeping
Update api.py
Browse files
api.py
CHANGED
@@ -124,14 +124,16 @@ _ = model(dummy_input)
|
|
124 |
model.load_weights("InteractGPT.weights.h5")
|
125 |
print("모델 가중치 로드 완료!")
|
126 |
|
127 |
-
|
|
|
128 |
temperature=1.0, min_len=20,
|
129 |
-
repetition_penalty=1.2, eta=0.1, m=100, p=0.9):
|
130 |
model_input = text_to_ids(f"<start> {prompt} <sep>")
|
131 |
model_input = model_input[:max_len]
|
132 |
generated = list(model_input)
|
133 |
|
134 |
tau = 5.0 # 초기 목표 surprise
|
|
|
135 |
|
136 |
for step in range(max_gen):
|
137 |
pad_length = max(0, max_len - len(generated))
|
@@ -177,18 +179,30 @@ async def generate_text_mirostat_top_p(model, prompt, max_len=100, max_gen=98,
|
|
177 |
|
178 |
final_token = np.random.choice(filtered_indices, p=filtered_probs)
|
179 |
|
180 |
-
# 특수 토큰 처리
|
181 |
if final_token == end_id:
|
|
|
|
|
|
|
|
|
|
|
|
|
182 |
break
|
|
|
183 |
if final_token in [start_id, pad_id] or sp.id_to_piece(final_token) == "<sep>":
|
184 |
continue
|
185 |
|
186 |
-
# 정상 토큰만 추가 및 yield
|
187 |
generated.append(int(final_token))
|
188 |
-
|
189 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
190 |
|
191 |
@app.get("/generate")
|
192 |
async def generate(request: Request):
|
193 |
prompt = request.query_params.get("prompt", "안녕하세요")
|
194 |
-
return StreamingResponse(
|
|
|
124 |
model.load_weights("InteractGPT.weights.h5")
|
125 |
print("모델 가중치 로드 완료!")
|
126 |
|
127 |
+
|
128 |
+
async def generate_text_mirostat_top_p_with_buffer(model, prompt, max_len=100, max_gen=98,
|
129 |
temperature=1.0, min_len=20,
|
130 |
+
repetition_penalty=1.2, eta=0.1, m=100, p=0.9, buffer_size=3):
|
131 |
model_input = text_to_ids(f"<start> {prompt} <sep>")
|
132 |
model_input = model_input[:max_len]
|
133 |
generated = list(model_input)
|
134 |
|
135 |
tau = 5.0 # 초기 목표 surprise
|
136 |
+
buffer_tokens = []
|
137 |
|
138 |
for step in range(max_gen):
|
139 |
pad_length = max(0, max_len - len(generated))
|
|
|
179 |
|
180 |
final_token = np.random.choice(filtered_indices, p=filtered_probs)
|
181 |
|
|
|
182 |
if final_token == end_id:
|
183 |
+
# 버퍼에 남은 거 다 출력
|
184 |
+
if buffer_tokens:
|
185 |
+
decoded = sp.decode(buffer_tokens)
|
186 |
+
for token in ["<start>", "<sep>", "<end>"]:
|
187 |
+
decoded = decoded.replace(token, "")
|
188 |
+
yield decoded.strip()
|
189 |
break
|
190 |
+
|
191 |
if final_token in [start_id, pad_id] or sp.id_to_piece(final_token) == "<sep>":
|
192 |
continue
|
193 |
|
|
|
194 |
generated.append(int(final_token))
|
195 |
+
buffer_tokens.append(final_token)
|
196 |
+
|
197 |
+
if len(buffer_tokens) >= buffer_size or sp.id_to_piece(final_token).endswith("▁"):
|
198 |
+
# 띄어쓰기 있는 토큰 나오거나 버퍼 꽉 찼으면 출력
|
199 |
+
decoded = sp.decode(buffer_tokens)
|
200 |
+
for token in ["<start>", "<sep>", "<end>"]:
|
201 |
+
decoded = decoded.replace(token, "")
|
202 |
+
yield decoded.strip()
|
203 |
+
buffer_tokens = []
|
204 |
|
205 |
@app.get("/generate")
async def generate(request: Request):
    """Stream generated text for the ``prompt`` query parameter.

    Reads ``prompt`` from the request's query string (a default greeting is
    used when the parameter is absent) and returns a ``text/plain``
    streaming response fed by ``generate_text_mirostat_top_p_with_buffer``.
    """
    prompt = request.query_params.get("prompt", "안녕하세요")
    # The generator's signature is (model, prompt, ...). The module-level
    # model must be passed first; passing only the prompt binds it to
    # `model` and raises TypeError for the missing `prompt` argument.
    chunks = generate_text_mirostat_top_p_with_buffer(model, prompt)
    return StreamingResponse(chunks, media_type="text/plain")
|