Yuchan5386 committed on
Commit
11fb468
·
verified ·
1 Parent(s): 6ae900b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -8
app.py CHANGED
@@ -142,7 +142,7 @@ def decode_sp_tokens(tokens):
142
  text = ''.join(tokens).replace('โ–', ' ').strip()
143
  return text
144
 
145
- def generate_text_topz_stream(model, prompt, max_len=100, max_gen=98, alpha=1.5, temperature=0.8, min_len=20):
146
  model_input = text_to_ids(f"<start> {prompt}")
147
  model_input = model_input[:max_len]
148
  generated = list(model_input)
@@ -155,29 +155,37 @@ def generate_text_topz_stream(model, prompt, max_len=100, max_gen=98, alpha=1.5,
155
  logits = model(input_tensor, training=False)
156
  next_token_logits = logits[0, len(generated) - 1].numpy()
157
 
 
158
  if len(generated) >= min_len:
159
  next_token_logits[end_id] -= 5.0
160
  next_token_logits[pad_id] -= 10.0
161
 
162
  # ์˜จ๋„ ์ ์šฉ
163
  logits_temp = next_token_logits / temperature
164
-
165
- # compute probabilities
166
  probs = tf.nn.softmax(logits_temp).numpy()
167
 
168
- # Top-z transform (z-transform: flattening)
169
- transformed_probs = probs ** (1 / alpha)
170
- transformed_probs /= transformed_probs.sum()
 
 
 
 
 
 
 
171
 
172
  # sampling
173
- next_token_id = np.random.choice(len(transformed_probs), p=transformed_probs)
174
 
 
175
  generated.append(int(next_token_id))
176
  next_word = sp.id_to_piece(int(next_token_id))
177
  text_so_far.append(next_word)
178
 
179
  decoded_text = decode_sp_tokens(text_so_far)
180
 
 
181
  if len(generated) >= min_len and next_token_id == end_id:
182
  break
183
  if len(generated) >= min_len and decoded_text.endswith(('.', '!', '?')):
@@ -187,7 +195,7 @@ def generate_text_topz_stream(model, prompt, max_len=100, max_gen=98, alpha=1.5,
187
 
188
  def chat_stream(user_input, history_text):
189
  partial_text = ""
190
- for partial_response in generate_text_topz_stream(model, user_input):
191
  partial_text = partial_response
192
  yield history_text + f"์‚ฌ์šฉ์ž: {user_input}\nKeraLux: {partial_text}\n", \
193
  history_text + f"์‚ฌ์šฉ์ž: {user_input}\nKeraLux: {partial_text}\n"
 
142
  text = ''.join(tokens).replace('โ–', ' ').strip()
143
  return text
144
 
145
+ def generate_text_topp_stream(model, prompt, max_len=100, max_gen=98, p=0.9, temperature=0.8, min_len=20):
146
  model_input = text_to_ids(f"<start> {prompt}")
147
  model_input = model_input[:max_len]
148
  generated = list(model_input)
 
155
  logits = model(input_tensor, training=False)
156
  next_token_logits = logits[0, len(generated) - 1].numpy()
157
 
158
+ # lower the probability of specific tokens
159
  if len(generated) >= min_len:
160
  next_token_logits[end_id] -= 5.0
161
  next_token_logits[pad_id] -= 10.0
162
 
163
  # ์˜จ๋„ ์ ์šฉ
164
  logits_temp = next_token_logits / temperature
 
 
165
  probs = tf.nn.softmax(logits_temp).numpy()
166
 
167
+ # sort probabilities in descending order
168
+ sorted_idx = np.argsort(probs)[::-1]
169
+ sorted_probs = probs[sorted_idx]
170
+ cumulative_probs = np.cumsum(sorted_probs)
171
+
172
+ # keep only the positions up to where the cumulative sum exceeds p
173
+ cutoff = np.searchsorted(cumulative_probs, p, side='right') + 1
174
+ filtered_indices = sorted_idx[:cutoff]
175
+ filtered_probs = sorted_probs[:cutoff]
176
+ filtered_probs /= filtered_probs.sum()
177
 
178
  # sampling
179
+ next_token_id = np.random.choice(filtered_indices, p=filtered_probs)
180
 
181
+ # accumulate the result
182
  generated.append(int(next_token_id))
183
  next_word = sp.id_to_piece(int(next_token_id))
184
  text_so_far.append(next_word)
185
 
186
  decoded_text = decode_sp_tokens(text_so_far)
187
 
188
+ # stopping conditions
189
  if len(generated) >= min_len and next_token_id == end_id:
190
  break
191
  if len(generated) >= min_len and decoded_text.endswith(('.', '!', '?')):
 
195
 
196
  def chat_stream(user_input, history_text):
197
  partial_text = ""
198
+ for partial_response in generate_text_topp_stream(model, user_input):
199
  partial_text = partial_response
200
  yield history_text + f"์‚ฌ์šฉ์ž: {user_input}\nKeraLux: {partial_text}\n", \
201
  history_text + f"์‚ฌ์šฉ์ž: {user_input}\nKeraLux: {partial_text}\n"