InteractGPT-API / api.py
Yuchan5386's picture
Update api.py
35d657e verified
raw
history blame
10.5 kB
import asyncio

import nltk
import numpy as np
import requests
import sentencepiece as spm
import tensorflow as tf
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from nltk.tokenize import sent_tokenize
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

nltk.download('punkt')
app = FastAPI()

# SentencePiece tokenizer shared by text_to_ids / ids_to_text and the generator.
sp = spm.SentencePieceProcessor()
sp.load("kolig_unigram.model")


def _piece_id(piece, fallback):
    """Return the id of *piece* in ``sp``, or *fallback* when the lookup returns -1.

    NOTE(review): the SentencePiece API documents ``piece_to_id`` as returning
    the unknown-token id (not -1) for pieces missing from the vocabulary, so the
    fallback may never trigger; confirm against the installed version.
    """
    piece_id = sp.piece_to_id(piece)
    return fallback if piece_id == -1 else piece_id


pad_id = _piece_id("<pad>", 0)
start_id = _piece_id("<start>", 1)
# Written without inner spaces, like "<start>"/"<sep>" and the "<end>" marker
# that generate_text_mirostat_top_p strips from decoded text.
end_id = _piece_id("<end>", 2)
unk_id = _piece_id("<unk>", 3)
vocab_size = sp.get_piece_size()
max_len = 100
def text_to_ids(text):
    """Encode *text* into a list of integer token ids with the shared tokenizer."""
    token_ids = sp.encode(text, out_type=int)
    return token_ids
def ids_to_text(ids):
    """Decode a sequence of token ids back into text with the shared tokenizer."""
    text = sp.decode(ids)
    return text
class RotaryPositionalEmbedding(tf.keras.layers.Layer):
    """Rotate interleaved feature pairs by position-dependent angles (rotary embedding).

    The sequence axis is read from dimension 2 of the input, so inputs are
    expected as ``(batch, heads, seq_len, depth)``. For the broadcast and the
    final reshape to succeed, ``depth`` presumably must equal ``dim`` and be
    even. The output has the same shape as the input.
    """

    def __init__(self, dim, base=10000):
        """Precompute inverse frequencies ``1 / base ** (2i / dim)`` for the even indices 2i < dim.

        Args:
            dim: size of the feature axis that will be rotated.
            base: frequency base; the default 10000 is the previously hard-coded value.
        """
        super().__init__()
        inv_freq = 1.0 / (base ** (np.arange(0, dim, 2) / dim))
        self.inv_freq = tf.constant(inv_freq, dtype=tf.float32)

    def call(self, x):
        seq_len = tf.shape(x)[2]
        positions = tf.range(seq_len, dtype=tf.float32)
        # (seq_len, dim / 2) matrix of rotation angles position * inv_freq.
        freqs = tf.einsum('i,j->ij', positions, self.inv_freq)
        # Leading singleton axes broadcast over batch and heads.
        emb_cos = tf.reshape(tf.cos(freqs), [1, 1, seq_len, -1])
        emb_sin = tf.reshape(tf.sin(freqs), [1, 1, seq_len, -1])
        # Even / odd features form the (x1, x2) pairs that are rotated.
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        x_rotated = tf.stack([
            x1 * emb_cos - x2 * emb_sin,
            x1 * emb_sin + x2 * emb_cos,
        ], axis=-1)
        # Stacking on a new last axis then reshaping re-interleaves the pairs.
        return tf.reshape(x_rotated, tf.shape(x))
class SwiGLU(tf.keras.layers.Layer):
    """SwiGLU feed-forward layer: ``out(value * silu(gate))``.

    A single Dense layer of ``2 * d_ff`` units produces both the value and the
    gate halves; a second Dense layer projects back to ``d_model``.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        # Attribute names are kept as-is; saved weights are loaded into them.
        self.proj = tf.keras.layers.Dense(d_ff * 2)
        self.out = tf.keras.layers.Dense(d_model)

    def call(self, x):
        value, gate = tf.split(self.proj(x), num_or_size_splits=2, axis=-1)
        gated = value * tf.nn.silu(gate)
        return self.out(gated)
class GPTBlock(tf.keras.layers.Layer):
    """Pre-norm decoder block: causal self-attention with an adapter, then a SwiGLU FFN.

    NOTE(review): keep the sub-layer attribute names unchanged; the weights
    loaded later by ``model.load_weights`` presumably map onto them.
    """

    def __init__(self, d_model, d_ff, num_heads=8, dropout_rate=0.1, adapter_dim=64):
        super().__init__()
        self.ln1 = tf.keras.layers.LayerNormalization(epsilon=1e-5)
        self.mha = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model // num_heads)
        self.dropout1 = tf.keras.layers.Dropout(dropout_rate)
        # Bottleneck adapter whose output is added back onto the attention output.
        self.adapter_down = tf.keras.layers.Dense(adapter_dim, activation='gelu')
        self.adapter_up = tf.keras.layers.Dense(d_model)
        self.ln2 = tf.keras.layers.LayerNormalization(epsilon=1e-5)
        self.ffn = SwiGLU(d_model, d_ff)
        self.dropout2 = tf.keras.layers.Dropout(dropout_rate)
        self.rope = RotaryPositionalEmbedding(d_model // num_heads)

    def call(self, x, training=False):
        """Apply the block to ``x``.

        Args:
            x: presumably ``(batch, seq_len, d_model)``; dims 0 and 1 of its
                dynamic shape are read as batch and sequence length.
            training: forwarded to attention and dropout layers.

        Returns:
            A tensor with the same shape as ``x``.
        """
        x_norm = self.ln1(x)
        shape = tf.shape(x_norm)
        b, s = shape[0], shape[1]
        h = self.mha.num_heads
        d = x_norm.shape[-1] // h
        # Split features into heads and move heads before the sequence axis:
        # (b, s, h, d) -> (b, h, s, d), the layout the rotary layer reads.
        heads = tf.transpose(tf.reshape(x_norm, [b, s, h, d]), [0, 2, 1, 3])
        # Query and key inputs are the same rotated tensor, so rotate only once.
        # NOTE(review): the rotation is applied before MultiHeadAttention's own
        # query/key projections, not after them as in standard RoPE; kept because
        # the loaded weights were presumably trained with this arrangement.
        rotated = self.rope(heads)
        rotated = tf.reshape(tf.transpose(rotated, [0, 2, 1, 3]), [b, s, h * d])
        q = k = rotated
        attn_out = self.mha(query=q, value=x_norm, key=k, use_causal_mask=True, training=training)
        attn_out = self.dropout1(attn_out, training=training)
        adapter_out = self.adapter_up(self.adapter_down(attn_out))
        attn_out = attn_out + adapter_out
        x = x + attn_out
        ffn_out = self.ffn(self.ln2(x))
        x = x + self.dropout2(ffn_out, training=training)
        return x
class InteractGPT(tf.keras.Model):
    """Decoder-only language model: token embedding, a stack of GPTBlocks, final norm.

    Output logits reuse the token-embedding matrix (weight tying), so there is
    no separate output projection layer.
    """

    def __init__(self, vocab_size, seq_len, d_model, d_ff, n_layers, num_heads=8, dropout_rate=0.1):
        # NOTE: seq_len is accepted for interface compatibility but not used here.
        super().__init__()
        self.token_embedding = tf.keras.layers.Embedding(vocab_size, d_model)
        self.blocks = [
            GPTBlock(d_model, d_ff, num_heads, dropout_rate)
            for _ in range(n_layers)
        ]
        self.ln_f = tf.keras.layers.LayerNormalization(epsilon=1e-5)

    def call(self, x, training=False):
        """Map token ids ``x`` to vocabulary logits at every position."""
        hidden = self.token_embedding(x)
        for gpt_block in self.blocks:
            hidden = gpt_block(hidden, training=training)
        hidden = self.ln_f(hidden)
        # Project onto the vocabulary with the transposed embedding matrix.
        return tf.matmul(hidden, self.token_embedding.embeddings, transpose_b=True)
# Build the model at import time. One dummy forward pass creates every layer's
# variables so that the saved weights have something to load into.
model = InteractGPT(
    vocab_size=vocab_size,
    seq_len=max_len,
    d_model=256,
    d_ff=1024,
    n_layers=6,
)
dummy_input = tf.zeros((1, max_len), dtype=tf.int32)  # batch of 1, sequence length max_len
model(dummy_input)  # builds the model
model.load_weights("InteractGPT.weights.h5")
print("모델 가중치 로드 완료!")
def extract_main_query(query):
    """Return at most the first three whitespace-separated words of *query*, space-joined."""
    head = query.split()[:3]
    return " ".join(head)
def get_wikipedia_summary(query, timeout=5.0):
    """Look up a Korean Wikipedia page summary for *query*.

    Only the first three words of *query* (via ``extract_main_query``) are used
    as the page title.

    Args:
        query: free-form user text.
        timeout: seconds to wait for the Wikipedia REST API before giving up.

    Returns:
        The page's ``extract`` field. Returns "요약 정보를 찾을 수 없습니다." when
        the response has no extract. Returns "위키백과에서 정보를 가져올 수 없습니다."
        in any of these cases:
        - the title is empty;
        - the request fails or times out;
        - the status is not 200;
        - the body is not valid JSON.
    """
    from urllib.parse import quote

    failure_message = "위키백과에서 정보를 가져올 수 없습니다."
    cleaned_query = extract_main_query(query)
    if not cleaned_query:
        return failure_message
    # Page titles use "_" for spaces. Percent-encode everything else, including
    # "/", "?" and "#", so user text cannot alter the URL path or add a query string.
    title = quote(cleaned_query.replace(" ", "_"), safe="")
    url = f"https://ko.wikipedia.org/api/rest_v1/page/summary/{title}"
    try:
        # Without a timeout, requests can wait indefinitely on a stalled connection.
        res = requests.get(url, timeout=timeout)
    except requests.RequestException:
        return failure_message
    if res.status_code != 200:
        return failure_message
    try:
        payload = res.json()
    except ValueError:
        return failure_message
    return payload.get("extract", "요약 정보를 찾을 수 없습니다.")
def summarize_text(text, top_n=3):
    """Return an extractive summary of *text* made of its *top_n* most central sentences.

    Sentences are split with ``sent_tokenize`` and vectorized with TF-IDF over
    unigrams and bigrams. Each sentence is scored by the sum of its cosine
    similarity to every other sentence. The highest-scoring sentences are joined
    in their original order.

    Args:
        text: text to summarize.
        top_n: number of sentences to keep.

    Returns:
        ``text`` unchanged when it has at most ``top_n`` sentences or when TF-IDF
        finds no usable terms; otherwise the selected sentences joined by spaces.
    """
    sentences = sent_tokenize(text)
    if len(sentences) <= top_n:
        return text
    # NOTE(review): the default token_pattern only yields tokens of two or more
    # word characters, so the one-character particles below are never produced
    # as tokens and have no effect as stop words. Confirm before relying on it.
    vectorizer = TfidfVectorizer(ngram_range=(1, 2), stop_words=['은', '는', '이', '가', '을', '를', '에', '에서'])
    try:
        tfidf_matrix = vectorizer.fit_transform(sentences)
    except ValueError:
        # Raised ("empty vocabulary") when no sentence yields a usable term.
        # Return the text unsummarized instead of crashing the response.
        return text
    sim_matrix = cosine_similarity(tfidf_matrix, tfidf_matrix)
    # Ignore self-similarity so a sentence does not boost its own score.
    np.fill_diagonal(sim_matrix, 0)
    scores = sim_matrix.sum(axis=1)
    ranked_idx = np.argsort(scores)[::-1]
    # Restore document order among the selected sentences.
    selected_idx = sorted(ranked_idx[:top_n])
    return " ".join(sentences[i] for i in selected_idx)
def simple_intent_classifier(text):
    """Label *text* as greeting, information question, or small talk by keyword.

    Keyword lists are checked in order against the lowercased text, and the
    first list with any substring match wins. A text matching both the greeting
    and the information keywords is therefore labelled "인사". Text matching
    neither list is "일상대화".
    """
    lowered = text.lower()
    rules = (
        ("인사", ["안녕", "반가워", "이름", "누구", "소개", "어디서 왔", "정체", "몇 살", "너 뭐야"]),
        ("정보질문", ["설명", "정보", "무엇", "뭐야", "어디", "누구", "왜", "어떻게", "종류", "개념"]),
    )
    for label, keywords in rules:
        if any(keyword in lowered for keyword in keywords):
            return label
    return "일상대화"
def _apply_repetition_penalty(logits, token_ids, penalty):
    """Lower the logit of every token already in *token_ids*, modifying *logits* in place.

    The penalty compounds per occurrence (``penalty ** count``). Positive logits
    are divided by it and non-positive logits multiplied by it, so for
    ``penalty > 1`` a repeated token always becomes less likely. Plainly
    dividing a negative logit would instead make it more likely.

    Returns:
        The same *logits* array.
    """
    if len(token_ids) == 0:
        return logits
    ids, counts = np.unique(np.asarray(token_ids), return_counts=True)
    factors = penalty ** counts.astype(np.float64)
    current = logits[ids]
    logits[ids] = np.where(current > 0, current / factors, current * factors)
    return logits


def _strip_control_tokens(text):
    """Remove the ``<start>``, ``<sep>`` and ``<end>`` markers, then surrounding whitespace."""
    for token in ("<start>", "<sep>", "<end>"):
        text = text.replace(token, "")
    return text.strip()


def generate_text_mirostat_top_p(model, prompt, max_len=100, max_gen=98,
                                 temperature=1.0, min_len=20,
                                 repetition_penalty=1.2, eta=0.1, m=100, p=0.9):
    """Sample a reply to *prompt* token by token and yield the cleaned text once.

    The prompt is wrapped as ``"<start> {prompt} <sep>"``, encoded, and cut to
    its first ``max_len`` ids. Each step does the following:
    - runs ``model`` on the sequence right-padded with ``pad_id`` to ``max_len``;
    - reads the logits at the last real position;
    - samples one token with top-p filtering among the ``m`` most probable tokens.

    Args:
        model: callable whose output is indexed as ``logits[0, position]``.
        prompt: user text.
        max_len: padded model-input length and cap on the number of prompt ids.
        max_gen: maximum number of tokens to sample.
        temperature: divisor applied to the logits before softmax.
        min_len: minimum total sequence length (prompt included) before stopping.
        repetition_penalty: per-occurrence penalty for tokens already present.
        eta: update rate of the surprise estimate ``tau``.
        m: number of most probable tokens considered at each step.
        p: cumulative-probability threshold for top-p filtering.

    Yields:
        Exactly one string: the decoded sequence (prompt included) with the
        control markers removed.
    """
    generated = list(text_to_ids(f"<start> {prompt} <sep>")[:max_len])
    tau = 5.0  # initial target surprise
    for _ in range(max_gen):
        pad_length = max(0, max_len - len(generated))
        input_padded = np.pad(generated, (0, pad_length), constant_values=pad_id)
        input_tensor = tf.convert_to_tensor([input_padded])
        logits = model(input_tensor, training=False)
        next_token_logits = logits[0, len(generated) - 1].numpy()

        _apply_repetition_penalty(next_token_logits, generated, repetition_penalty)

        # Once min_len is reached, push down the end-token and pad logits.
        # NOTE(review): this makes stopping on the end token *harder* after
        # min_len, while the stop check below only fires after min_len. Confirm
        # the intended direction of this condition.
        if len(generated) >= min_len:
            next_token_logits[end_id] -= 5.0
            next_token_logits[pad_id] -= 10.0

        next_token_logits = next_token_logits / temperature

        # Numerically stable softmax.
        logits_stable = next_token_logits - np.max(next_token_logits)
        probs = np.exp(logits_stable)
        probs /= probs.sum()

        # Restrict sampling to the m most probable tokens.
        sorted_indices = np.argsort(-probs)
        top_indices = sorted_indices[:m]
        top_probs = probs[top_indices]
        top_probs /= top_probs.sum()

        # Mirostat-style surprise tracking.
        # NOTE(review): tau is updated but never used to truncate the
        # distribution, so this draw only consumes randomness. The emitted token
        # comes from the top-p draw below.
        sampled_index = np.random.choice(top_indices, p=top_probs)
        observed_surprise = -np.log(probs[sampled_index] + 1e-9)
        tau += eta * (observed_surprise - tau)

        # Top-p: keep the shortest most-probable prefix whose cumulative
        # probability reaches p, renormalize, and sample from it.
        order = np.argsort(-top_probs)
        sorted_top_indices = top_indices[order]
        sorted_top_probs = top_probs[order]
        cumulative_probs = np.cumsum(sorted_top_probs)
        cutoff = np.searchsorted(cumulative_probs, p, side='left') + 1
        filtered_indices = sorted_top_indices[:cutoff]
        filtered_probs = sorted_top_probs[:cutoff]
        filtered_probs /= filtered_probs.sum()
        final_token = np.random.choice(filtered_indices, p=filtered_probs)
        generated.append(int(final_token))

        decoded_text = _strip_control_tokens(ids_to_text(generated))
        reached_stop = final_token == end_id or decoded_text.endswith(('.', '!', '?'))
        if len(generated) >= min_len and reached_stop:
            yield decoded_text
            return

    # Token budget exhausted without a stop condition. Still emit the text so
    # the caller never receives an empty reply.
    yield _strip_control_tokens(ids_to_text(generated))
async def async_generator_wrapper(prompt: str):
    """Stream the reply for *prompt* as text chunks without blocking the event loop.

    For prompts classified as "정보질문", a Wikipedia summary chunk is yielded
    first. The generated reply follows.

    The Wikipedia request, TF-IDF summarization and model inference are all
    blocking calls. They run in the loop's default executor so other requests
    keep being served meanwhile.
    """
    loop = asyncio.get_running_loop()
    intent = simple_intent_classifier(prompt)
    if intent == "정보질문":
        wiki_summary = await loop.run_in_executor(None, get_wikipedia_summary, prompt)
        summarized = await loop.run_in_executor(None, summarize_text, wiki_summary, 3)
        yield f"『 \"{prompt}\" 에 대한 위키백과 요약입니다. 』\n\n{summarized}\n\n"
    # Then continue with regular generation (streamed).
    gen = generate_text_mirostat_top_p(model, prompt)
    exhausted = object()
    while True:
        # Each next() call runs model inference, so execute it off the event loop.
        text_piece = await loop.run_in_executor(None, next, gen, exhausted)
        if text_piece is exhausted:
            break
        yield text_piece
        await asyncio.sleep(0.1)
@app.get("/generate")
async def generate(request: Request):
    """Stream generated plain text for the ``prompt`` query parameter.

    A default Korean greeting is used when ``prompt`` is absent.
    """
    params = request.query_params
    prompt = params.get("prompt", "안녕하세요")
    stream = async_generator_wrapper(prompt)
    return StreamingResponse(stream, media_type="text/plain")