Yuchan5386 committed on
Commit
1d3aa80
·
verified ·
1 Parent(s): d4685ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -14
app.py CHANGED
@@ -126,13 +126,15 @@ def decode_sp_tokens(tokens):
126
  text = ''.join(tokens).replace('▁', ' ').strip()
127
  return text
128
 
129
- def generate_text_mirostat(model, prompt, max_len=100, max_gen=98, temperature=1.0, min_len=20, repetition_penalty=1.2, eta=0.1, m=100):
 
 
130
  model_input = text_to_ids(f"<start> {prompt} <sep>")
131
  model_input = model_input[:max_len]
132
  generated = list(model_input)
133
  text_so_far = []
134
 
135
- tau = 5.0 # 초기 λͺ©ν‘œ surprise (μ •λ³΄λŸ‰)
136
 
137
  for step in range(max_gen):
138
  pad_length = max(0, max_len - len(generated))
@@ -141,7 +143,7 @@ def generate_text_mirostat(model, prompt, max_len=100, max_gen=98, temperature=1
141
  logits = model(input_tensor, training=False)
142
  next_token_logits = logits[0, len(generated) - 1].numpy()
143
 
144
- # 반볡 νŽ˜λ„ν‹°
145
  token_counts = {}
146
  for t in generated:
147
  token_counts[t] = token_counts.get(t, 0) + 1
@@ -156,45 +158,54 @@ def generate_text_mirostat(model, prompt, max_len=100, max_gen=98, temperature=1
156
  # μ˜¨λ„ 쑰절
157
  next_token_logits = next_token_logits / temperature
158
 
159
- # --- λ―Έλ‘œμŠ€νƒ€νŠΈ μƒ˜ν”Œλ§ ---
160
- logits_stable = next_token_logits - np.max(next_token_logits) # μ•ˆμ •ν™”
161
  probs = np.exp(logits_stable)
162
  probs /= probs.sum()
163
 
 
164
  sorted_indices = np.argsort(-probs)
165
  top_indices = sorted_indices[:m]
166
  top_probs = probs[top_indices]
167
  top_probs /= top_probs.sum()
168
 
 
169
  sampled_index = np.random.choice(top_indices, p=top_probs)
170
  sampled_prob = probs[sampled_index]
171
  observed_surprise = -np.log(sampled_prob + 1e-9)
172
-
173
- # tau μ—…λ°μ΄νŠΈ
174
  tau += eta * (observed_surprise - tau)
175
 
176
- generated.append(int(sampled_index))
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
- next_word = sp.id_to_piece(int(sampled_index))
179
  text_so_far.append(next_word)
180
  decoded_text = decode_sp_tokens(text_so_far)
181
 
182
- if len(generated) >= min_len and sampled_index == end_id:
183
  break
184
  if len(generated) >= min_len and decoded_text.endswith(('.', '!', '?', '<end>')):
185
  break
186
 
187
  yield decoded_text
188
 
189
-
190
- import gradio as gr
191
-
192
  nickname = "μ‚¬μš©μž"
193
 
def respond(message, chat_history):
    """Stream chatbot replies for the Gradio chat UI.

    Substitutes the nickname placeholder in the incoming message, then
    yields every partial decoding produced by the mirostat generator so
    the interface updates incrementally. `chat_history` is accepted for
    the Gradio callback signature but is not used here.
    """
    # Swap the placeholder token for the configured display name.
    prompt = message.replace("@μ‚¬μš©μž1@", nickname)
    latest = ""
    for chunk in generate_text_mirostat(model, prompt):
        latest = chunk
        yield latest
200
 
 
126
  text = ''.join(tokens).replace('▁', ' ').strip()
127
  return text
128
 
129
+ def generate_text_mirostat_top_p(model, prompt, max_len=100, max_gen=98,
130
+ temperature=1.0, min_len=20,
131
+ repetition_penalty=1.2, eta=0.1, m=100, p=0.9):
132
  model_input = text_to_ids(f"<start> {prompt} <sep>")
133
  model_input = model_input[:max_len]
134
  generated = list(model_input)
135
  text_so_far = []
136
 
137
+ tau = 5.0 # 초기 λͺ©ν‘œ surprise
138
 
139
  for step in range(max_gen):
140
  pad_length = max(0, max_len - len(generated))
 
143
  logits = model(input_tensor, training=False)
144
  next_token_logits = logits[0, len(generated) - 1].numpy()
145
 
146
+ # 반볡 νŽ˜λ„ν‹° 적용
147
  token_counts = {}
148
  for t in generated:
149
  token_counts[t] = token_counts.get(t, 0) + 1
 
158
  # μ˜¨λ„ 쑰절
159
  next_token_logits = next_token_logits / temperature
160
 
161
+ # --- λ―Έλ‘œμŠ€νƒ€νŠΈ + Top-p μƒ˜ν”Œλ§ μ‹œμž‘ ---
162
+ logits_stable = next_token_logits - np.max(next_token_logits)
163
  probs = np.exp(logits_stable)
164
  probs /= probs.sum()
165
 
166
+ # 1. mirostat top-m 후보 좔리기
167
  sorted_indices = np.argsort(-probs)
168
  top_indices = sorted_indices[:m]
169
  top_probs = probs[top_indices]
170
  top_probs /= top_probs.sum()
171
 
172
+ # 2. mirostat μƒ˜ν”Œλ§
173
  sampled_index = np.random.choice(top_indices, p=top_probs)
174
  sampled_prob = probs[sampled_index]
175
  observed_surprise = -np.log(sampled_prob + 1e-9)
 
 
176
  tau += eta * (observed_surprise - tau)
177
 
178
+ # 3. top-p 필터링
179
+ sorted_top_indices = top_indices[np.argsort(-top_probs)]
180
+ sorted_top_probs = np.sort(top_probs)[::-1]
181
+ cumulative_probs = np.cumsum(sorted_top_probs)
182
+ cutoff = np.searchsorted(cumulative_probs, p, side='left') + 1
183
+ filtered_indices = sorted_top_indices[:cutoff]
184
+ filtered_probs = sorted_top_probs[:cutoff]
185
+ filtered_probs /= filtered_probs.sum()
186
+
187
+ # 4. μ΅œμ’… 토큰은 filtered μ§‘ν•©μ—μ„œ λ‹€μ‹œ μƒ˜ν”Œλ§
188
+ final_token = np.random.choice(filtered_indices, p=filtered_probs)
189
+
190
+ generated.append(int(final_token))
191
 
192
+ next_word = sp.id_to_piece(int(final_token))
193
  text_so_far.append(next_word)
194
  decoded_text = decode_sp_tokens(text_so_far)
195
 
196
+ if len(generated) >= min_len and final_token == end_id:
197
  break
198
  if len(generated) >= min_len and decoded_text.endswith(('.', '!', '?', '<end>')):
199
  break
200
 
201
  yield decoded_text
202
 
 
 
 
203
  nickname = "μ‚¬μš©μž"
204
 
205
  def respond(message, chat_history):
206
  message = message.replace("@μ‚¬μš©μž1@", nickname)
207
  response = ""
208
+ for partial in generate_text_mirostat_top_p(model, message):
209
  response = partial
210
  yield response
211