KeraLux-API / app.py
Yuchan5386's picture
Update app.py
7cb8b57 verified
raw
history blame
8.06 kB
import json
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
import gradio as gr
import re
import requests
import math
import sentencepiece as spm
# Load SentencePiece (the tokenizer and the special-token IDs are set up the
# same way here — presumably the same way as when the weights were trained).
sp = spm.SentencePieceProcessor()
sp.load("ko_unigram4.model")
# Each special piece is looked up by name and replaced by a fixed fallback ID
# when the lookup yields -1.
# NOTE(review): SentencePiece may return the unknown-token ID instead of -1
# for a piece that is not in the vocabulary — confirm these fallbacks can fire.
pad_id = sp.piece_to_id("<pad>")
if pad_id == -1: pad_id = 0
start_id = sp.piece_to_id("<start>")
if start_id == -1: start_id = 1
# NOTE(review): this piece name contains spaces, unlike the other special
# pieces — confirm it matches the entry in the tokenizer vocabulary.
end_id = sp.piece_to_id("< end >")
if end_id == -1: end_id = 2
unk_id = sp.piece_to_id("<unk>")
if unk_id == -1: unk_id = 3
# Vocabulary size reported by the loaded tokenizer; sizes the embedding table.
vocab_size = sp.get_piece_size()
# Sequence length used for the model constructor and the dummy build input.
max_len = 100
def text_to_ids(text):
    """Encode ``text`` with the module-level SentencePiece processor.

    Returns the result of ``sp.encode(text, out_type=int)``, i.e. the
    integer token IDs for ``text``.
    """
    token_ids = sp.encode(text, out_type=int)
    return token_ids
def ids_to_text(ids):
    """Decode token IDs back to text with the module-level SentencePiece processor."""
    decoded = sp.decode(ids)
    return decoded
class RotaryPositionalEmbedding(layers.Layer):
    """Rotate adjacent feature pairs of the input by position-dependent angles.

    The layer holds no trainable weights: it only stores the inverse
    frequencies ``1 / 10000 ** (2i / dim)`` for ``i`` in ``range(dim // 2)``.

    Args:
        dim: size of the last axis of the tensors passed to ``call``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        exponents = np.arange(0, dim, 2) / dim
        self.inv_freq = tf.constant(1.0 / (10000 ** exponents), dtype=tf.float32)

    def call(self, x):
        """Apply the rotation to ``x`` and return a tensor of the same shape.

        ``x`` is unpacked as a rank-4 tensor (batch, heads, seq_len, depth).
        """
        input_shape = tf.shape(x)
        _, _, seq_len, _ = tf.unstack(input_shape)

        # Angle for every (position, frequency) pair: (seq_len, dim // 2).
        positions = tf.range(seq_len, dtype=tf.float32)
        angles = tf.einsum('i,j->ij', positions, self.inv_freq)

        # Reshape to (1, 1, seq_len, dim // 2) so they broadcast against x.
        cos = tf.reshape(tf.cos(angles), [1, 1, seq_len, -1])
        sin = tf.reshape(tf.sin(angles), [1, 1, seq_len, -1])

        # Even / odd features form the two coordinates of each rotated pair.
        even = x[..., ::2]
        odd = x[..., 1::2]
        rotated_even = even * cos - odd * sin
        rotated_odd = even * sin + odd * cos

        # Stack to (..., depth // 2, 2), then flatten back so the rotated
        # pairs are interleaved in their original positions.
        interleaved = tf.stack([rotated_even, rotated_odd], axis=-1)
        return tf.reshape(interleaved, input_shape)
class GEGLU(tf.keras.layers.Layer):
    """Gated feed-forward layer: ``out(value * gelu(gate))``.

    A single dense projection produces ``2 * d_ff`` features which are split
    in half along the last axis into a value part and a gate part.

    Args:
        d_model: output width of the final dense layer.
        d_ff: width of each of the value / gate halves.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.proj = layers.Dense(d_ff * 2)
        self.out = layers.Dense(d_model)

    def call(self, x):
        """Return the gated projection of ``x``."""
        value_half, gate_half = tf.split(self.proj(x), 2, axis=-1)
        gated = value_half * tf.nn.gelu(gate_half)
        return self.out(gated)
class KeraLuxBlock(tf.keras.layers.Layer):
    """Pre-norm transformer block: causal self-attention, then a GEGLU FFN.

    Both sub-layers are applied to a layer-normalised copy of the input and
    added back residually, with dropout on each branch. The attention query
    and key are the normalised input after a rotary position rotation; the
    value is the normalised input itself.

    Args:
        d_model: feature width of the block input/output.
        d_ff: hidden width handed to the GEGLU feed-forward layer.
        num_heads: number of attention heads (``d_model // num_heads`` is
            used as the per-head key dimension and the rotary dimension).
        dropout_rate: dropout rate for both residual branches.
    """

    def __init__(self, d_model, d_ff, num_heads=20, dropout_rate=0.1):
        super().__init__()
        # NOTE(review): attribute names and creation order are left exactly
        # as they were; saved weight files are typically matched against
        # them — confirm before renaming or reordering anything here.
        self.ln1 = layers.LayerNormalization(epsilon=1e-5)
        self.mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model // num_heads)
        self.dropout1 = layers.Dropout(dropout_rate)
        self.ln2 = layers.LayerNormalization(epsilon=1e-5)
        self.ffn = GEGLU(d_model, d_ff)
        self.dropout2 = layers.Dropout(dropout_rate)
        self.rope = RotaryPositionalEmbedding(d_model // num_heads)

    def call(self, x, training=False):
        """Run the block on ``x`` (batch, seq_len, d_model) and return the same shape."""
        x_norm = self.ln1(x)

        batch_size = tf.shape(x_norm)[0]
        seq_len = tf.shape(x_norm)[1]
        num_heads = self.mha.num_heads
        depth = x_norm.shape[-1] // num_heads

        # (batch, seq_len, d_model) -> (batch, heads, seq_len, depth)
        per_head = tf.reshape(x_norm, [batch_size, seq_len, num_heads, depth])
        per_head = tf.transpose(per_head, [0, 2, 1, 3])

        # Query and key are both derived from the same x_norm, so the rotary
        # rotation is computed once and shared instead of being run twice.
        rotated = self.rope(per_head)

        # Back to (batch, seq_len, d_model) for the attention layer.
        rotated = tf.transpose(rotated, [0, 2, 1, 3])
        rotated = tf.reshape(rotated, [batch_size, seq_len, num_heads * depth])

        # Attention with rotated query/key and the un-rotated x_norm as value.
        attn_out = self.mha(
            query=rotated,
            value=x_norm,
            key=rotated,
            use_causal_mask=True,
            training=training,
        )
        x = x + self.dropout1(attn_out, training=training)

        ffn_out = self.ffn(self.ln2(x))
        x = x + self.dropout2(ffn_out, training=training)
        return x
class KeraLux(tf.keras.Model):
    """Decoder-style language model built from a stack of ``KeraLuxBlock``s.

    Token IDs are embedded, passed through ``n_layers`` blocks and a final
    layer normalisation, then projected to vocabulary logits by multiplying
    with the transposed token-embedding matrix (tied input/output weights).

    Args:
        vocab_size: number of rows in the token-embedding table.
        seq_len: accepted for interface compatibility; not used by this class.
        d_model: embedding / hidden width.
        d_ff: feed-forward width forwarded to every block.
        n_layers: number of blocks.
        num_heads: attention heads per block.
        dropout_rate: dropout rate forwarded to every block.
    """

    def __init__(self, vocab_size, seq_len, d_model, d_ff, n_layers, num_heads=20, dropout_rate=0.1):
        super().__init__()
        self.token_embedding = layers.Embedding(vocab_size, d_model)
        # pos_embedding removed: there is no positional-embedding layer here.
        self.blocks = [KeraLuxBlock(d_model, d_ff, num_heads, dropout_rate) for _ in range(n_layers)]
        self.ln_f = layers.LayerNormalization(epsilon=1e-5)

    def call(self, x, training=False):
        """Return logits of shape (batch, seq_len, vocab_size) for token IDs ``x``."""
        x = self.token_embedding(x)
        for block in self.blocks:
            x = block(x, training=training)
        x = self.ln_f(x)
        # Output projection reuses the embedding matrix (transposed).
        logits = tf.matmul(x, self.token_embedding.embeddings, transpose_b=True)
        return logits
# ๋ชจ๋ธ ์ƒ์„ฑ & ๊ฐ€์ค‘์น˜ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ
model = KeraLux(vocab_size=vocab_size, seq_len=max_len, d_model=160, d_ff=616, n_layers=6)
dummy_input = tf.zeros((1, max_len), dtype=tf.int32) # ๋ฐฐ์น˜1, ์‹œํ€€์Šค๊ธธ์ด max_len
_ = model(dummy_input) # ๋ชจ๋ธ์ด ๋นŒ๋“œ๋จ
model.load_weights("KeraLux3.weights.h5")
print("๋ชจ๋ธ ๊ฐ€์ค‘์น˜ ๋กœ๋“œ ์™„๋ฃŒ!")
def decode_sp_tokens(tokens):
    """Join SentencePiece pieces into plain text.

    Args:
        tokens: iterable of piece strings.

    Returns:
        The concatenated pieces with every SentencePiece word-boundary
        marker (U+2581) replaced by a space, and leading/trailing
        whitespace stripped.
    """
    # The marker is written as an escape so it cannot be corrupted by a
    # re-encoding of this file (a mis-decoded literal never matches a piece).
    text = ''.join(tokens).replace('\u2581', ' ').strip()
    return text
def generate_text_topp_stream(model, prompt, max_len=100, max_gen=98, p=0.9, temperature=0.8, min_len=20):
    """Stream a top-p (nucleus) sampled continuation of ``prompt``.

    The prompt is prefixed with ``"<start> "``, tokenised and truncated to
    ``max_len`` tokens. Each step pads the running sequence to ``max_len``,
    runs the model, samples one token from the smallest set of tokens whose
    cumulative probability exceeds ``p``, and yields the text generated so
    far (prompt excluded).

    Args:
        model: callable returning logits indexed as [batch, position, vocab].
        prompt: user text to continue.
        max_len: length the input is padded to before each model call.
        max_gen: maximum number of tokens to generate.
        p: cumulative-probability threshold for nucleus sampling.
        temperature: divisor applied to the logits before the softmax.
        min_len: minimum total length (prompt + generated tokens) before
            either stop condition may end generation.

    Yields:
        The decoded generated text after each new token. When generation
        stops on sentence-ending punctuation the final text, including that
        punctuation, is yielded before stopping.
    """
    prompt_ids = text_to_ids(f"<start> {prompt}")
    generated = list(prompt_ids[:max_len])
    pieces = []
    for _ in range(max_gen):
        pad_length = max(0, max_len - len(generated))
        input_padded = np.pad(generated, (0, pad_length), constant_values=pad_id)
        input_tensor = tf.convert_to_tensor([input_padded])
        logits = model(input_tensor, training=False)
        # Logits at the last real (non-padding) position predict the next token.
        next_token_logits = logits[0, len(generated) - 1].numpy()

        # Lower the probability of specific tokens.
        # NOTE(review): the penalty applies once the sequence is already at
        # least min_len long, the same condition under which stopping is
        # allowed — confirm this was not meant to be `<`.
        if len(generated) >= min_len:
            next_token_logits[end_id] -= 5.0
            next_token_logits[pad_id] -= 10.0

        # Apply temperature.
        probs = tf.nn.softmax(next_token_logits / temperature).numpy()

        # Sort probabilities in descending order.
        sorted_idx = np.argsort(probs)[::-1]
        sorted_probs = probs[sorted_idx]
        cumulative_probs = np.cumsum(sorted_probs)

        # Keep tokens up to and including the one that pushes the sum past p.
        cutoff = np.searchsorted(cumulative_probs, p, side='right') + 1
        filtered_indices = sorted_idx[:cutoff]
        filtered_probs = sorted_probs[:cutoff]
        filtered_probs /= filtered_probs.sum()

        # Sample.
        next_token_id = int(np.random.choice(filtered_indices, p=filtered_probs))
        generated.append(next_token_id)
        may_stop = len(generated) >= min_len

        # Stop on the end token without emitting its piece.
        if may_stop and next_token_id == end_id:
            break

        pieces.append(sp.id_to_piece(next_token_id))
        decoded_text = decode_sp_tokens(pieces)

        # Emit first, then stop: breaking before the yield would drop the
        # final chunk containing the sentence-ending punctuation.
        yield decoded_text
        if may_stop and decoded_text.endswith(('.', '!', '?')):
            break
def chat_stream(user_input, history_text):
    """Stream the growing conversation transcript for one user message.

    For every partial response produced by ``generate_text_topp_stream``,
    yields a pair of identical strings: ``history_text`` followed by the
    current exchange. The two elements feed the two outputs this handler is
    wired to (the transcript display and the session state).
    """
    for partial_text in generate_text_topp_stream(model, user_input):
        transcript = history_text + f"์‚ฌ์šฉ์ž: {user_input}\nColloGPT: {partial_text}\n"
        yield transcript, transcript
# Gradio UI: a single-line input box, a read-only transcript area and a
# state value holding the transcript text.
with gr.Blocks() as demo:
    gr.Markdown("### ๐Ÿ“Ÿ ColloGPT Textbot with Streaming")
    textbox = gr.Textbox(placeholder="๋ฉ”์‹œ์ง€๋ฅผ ์ž…๋ ฅํ•˜์„ธ์š”", lines=1)
    output_area = gr.Textbox(label="๋Œ€ํ™” ๊ธฐ๋ก", lines=20, interactive=False)
    state = gr.State("") # per-session storage
    # Submitting streams chat_stream's output into the transcript and state.
    textbox.submit(chat_stream, inputs=[textbox, state], outputs=[output_area, state])
    # A second handler on the same event clears the input box.
    textbox.submit(lambda: "", None, textbox)
demo.launch()