Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -163,7 +163,7 @@ def decode_sp_tokens(tokens):
|
|
163 |
text = ''.join(tokens).replace('โ', ' ').strip()
|
164 |
return text
|
165 |
|
166 |
-
def
|
167 |
model_input = text_to_ids(f"<start> {prompt}")
|
168 |
model_input = model_input[:max_len]
|
169 |
generated = list(model_input)
|
@@ -176,43 +176,36 @@ def generate_text_topp_stream(model, prompt, max_len=100, max_gen=98, p=0.9, tem
|
|
176 |
logits = model(input_tensor, training=False)
|
177 |
next_token_logits = logits[0, len(generated) - 1].numpy()
|
178 |
|
179 |
-
# ํน์ ํ ํฐ๋ค ํ๋ฅ ๋ฎ์ถค
|
180 |
if len(generated) >= min_len:
|
181 |
next_token_logits[end_id] -= 5.0
|
182 |
next_token_logits[pad_id] -= 10.0
|
183 |
|
184 |
-
# ์จ๋ ์ ์ฉ
|
185 |
logits_temp = next_token_logits / temperature
|
186 |
probs = tf.nn.softmax(logits_temp).numpy()
|
187 |
|
188 |
-
# ํ๋ฅ ๋ด๋ฆผ์ฐจ์ ์ ๋ ฌ
|
189 |
sorted_idx = np.argsort(probs)[::-1]
|
190 |
sorted_probs = probs[sorted_idx]
|
191 |
cumulative_probs = np.cumsum(sorted_probs)
|
192 |
|
193 |
-
# ๋์ ํฉ์ด p ๋๋ ์์น๊น์ง๋ง ์ ํ
|
194 |
cutoff = np.searchsorted(cumulative_probs, p, side='right') + 1
|
195 |
filtered_indices = sorted_idx[:cutoff]
|
196 |
filtered_probs = sorted_probs[:cutoff]
|
197 |
filtered_probs /= filtered_probs.sum()
|
198 |
|
199 |
-
# ์ํ๋ง
|
200 |
next_token_id = np.random.choice(filtered_indices, p=filtered_probs)
|
201 |
|
202 |
-
# ๊ฒฐ๊ณผ ๋์
|
203 |
generated.append(int(next_token_id))
|
204 |
next_word = sp.id_to_piece(int(next_token_id))
|
205 |
text_so_far.append(next_word)
|
206 |
|
207 |
decoded_text = decode_sp_tokens(text_so_far)
|
208 |
|
209 |
-
# ์ ์ง ์กฐ๊ฑด
|
210 |
if len(generated) >= min_len and next_token_id == end_id:
|
211 |
break
|
212 |
if len(generated) >= min_len and decoded_text.endswith(('.', '!', '?')):
|
213 |
break
|
214 |
|
215 |
-
|
216 |
|
217 |
def respond(input_text):
|
218 |
if "์ด๋ฆ" in input_text:
|
@@ -220,10 +213,7 @@ def respond(input_text):
|
|
220 |
if "๋๊ตฌ" in input_text:
|
221 |
return "์ ๋ KeraLux๋ผ๊ณ ํด์."
|
222 |
|
223 |
-
|
224 |
-
response = ''.join(generate_text_topp_stream(model, input_text))
|
225 |
-
return response
|
226 |
-
|
227 |
|
228 |
@app.get("/generate", response_class=PlainTextResponse)
|
229 |
async def generate(request: Request):
|
|
|
163 |
text = ''.join(tokens).replace('โ', ' ').strip()
|
164 |
return text
|
165 |
|
166 |
+
def generate_text_topp(model, prompt, max_len=100, max_gen=98, p=0.9, temperature=0.8, min_len=20):
|
167 |
model_input = text_to_ids(f"<start> {prompt}")
|
168 |
model_input = model_input[:max_len]
|
169 |
generated = list(model_input)
|
|
|
176 |
logits = model(input_tensor, training=False)
|
177 |
next_token_logits = logits[0, len(generated) - 1].numpy()
|
178 |
|
|
|
179 |
if len(generated) >= min_len:
|
180 |
next_token_logits[end_id] -= 5.0
|
181 |
next_token_logits[pad_id] -= 10.0
|
182 |
|
|
|
183 |
logits_temp = next_token_logits / temperature
|
184 |
probs = tf.nn.softmax(logits_temp).numpy()
|
185 |
|
|
|
186 |
sorted_idx = np.argsort(probs)[::-1]
|
187 |
sorted_probs = probs[sorted_idx]
|
188 |
cumulative_probs = np.cumsum(sorted_probs)
|
189 |
|
|
|
190 |
cutoff = np.searchsorted(cumulative_probs, p, side='right') + 1
|
191 |
filtered_indices = sorted_idx[:cutoff]
|
192 |
filtered_probs = sorted_probs[:cutoff]
|
193 |
filtered_probs /= filtered_probs.sum()
|
194 |
|
|
|
195 |
next_token_id = np.random.choice(filtered_indices, p=filtered_probs)
|
196 |
|
|
|
197 |
generated.append(int(next_token_id))
|
198 |
next_word = sp.id_to_piece(int(next_token_id))
|
199 |
text_so_far.append(next_word)
|
200 |
|
201 |
decoded_text = decode_sp_tokens(text_so_far)
|
202 |
|
|
|
203 |
if len(generated) >= min_len and next_token_id == end_id:
|
204 |
break
|
205 |
if len(generated) >= min_len and decoded_text.endswith(('.', '!', '?')):
|
206 |
break
|
207 |
|
208 |
+
return decoded_text
|
209 |
|
210 |
def respond(input_text):
|
211 |
if "์ด๋ฆ" in input_text:
|
|
|
213 |
if "๋๊ตฌ" in input_text:
|
214 |
return "์ ๋ KeraLux๋ผ๊ณ ํด์."
|
215 |
|
216 |
+
return generate_text_topp(model, input_text)
|
|
|
|
|
|
|
217 |
|
218 |
@app.get("/generate", response_class=PlainTextResponse)
|
219 |
async def generate(request: Request):
|