Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -142,7 +142,7 @@ def decode_sp_tokens(tokens):
|
|
142 |
text = ''.join(tokens).replace('โ', ' ').strip()
|
143 |
return text
|
144 |
|
145 |
-
def
|
146 |
model_input = text_to_ids(f"<start> {prompt}")
|
147 |
model_input = model_input[:max_len]
|
148 |
generated = list(model_input)
|
@@ -155,29 +155,37 @@ def generate_text_topz_stream(model, prompt, max_len=100, max_gen=98, alpha=1.5,
|
|
155 |
logits = model(input_tensor, training=False)
|
156 |
next_token_logits = logits[0, len(generated) - 1].numpy()
|
157 |
|
|
|
158 |
if len(generated) >= min_len:
|
159 |
next_token_logits[end_id] -= 5.0
|
160 |
next_token_logits[pad_id] -= 10.0
|
161 |
|
162 |
# ์จ๋ ์ ์ฉ
|
163 |
logits_temp = next_token_logits / temperature
|
164 |
-
|
165 |
-
# ํ๋ฅ ๊ณ์ฐ
|
166 |
probs = tf.nn.softmax(logits_temp).numpy()
|
167 |
|
168 |
-
#
|
169 |
-
|
170 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
171 |
|
172 |
# ์ํ๋ง
|
173 |
-
next_token_id = np.random.choice(
|
174 |
|
|
|
175 |
generated.append(int(next_token_id))
|
176 |
next_word = sp.id_to_piece(int(next_token_id))
|
177 |
text_so_far.append(next_word)
|
178 |
|
179 |
decoded_text = decode_sp_tokens(text_so_far)
|
180 |
|
|
|
181 |
if len(generated) >= min_len and next_token_id == end_id:
|
182 |
break
|
183 |
if len(generated) >= min_len and decoded_text.endswith(('.', '!', '?')):
|
@@ -187,7 +195,7 @@ def generate_text_topz_stream(model, prompt, max_len=100, max_gen=98, alpha=1.5,
|
|
187 |
|
188 |
def chat_stream(user_input, history_text):
|
189 |
partial_text = ""
|
190 |
-
for partial_response in
|
191 |
partial_text = partial_response
|
192 |
yield history_text + f"์ฌ์ฉ์: {user_input}\nKeraLux: {partial_text}\n", \
|
193 |
history_text + f"์ฌ์ฉ์: {user_input}\nKeraLux: {partial_text}\n"
|
|
|
142 |
text = ''.join(tokens).replace('โ', ' ').strip()
|
143 |
return text
|
144 |
|
145 |
+
def generate_text_topp_stream(model, prompt, max_len=100, max_gen=98, p=0.9, temperature=0.8, min_len=20):
|
146 |
model_input = text_to_ids(f"<start> {prompt}")
|
147 |
model_input = model_input[:max_len]
|
148 |
generated = list(model_input)
|
|
|
155 |
logits = model(input_tensor, training=False)
|
156 |
next_token_logits = logits[0, len(generated) - 1].numpy()
|
157 |
|
158 |
+
# ํน์ ํ ํฐ๋ค ํ๋ฅ ๋ฎ์ถค
|
159 |
if len(generated) >= min_len:
|
160 |
next_token_logits[end_id] -= 5.0
|
161 |
next_token_logits[pad_id] -= 10.0
|
162 |
|
163 |
# ์จ๋ ์ ์ฉ
|
164 |
logits_temp = next_token_logits / temperature
|
|
|
|
|
165 |
probs = tf.nn.softmax(logits_temp).numpy()
|
166 |
|
167 |
+
# ํ๋ฅ ๋ด๋ฆผ์ฐจ์ ์ ๋ ฌ
|
168 |
+
sorted_idx = np.argsort(probs)[::-1]
|
169 |
+
sorted_probs = probs[sorted_idx]
|
170 |
+
cumulative_probs = np.cumsum(sorted_probs)
|
171 |
+
|
172 |
+
# ๋์ ํฉ์ด p ๋๋ ์์น๊น์ง๋ง ์ ํ
|
173 |
+
cutoff = np.searchsorted(cumulative_probs, p, side='right') + 1
|
174 |
+
filtered_indices = sorted_idx[:cutoff]
|
175 |
+
filtered_probs = sorted_probs[:cutoff]
|
176 |
+
filtered_probs /= filtered_probs.sum()
|
177 |
|
178 |
# ์ํ๋ง
|
179 |
+
next_token_id = np.random.choice(filtered_indices, p=filtered_probs)
|
180 |
|
181 |
+
# ๊ฒฐ๊ณผ ๋์
|
182 |
generated.append(int(next_token_id))
|
183 |
next_word = sp.id_to_piece(int(next_token_id))
|
184 |
text_so_far.append(next_word)
|
185 |
|
186 |
decoded_text = decode_sp_tokens(text_so_far)
|
187 |
|
188 |
+
# ์ ์ง ์กฐ๊ฑด
|
189 |
if len(generated) >= min_len and next_token_id == end_id:
|
190 |
break
|
191 |
if len(generated) >= min_len and decoded_text.endswith(('.', '!', '?')):
|
|
|
195 |
|
196 |
def chat_stream(user_input, history_text):
    """Stream a chat turn to the UI.

    Drives the top-p token generator for *user_input* and, for every
    partial response it emits, yields a pair of identical transcript
    strings (the Gradio pattern of updating two outputs at once):
    the prior *history_text* plus the current user/bot exchange.
    """
    latest = ""  # most recent partial response from the generator
    for latest in generate_text_topp_stream(model, user_input):
        # Rebuild the running transcript with the newest partial text.
        transcript = history_text + f"์ฌ์ฉ์: {user_input}\nKeraLux: {latest}\n"
        yield transcript, transcript
|