Yuchan5386 committed
Commit a19f837 · verified · 1 Parent(s): f8cd8f6

Update app.py

Files changed (1):
  app.py +184 -64

app.py CHANGED
@@ -1,64 +1,184 @@
- import gradio as gr
- from huggingface_hub import InferenceClient
-
- """
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
- """
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
-
- def respond(
-     message,
-     history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
- ):
-     messages = [{"role": "system", "content": system_message}]
-
-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})
-
-     messages.append({"role": "user", "content": message})
-
-     response = ""
-
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         token = message.choices[0].delta.content
-
-         response += token
-         yield response
-
-
- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- demo = gr.ChatInterface(
-     respond,
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
- )
-
-
- if __name__ == "__main__":
-     demo.launch()
+ import json
+ import numpy as np
+ import tensorflow as tf
+ from tensorflow.keras import layers
+ import gradio as gr
+ import re
+ import requests
+ import math
+ import sentencepiece as spm
+
+ # Load SentencePiece (special-token IDs are set up to match the tokenizer)
+ sp = spm.SentencePieceProcessor()
+ sp.load("ko_unigram3.model")
+
+ # Fall back to conventional IDs when a special piece is missing from the vocab
+ pad_id = sp.piece_to_id("<pad>")
+ if pad_id == -1: pad_id = 0
+ start_id = sp.piece_to_id("<start>")
+ if start_id == -1: start_id = 1
+ end_id = sp.piece_to_id("< end >")
+ if end_id == -1: end_id = 2
+ unk_id = sp.piece_to_id("<unk>")
+ if unk_id == -1: unk_id = 3
+
+ vocab_size = sp.get_piece_size()
+ max_len = 128
+
+ def text_to_ids(text):
+     return sp.encode(text, out_type=int)
+
+ def ids_to_text(ids):
+     return sp.decode(ids)
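+ # Round trip (illustrative): ids_to_text(text_to_ids("hello")) == "hello",
+ # up to SentencePiece's text normalization.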
+
+ # GEGLU feed-forward layer
+ class GEGLU(tf.keras.layers.Layer):
+     def __init__(self, d_model, d_ff):
+         super().__init__()
+         self.proj = layers.Dense(d_ff * 2)
+         self.out = layers.Dense(d_model)
+     def call(self, x):
+         x_proj = self.proj(x)
+         x_val, x_gate = tf.split(x_proj, 2, axis=-1)
+         return self.out(x_val * tf.nn.gelu(x_gate))
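+ # The single Dense(d_ff * 2) packs the value and gate projections into one
+ # matmul; tf.split separates the halves, so GEGLU(x) = W_out·(x·W_val ⊙ gelu(x·W_gate)).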
+
+ # GPT block (pre-LayerNorm residual layout)
+ class GPTBlock(tf.keras.layers.Layer):
+     def __init__(self, d_model, d_ff, num_heads=16, dropout_rate=0.1):
+         super().__init__()
+         self.ln1 = layers.LayerNormalization(epsilon=1e-5)
+         self.attn = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model // num_heads)
+         self.dropout1 = layers.Dropout(dropout_rate)
+         self.ln2 = layers.LayerNormalization(epsilon=1e-5)
+         self.ffn = GEGLU(d_model, d_ff)
+         self.dropout2 = layers.Dropout(dropout_rate)
+     def call(self, x, training=False):
+         x_norm = self.ln1(x)
+         attn_out = self.attn(query=x_norm, value=x_norm, key=x_norm,
+                              use_causal_mask=True, training=training)
+         x = x + self.dropout1(attn_out, training=training)
+         ffn_out = self.ffn(self.ln2(x))
+         x = x + self.dropout2(ffn_out, training=training)
+         return x
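+ # use_causal_mask=True lets Keras build the lower-triangular attention mask
+ # internally, so position i attends only to positions <= i and no manual
+ # mask tensor is needed.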
+
+ # GPT model
+ class GPT(tf.keras.Model):
+     def __init__(self, vocab_size, seq_len, d_model, d_ff, n_layers, num_heads=16, dropout_rate=0.1):
+         super().__init__()
+         self.token_embedding = layers.Embedding(vocab_size, d_model)
+         self.pos_embedding = self.add_weight(
+             name="pos_embedding",
+             shape=[seq_len, d_model],
+             initializer=tf.keras.initializers.RandomNormal(stddev=0.01)
+         )
+         self.blocks = [GPTBlock(d_model, d_ff, num_heads, dropout_rate) for _ in range(n_layers)]
+         self.ln_f = layers.LayerNormalization(epsilon=1e-5)
+     def call(self, x, training=False):
+         seq_len = tf.shape(x)[1]
+         x = self.token_embedding(x) + self.pos_embedding[tf.newaxis, :seq_len, :]
+         for block in self.blocks:
+             x = block(x, training=training)
+         x = self.ln_f(x)
+         logits = tf.matmul(x, self.token_embedding.embeddings, transpose_b=True)
+         return logits
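+ # Weight tying: the output projection reuses the token-embedding matrix, so
+ # logits have shape (batch, seq_len, vocab_size) with no extra head parameters.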
+
+ # Build the model & load the weights
+ model = GPT(vocab_size=vocab_size, seq_len=max_len, d_model=128, d_ff=512, n_layers=6)
+ dummy_input = tf.zeros((1, max_len), dtype=tf.int32)  # batch 1, sequence length max_len
+ _ = model(dummy_input)  # forward pass builds the variables
+ model.load_weights("KeraLux3.weights.h5")
+ print("Model weights loaded!")
+
+ def decode_sp_tokens(tokens):
+     text = ''.join(tokens).replace('▁', ' ').strip()
+     return text
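+ # Illustration: decode_sp_tokens(['▁Hello', '▁world']) -> 'Hello world'
+ # ('▁' is the SentencePiece word-boundary marker).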
+
+ def generate_text_topkp_stream(model, prompt, max_len=100, max_gen=98, p=0.9, k=50, temperature=0.8, min_len=20):
+     model_input = text_to_ids(f"<start> {prompt}")
+     model_input = model_input[:max_len]
+     generated = list(model_input)
+     text_so_far = []
+
+     for step in range(max_gen):
+         pad_length = max(0, max_len - len(generated))
+         input_padded = np.pad(generated, (0, pad_length), constant_values=pad_id)
+         input_tensor = tf.convert_to_tensor([input_padded])
+         logits = model(input_tensor, training=False)
+         next_token_logits = logits[0, len(generated) - 1].numpy()
+
+         # Push down the <end>/<pad> logits once min_len tokens exist
+         if len(generated) >= min_len:
+             next_token_logits[end_id] -= 5.0
+             next_token_logits[pad_id] -= 10.0
+
+         # Apply temperature
+         logits_temp = next_token_logits / temperature
+
+         # 1. Compute probabilities
+         probs = tf.nn.softmax(logits_temp).numpy()
+
+         # 2. Top-k filtering
+         top_k_indices = np.argpartition(probs, -k)[-k:]
+         top_k_probs = probs[top_k_indices]
+
+         # 3. Top-p filtering (sort for the cumulative sum)
+         sorted_idx = np.argsort(top_k_probs)[::-1]
+         top_k_indices = top_k_indices[sorted_idx]
+         top_k_probs = top_k_probs[sorted_idx]
+         cumulative_probs = np.cumsum(top_k_probs)
+
+         # Cut off everything past p
+         cutoff = np.searchsorted(cumulative_probs, p, side='right') + 1
+
+         filtered_indices = top_k_indices[:cutoff]
+         filtered_probs = top_k_probs[:cutoff]
+
+         # Renormalize the probabilities
+         filtered_probs /= filtered_probs.sum()
+
+         # Sample
+         next_token_id = np.random.choice(filtered_indices, p=filtered_probs)
+
+         generated.append(int(next_token_id))
+         next_word = sp.id_to_piece(int(next_token_id))
+         text_so_far.append(next_word)
+
+         decoded_text = decode_sp_tokens(text_so_far)
+
+         if len(generated) >= min_len and next_token_id == end_id:
+             break
+         if len(generated) >= min_len and decoded_text.endswith(('.', '!', '?')):
+             break
+
+         yield decoded_text
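+ # Usage sketch (assumed prompt, for illustration): each yield is the full
+ # decoded text so far, so consumers just print/display the latest value.
+ # for partial in generate_text_topkp_stream(model, "Hello?"):
+ #     print(partial)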
+
+ def chat(user_input, history):
+     if history is None:
+         history = []
+
+     for partial_response in generate_text_topkp_stream(model, user_input, p=0.9):
+         yield history + [(user_input, partial_response)], history + [(user_input, partial_response)]
+
+ with gr.Blocks(title="KeraLux Chat") as demo:
+     gr.Markdown(
+         """
+         # 💡 Chat with KeraLux!
+         Type a message and KeraLux will give you a smart reply.
+         """,
+         elem_id="title",
+     )
+     gr.Markdown("---")
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             chatbot = gr.Chatbot(label="KeraLux chat window", bubble_full_width=False)
+         with gr.Column(scale=0):
+             msg = gr.Textbox(
+                 label="Type your question!",
+                 placeholder="e.g., Can you help me out?",
+                 lines=1,
+             )
+     state = gr.State([])
+
+     msg.submit(chat, inputs=[msg, state], outputs=[chatbot, state])
+     msg.submit(lambda: "", None, msg)  # clear the input box
+
+ demo.launch(share=True)