# Flexi-API / api.py
# Author: Yuchan5386
# Last commit: "Update api.py" (56efcce, verified)
import requests
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
import asyncio
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse, PlainTextResponse
import sentencepiece as spm
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from fastapi.middleware.cors import CORSMiddleware
import re
app = FastAPI()

# Browsers send the Origin header as scheme://host[:port] with no path, and
# CORSMiddleware checks it for exact membership in this list. An entry with a
# path (such as "https://insect5386.github.io/insect5386") can therefore never
# match. The bare origin below already covers every page served from that site.
origins = [
    "https://insect5386.github.io",
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# SentencePiece tokenizer used by the encode/decode helpers and generation code.
sp = spm.SentencePieceProcessor()
sp.load("kolig_unigram.model")


def _special_id(piece, fallback):
    """Return the vocabulary id of *piece*, or *fallback* if it is not in the vocab.

    ``piece_to_id`` does not return -1 for an out-of-vocabulary piece: it
    returns the id of the unknown token. A missing special token is only
    detected by checking that the returned id maps back to the same piece.
    """
    idx = sp.piece_to_id(piece)
    if idx < 0 or sp.id_to_piece(idx) != piece:
        return fallback
    return idx


pad_id = _special_id("<pad>", 0)
start_id = _special_id("<start>", 1)
end_id = _special_id("<end>", 2)
unk_id = _special_id("<unk>", 3)
vocab_size = sp.get_piece_size()
# Sequence length used to build the model and to pad/truncate generation input.
max_len = 100
def text_to_ids(text):
    """Tokenize *text* with the module tokenizer and return a list of int ids."""
    ids = sp.encode(text, out_type=int)
    return ids
def ids_to_text(ids):
    """Decode a sequence of token ids back into a string with the module tokenizer."""
    text = sp.decode(ids)
    return text
class RotaryPositionalEmbedding(layers.Layer):
    """Rotary positional embedding applied to interleaved feature pairs.

    The layer owns no trainable weights. For sequence position ``p`` and pair
    index ``i`` it rotates ``(x[..., 2i], x[..., 2i + 1])`` by the angle
    ``p * inv_freq[i]``, where ``inv_freq[i] = 1 / 10000 ** (2i / dim)``.

    Args:
        dim: size of the last axis of the inputs (assumed even, since the
            features are processed in pairs -- TODO confirm for new configs).
    """

    def __init__(self, dim):
        super().__init__()
        exponents = np.arange(0, dim, 2) / dim
        self.inv_freq = tf.constant(1.0 / (10000 ** exponents), dtype=tf.float32)

    def call(self, x):
        """Rotate *x*, a rank-4 tensor whose third axis is the sequence axis."""
        # Unpacking all four dims keeps the implicit rank-4 requirement.
        _, _, seq_len, _ = tf.unstack(tf.shape(x))
        positions = tf.range(seq_len, dtype=tf.float32)
        angles = tf.einsum('i,j->ij', positions, self.inv_freq)
        # Broadcast the (seq_len, dim / 2) angle table over the first two axes.
        cos = tf.reshape(tf.cos(angles), [1, 1, seq_len, -1])
        sin = tf.reshape(tf.sin(angles), [1, 1, seq_len, -1])
        even, odd = x[..., ::2], x[..., 1::2]
        rotated_pairs = tf.stack(
            [even * cos - odd * sin, even * sin + odd * cos],
            axis=-1,
        )
        # Re-interleave the rotated pairs into the original last-axis layout.
        return tf.reshape(rotated_pairs, tf.shape(x))
class SwiGLU(tf.keras.layers.Layer):
    """Gated feed-forward layer computing ``out(value * silu(gate))``.

    One Dense layer emits ``2 * d_ff`` features. They are split in half into
    a value path and a gate path, and the gated product is projected back to
    ``d_model`` features.

    Args:
        d_model: output feature size.
        d_ff: width of each of the value and gate paths.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        # Keep these attribute names and their creation order stable: the
        # checkpoint loaded at module level is restored into this structure.
        self.proj = tf.keras.layers.Dense(d_ff * 2)
        self.out = tf.keras.layers.Dense(d_model)

    def call(self, x):
        value, gate = tf.split(self.proj(x), num_or_size_splits=2, axis=-1)
        gated = value * tf.nn.silu(gate)
        return self.out(gated)
class Block(tf.keras.layers.Layer):
    """Pre-norm transformer block with causal self-attention and an adapter.

    As implemented in ``call``:
      1. LayerNorm -> causal multi-head attention -> dropout. A bottleneck
         adapter's output is added to the attention output, and the sum is
         added to the residual stream.
      2. LayerNorm -> SwiGLU feed-forward -> dropout, added to the residual.

    Args:
        d_model: residual-stream width; expected divisible by num_heads.
        d_ff: hidden width of the SwiGLU feed-forward layer.
        num_heads: number of attention heads.
        dropout_rate: dropout rate after attention and after the FFN.
        adapter_dim: bottleneck width of the adapter.
    """

    def __init__(self, d_model, d_ff, num_heads=8, dropout_rate=0.05, adapter_dim=64):
        super().__init__()
        # Sub-layer attribute names and creation order must stay as they are:
        # the pretrained checkpoint is restored into this exact structure.
        self.ln1 = tf.keras.layers.LayerNormalization(epsilon=1e-5)
        self.mha = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model // num_heads)
        self.dropout1 = tf.keras.layers.Dropout(dropout_rate)
        self.adapter_down = tf.keras.layers.Dense(adapter_dim, activation='gelu')
        self.adapter_up = tf.keras.layers.Dense(d_model)
        self.ln2 = tf.keras.layers.LayerNormalization(epsilon=1e-5)
        self.ffn = SwiGLU(d_model, d_ff)
        self.dropout2 = tf.keras.layers.Dropout(dropout_rate)
        self.rope = RotaryPositionalEmbedding(d_model // num_heads)

    def call(self, x, training=False):
        x_norm = self.ln1(x)
        dyn_shape = tf.shape(x_norm)
        b, s = dyn_shape[0], dyn_shape[1]
        h = self.mha.num_heads
        d = x_norm.shape[-1] // h
        # Split features into heads: (batch, seq, d_model) -> (batch, heads, seq, head_dim).
        heads = tf.transpose(tf.reshape(x_norm, [b, s, h, d]), [0, 2, 1, 3])
        # NOTE: RoPE is applied to the normalized input *before* the MHA
        # layer's own query/key projections, and query and key receive the
        # same rotated tensor. Presumably the checkpoint was trained with this
        # arrangement, so it is kept. Because q == k, the rotation is computed
        # only once.
        rotated = self.rope(heads)
        qk = tf.reshape(tf.transpose(rotated, [0, 2, 1, 3]), [b, s, h * d])
        attn_out = self.mha(query=qk, value=x_norm, key=qk, use_causal_mask=True, training=training)
        attn_out = self.dropout1(attn_out, training=training)
        # Bottleneck adapter added on top of the (dropped-out) attention output.
        attn_out = attn_out + self.adapter_up(self.adapter_down(attn_out))
        x = x + attn_out
        ffn_out = self.ffn(self.ln2(x))
        return x + self.dropout2(ffn_out, training=training)
class Flexi(tf.keras.Model):
    """Causal language model: embedding -> stacked Blocks -> LayerNorm -> tied logits.

    Args:
        vocab_size: number of tokens in the vocabulary.
        seq_len: accepted for configuration compatibility; not used by the model.
        d_model: embedding / residual-stream width.
        d_ff: SwiGLU hidden width passed to each Block.
        n_layers: number of stacked Blocks.
        num_heads: attention heads per Block.
        dropout_rate: dropout rate passed to each Block.
    """

    def __init__(self, vocab_size, seq_len, d_model, d_ff, n_layers, num_heads=8, dropout_rate=0.05):
        super().__init__()
        # Creation order and attribute names match the structure the
        # checkpoint is restored into; do not reorder.
        self.token_embedding = tf.keras.layers.Embedding(vocab_size, d_model)
        self.blocks = [
            Block(d_model, d_ff, num_heads=num_heads, dropout_rate=dropout_rate)
            for _ in range(n_layers)
        ]
        self.ln_f = tf.keras.layers.LayerNormalization(epsilon=1e-5)

    def call(self, x, training=False):
        """Map token ids to per-position logits over the vocabulary."""
        hidden = self.token_embedding(x)
        for layer in self.blocks:
            hidden = layer(hidden, training=training)
        hidden = self.ln_f(hidden)
        # The output projection reuses the input embedding matrix (weight tying).
        return tf.matmul(hidden, self.token_embedding.embeddings, transpose_b=True)
# Build the network with the hyper-parameters the checkpoint below is loaded into.
model = Flexi(
    vocab_size=vocab_size,
    seq_len=max_len,
    d_model=256,
    d_ff=1024,
    n_layers=16,
)
# One forward pass on a dummy batch (batch size 1, sequence length max_len)
# builds the model so that its variables exist before the weights are loaded.
dummy_input = tf.zeros((1, max_len), dtype=tf.int32)
_ = model(dummy_input)
model.load_weights("Flexi.weights.h5")
print("모델 가중치 로드 완료!")
def _strip_special_tokens(text):
    """Remove literal "<start>", "<sep>" and "<end>" markers and trim whitespace."""
    for tok in ("<start>", "<sep>", "<end>"):
        text = text.replace(tok, "")
    return text.strip()


def generate_text_sample(model, prompt, max_len=100, max_gen=98,
                         temperature=0.85, top_k=65, top_p=0.9, min_len=12):
    """Sample a continuation of *prompt* one token at a time.

    Each step runs the full model on the right-padded sequence and takes the
    logits at the last real position. It then applies temperature, top-k and
    top-p (nucleus) filtering and draws one token with ``np.random.choice``.

    Args:
        model: callable returning logits indexable as ``[batch, position, vocab]``.
        prompt: text to continue; tokenized and truncated to *max_len* ids.
        max_len: length the model input is padded to, and the prompt truncation length.
        max_gen: maximum number of tokens to sample.
        temperature: logits are divided by this value; must be > 0.
        top_k: keep only the k most probable tokens (None or <= 0 disables).
            Values larger than the vocabulary keep every token.
        top_p: keep tokens in descending probability up to and including the
            one whose cumulative probability first exceeds top_p (None or a
            value outside (0, 1) disables).
        min_len: minimum total token count, prompt included, before an early
            stop is allowed.

    Returns:
        The decoded prompt plus continuation, with the special-token strings
        removed and surrounding whitespace stripped.

    Raises:
        ValueError: if temperature is not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")

    generated = list(text_to_ids(str(prompt))[:max_len])
    for _ in range(max_gen):
        # NOTE(review): nothing stops the sequence growing past max_len. Once
        # it does, pad_len is 0 and the model sees an input longer than
        # max_len. Confirm this is acceptable for the trained model.
        pad_len = max(0, max_len - len(generated))
        input_padded = np.pad(generated, (0, pad_len), constant_values=pad_id)
        logits = model(tf.convert_to_tensor([input_padded]), training=False)
        next_logits = logits[0, len(generated) - 1].numpy()

        # Temperature scaling followed by a numerically stable softmax.
        next_logits = next_logits / temperature
        probs = np.exp(next_logits - np.max(next_logits))
        probs = probs / probs.sum()

        # Top-k: zero out everything below the k-th largest probability.
        # k is clamped so that a value above the vocabulary size cannot index
        # out of range.
        if top_k is not None and top_k > 0:
            kth_largest = np.sort(probs)[-min(top_k, probs.size)]
            probs[probs < kth_largest] = 0
            probs /= probs.sum()

        # Top-p (cumulative probability) filtering.
        if top_p is not None and 0 < top_p < 1:
            sorted_indices = np.argsort(probs)[::-1]
            cumulative_probs = np.cumsum(probs[sorted_indices])
            cutoff_index = np.searchsorted(cumulative_probs, top_p, side='right')
            mask = np.ones_like(probs, dtype=bool)
            mask[sorted_indices[:cutoff_index + 1]] = False
            probs[mask] = 0
            probs /= probs.sum()

        next_token = np.random.choice(len(probs), p=probs)
        generated.append(int(next_token))

        decoded = _strip_special_tokens(sp.decode(generated))
        # Stop early once min_len is reached and either the end token was
        # sampled or the text ends like a finished Korean/English sentence.
        if len(generated) >= min_len and (next_token == end_id or decoded.endswith(('요', '다', '.', '!', '?'))):
            return decoded
    return _strip_special_tokens(sp.decode(generated))
# Check whether a generated response is usable.
def is_valid_response(response):
    """Heuristically decide whether a generated reply is acceptable.

    A reply is rejected when it:
      * has fewer than 2 characters after stripping whitespace,
      * contains a run of 3 or more bare Hangul jamo (e.g. "ㅋㅋㅋ"),
      * has fewer than 2 whitespace-separated words,
      * contains fewer than 2 space characters, or
      * contains "hello" or "this" (case-insensitive substring match).

    Returns:
        True if none of the rejection rules apply, else False.
    """
    if len(response.strip()) < 2:
        return False
    # Consonants ㄱ-ㅎ and vowels ㅏ-ㅣ of the Hangul Compatibility Jamo block.
    if re.search(r'[ㄱ-ㅎㅏ-ㅣ]{3,}', response):
        return False
    if len(response.split()) < 2:
        return False
    if response.count(' ') < 2:
        return False
    # The separate 'ㅋㅋ' check was removed.
    if any(tok in response.lower() for tok in ('hello', 'this')):
        return False
    return True
def respond(input_text):
# ์ด๋ฆ„ ๊ด€๋ จ ์งˆ๋ฌธ์— ๋”ฑ ๋ฐ˜์‘ํ•˜๋Š” ๋ถ€๋ถ„ ์œ ์ง€
if "์ด๋ฆ„" in input_text:
response = "์ œ ์ด๋ฆ„์€ Flexi์ž…๋‹ˆ๋‹ค."
return response
if "๋ˆ„๊ตฌ" in input_text:
response = "์ €๋Š” Flexi๋ผ๊ณ  ํ•ด์š”."
return response
# mismatch_tone ๊ฒ€์‚ฌ ์ œ๊ฑฐ
full_prompt = f"<start> {input_text} <sep>"
for _ in range(3): # ์ตœ๋Œ€ 3๋ฒˆ ์žฌ์‹œ๋„
full_response = generate_text_sample(model, full_prompt)
if "์‘๋‹ต:" in full_response:
response = full_response.split("<sep>")[-1].strip()
else:
response = full_response.strip()
if is_valid_response(response): # mismatch_tone ์ œ๊ฑฐ
return response
return "์ฃ„์†กํ•ด์š”, ์ œ๋Œ€๋กœ ๋‹ต๋ณ€์„ ๋ชปํ–ˆ์–ด์š”."
@app.get("/generate", response_class=PlainTextResponse)
async def generate(request: Request):
    """Return the chatbot reply for the ``prompt`` query parameter as plain text.

    ``respond`` runs synchronous model inference. Calling it directly inside
    this coroutine would block the event loop, and with it every other
    request, until generation finished, so it is dispatched to the default
    thread-pool executor instead.
    """
    prompt = request.query_params.get("prompt", "안녕하세요")
    loop = asyncio.get_running_loop()
    response_text = await loop.run_in_executor(None, respond, prompt)
    return response_text