# KeraLux-API / app.py
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
import gradio as gr
import sentencepiece as spm
# Load SentencePiece (the tokenizer and its special-token IDs must match training)
sp = spm.SentencePieceProcessor()
sp.load("ko_unigram3.model")
pad_id = sp.piece_to_id("<pad>")
if pad_id == -1: pad_id = 0
start_id = sp.piece_to_id("<start>")
if start_id == -1: start_id = 1
end_id = sp.piece_to_id("<end>")
if end_id == -1: end_id = 2
unk_id = sp.piece_to_id("<unk>")
if unk_id == -1: unk_id = 3
vocab_size = sp.get_piece_size()
max_len = 128
def text_to_ids(text):
    return sp.encode(text, out_type=int)

def ids_to_text(ids):
    return sp.decode(ids)
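
# Round-trip example (illustrative only; the actual IDs depend on the
# ko_unigram3 vocabulary):
#   ids = text_to_ids("안녕하세요")  # e.g. [1204, 87, ...]
#   ids_to_text(ids)                 # -> "안녕하세요"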
# GEGLU layer
class GEGLU(tf.keras.layers.Layer):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.proj = layers.Dense(d_ff * 2)  # one projection, split into value and gate halves
        self.out = layers.Dense(d_model)

    def call(self, x):
        x_proj = self.proj(x)
        x_val, x_gate = tf.split(x_proj, 2, axis=-1)
        return self.out(x_val * tf.nn.gelu(x_gate))
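
# GEGLU is the GELU-gated feed-forward variant from Shazeer,
# "GLU Variants Improve Transformer" (2020):
#   GEGLU(x) = (x W) * GELU(x V)
# Using a single Dense(d_ff * 2) and splitting the result is equivalent to
# two separate d_ff-wide projections W and V.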
# GPT block
class GPTBlock(tf.keras.layers.Layer):
    def __init__(self, d_model, d_ff, num_heads=16, dropout_rate=0.1):
        super().__init__()
        self.ln1 = layers.LayerNormalization(epsilon=1e-5)
        self.attn = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model // num_heads)
        self.dropout1 = layers.Dropout(dropout_rate)
        self.ln2 = layers.LayerNormalization(epsilon=1e-5)
        self.ffn = GEGLU(d_model, d_ff)
        self.dropout2 = layers.Dropout(dropout_rate)

    def call(self, x, training=False):
        # Causal self-attention, then the GEGLU feed-forward,
        # each wrapped in a residual connection
        x_norm = self.ln1(x)
        attn_out = self.attn(query=x_norm, value=x_norm, key=x_norm,
                             use_causal_mask=True, training=training)
        x = x + self.dropout1(attn_out, training=training)
        ffn_out = self.ffn(self.ln2(x))
        x = x + self.dropout2(ffn_out, training=training)
        return x
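
# Note: the block normalizes before attention and the FFN (pre-LN). Pre-LN
# transformers are generally more stable to train than the original post-LN
# arrangement, especially without a carefully tuned warmup schedule.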
# GPT model
class GPT(tf.keras.Model):
    def __init__(self, vocab_size, seq_len, d_model, d_ff, n_layers, num_heads=16, dropout_rate=0.1):
        super().__init__()
        self.token_embedding = layers.Embedding(vocab_size, d_model)
        # Learned absolute positional embeddings, one row per position
        self.pos_embedding = self.add_weight(
            name="pos_embedding",
            shape=[seq_len, d_model],
            initializer=tf.keras.initializers.RandomNormal(stddev=0.01)
        )
        self.blocks = [GPTBlock(d_model, d_ff, num_heads, dropout_rate) for _ in range(n_layers)]
        self.ln_f = layers.LayerNormalization(epsilon=1e-5)

    def call(self, x, training=False):
        seq_len = tf.shape(x)[1]
        x = self.token_embedding(x) + self.pos_embedding[tf.newaxis, :seq_len, :]
        for block in self.blocks:
            x = block(x, training=training)
        x = self.ln_f(x)
        # Weight tying: the transposed token embedding doubles as the output projection
        logits = tf.matmul(x, self.token_embedding.embeddings, transpose_b=True)
        return logits
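
# Shape sanity check (illustrative, commented out; the real build happens below):
#   out = GPT(vocab_size, max_len, 128, 512, 6)(tf.zeros((1, max_len), tf.int32))
#   out.shape  # -> (1, max_len, vocab_size)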
# ๋ชจ๋ธ ์ƒ์„ฑ & ๊ฐ€์ค‘์น˜ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ
model = GPT(vocab_size=vocab_size, seq_len=max_len, d_model=128, d_ff=512, n_layers=6)
dummy_input = tf.zeros((1, max_len), dtype=tf.int32) # ๋ฐฐ์น˜1, ์‹œํ€€์Šค๊ธธ์ด max_len
_ = model(dummy_input) # ๋ชจ๋ธ์ด ๋นŒ๋“œ๋จ
model.load_weights("KeraLux3.weights.h5")
print("๋ชจ๋ธ ๊ฐ€์ค‘์น˜ ๋กœ๋“œ ์™„๋ฃŒ!")
def decode_sp_tokens(tokens):
    # SentencePiece marks word boundaries with '▁'; join pieces and restore spaces
    text = ''.join(tokens).replace('▁', ' ').strip()
    return text
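
# Pieces are joined by hand instead of calling sp.decode() so that partially
# generated sequences can be rendered while tokens are still streaming.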
def generate_text_topkp_stream(model, prompt, max_len=100, max_gen=98, p=0.9, k=50, temperature=0.8, min_len=20):
    model_input = text_to_ids(f"<start> {prompt}")
    model_input = model_input[:max_len]
    generated = list(model_input)
    text_so_far = []

    for step in range(max_gen):
        # Keep only the most recent max_len tokens as context (the positional
        # table is finite), pad to max_len, and read the logits at the last
        # real position
        context = generated[-max_len:]
        pad_length = max(0, max_len - len(context))
        input_padded = np.pad(context, (0, pad_length), constant_values=pad_id)
        input_tensor = tf.convert_to_tensor([input_padded])
        logits = model(input_tensor, training=False)
        next_token_logits = logits[0, len(context) - 1].numpy()

        # Until the minimum length is reached, discourage <end> and <pad>
        # so generation does not stop (or degenerate) too early
        if len(generated) < min_len:
            next_token_logits[end_id] -= 5.0
            next_token_logits[pad_id] -= 10.0

        # Apply temperature
        logits_temp = next_token_logits / temperature
        # 1. Convert logits to probabilities
        probs = tf.nn.softmax(logits_temp).numpy()
        # 2. Top-k filtering
        top_k_indices = np.argpartition(probs, -k)[-k:]
        top_k_probs = probs[top_k_indices]
        # 3. Top-p filtering (sort descending so the cumulative sum is meaningful)
        sorted_idx = np.argsort(top_k_probs)[::-1]
        top_k_indices = top_k_indices[sorted_idx]
        top_k_probs = top_k_probs[sorted_idx]
        cumulative_probs = np.cumsum(top_k_probs)
        # Cut off everything beyond the nucleus mass p
        cutoff = np.searchsorted(cumulative_probs, p, side='right') + 1
        filtered_indices = top_k_indices[:cutoff]
        filtered_probs = top_k_probs[:cutoff]
        # Renormalize the surviving probabilities
        filtered_probs /= filtered_probs.sum()
        # Sample the next token
        next_token_id = np.random.choice(filtered_indices, p=filtered_probs)

        generated.append(int(next_token_id))
        next_word = sp.id_to_piece(int(next_token_id))
        text_so_far.append(next_word)
        decoded_text = decode_sp_tokens(text_so_far)

        if len(generated) >= min_len and next_token_id == end_id:
            break
        if len(generated) >= min_len and decoded_text.endswith(('.', '!', '?')):
            break

        yield decoded_text
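
# Standalone usage sketch (hypothetical prompt; in this app the Gradio
# callback below drives the generator):
#   for partial in generate_text_topkp_stream(model, "안녕하세요", p=0.9):
#       print(partial)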
def chat(user_input, history):
    if history is None:
        history = []
    for partial_response in generate_text_topkp_stream(model, user_input, p=0.9):
        updated = history + [(user_input, partial_response)]
        yield updated, updated
with gr.Blocks(title="KeraLux Chat") as demo:
    gr.Markdown(
        """
        # 💡 Chat with KeraLux!
        Type a message and KeraLux will answer smartly.
        """,
        elem_id="title",
    )
    gr.Markdown("---")
    with gr.Row():
        with gr.Column(scale=1):
            chatbot = gr.Chatbot(label="KeraLux chat window", bubble_full_width=False)
        with gr.Column(scale=0):
            msg = gr.Textbox(
                label="Enter your question!",
                placeholder="e.g. Can you help me?",
                lines=1,
            )
    state = gr.State([])
    msg.submit(chat, inputs=[msg, state], outputs=[chatbot, state])
    msg.submit(lambda: "", None, msg)  # clear the input box

demo.launch(share=True)