Yuchan5386 commited on
Commit
86c9c3c
·
verified ·
1 Parent(s): 1d3aa80

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -60
app.py CHANGED
@@ -126,86 +126,64 @@ def decode_sp_tokens(tokens):
126
  text = ''.join(tokens).replace('โ–', ' ').strip()
127
  return text
128
 
129
- def generate_text_mirostat_top_p(model, prompt, max_len=100, max_gen=98,
130
- temperature=1.0, min_len=20,
131
- repetition_penalty=1.2, eta=0.1, m=100, p=0.9):
132
  model_input = text_to_ids(f"<start> {prompt} <sep>")
133
  model_input = model_input[:max_len]
134
  generated = list(model_input)
135
- text_so_far = []
136
-
137
- tau = 5.0 # ์ดˆ๊ธฐ ๋ชฉํ‘œ surprise
138
 
139
  for step in range(max_gen):
140
- pad_length = max(0, max_len - len(generated))
141
- input_padded = np.pad(generated, (0, pad_length), constant_values=pad_id)
142
  input_tensor = tf.convert_to_tensor([input_padded])
 
143
  logits = model(input_tensor, training=False)
144
- next_token_logits = logits[0, len(generated) - 1].numpy()
145
-
146
- # ๋ฐ˜๋ณต ํŽ˜๋„ํ‹ฐ ์ ์šฉ
147
- token_counts = {}
148
- for t in generated:
149
- token_counts[t] = token_counts.get(t, 0) + 1
150
- for token_id, count in token_counts.items():
151
- next_token_logits[token_id] /= (repetition_penalty ** count)
152
-
153
- # ์ตœ์†Œ ๊ธธ์ด ๋„˜์œผ๋ฉด ์ข…๋ฃŒ ํ† ํฐ ํ™•๋ฅ  ๋‚ฎ์ถ”๊ธฐ
154
- if len(generated) >= min_len:
155
- next_token_logits[end_id] -= 5.0
156
- next_token_logits[pad_id] -= 10.0
157
-
158
- # ์˜จ๋„ ์กฐ์ ˆ
159
- next_token_logits = next_token_logits / temperature
160
-
161
- # --- ๋ฏธ๋กœ์Šคํƒ€ํŠธ + Top-p ์ƒ˜ํ”Œ๋ง ์‹œ์ž‘ ---
162
- logits_stable = next_token_logits - np.max(next_token_logits)
163
- probs = np.exp(logits_stable)
164
  probs /= probs.sum()
165
 
166
- # 1. mirostat top-m ํ›„๋ณด ์ถ”๋ฆฌ๊ธฐ
167
- sorted_indices = np.argsort(-probs)
168
- top_indices = sorted_indices[:m]
169
- top_probs = probs[top_indices]
170
- top_probs /= top_probs.sum()
171
-
172
- # 2. mirostat ์ƒ˜ํ”Œ๋ง
173
- sampled_index = np.random.choice(top_indices, p=top_probs)
174
- sampled_prob = probs[sampled_index]
175
- observed_surprise = -np.log(sampled_prob + 1e-9)
176
- tau += eta * (observed_surprise - tau)
177
-
178
- # 3. top-p ํ•„ํ„ฐ๋ง
179
- sorted_top_indices = top_indices[np.argsort(-top_probs)]
180
- sorted_top_probs = np.sort(top_probs)[::-1]
181
- cumulative_probs = np.cumsum(sorted_top_probs)
182
- cutoff = np.searchsorted(cumulative_probs, p, side='left') + 1
183
- filtered_indices = sorted_top_indices[:cutoff]
184
- filtered_probs = sorted_top_probs[:cutoff]
185
- filtered_probs /= filtered_probs.sum()
186
 
187
- # 4. ์ตœ์ข… ํ† ํฐ์€ filtered ์ง‘ํ•ฉ์—์„œ ๋‹ค์‹œ ์ƒ˜ํ”Œ๋ง
188
- final_token = np.random.choice(filtered_indices, p=filtered_probs)
 
189
 
190
- generated.append(int(final_token))
 
191
 
192
- next_word = sp.id_to_piece(int(final_token))
193
- text_so_far.append(next_word)
194
- decoded_text = decode_sp_tokens(text_so_far)
 
195
 
196
- if len(generated) >= min_len and final_token == end_id:
 
197
  break
198
- if len(generated) >= min_len and decoded_text.endswith(('.', '!', '?', '<end>')):
199
- break
200
-
201
- yield decoded_text
202
 
203
  nickname = "์‚ฌ์šฉ์ž"
204
 
205
  def respond(message, chat_history):
206
  message = message.replace("@์‚ฌ์šฉ์ž1@", nickname)
207
  response = ""
208
- for partial in generate_text_mirostat_top_p(model, message):
209
  response = partial
210
  yield response
211
 
 
126
  text = ''.join(tokens).replace('โ–', ' ').strip()
127
  return text
128
 
129
+ def generate_text_top_p(model, prompt, max_len=100, max_gen=98,
130
+ temperature=1.0, min_len=20,
131
+ repetition_penalty=1.1, top_p=0.9):
132
  model_input = text_to_ids(f"<start> {prompt} <sep>")
133
  model_input = model_input[:max_len]
134
  generated = list(model_input)
 
 
 
135
 
136
  for step in range(max_gen):
137
+ pad_len = max(0, max_len - len(generated))
138
+ input_padded = np.pad(generated, (0, pad_len), constant_values=pad_id)
139
  input_tensor = tf.convert_to_tensor([input_padded])
140
+
141
  logits = model(input_tensor, training=False)
142
+ next_logits = logits[0, len(generated) - 1].numpy()
143
+
144
+ # ๋ฐ˜๋ณต ์–ต์ œ penalty
145
+ for t in set(generated):
146
+ count = generated.count(t)
147
+ next_logits[t] /= (repetition_penalty ** count)
148
+
149
+ # ์ข…๋ฃŒ ์กฐ๊ฑด ๋ฐฉ์ง€
150
+ if len(generated) < min_len:
151
+ next_logits[end_id] -= 5.0
152
+ next_logits[pad_id] -= 10.0
153
+
154
+ # ์˜จ๋„ ์ ์šฉ
155
+ next_logits = next_logits / temperature
156
+ probs = np.exp(next_logits - np.max(next_logits))
 
 
 
 
 
157
  probs /= probs.sum()
158
 
159
+ # Top-p ํ•„ํ„ฐ๋ง
160
+ sorted_idx = np.argsort(-probs)
161
+ sorted_probs = probs[sorted_idx]
162
+ cum_probs = np.cumsum(sorted_probs)
163
+ cutoff = np.searchsorted(cum_probs, top_p) + 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
 
165
+ filtered_idx = sorted_idx[:cutoff]
166
+ filtered_probs = sorted_probs[:cutoff]
167
+ filtered_probs /= filtered_probs.sum()
168
 
169
+ sampled = np.random.choice(filtered_idx, p=filtered_probs)
170
+ generated.append(int(sampled))
171
 
172
+ decoded = sp.decode(generated)
173
+ for t in ["<start>", "<sep>", "<end>"]:
174
+ decoded = decoded.replace(t, "")
175
+ decoded = decoded.strip()
176
 
177
+ if len(generated) >= min_len and (sampled == end_id or decoded.endswith(('.', '!', '?'))):
178
+ yield decoded
179
  break
 
 
 
 
180
 
181
  nickname = "์‚ฌ์šฉ์ž"
182
 
183
  def respond(message, chat_history):
184
  message = message.replace("@์‚ฌ์šฉ์ž1@", nickname)
185
  response = ""
186
+ for partial in generate_text_top_p(model, message):
187
  response = partial
188
  yield response
189