KeraLux-API / app.py
Yuchan5386's picture
Update app.py
75a0ea0 verified
raw
history blame
8.09 kB
import json
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
import gradio as gr
import re
import requests
import math
import sentencepiece as spm
# Load the SentencePiece tokenizer and resolve the special-token IDs that the
# rest of the module uses.
sp = spm.SentencePieceProcessor()
sp.load("ko_unigram4.model")


def _special_id(piece, fallback):
    """Return the ID of *piece*, or *fallback* when the lookup yields -1."""
    found = sp.piece_to_id(piece)
    return fallback if found == -1 else found


pad_id = _special_id("<pad>", 0)
start_id = _special_id("<start>", 1)
# NOTE(review): this piece is spelled with inner spaces ("< end >") unlike the
# other special tokens -- confirm it matches the piece stored in the vocabulary.
end_id = _special_id("< end >", 2)
unk_id = _special_id("<unk>", 3)

vocab_size = sp.get_piece_size()
max_len = 100
def text_to_ids(text):
    """Encode *text* into integer token IDs with the module-level tokenizer."""
    token_ids = sp.encode(text, out_type=int)
    return token_ids
def ids_to_text(ids):
    """Decode a sequence of token IDs back into text with the module-level tokenizer."""
    decoded = sp.decode(ids)
    return decoded
class RotaryPositionalEmbedding(layers.Layer):
    """Rotary positional embedding (RoPE) over the last axis of a 4-D tensor.

    Consecutive (even, odd) feature pairs of the last axis are rotated by an
    angle that depends on the position along axis 2 and on the pair index.

    Args:
        dim: Size of the last axis of the tensors passed to ``call``.
            Must be even because features are rotated in pairs.

    Raises:
        ValueError: If ``dim`` is odd.
    """

    def __init__(self, dim):
        super().__init__()
        if dim % 2 != 0:
            # An odd size cannot be split into (even, odd) pairs; fail here
            # instead of with an opaque shape error inside call().
            raise ValueError(f"RotaryPositionalEmbedding requires an even dim, got {dim}")
        self.dim = dim
        # One inverse frequency per feature pair: 1 / 10000^(2i / dim).
        inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim))
        self.inv_freq = tf.constant(inv_freq, dtype=tf.float32)

    def call(self, x):
        """Rotate ``x`` of shape (batch, heads, seq_len, depth); shape is preserved."""
        seq_len = tf.shape(x)[2]
        positions = tf.range(seq_len, dtype=tf.float32)  # (seq_len,)
        freqs = tf.einsum('i,j->ij', positions, self.inv_freq)  # (seq_len, dim//2)

        # (seq_len, dim//2) -> (1, 1, seq_len, dim//2) so the tables broadcast
        # over batch and heads; cast so non-float32 inputs do not hit a dtype
        # mismatch (a no-op for float32).
        emb_cos = tf.cast(tf.reshape(tf.cos(freqs), [1, 1, seq_len, -1]), x.dtype)
        emb_sin = tf.cast(tf.reshape(tf.sin(freqs), [1, 1, seq_len, -1]), x.dtype)

        x_even = x[..., ::2]  # (batch, heads, seq_len, depth//2)
        x_odd = x[..., 1::2]

        # Stack the rotated pair on a new trailing axis, then flatten it back
        # so each rotated (even, odd) pair ends up interleaved as in the input.
        x_rotated = tf.stack([
            x_even * emb_cos - x_odd * emb_sin,
            x_even * emb_sin + x_odd * emb_cos,
        ], axis=-1)  # (batch, heads, seq_len, depth//2, 2)
        return tf.reshape(x_rotated, tf.shape(x))  # back to (batch, heads, seq_len, depth)
class GEGLU(tf.keras.layers.Layer):
    """Gated feed-forward layer: a value half multiplied by a GELU-activated gate half.

    Args:
        d_model: Width of the output projection.
        d_ff: Width of each of the two halves produced by the input projection.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        # A single projection yields both halves; they are split apart in call().
        self.proj = layers.Dense(d_ff * 2)
        self.out = layers.Dense(d_model)

    def call(self, x):
        """Project ``x``, gate the value half with GELU(gate half), project back."""
        value_half, gate_half = tf.split(self.proj(x), 2, axis=-1)
        gated = value_half * tf.nn.gelu(gate_half)
        return self.out(gated)
class KeraLuxBlock(tf.keras.layers.Layer):
    """Pre-norm transformer block: causal self-attention followed by a GEGLU feed-forward.

    The rotary embedding is applied to the layer-normalised input, and the
    rotated tensor is handed to ``MultiHeadAttention`` as both query and key;
    the un-rotated normalised input is the value.

    Args:
        d_model: Feature width of the block input and output.
        d_ff: Hidden width passed to the GEGLU feed-forward.
        num_heads: Number of attention heads; must divide ``d_model``.
        dropout_rate: Dropout rate applied to both residual branches.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, d_ff, num_heads=20, dropout_rate=0.1):
        super().__init__()
        if d_model % num_heads != 0:
            # call() reshapes the features into (num_heads, depth); a remainder
            # would only surface later as a reshape error.
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        # Sublayer attribute names and creation order are kept unchanged so
        # that previously saved weights still map onto this block.
        self.ln1 = layers.LayerNormalization(epsilon=1e-5)
        self.mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model // num_heads)
        self.dropout1 = layers.Dropout(dropout_rate)
        self.ln2 = layers.LayerNormalization(epsilon=1e-5)
        self.ffn = GEGLU(d_model, d_ff)
        self.dropout2 = layers.Dropout(dropout_rate)
        self.rope = RotaryPositionalEmbedding(d_model // num_heads)

    def call(self, x, training=False):
        """Apply the block to ``x`` of shape (batch, seq_len, d_model)."""
        x_norm = self.ln1(x)

        batch_size = tf.shape(x_norm)[0]
        seq_len = tf.shape(x_norm)[1]
        num_heads = self.mha.num_heads
        depth = x_norm.shape[-1] // num_heads

        # (batch, seq_len, d_model) -> (batch, heads, seq_len, depth) for RoPE.
        per_head = tf.reshape(x_norm, [batch_size, seq_len, num_heads, depth])
        per_head = tf.transpose(per_head, [0, 2, 1, 3])

        # Query and key are both the rotated x_norm, so rotate once and reuse
        # the result instead of computing the identical tensor twice.
        rotated = self.rope(per_head)

        # Back to (batch, seq_len, d_model), the layout MultiHeadAttention expects.
        rotated = tf.transpose(rotated, [0, 2, 1, 3])
        qk = tf.reshape(rotated, [batch_size, seq_len, num_heads * depth])

        attn_out = self.mha(query=qk, value=x_norm, key=qk, use_causal_mask=True, training=training)
        x = x + self.dropout1(attn_out, training=training)

        ffn_out = self.ffn(self.ln2(x))
        x = x + self.dropout2(ffn_out, training=training)
        return x
class KeraLux(tf.keras.Model):
    """Decoder-only language model: token embedding, a stack of KeraLuxBlock, final LayerNorm.

    There is no learned positional embedding; position information comes from
    the rotary embedding inside each block. The output projection reuses the
    token-embedding matrix (transposed) to produce vocabulary logits.

    Args:
        vocab_size: Number of rows in the token-embedding table.
        seq_len: Accepted for interface compatibility; not used by this class.
        d_model: Embedding / hidden width.
        d_ff: Feed-forward width passed to every block.
        n_layers: Number of KeraLuxBlock layers.
        num_heads: Attention heads per block.
        dropout_rate: Dropout rate passed to every block.
    """

    def __init__(self, vocab_size, seq_len, d_model, d_ff, n_layers, num_heads=20, dropout_rate=0.1):
        super().__init__()
        self.token_embedding = layers.Embedding(vocab_size, d_model)
        self.blocks = [KeraLuxBlock(d_model, d_ff, num_heads, dropout_rate) for _ in range(n_layers)]
        self.ln_f = layers.LayerNormalization(epsilon=1e-5)

    def call(self, x, training=False):
        """Map token IDs of shape (batch, seq_len) to logits of shape (batch, seq_len, vocab_size)."""
        x = self.token_embedding(x)
        for block in self.blocks:
            x = block(x, training=training)
        x = self.ln_f(x)
        # Tied output projection: hidden states times the transposed embedding table.
        logits = tf.matmul(x, self.token_embedding.embeddings, transpose_b=True)
        return logits
# Build the model and load the pretrained weights.
model = KeraLux(vocab_size=vocab_size, seq_len=max_len, d_model=160, d_ff=616, n_layers=6)
# Run one dummy batch (batch size 1, sequence length max_len) through the model
# so it is built before the weights are loaded.
dummy_input = tf.zeros((1, max_len), dtype=tf.int32)
_ = model(dummy_input)
model.load_weights("KeraLux3.weights.h5")
# The status message had been corrupted by an encoding round-trip; restored
# to the original Korean text ("model weights loaded").
print("모델 가중치 로드 완료!")
def decode_sp_tokens(tokens):
    """Join SentencePiece pieces into plain text.

    The pieces are concatenated, every word-boundary marker (U+2581) becomes
    a space, and leading/trailing whitespace is stripped.

    Args:
        tokens: Iterable of piece strings.

    Returns:
        The reconstructed text.
    """
    # The marker is written as an escape so that an encoding round-trip of
    # this source file cannot silently corrupt the character being replaced.
    text = ''.join(tokens).replace('\u2581', ' ').strip()
    return text
def _sample_top_k_top_p(logits, k, p, temperature):
    """Sample one token ID from *logits* after temperature, top-k and top-p filtering."""
    # Temperature-scaled probabilities.
    probs = tf.nn.softmax(logits / temperature).numpy()

    # Top-k: clamp k so argpartition is never asked for more entries than exist.
    k = min(k, probs.shape[0])
    top_k_indices = np.argpartition(probs, -k)[-k:]

    # Sort the survivors by descending probability for the cumulative sum.
    order = np.argsort(probs[top_k_indices])[::-1]
    top_k_indices = top_k_indices[order]
    top_k_probs = probs[top_k_indices]

    # Top-p: keep the shortest prefix whose cumulative probability exceeds p
    # (the +1 keeps the entry that crosses the threshold).
    cumulative_probs = np.cumsum(top_k_probs)
    cutoff = np.searchsorted(cumulative_probs, p, side='right') + 1
    kept_indices = top_k_indices[:cutoff]
    kept_probs = top_k_probs[:cutoff]

    # Renormalise and sample.
    kept_probs = kept_probs / kept_probs.sum()
    return int(np.random.choice(kept_indices, p=kept_probs))


def generate_text_topkp_stream(model, prompt, max_len=100, max_gen=98, p=0.9, k=50, temperature=0.8, min_len=20):
    """Stream a sampled continuation of *prompt*, yielding the text generated so far.

    Args:
        model: Callable mapping an int tensor (1, seq_len) to logits
            (1, seq_len, vocab_size).
        prompt: User text; it is encoded as ``"<start> {prompt}"``.
        max_len: The encoded prompt is truncated to this many tokens and the
            model input is right-padded with ``pad_id`` up to this length.
        max_gen: Maximum number of tokens to sample.
        p: Cumulative-probability threshold for top-p filtering.
        k: Number of candidates kept by top-k filtering (clamped to the
            vocabulary size).
        temperature: Divisor applied to the logits before the softmax.
        min_len: Total token count (prompt + generated) required before the
            end token or sentence-final punctuation may stop generation.

    Yields:
        The decoded text of all tokens generated so far (prompt excluded),
        once per sampled token. The token that ends generation by closing a
        sentence is included in the final yield; a stopping end token is not.
    """
    generated = list(text_to_ids(f"<start> {prompt}")[:max_len])
    pieces = []

    for _ in range(max_gen):
        pad_length = max(0, max_len - len(generated))
        input_padded = np.pad(generated, (0, pad_length), constant_values=pad_id)
        input_tensor = tf.convert_to_tensor([input_padded])

        logits = model(input_tensor, training=False)
        # Logits at the last real (non-padding) position predict the next token.
        next_token_logits = logits[0, len(generated) - 1].numpy()

        # NOTE(review): the end-token penalty is applied once min_len has been
        # reached, not before it -- confirm this direction is intended.
        if len(generated) >= min_len:
            next_token_logits[end_id] -= 5.0
        next_token_logits[pad_id] -= 10.0

        next_token_id = _sample_top_k_top_p(next_token_logits, k, p, temperature)
        generated.append(next_token_id)

        reached_min_len = len(generated) >= min_len
        if reached_min_len and next_token_id == end_id:
            # The end token itself is never shown to the consumer.
            break

        pieces.append(sp.id_to_piece(next_token_id))
        decoded_text = decode_sp_tokens(pieces)
        # Yield before the punctuation check so the sentence-closing token
        # reaches the consumer instead of being dropped by the break.
        yield decoded_text

        if reached_min_len and decoded_text.endswith(('.', '!', '?')):
            break
# Running transcript. NOTE(review): this is a single module-level string, so
# every request handled by this process appends to the same transcript.
history = ""


def chat(user_input):
    """Generate a reply to *user_input*, append the exchange to ``history``, return it.

    Args:
        user_input: The message typed by the user.

    Returns:
        The full transcript, including the new exchange.
    """
    global history
    # The original called generate_text(), which is not defined in this file.
    # Drain the streaming generator instead and keep its final (most complete)
    # text; the reply stays empty if nothing is yielded.
    response = ""
    for response in generate_text_topkp_stream(model, user_input):
        pass
    # The speaker label had been corrupted by an encoding round-trip; restored
    # to the original Korean text ("user").
    history += f"사용자: {user_input}\nKeraLux: {response}\n\n"
    return history
# Gradio front end: a one-line input box and a read-only transcript area.
# The user-facing strings had been corrupted by an encoding round-trip and
# are restored to the original Korean text.
with gr.Blocks() as demo:
    gr.Markdown("### 📟 KeraLux Textbot\n간단하고 빠른 대화용 봇이에요.\n")
    textbox = gr.Textbox(placeholder="메시지를 입력하세요", lines=1)
    output_area = gr.Textbox(label="대화 기록", lines=20, interactive=False)
    # Submitting the textbox runs chat() and shows the returned transcript.
    textbox.submit(chat, inputs=textbox, outputs=output_area)
    # A second handler on the same event clears the input box.
    textbox.submit(lambda: "", None, textbox)

demo.launch(share=True)