Yuchan5386 commited on
Commit
22e1363
·
verified ·
1 Parent(s): 8fdf3c6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -13
app.py CHANGED
@@ -163,7 +163,7 @@ def decode_sp_tokens(tokens):
163
  text = ''.join(tokens).replace('▁', ' ').strip()
164
  return text
165
 
166
- def generate_text_topp_stream(model, prompt, max_len=100, max_gen=98, p=0.9, temperature=0.8, min_len=20):
167
  model_input = text_to_ids(f"<start> {prompt}")
168
  model_input = model_input[:max_len]
169
  generated = list(model_input)
@@ -176,43 +176,36 @@ def generate_text_topp_stream(model, prompt, max_len=100, max_gen=98, p=0.9, tem
176
  logits = model(input_tensor, training=False)
177
  next_token_logits = logits[0, len(generated) - 1].numpy()
178
 
179
- # 특정 토큰들 확률 낮춤
180
  if len(generated) >= min_len:
181
  next_token_logits[end_id] -= 5.0
182
  next_token_logits[pad_id] -= 10.0
183
 
184
- # ์˜จ๋„ ์ ์šฉ
185
  logits_temp = next_token_logits / temperature
186
  probs = tf.nn.softmax(logits_temp).numpy()
187
 
188
- # 확률 내림차순 정렬
189
  sorted_idx = np.argsort(probs)[::-1]
190
  sorted_probs = probs[sorted_idx]
191
  cumulative_probs = np.cumsum(sorted_probs)
192
 
193
- # 누적합이 p 넘는 위치까지만 선택
194
  cutoff = np.searchsorted(cumulative_probs, p, side='right') + 1
195
  filtered_indices = sorted_idx[:cutoff]
196
  filtered_probs = sorted_probs[:cutoff]
197
  filtered_probs /= filtered_probs.sum()
198
 
199
- # 샘플링
200
  next_token_id = np.random.choice(filtered_indices, p=filtered_probs)
201
 
202
- # 결과 누적
203
  generated.append(int(next_token_id))
204
  next_word = sp.id_to_piece(int(next_token_id))
205
  text_so_far.append(next_word)
206
 
207
  decoded_text = decode_sp_tokens(text_so_far)
208
 
209
- # 정지 조건
210
  if len(generated) >= min_len and next_token_id == end_id:
211
  break
212
  if len(generated) >= min_len and decoded_text.endswith(('.', '!', '?')):
213
  break
214
 
215
- return decoded_text
216
 
217
  def respond(input_text):
218
  if "์ด๋ฆ„" in input_text:
@@ -220,10 +213,7 @@ def respond(input_text):
220
  if "누구" in input_text:
221
  return "저는 KeraLux라고 해요."
222
 
223
- # 제너레이터 결과를 모두 이어붙임
224
- response = ''.join(generate_text_topp_stream(model, input_text))
225
- return response
226
-
227
 
228
  @app.get("/generate", response_class=PlainTextResponse)
229
  async def generate(request: Request):
 
163
  text = ''.join(tokens).replace('▁', ' ').strip()
164
  return text
165
 
166
+ def generate_text_topp(model, prompt, max_len=100, max_gen=98, p=0.9, temperature=0.8, min_len=20):
167
  model_input = text_to_ids(f"<start> {prompt}")
168
  model_input = model_input[:max_len]
169
  generated = list(model_input)
 
176
  logits = model(input_tensor, training=False)
177
  next_token_logits = logits[0, len(generated) - 1].numpy()
178
 
 
179
  if len(generated) >= min_len:
180
  next_token_logits[end_id] -= 5.0
181
  next_token_logits[pad_id] -= 10.0
182
 
 
183
  logits_temp = next_token_logits / temperature
184
  probs = tf.nn.softmax(logits_temp).numpy()
185
 
 
186
  sorted_idx = np.argsort(probs)[::-1]
187
  sorted_probs = probs[sorted_idx]
188
  cumulative_probs = np.cumsum(sorted_probs)
189
 
 
190
  cutoff = np.searchsorted(cumulative_probs, p, side='right') + 1
191
  filtered_indices = sorted_idx[:cutoff]
192
  filtered_probs = sorted_probs[:cutoff]
193
  filtered_probs /= filtered_probs.sum()
194
 
 
195
  next_token_id = np.random.choice(filtered_indices, p=filtered_probs)
196
 
 
197
  generated.append(int(next_token_id))
198
  next_word = sp.id_to_piece(int(next_token_id))
199
  text_so_far.append(next_word)
200
 
201
  decoded_text = decode_sp_tokens(text_so_far)
202
 
 
203
  if len(generated) >= min_len and next_token_id == end_id:
204
  break
205
  if len(generated) >= min_len and decoded_text.endswith(('.', '!', '?')):
206
  break
207
 
208
+ return decoded_text
209
 
210
  def respond(input_text):
211
  if "์ด๋ฆ„" in input_text:
 
213
  if "누구" in input_text:
214
  return "저는 KeraLux라고 해요."
215
 
216
+ return generate_text_topp(model, input_text)
 
 
 
217
 
218
  @app.get("/generate", response_class=PlainTextResponse)
219
  async def generate(request: Request):