# InteractGPT-API / api.py
# Source: Hugging Face file view (author: Yuchan5386, commit 009dbf5 "Update api.py", verified, 8.59 kB).
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
import asyncio
import json
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
import sentencepiece as spm
import requests
app = FastAPI()

# SentencePiece tokenizer shared by every request.
sp = spm.SentencePieceProcessor()
sp.load("kolig_unigram.model")


def _special_id(piece, fallback):
    """Return the vocabulary id of special token *piece*, or *fallback* if absent.

    ``SentencePieceProcessor.piece_to_id`` does not report a missing piece as
    -1; it returns the unknown-token id. A lookup that resolves to the unk id
    is therefore treated as "missing", unless *piece* really is the unknown
    token itself. A literal -1 is also treated as missing.
    """
    piece_id = sp.piece_to_id(piece)
    if piece_id == -1:
        return fallback
    if piece_id == sp.unk_id() and sp.id_to_piece(piece_id) != piece:
        return fallback
    return piece_id


pad_id = _special_id("<pad>", 0)
start_id = _special_id("<start>", 1)
# This was "< end >" (with spaces), which is not a vocabulary piece, so the
# lookup silently resolved to the unknown-token id. The end marker is spelled
# "<end>", matching the "<start>" / "<pad>" / "<unk>" spellings.
end_id = _special_id("<end>", 2)
unk_id = _special_id("<unk>", 3)
vocab_size = sp.get_piece_size()
# Sequence length the model is built with.
max_len = 100
def text_to_ids(text):
    """Encode *text* with the shared SentencePiece model into a list of int ids."""
    encoded = sp.encode(text, out_type=int)
    return encoded
def ids_to_text(ids):
    """Decode a sequence of SentencePiece token ids back into a string."""
    decoded = sp.decode(ids)
    return decoded
class RotaryPositionalEmbedding(layers.Layer):
    """Rotate interleaved feature pairs by position-dependent angles (RoPE).

    Feature pair ``(2i, 2i + 1)`` at sequence position ``t`` is rotated by the
    angle ``t * inv_freq[i]``, where ``inv_freq[i] = 1 / 10000 ** (2i / dim)``.
    The input is unpacked as ``(batch, heads, seq_len, depth)``. The layer
    creates no trainable weights.
    """

    def __init__(self, dim):
        super().__init__()
        exponents = np.arange(0, dim, 2) / dim
        self.inv_freq = tf.constant(1.0 / (10000 ** exponents), dtype=tf.float32)

    def call(self, x):
        _, _, seq_len, _ = tf.unstack(tf.shape(x))
        positions = tf.range(seq_len, dtype=tf.float32)
        # Outer product: angles[t, i] = t * inv_freq[i].
        angles = tf.einsum('i,j->ij', positions, self.inv_freq)
        # Broadcastable over the batch and head axes: (1, 1, seq_len, dim // 2).
        broadcast_shape = [1, 1, seq_len, -1]
        cos_table = tf.reshape(tf.cos(angles), broadcast_shape)
        sin_table = tf.reshape(tf.sin(angles), broadcast_shape)

        even, odd = x[..., ::2], x[..., 1::2]
        rotated_even = even * cos_table - odd * sin_table
        rotated_odd = even * sin_table + odd * cos_table
        # Stacking on a new last axis and reshaping re-interleaves the pairs.
        pairs = tf.stack([rotated_even, rotated_odd], axis=-1)
        return tf.reshape(pairs, tf.shape(x))
class SwiGLU(tf.keras.layers.Layer):
    """SwiGLU feed-forward network: ``out(value * silu(gate))``.

    One Dense layer produces ``2 * d_ff`` features, which are split into a
    value half and a gate half. A second Dense layer maps the gated product
    back to ``d_model`` features.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        # Attribute names and creation order are left untouched: saved weights
        # are presumably matched against them when loading.
        self.proj = tf.keras.layers.Dense(d_ff * 2)
        self.out = tf.keras.layers.Dense(d_model)

    def call(self, x):
        value, gate = tf.split(self.proj(x), num_or_size_splits=2, axis=-1)
        gated = value * tf.nn.silu(gate)
        return self.out(gated)
class GPTBlock(tf.keras.layers.Layer):
    """Pre-LayerNorm decoder block: causal self-attention (+ adapter), then a SwiGLU FFN.

    Args:
        d_model: width of the block's input and output features.
        d_ff: hidden width of the SwiGLU feed-forward network.
        num_heads: number of attention heads. ``d_model // num_heads`` is the
            per-head width used by both the attention and the RoPE layer.
        dropout_rate: dropout applied to the attention output and the FFN output.
        adapter_dim: bottleneck width of the residual adapter on the attention output.
    """

    def __init__(self, d_model, d_ff, num_heads=8, dropout_rate=0.1, adapter_dim=64):
        # NOTE(review): attribute names and creation order presumably determine
        # how load_weights maps saved tensors onto this layer; keep them unchanged.
        super().__init__()
        self.ln1 = tf.keras.layers.LayerNormalization(epsilon=1e-5)
        self.mha = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model // num_heads)
        self.dropout1 = tf.keras.layers.Dropout(dropout_rate)
        self.adapter_down = tf.keras.layers.Dense(adapter_dim, activation='gelu')
        self.adapter_up = tf.keras.layers.Dense(d_model)
        self.ln2 = tf.keras.layers.LayerNormalization(epsilon=1e-5)
        self.ffn = SwiGLU(d_model, d_ff)
        self.dropout2 = tf.keras.layers.Dropout(dropout_rate)
        self.rope = RotaryPositionalEmbedding(d_model // num_heads)

    def call(self, x, training=False):
        """Apply the block to ``x`` of shape (batch, seq, d_model); returns the same shape."""
        x_norm = self.ln1(x)
        dyn_shape = tf.shape(x_norm)
        b, s = dyn_shape[0], dyn_shape[1]
        h = self.mha.num_heads
        d = x_norm.shape[-1] // h
        # Split features into heads: (b, s, h * d) -> (b, h, s, d), the layout RoPE expects.
        heads = tf.transpose(tf.reshape(x_norm, [b, s, h, d]), [0, 2, 1, 3])
        # Query and key are rotated from the same input, so RoPE is computed
        # once and the result is shared. This gives an identical tensor at half the cost.
        # NOTE(review): RoPE is applied to the normalized input *before* MHA's own
        # query/key projections, and query == key here. The loaded weights
        # presumably depend on this arrangement, so it is preserved as-is.
        rotated = self.rope(heads)
        qk = tf.reshape(tf.transpose(rotated, [0, 2, 1, 3]), [b, s, h * d])
        attn_out = self.mha(query=qk, value=x_norm, key=qk, use_causal_mask=True, training=training)
        attn_out = self.dropout1(attn_out, training=training)
        # Residual adapter (GELU down-projection, then up-projection) on the attention output.
        attn_out = attn_out + self.adapter_up(self.adapter_down(attn_out))
        x = x + attn_out
        ffn_out = self.ffn(self.ln2(x))
        return x + self.dropout2(ffn_out, training=training)
class InteractGPT(tf.keras.Model):
    """Decoder-only language model with tied input/output embeddings.

    Pipeline: token embedding -> ``n_layers`` GPTBlock layers -> final
    LayerNorm -> projection onto the transposed embedding matrix.

    Args:
        vocab_size: number of token ids (rows of the embedding matrix).
        seq_len: accepted but not used by this class.
            NOTE(review): presumably kept for interface compatibility.
        d_model: embedding / hidden width.
        d_ff: SwiGLU hidden width inside each block.
        n_layers: number of GPTBlock layers.
        num_heads: attention heads per block.
        dropout_rate: dropout rate passed to every block.

    NOTE(review): attribute names and layer-creation order are presumably what
    ``load_weights`` matches saved tensors against; do not rename or reorder
    them without re-exporting the weight file.
    """

    def __init__(self, vocab_size, seq_len, d_model, d_ff, n_layers, num_heads=8, dropout_rate=0.1):
        super().__init__()
        self.token_embedding = tf.keras.layers.Embedding(vocab_size, d_model)
        self.blocks = [GPTBlock(d_model, d_ff, num_heads, dropout_rate) for _ in range(n_layers)]
        self.ln_f = tf.keras.layers.LayerNormalization(epsilon=1e-5)

    def call(self, x, training=False):
        """Map token ids ``x`` to logits with one score per vocabulary id on the last axis."""
        x = self.token_embedding(x)
        for block in self.blocks:
            x = block(x, training=training)
        x = self.ln_f(x)
        # Weight tying: reuse the embedding matrix as the output projection.
        logits = tf.matmul(x, self.token_embedding.embeddings, transpose_b=True)
        return logits
# Build the model. NOTE(review): these hyper-parameters must match the ones the
# weight file was saved with (presumably the training configuration) — confirm.
model = InteractGPT(vocab_size=vocab_size, seq_len=max_len, d_model=256, d_ff=1024, n_layers=6)
dummy_input = tf.zeros((1, max_len), dtype=tf.int32)  # batch of 1, sequence length max_len
_ = model(dummy_input)  # one forward pass builds the model (creates its variables) before loading
model.load_weights("InteractGPT.weights.h5")
print("모델 가중치 로드 완료!")  # prints "Model weights loaded!"
def decode_sp_tokens(tokens):
    """Join SentencePiece pieces into plain text.

    The word-boundary marker '▁' becomes a space, and leading/trailing
    whitespace is stripped, e.g. ['▁hello', '▁world'] -> 'hello world'.
    """
    joined = "".join(tokens)
    spaced = joined.replace('▁', ' ')
    return spaced.strip()
def generate_text_mirostat_top_p(model, prompt, max_len=100, max_gen=98,
                                 temperature=1.0, min_len=20,
                                 repetition_penalty=1.2, eta=0.1, m=100, p=0.9,
                                 tau=5.0):
    """Sample a continuation of *prompt*, yielding the decoded text so far after each token.

    Each step does the following:
    - Run the model on the most recent ``max_len`` tokens.
    - Adjust the next-token logits: repetition penalty, ``<pad>`` masking,
      ``<end>`` masking below ``min_len``, and temperature.
    - Keep the ``m`` most probable candidates.
    - Apply Mirostat v2 truncation (drop candidates whose surprise exceeds
      the adaptive threshold ``mu``).
    - Apply nucleus (top-p) filtering and sample.

    Args:
        model: called as ``model(ids, training=False)``; ``logits[0, i]`` must
            be the next-token scores after position ``i``.
        prompt: user text, wrapped as ``"<start> {prompt} <sep>"``.
        max_len: context window in tokens. The encoded prompt is truncated to
            it, and each step feeds only the last ``max_len`` tokens.
        max_gen: maximum number of sampling steps.
        temperature: positive divisor applied to the logits.
        min_len: minimum total length (prompt + generated tokens) before
            ``<end>`` may be sampled or a punctuation stop applies.
        repetition_penalty: factor applied once per earlier occurrence of a
            token, always in the "less likely" direction.
        eta: Mirostat learning rate for the threshold ``mu``.
        m: number of most-probable candidates considered per step (at least 1).
        p: nucleus probability mass.
        tau: Mirostat target surprise in bits; ``mu`` starts at ``2 * tau``.

    Yields:
        str: the full decoded text generated so far (cumulative, not a delta).
        Generation stops after ``<end>`` (not yielded) or after text ending
        in '.', '!', '?' or '<end>' once ``min_len`` is reached.

    Raises:
        ValueError: if ``temperature`` is not positive (raised on first iteration).
    """
    if temperature <= 0:
        raise ValueError("temperature must be > 0")
    m = max(int(m), 1)

    generated = list(text_to_ids(f"<start> {prompt} <sep>")[:max_len])
    pieces = []  # SentencePiece pieces of the newly generated tokens only
    mu = 2.0 * tau  # Mirostat v2 surprise threshold (bits), adapted every step

    for _ in range(max_gen):
        # Sliding window: never feed more than max_len tokens. The window is
        # right-padded to a fixed length; the causal mask keeps the padding
        # from influencing the logits at the last real position.
        window = generated[-max_len:]
        padded = np.pad(window, (0, max_len - len(window)), constant_values=pad_id)
        logits = model(tf.convert_to_tensor([padded]), training=False)
        # Copy (float64) so the in-place edits below never touch the tensor's buffer.
        scores = logits[0, len(window) - 1].numpy().astype(np.float64)

        # Repetition penalty: shrink positive logits and push negative ones
        # further down. Dividing a negative logit would *raise* it and make
        # the repeat more likely.
        seen_ids, seen_counts = np.unique(np.asarray(generated, dtype=np.int64),
                                          return_counts=True)
        factor = repetition_penalty ** seen_counts
        seen = scores[seen_ids]
        scores[seen_ids] = np.where(seen > 0, seen / factor, seen * factor)

        # <pad> is never a valid continuation, and <end> may not be sampled
        # before the minimum length is reached.
        scores[pad_id] = -np.inf
        if len(generated) < min_len:
            scores[end_id] = -np.inf

        scores /= temperature
        probs = np.exp(scores - scores.max())
        probs /= probs.sum()

        # 1. Keep the m most probable tokens, in descending probability order.
        cand_ids = np.argsort(-probs)[:m]
        cand_probs = probs[cand_ids]

        # 2. Mirostat v2: drop candidates whose surprise (-log2 p) exceeds mu.
        #    Candidates are sorted, so this keeps a prefix; the top token always survives.
        with np.errstate(divide="ignore"):
            surprise = -np.log2(cand_probs)
        keep = surprise <= mu
        keep[0] = True
        cand_ids, cand_probs = cand_ids[keep], cand_probs[keep]

        # 3. Nucleus filtering on the renormalized survivors: keep the smallest
        #    prefix whose cumulative mass reaches p.
        cand_probs = cand_probs / cand_probs.sum()
        cutoff = np.searchsorted(np.cumsum(cand_probs), p, side="left") + 1
        cand_ids = cand_ids[:cutoff]
        cand_probs = cand_probs[:cutoff] / cand_probs[:cutoff].sum()

        # 4. Sample, then steer mu toward the target surprise tau, using the
        #    sampled token's surprise under the full distribution.
        token = int(np.random.choice(cand_ids, p=cand_probs))
        mu -= eta * (-np.log2(probs[token]) - tau)

        generated.append(token)
        if token == end_id:
            # Only reachable once min_len is met (masked above); don't stream the marker.
            break
        pieces.append(sp.id_to_piece(token))
        decoded_text = decode_sp_tokens(pieces)
        # Yield before the punctuation check so the final '.', '!' or '?' is streamed.
        yield decoded_text
        if len(generated) >= min_len and decoded_text.endswith(('.', '!', '?', '<end>')):
            break
async def async_generator_wrapper(prompt: str):
    """Stream text generated for *prompt* as incremental chunks.

    Each ``next()`` on the synchronous generator runs model inference. It is
    therefore executed in the event loop's default executor (a worker thread)
    rather than on the loop itself; running it inline would block every other
    request served by this process for the whole generation.

    The underlying generator yields the cumulative text. Only the newly
    appended suffix is forwarded here, because a streamed HTTP body is
    concatenated by the client. Forwarding cumulative text would repeat every
    prefix.

    Yields:
        str: the text added since the previous chunk (empty deltas are skipped).
    """
    loop = asyncio.get_running_loop()
    gen = generate_text_mirostat_top_p(model, prompt)
    # next(gen, done) returns the sentinel on exhaustion. StopIteration cannot
    # propagate through an executor future.
    done = object()
    sent = ""
    while True:
        text = await loop.run_in_executor(None, next, gen, done)
        if text is done:
            break
        # The cumulative text normally extends the previously sent text. If it
        # ever does not, send it whole rather than drop output.
        chunk = text[len(sent):] if text.startswith(sent) else text
        sent = text
        if chunk:
            yield chunk
        # Throttle the stream: 0.1 s delay between generated tokens.
        await asyncio.sleep(0.1)
@app.get("/generate")
async def generate(request: Request):
    """Stream model output for the ``prompt`` query parameter as plain text.

    If no prompt is supplied, "안녕하세요" ("hello") is used.
    """
    params = request.query_params
    user_prompt = params.get("prompt", "안녕하세요")
    body = async_generator_wrapper(user_prompt)
    return StreamingResponse(body, media_type="text/plain")