# KeraLux-API / app.py
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
from fastapi import FastAPI, Request
from fastapi.responses import PlainTextResponse
from fastapi.middleware.cors import CORSMiddleware
import sentencepiece as spm

app = FastAPI()
origins = [
    "https://insect5386.github.io",
    "https://insect5386.github.io/insect5386",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Load the SentencePiece tokenizer (special-token IDs must match the training setup).
sp = spm.SentencePieceProcessor()
sp.load("ko_unigram4.model")

# Resolve special-token IDs, falling back to conventional positions
# if a piece is missing from the vocabulary.
pad_id = sp.piece_to_id("<pad>")
if pad_id == -1: pad_id = 0
start_id = sp.piece_to_id("<start>")
if start_id == -1: start_id = 1
end_id = sp.piece_to_id("<end>")
if end_id == -1: end_id = 2
unk_id = sp.piece_to_id("<unk>")
if unk_id == -1: unk_id = 3

vocab_size = sp.get_piece_size()
max_len = 100
def text_to_ids(text):
    return sp.encode(text, out_type=int)

def ids_to_text(ids):
    return sp.decode(ids)
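# Illustrative round trip (actual IDs depend on ko_unigram4.model):
#   ids = text_to_ids("안녕하세요")   # -> list of token IDs
#   ids_to_text(ids)                  # -> "안녕하세요"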
class RotaryPositionalEmbedding(layers.Layer):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim))
        self.inv_freq = tf.constant(inv_freq, dtype=tf.float32)

    def call(self, x):
        # x shape: (batch, heads, seq_len, depth)
        batch, heads, seq_len, depth = tf.unstack(tf.shape(x))
        t = tf.range(seq_len, dtype=tf.float32)         # (seq_len,)
        freqs = tf.einsum('i,j->ij', t, self.inv_freq)  # (seq_len, dim//2)
        emb_sin = tf.sin(freqs)                         # (seq_len, dim//2)
        emb_cos = tf.cos(freqs)                         # (seq_len, dim//2)
        # (seq_len, dim//2) -> (1, 1, seq_len, dim//2) to broadcast with x
        emb_cos = tf.reshape(emb_cos, [1, 1, seq_len, -1])
        emb_sin = tf.reshape(emb_sin, [1, 1, seq_len, -1])
        x1 = x[..., ::2]   # even features: (batch, heads, seq_len, depth//2)
        x2 = x[..., 1::2]  # odd features
        # Rotate each (x1, x2) pair by its position-dependent angle.
        x_rotated = tf.stack([
            x1 * emb_cos - x2 * emb_sin,
            x1 * emb_sin + x2 * emb_cos
        ], axis=-1)  # (batch, heads, seq_len, depth//2, 2)
        x_rotated = tf.reshape(x_rotated, tf.shape(x))  # back to (batch, heads, seq_len, depth)
        return x_rotated
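# Sanity-check sketch (illustrative; not executed at import time). RoPE rotates
# each (even, odd) feature pair by a position-dependent angle, so it preserves
# vector norms:
#   x = tf.random.normal([1, 2, 8, 16])   # (batch, heads, seq_len, depth)
#   rope = RotaryPositionalEmbedding(16)
#   tf.debugging.assert_near(tf.norm(rope(x)), tf.norm(x))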
class GEGLU(tf.keras.layers.Layer):
    # Gated GELU feed-forward: project to 2*d_ff, split into value and gate
    # halves, and combine as out(value * gelu(gate)).
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.proj = layers.Dense(d_ff * 2)
        self.out = layers.Dense(d_model)

    def call(self, x):
        x_proj = self.proj(x)
        x_val, x_gate = tf.split(x_proj, 2, axis=-1)
        return self.out(x_val * tf.nn.gelu(x_gate))
class KeraLuxBlock(tf.keras.layers.Layer):
    def __init__(self, d_model, d_ff, num_heads=20, dropout_rate=0.1):
        super().__init__()
        self.ln1 = layers.LayerNormalization(epsilon=1e-5)
        self.mha = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model // num_heads)
        self.dropout1 = layers.Dropout(dropout_rate)
        self.ln2 = layers.LayerNormalization(epsilon=1e-5)
        self.ffn = GEGLU(d_model, d_ff)
        self.dropout2 = layers.Dropout(dropout_rate)
        self.rope = RotaryPositionalEmbedding(d_model // num_heads)

    def call(self, x, training=False):
        x_norm = self.ln1(x)

        # Apply RoPE to the attention query and key.
        batch_size = tf.shape(x_norm)[0]
        seq_len = tf.shape(x_norm)[1]
        num_heads = self.mha.num_heads
        depth = x_norm.shape[-1] // num_heads

        # (batch, seq_len, d_model) -> (batch, num_heads, seq_len, depth)
        qkv = tf.reshape(x_norm, [batch_size, seq_len, num_heads, depth])
        qkv = tf.transpose(qkv, [0, 2, 1, 3])  # (batch, heads, seq_len, depth)

        # Apply RoPE (query and key both come from the same x_norm, so rotate both).
        q = self.rope(qkv)
        k = self.rope(qkv)

        # Flatten back to the original (batch, seq_len, d_model) shape.
        q = tf.transpose(q, [0, 2, 1, 3])
        q = tf.reshape(q, [batch_size, seq_len, num_heads * depth])
        k = tf.transpose(k, [0, 2, 1, 3])
        k = tf.reshape(k, [batch_size, seq_len, num_heads * depth])

        # MHA call: query, key, and value would all be x_norm, but the
        # RoPE-rotated q and k are substituted for the query and key inputs.
        attn_out = self.mha(query=q, value=x_norm, key=k, use_causal_mask=True, training=training)
        x = x + self.dropout1(attn_out, training=training)
        ffn_out = self.ffn(self.ln2(x))
        x = x + self.dropout2(ffn_out, training=training)
        return x
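# Shape walkthrough for the configuration used below (d_model=160, num_heads=20,
# so depth = 8 per head):
#   x        : (batch, seq_len, 160)
#   qkv      : (batch, 20, seq_len, 8)   after reshape + transpose
#   q, k     : (batch, seq_len, 160)     after RoPE, flattened back
#   attn_out : (batch, seq_len, 160)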
class KeraLux(tf.keras.Model):
    def __init__(self, vocab_size, seq_len, d_model, d_ff, n_layers, num_heads=20, dropout_rate=0.1):
        super().__init__()
        self.token_embedding = layers.Embedding(vocab_size, d_model)
        # No learned positional embedding: RoPE inside each block handles positions.
        self.blocks = [KeraLuxBlock(d_model, d_ff, num_heads, dropout_rate) for _ in range(n_layers)]
        self.ln_f = layers.LayerNormalization(epsilon=1e-5)

    def call(self, x, training=False):
        x = self.token_embedding(x)
        for block in self.blocks:
            x = block(x, training=training)
        x = self.ln_f(x)
        # Weight-tied output head: project onto the transposed embedding matrix.
        logits = tf.matmul(x, self.token_embedding.embeddings, transpose_b=True)
        return logits
# Build the model and load the pretrained weights.
model = KeraLux(vocab_size=vocab_size, seq_len=max_len, d_model=160, d_ff=616, n_layers=6)
dummy_input = tf.zeros((1, max_len), dtype=tf.int32)  # batch 1, sequence length max_len
_ = model(dummy_input)  # a forward pass builds the variables
model.load_weights("KeraLux3.weights.h5")
print("Model weights loaded!")
def decode_sp_tokens(tokens):
    # Join SentencePiece pieces and turn the '▁' word-boundary marker into spaces.
    text = ''.join(tokens).replace('▁', ' ').strip()
    return text
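# Example: pieces ['▁제', '▁이름', '은'] decode to "제 이름은"
# ('▁' marks the start of a space-separated word).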
def generate_text_topp(model, prompt, max_len=100, max_gen=98, p=0.9, temperature=0.8, min_len=20):
    model_input = text_to_ids(f"<start> {prompt}")
    model_input = model_input[:max_len]
    generated = list(model_input)
    text_so_far = []
    for step in range(max_gen):
        pad_length = max(0, max_len - len(generated))
        input_padded = np.pad(generated, (0, pad_length), constant_values=pad_id)
        input_tensor = tf.convert_to_tensor([input_padded])
        logits = model(input_tensor, training=False)
        next_token_logits = logits[0, len(generated) - 1].numpy()
        # Suppress <end> and <pad> until the minimum length is reached, so the
        # model cannot terminate (or emit padding) too early.
        if len(generated) < min_len:
            next_token_logits[end_id] -= 5.0
            next_token_logits[pad_id] -= 10.0
        # Temperature scaling followed by nucleus (top-p) filtering.
        logits_temp = next_token_logits / temperature
        probs = tf.nn.softmax(logits_temp).numpy()
        sorted_idx = np.argsort(probs)[::-1]
        sorted_probs = probs[sorted_idx]
        cumulative_probs = np.cumsum(sorted_probs)
        cutoff = np.searchsorted(cumulative_probs, p, side='right') + 1
        filtered_indices = sorted_idx[:cutoff]
        filtered_probs = sorted_probs[:cutoff]
        filtered_probs /= filtered_probs.sum()
        next_token_id = np.random.choice(filtered_indices, p=filtered_probs)

        generated.append(int(next_token_id))
        next_word = sp.id_to_piece(int(next_token_id))
        text_so_far.append(next_word)
        decoded_text = decode_sp_tokens(text_so_far)

        # Stop on <end>, or at sentence-final punctuation, once min_len is reached.
        if len(generated) >= min_len and next_token_id == end_id:
            break
        if len(generated) >= min_len and decoded_text.endswith(('.', '!', '?')):
            break
    return decoded_text
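# Worked example of the top-p cutoff with p=0.9: for sorted probabilities
# [0.5, 0.3, 0.15, 0.05], the cumulative sums are [0.5, 0.8, 0.95, 1.0];
# np.searchsorted(..., 0.9, side='right') returns 2, so cutoff = 3 and the
# top three tokens are kept and renormalized before sampling.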
def respond(input_text):
    # Hard-coded identity answers ("이름" = "name", "누구" = "who").
    if "이름" in input_text:
        return "제 이름은 KeraLux입니다."  # "My name is KeraLux."
    if "누구" in input_text:
        return "저는 KeraLux라고 해요."  # "I'm called KeraLux."
    return generate_text_topp(model, input_text)
@app.get("/generate", response_class=PlainTextResponse)
async def generate(request: Request):
    prompt = request.query_params.get("prompt", "안녕하세요")  # default prompt: "Hello"
    response_text = respond(prompt)
    return response_text
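# Minimal local-run sketch (an assumption, not part of the original app: the
# hosting platform normally supplies its own server entrypoint, and 7860 is the
# conventional Hugging Face Spaces port). Once running, try:
#   curl "http://localhost:7860/generate?prompt=안녕하세요"
if __name__ == "__main__":
    import uvicorn  # assumed available alongside FastAPI
    uvicorn.run(app, host="0.0.0.0", port=7860)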