Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -126,13 +126,15 @@ def decode_sp_tokens(tokens):
|
|
126 |
text = ''.join(tokens).replace('▁', ' ').strip()
|
127 |
return text
|
128 |
|
129 |
-
def
|
|
|
|
|
130 |
model_input = text_to_ids(f"<start> {prompt} <sep>")
|
131 |
model_input = model_input[:max_len]
|
132 |
generated = list(model_input)
|
133 |
text_so_far = []
|
134 |
|
135 |
-
tau = 5.0  # 초기 목표 surprise (initial target surprise)
|
136 |
|
137 |
for step in range(max_gen):
|
138 |
pad_length = max(0, max_len - len(generated))
|
@@ -141,7 +143,7 @@ def generate_text_mirostat(model, prompt, max_len=100, max_gen=98, temperature=1
|
|
141 |
logits = model(input_tensor, training=False)
|
142 |
next_token_logits = logits[0, len(generated) - 1].numpy()
|
143 |
|
144 |
-
# 반복 패널티 (repetition penalty)
|
145 |
token_counts = {}
|
146 |
for t in generated:
|
147 |
token_counts[t] = token_counts.get(t, 0) + 1
|
@@ -156,45 +158,54 @@ def generate_text_mirostat(model, prompt, max_len=100, max_gen=98, temperature=1
|
|
156 |
# 온도 조절 (temperature scaling)
|
157 |
next_token_logits = next_token_logits / temperature
|
158 |
|
159 |
-
# --- 미로스탯 샘플링 (mirostat sampling) ---
|
160 |
-
logits_stable = next_token_logits - np.max(next_token_logits)
|
161 |
probs = np.exp(logits_stable)
|
162 |
probs /= probs.sum()
|
163 |
|
|
|
164 |
sorted_indices = np.argsort(-probs)
|
165 |
top_indices = sorted_indices[:m]
|
166 |
top_probs = probs[top_indices]
|
167 |
top_probs /= top_probs.sum()
|
168 |
|
|
|
169 |
sampled_index = np.random.choice(top_indices, p=top_probs)
|
170 |
sampled_prob = probs[sampled_index]
|
171 |
observed_surprise = -np.log(sampled_prob + 1e-9)
|
172 |
-
|
173 |
-
# tau 업데이트 (update tau)
|
174 |
tau += eta * (observed_surprise - tau)
|
175 |
|
176 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
177 |
|
178 |
-
next_word = sp.id_to_piece(int(
|
179 |
text_so_far.append(next_word)
|
180 |
decoded_text = decode_sp_tokens(text_so_far)
|
181 |
|
182 |
-
if len(generated) >= min_len and
|
183 |
break
|
184 |
if len(generated) >= min_len and decoded_text.endswith(('.', '!', '?', '<end>')):
|
185 |
break
|
186 |
|
187 |
yield decoded_text
|
188 |
|
189 |
-
|
190 |
-
import gradio as gr
|
191 |
-
|
192 |
nickname = "μ¬μ©μ"
|
193 |
|
194 |
def respond(message, chat_history):
|
195 |
message = message.replace("@μ¬μ©μ1@", nickname)
|
196 |
response = ""
|
197 |
-
for partial in
|
198 |
response = partial
|
199 |
yield response
|
200 |
|
|
|
126 |
text = ''.join(tokens).replace('▁', ' ').strip()
|
127 |
return text
|
128 |
|
129 |
+
def generate_text_mirostat_top_p(model, prompt, max_len=100, max_gen=98,
|
130 |
+
temperature=1.0, min_len=20,
|
131 |
+
repetition_penalty=1.2, eta=0.1, m=100, p=0.9):
|
132 |
model_input = text_to_ids(f"<start> {prompt} <sep>")
|
133 |
model_input = model_input[:max_len]
|
134 |
generated = list(model_input)
|
135 |
text_so_far = []
|
136 |
|
137 |
+
tau = 5.0  # 초기 목표 surprise (initial target surprise)
|
138 |
|
139 |
for step in range(max_gen):
|
140 |
pad_length = max(0, max_len - len(generated))
|
|
|
143 |
logits = model(input_tensor, training=False)
|
144 |
next_token_logits = logits[0, len(generated) - 1].numpy()
|
145 |
|
146 |
+
# 반복 패널티 적용 (apply repetition penalty)
|
147 |
token_counts = {}
|
148 |
for t in generated:
|
149 |
token_counts[t] = token_counts.get(t, 0) + 1
|
|
|
158 |
# 온도 조절 (temperature scaling)
|
159 |
next_token_logits = next_token_logits / temperature
|
160 |
|
161 |
+
# --- 미로스탯 + Top-p 샘플링 시작 (mirostat + top-p sampling) ---
|
162 |
+
logits_stable = next_token_logits - np.max(next_token_logits)
|
163 |
probs = np.exp(logits_stable)
|
164 |
probs /= probs.sum()
|
165 |
|
166 |
+
# 1. mirostat top-m 후보 추리기 (narrow to top-m candidates)
|
167 |
sorted_indices = np.argsort(-probs)
|
168 |
top_indices = sorted_indices[:m]
|
169 |
top_probs = probs[top_indices]
|
170 |
top_probs /= top_probs.sum()
|
171 |
|
172 |
+
# 2. mirostat 샘플링 (mirostat sampling)
|
173 |
sampled_index = np.random.choice(top_indices, p=top_probs)
|
174 |
sampled_prob = probs[sampled_index]
|
175 |
observed_surprise = -np.log(sampled_prob + 1e-9)
|
|
|
|
|
176 |
tau += eta * (observed_surprise - tau)
|
177 |
|
178 |
+
# 3. top-p 필터링 (top-p filtering)
|
179 |
+
sorted_top_indices = top_indices[np.argsort(-top_probs)]
|
180 |
+
sorted_top_probs = np.sort(top_probs)[::-1]
|
181 |
+
cumulative_probs = np.cumsum(sorted_top_probs)
|
182 |
+
cutoff = np.searchsorted(cumulative_probs, p, side='left') + 1
|
183 |
+
filtered_indices = sorted_top_indices[:cutoff]
|
184 |
+
filtered_probs = sorted_top_probs[:cutoff]
|
185 |
+
filtered_probs /= filtered_probs.sum()
|
186 |
+
|
187 |
+
# 4. 최종 토큰은 filtered 집합에서 다시 샘플링 (resample final token from the filtered set)
|
188 |
+
final_token = np.random.choice(filtered_indices, p=filtered_probs)
|
189 |
+
|
190 |
+
generated.append(int(final_token))
|
191 |
|
192 |
+
next_word = sp.id_to_piece(int(final_token))
|
193 |
text_so_far.append(next_word)
|
194 |
decoded_text = decode_sp_tokens(text_so_far)
|
195 |
|
196 |
+
if len(generated) >= min_len and final_token == end_id:
|
197 |
break
|
198 |
if len(generated) >= min_len and decoded_text.endswith(('.', '!', '?', '<end>')):
|
199 |
break
|
200 |
|
201 |
yield decoded_text
|
202 |
|
|
|
|
|
|
|
203 |
# Display name substituted for the "@사용자1@" placeholder in incoming messages.
# NOTE(review): restored from mojibake ("μ¬μ©μ" is "사용자" UTF-8 read as a
# single-byte codepage) — confirm against the deployed app.py.
nickname = "사용자"
|
204 |
|
205 |
def respond(message, chat_history):
    """Gradio chat handler: stream the model's generated reply for *message*.

    Parameters
    ----------
    message : str
        The user's input; the literal placeholder "@사용자1@" is replaced
        with the module-level ``nickname`` before generation.
    chat_history : list
        Accepted to satisfy the Gradio chat-callback signature; unused here.

    Yields
    ------
    str
        Progressively longer decoded text, so the UI can render the reply
        as it is generated.
    """
    # Restore the user-facing nickname in place of the training placeholder.
    # NOTE(review): placeholder restored from mojibake ("@μ¬μ©μ1@" is
    # "@사용자1@" mis-decoded) — confirm against the deployed app.py.
    message = message.replace("@사용자1@", nickname)
    response = ""
    # Each partial is the full decoded text so far; re-yield it so the
    # Gradio streaming interface shows token-by-token progress.
    for partial in generate_text_mirostat_top_p(model, message):
        response = partial
        yield response
|
211 |
|