Update app.py

app.py CHANGED
@@ -1,64 +1,184 @@
import json
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
import gradio as gr
import re
import requests
import math
import sentencepiece as spm

# Load the SentencePiece tokenizer (special-token IDs must match training)
sp = spm.SentencePieceProcessor()
sp.load("ko_unigram3.model")

pad_id = sp.piece_to_id("<pad>")
if pad_id == -1: pad_id = 0
start_id = sp.piece_to_id("<start>")
if start_id == -1: start_id = 1
end_id = sp.piece_to_id("<end>")
if end_id == -1: end_id = 2
unk_id = sp.piece_to_id("<unk>")
if unk_id == -1: unk_id = 3
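# The == -1 checks above are a defensive fallback: if a special piece is
# missing from the vocab, the conventional IDs 0-3 are used instead.
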
vocab_size = sp.get_piece_size()
max_len = 128

def text_to_ids(text):
    return sp.encode(text, out_type=int)

def ids_to_text(ids):
    return sp.decode(ids)
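# Illustrative round trip (actual IDs depend on the trained vocab):
#   ids = text_to_ids("안녕하세요")   # e.g. [1532, 87]
#   ids_to_text(ids)                  # -> "안녕하세요"
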
# GEGLU layer
class GEGLU(tf.keras.layers.Layer):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.proj = layers.Dense(d_ff * 2)
        self.out = layers.Dense(d_model)

    def call(self, x):
        x_proj = self.proj(x)
        x_val, x_gate = tf.split(x_proj, 2, axis=-1)
        return self.out(x_val * tf.nn.gelu(x_gate))
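# In equation form: GEGLU(x) = (x·W_v ⊙ GELU(x·W_g))·W_out; proj computes
# both halves with a single Dense(d_ff * 2), and tf.split separates the
# value path from the gate path.
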
# GPT block
class GPTBlock(tf.keras.layers.Layer):
    def __init__(self, d_model, d_ff, num_heads=16, dropout_rate=0.1):
        super().__init__()
        self.ln1 = layers.LayerNormalization(epsilon=1e-5)
        self.attn = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model // num_heads)
        self.dropout1 = layers.Dropout(dropout_rate)
        self.ln2 = layers.LayerNormalization(epsilon=1e-5)
        self.ffn = GEGLU(d_model, d_ff)
        self.dropout2 = layers.Dropout(dropout_rate)

    def call(self, x, training=False):
        x_norm = self.ln1(x)
        attn_out = self.attn(query=x_norm, value=x_norm, key=x_norm,
                             use_causal_mask=True, training=training)
        x = x + self.dropout1(attn_out, training=training)
        ffn_out = self.ffn(self.ln2(x))
        x = x + self.dropout2(ffn_out, training=training)
        return x
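# Pre-LN layout: each sublayer normalizes its input and adds its output back
# onto the residual stream, which keeps deeper stacks trainable, while
# use_causal_mask=True limits attention to the current and earlier positions.
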
# GPT model
class GPT(tf.keras.Model):
    def __init__(self, vocab_size, seq_len, d_model, d_ff, n_layers, num_heads=16, dropout_rate=0.1):
        super().__init__()
        self.token_embedding = layers.Embedding(vocab_size, d_model)
        self.pos_embedding = self.add_weight(
            name="pos_embedding",
            shape=[seq_len, d_model],
            initializer=tf.keras.initializers.RandomNormal(stddev=0.01)
        )
        self.blocks = [GPTBlock(d_model, d_ff, num_heads, dropout_rate) for _ in range(n_layers)]
        self.ln_f = layers.LayerNormalization(epsilon=1e-5)

    def call(self, x, training=False):
        seq_len = tf.shape(x)[1]
        x = self.token_embedding(x) + self.pos_embedding[tf.newaxis, :seq_len, :]
        for block in self.blocks:
            x = block(x, training=training)
        x = self.ln_f(x)
        logits = tf.matmul(x, self.token_embedding.embeddings, transpose_b=True)
        return logits
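# The final matmul reuses token_embedding.embeddings as the output projection
# (weight tying), avoiding a separate vocab_size x d_model output head.
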
# Build the model & load the trained weights
model = GPT(vocab_size=vocab_size, seq_len=max_len, d_model=128, d_ff=512, n_layers=6)
dummy_input = tf.zeros((1, max_len), dtype=tf.int32)  # batch 1, sequence length max_len
_ = model(dummy_input)  # builds the model
model.load_weights("KeraLux3.weights.h5")
print("Model weights loaded!")
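# Subclassed Keras models create their variables lazily on the first call, so
# the dummy forward pass above is what lets load_weights map the saved tensors
# onto built layers.
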
def decode_sp_tokens(tokens):
    # "▁" is SentencePiece's word-boundary marker; map it back to a space
    text = ''.join(tokens).replace('▁', ' ').strip()
    return text

def generate_text_topkp_stream(model, prompt, max_len=100, max_gen=98, p=0.9, k=50, temperature=0.8, min_len=20):
    model_input = text_to_ids(f"<start> {prompt}")
    model_input = model_input[:max_len]
    generated = list(model_input)
    text_so_far = []

    for step in range(max_gen):
        pad_length = max(0, max_len - len(generated))
        input_padded = np.pad(generated, (0, pad_length), constant_values=pad_id)
        input_tensor = tf.convert_to_tensor([input_padded])
        logits = model(input_tensor, training=False)
        next_token_logits = logits[0, len(generated) - 1].numpy()
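        # Right-padding keeps a fixed input shape; because attention is causal,
        # the pad tokens that follow cannot influence the logits read above.
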
        if len(generated) >= min_len:
            next_token_logits[end_id] -= 5.0
            next_token_logits[pad_id] -= 10.0
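        # Subtracting from a logit scales that token's probability by exp(-x):
        # -5.0 multiplies it by ~0.007 and -10.0 by ~0.00005, so <end> and
        # <pad> become rare rather than impossible.
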
        # Apply temperature
        logits_temp = next_token_logits / temperature

        # 1. Convert logits to probabilities
        probs = tf.nn.softmax(logits_temp).numpy()

        # 2. Top-k filtering
        top_k_indices = np.argpartition(probs, -k)[-k:]
        top_k_probs = probs[top_k_indices]

        # 3. Top-p filtering (sort so the cumulative sum can be taken)
        sorted_idx = np.argsort(top_k_probs)[::-1]
        top_k_indices = top_k_indices[sorted_idx]
        top_k_probs = top_k_probs[sorted_idx]
        cumulative_probs = np.cumsum(top_k_probs)

        # Cut off the tail once the cumulative probability exceeds p
        cutoff = np.searchsorted(cumulative_probs, p, side='right') + 1

        filtered_indices = top_k_indices[:cutoff]
        filtered_probs = top_k_probs[:cutoff]

        # Renormalize the remaining probabilities
        filtered_probs /= filtered_probs.sum()

        # Sample the next token
        next_token_id = np.random.choice(filtered_indices, p=filtered_probs)
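        # Worked example: sorted probs [0.6, 0.25, 0.1, 0.05] with p=0.9 give
        # cumulative sums [0.6, 0.85, 0.95, 1.0]; searchsorted returns 2, so
        # cutoff = 3 and the first three tokens are kept and renormalized.
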
        generated.append(int(next_token_id))
        next_word = sp.id_to_piece(int(next_token_id))
        text_so_far.append(next_word)

        decoded_text = decode_sp_tokens(text_so_far)

        # Stop without yielding when the end token is sampled, so the raw
        # <end> piece never reaches the chat window
        if len(generated) >= min_len and next_token_id == end_id:
            break

        yield decoded_text

        # Stop after yielding on sentence-final punctuation, so the closing
        # character is still streamed to the user
        if len(generated) >= min_len and decoded_text.endswith(('.', '!', '?')):
            break

def chat(user_input, history):
    if history is None:
        history = []

    for partial_response in generate_text_topkp_stream(model, user_input, p=0.9):
        yield history + [(user_input, partial_response)], history + [(user_input, partial_response)]
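# Gradio runs generator callbacks as streams: every yielded (chatbot, state)
# pair repaints the chat window, so the reply appears piece by piece.
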
with gr.Blocks(title="KeraLux Chat") as demo:
    gr.Markdown(
        """
        # 💡 Talk with KeraLux!
        Type a message and KeraLux will come back with a clever answer.
        """,
        elem_id="title",
    )
    gr.Markdown("---")

    with gr.Row():
        with gr.Column(scale=1):
            chatbot = gr.Chatbot(label="KeraLux chat window", bubble_full_width=False)
        with gr.Column(scale=0):
            msg = gr.Textbox(
                label="Type your question!",
                placeholder="e.g., Can you help me out?",
                lines=1,
            )
    state = gr.State([])

    msg.submit(chat, inputs=[msg, state], outputs=[chatbot, state])
    msg.submit(lambda: "", None, msg)  # clear the input box

demo.launch(share=True)
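# This Space expects ko_unigram3.model and KeraLux3.weights.h5 to sit
# alongside app.py, since both are loaded by relative path above.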