# Flexi-API / api.py
import requests
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
import asyncio
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse, PlainTextResponse
import sentencepiece as spm
import re
import math
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from fastapi.middleware.cors import CORSMiddleware
app = FastAPI()
origins = [
"https://insect5386.github.io",
"https://insect5386.github.io/insect5386"
]
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
sp = spm.SentencePieceProcessor()
sp.load("kolig_unigram.model")
pad_id = sp.piece_to_id("<pad>")
if pad_id == -1: pad_id = 0
start_id = sp.piece_to_id("<start>")
if start_id == -1: start_id = 1
end_id = sp.piece_to_id("<end>")
if end_id == -1: end_id = 2
unk_id = sp.piece_to_id("<unk>")
if unk_id == -1: unk_id = 3
vocab_size = sp.get_piece_size()
max_len = 100
def text_to_ids(text):
return sp.encode(text, out_type=int)
def ids_to_text(ids):
return sp.decode(ids)
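# Rotary positional embedding (RoPE): each even/odd feature pair (x[..., 2i], x[..., 2i+1])
# of every attention head is rotated by the angle t * inv_freq[i], where t is the token
# position and inv_freq[i] = 1 / 10000^(2i / dim). Position is thereby injected directly
# into the attention inputs instead of being added as a separate positional vector.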
class RotaryPositionalEmbedding(layers.Layer):
def __init__(self, dim):
super().__init__()
inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim))
self.inv_freq = tf.constant(inv_freq, dtype=tf.float32)
def call(self, x):
batch, heads, seq_len, depth = tf.unstack(tf.shape(x))
t = tf.range(seq_len, dtype=tf.float32)
freqs = tf.einsum('i,j->ij', t, self.inv_freq)
emb_sin = tf.sin(freqs)
emb_cos = tf.cos(freqs)
emb_cos = tf.reshape(emb_cos, [1, 1, seq_len, -1])
emb_sin = tf.reshape(emb_sin, [1, 1, seq_len, -1])
x1 = x[..., ::2]
x2 = x[..., 1::2]
x_rotated = tf.stack([
x1 * emb_cos - x2 * emb_sin,
x1 * emb_sin + x2 * emb_cos
], axis=-1)
x_rotated = tf.reshape(x_rotated, tf.shape(x))
return x_rotated
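# SwiGLU feed-forward: one Dense layer projects to 2 * d_ff, the result is split into a
# value half and a gate half, and the output is Dense(d_model)(value * silu(gate)).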
class SwiGLU(tf.keras.layers.Layer):
def __init__(self, d_model, d_ff):
super().__init__()
self.proj = tf.keras.layers.Dense(d_ff * 2)
self.out = tf.keras.layers.Dense(d_model)
def call(self, x):
x_proj = self.proj(x)
x_val, x_gate = tf.split(x_proj, 2, axis=-1)
return self.out(x_val * tf.nn.silu(x_gate))
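# Pre-norm transformer block: LayerNorm, then MultiHeadAttention with a causal mask where
# the normalized input is RoPE-rotated per head and fed as both query and key (the value
# is the un-rotated normalized input). A bottleneck adapter
# (Dense(adapter_dim, gelu) -> Dense(d_model)) is added to the attention output, followed
# by a second LayerNorm and a SwiGLU feed-forward; both sub-layers use residual connections.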
class Block(tf.keras.layers.Layer):
def __init__(self, d_model, d_ff, num_heads=8, dropout_rate=0.05, adapter_dim=64):
super().__init__()
self.ln1 = tf.keras.layers.LayerNormalization(epsilon=1e-5)
self.mha = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model // num_heads)
self.dropout1 = tf.keras.layers.Dropout(dropout_rate)
self.adapter_down = tf.keras.layers.Dense(adapter_dim, activation='gelu')
self.adapter_up = tf.keras.layers.Dense(d_model)
self.ln2 = tf.keras.layers.LayerNormalization(epsilon=1e-5)
self.ffn = SwiGLU(d_model, d_ff)
self.dropout2 = tf.keras.layers.Dropout(dropout_rate)
self.rope = RotaryPositionalEmbedding(d_model // num_heads)
def call(self, x, training=False):
x_norm = self.ln1(x)
b, s, _ = tf.shape(x_norm)[0], tf.shape(x_norm)[1], tf.shape(x_norm)[2]
h = self.mha.num_heads
d = x_norm.shape[-1] // h
qkv = tf.reshape(x_norm, [b, s, h, d])
qkv = tf.transpose(qkv, [0, 2, 1, 3])
q = self.rope(qkv)
k = self.rope(qkv)
q = tf.reshape(tf.transpose(q, [0, 2, 1, 3]), [b, s, h * d])
k = tf.reshape(tf.transpose(k, [0, 2, 1, 3]), [b, s, h * d])
attn_out = self.mha(query=q, value=x_norm, key=k, use_causal_mask=True, training=training)
attn_out = self.dropout1(attn_out, training=training)
adapter_out = self.adapter_up(self.adapter_down(attn_out))
attn_out = attn_out + adapter_out
x = x + attn_out
ffn_out = self.ffn(self.ln2(x))
x = x + self.dropout2(ffn_out, training=training)
return x
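# Flexi: a decoder-only stack of n_layers Blocks on top of a token embedding.
# The output projection is tied to the input embedding matrix:
# logits = hidden_states @ embeddings^T, so no separate output Dense layer is learned.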
class Flexi(tf.keras.Model):
def __init__(self, vocab_size, seq_len, d_model, d_ff, n_layers, num_heads=8, dropout_rate=0.05):
super().__init__()
self.token_embedding = tf.keras.layers.Embedding(vocab_size, d_model)
self.blocks = [Block(d_model, d_ff, num_heads, dropout_rate) for _ in range(n_layers)]
self.ln_f = tf.keras.layers.LayerNormalization(epsilon=1e-5)
def call(self, x, training=False):
x = self.token_embedding(x)
for block in self.blocks:
x = block(x, training=training)
x = self.ln_f(x)
logits = tf.matmul(x, self.token_embedding.embeddings, transpose_b=True)
return logits
model = Flexi(
vocab_size=vocab_size,
seq_len=max_len,
d_model=256,
d_ff=1024,
n_layers=16
)
dummy_input = tf.zeros((1, max_len), dtype=tf.int32)  # batch of 1, sequence length max_len
_ = model(dummy_input)  # forward pass builds the model so the weights can be loaded
model.load_weights("Flexi.weights.h5")
print("Model weights loaded!")
def is_greedy_response_acceptable(text):
    text = text.strip()
    # Reject sentences that are too short
    if len(text) < 5:
        return False
    # Reject responses with too few words
    if len(text.split()) < 3:
        return False
    # Reject runs of bare jamo such as 'ใ…‹ใ…‹ใ…‹' (allowed when the text contains 'ใ…‹ใ…‹')
    if re.search(r'[ใ„ฑ-ใ…Žใ…-ใ…ฃ]{3,}', text) and 'ใ…‹ใ…‹' not in text:
        return False
    # Reject awkward endings: require a typical Korean ending (๋‹ค/์š”/์ฃ ) or closing punctuation
    if not re.search(r'(๋‹ค|์š”|์ฃ |๋‹ค\.|์š”\.|์ฃ \.|๋‹ค!|์š”!|์ฃ !|\!|\?|\.)$', text):
        return False
    return True
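# Autoregressive sampling loop: at each step the model is run on the tokens generated so
# far (right-padded with <pad> up to max_len), the logits at the last generated position
# are scaled by `temperature`, filtered with top-k and top-p (nucleus) masks, renormalized,
# and the next token is drawn from the resulting distribution. Once at least min_len tokens
# exist, generation stops early when <end> is produced or the decoded text ends naturally
# and passes is_greedy_response_acceptable().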
def generate_text_sample(model, prompt, max_len=100, max_gen=98,
temperature=0.8, top_k=55, top_p=0.95, min_len=12):
model_input = text_to_ids(f"<start> {prompt} <sep>")
model_input = model_input[:max_len]
generated = list(model_input)
for _ in range(max_gen):
pad_len = max(0, max_len - len(generated))
input_padded = np.pad(generated, (0, pad_len), constant_values=pad_id)
input_tensor = tf.convert_to_tensor([input_padded])
logits = model(input_tensor, training=False)
next_logits = logits[0, len(generated) - 1].numpy()
        # Apply temperature scaling
next_logits = next_logits / temperature
probs = np.exp(next_logits - np.max(next_logits))
probs = probs / probs.sum()
        # Top-K filtering
if top_k is not None and top_k > 0:
indices_to_remove = probs < np.sort(probs)[-top_k]
probs[indices_to_remove] = 0
probs /= probs.sum()
        # Top-P (nucleus / cumulative-probability) filtering
if top_p is not None and 0 < top_p < 1:
sorted_indices = np.argsort(probs)[::-1]
sorted_probs = probs[sorted_indices]
cumulative_probs = np.cumsum(sorted_probs)
            # Drop tokens whose cumulative probability exceeds top_p
cutoff_index = np.searchsorted(cumulative_probs, top_p, side='right')
probs_to_keep = sorted_indices[:cutoff_index+1]
mask = np.ones_like(probs, dtype=bool)
mask[probs_to_keep] = False
probs[mask] = 0
probs /= probs.sum()
        # Sample the next token
next_token = np.random.choice(len(probs), p=probs)
generated.append(int(next_token))
        # Decode and post-process
decoded = sp.decode(generated)
for t in ["<start>", "<sep>", "<end>"]:
decoded = decoded.replace(t, "")
decoded = decoded.strip()
if len(generated) >= min_len and (next_token == end_id or decoded.endswith(('์š”', '๋‹ค', '.', '!', '?'))):
if is_greedy_response_acceptable(decoded):
return decoded
else:
continue
decoded = sp.decode(generated)
for t in ["<start>", "<sep>", "<end>"]:
decoded = decoded.replace(t, "")
return decoded.strip()
from sklearn.decomposition import TruncatedSVD
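# SimilarityMemory: a lightweight retrieval memory. Stored texts are re-vectorized with
# TF-IDF on every add and, when there are enough texts and features, reduced with
# TruncatedSVD (an LSA-style projection). retrieve() embeds the query the same way and
# returns the top_k stored texts ranked by cosine similarity.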
class SimilarityMemory:
def __init__(self, n_components=100):
self.memory_texts = []
self.vectorizer = TfidfVectorizer()
self.svd = TruncatedSVD(n_components=n_components)
self.embeddings = None
self.fitted = False
def add(self, text: str):
self.memory_texts.append(text)
self._update_embeddings()
def _update_embeddings(self):
if len(self.memory_texts) == 0:
self.embeddings = None
self.fitted = False
return
X = self.vectorizer.fit_transform(self.memory_texts)
n_comp = min(self.svd.n_components, X.shape[1], len(self.memory_texts)-1)
if n_comp <= 0:
self.embeddings = X.toarray()
self.fitted = True
return
self.svd = TruncatedSVD(n_components=n_comp)
self.embeddings = self.svd.fit_transform(X)
self.fitted = True
def retrieve(self, query: str, top_k=3):
if not self.fitted or self.embeddings is None or len(self.memory_texts) == 0:
return []
Xq = self.vectorizer.transform([query])
if self.svd.n_components > Xq.shape[1] or self.svd.n_components > len(self.memory_texts) - 1:
q_emb = Xq.toarray()
else:
q_emb = self.svd.transform(Xq)
sims = cosine_similarity(q_emb, self.embeddings)[0]
top_indices = sims.argsort()[::-1][:top_k]
return [self.memory_texts[i] for i in top_indices]
def process_input(self, new_text: str, top_k=3):
"""์ž๋™์œผ๋กœ ๊ธฐ์–ต ์ €์žฅํ•˜๊ณ , ์œ ์‚ฌํ•œ ๊ธฐ์–ต ์ฐพ์•„์„œ ํ•ฉ์นœ ํ”„๋กฌํ”„ํŠธ ์ƒ์„ฑ"""
related_memories = self.retrieve(new_text, top_k=top_k)
self.add(new_text)
return self.merge_prompt(new_text, related_memories)
def merge_prompt(self, prompt: str, memories: list):
context = "\n".join(memories)
return f"{context}\n\n{prompt}" if context else prompt
# Global conversational memory shared by respond() below
memory = SimilarityMemory()

def mismatch_tone(input_text, output_text):
if "ใ…‹ใ…‹" in input_text and not re.search(r'ใ…‹ใ…‹|ใ…Ž|์žฌ๋ฐŒ|๋†€|๋งŒ๋‚˜|๋ง›์ง‘|์—ฌํ–‰', output_text):
return True
return False
# Check whether a response is valid
def is_valid_response(response):
if len(response.strip()) < 2:
return False
if re.search(r'[ใ„ฑ-ใ…Žใ…-ใ…ฃ]{3,}', response):
return False
if len(response.split()) < 2:
return False
if response.count(' ') < 2:
return False
if any(tok in response.lower() for tok in ['hello', 'this', 'ใ…‹ใ…‹']):
return False
return True
# Wikipedia summary helpers
def extract_main_query(text):
sentences = re.split(r'[.?!]\s*', text)
sentences = [s.strip() for s in sentences if s.strip()]
if not sentences:
return text
last = sentences[-1]
last = re.sub(r'[^๊ฐ€-ํžฃa-zA-Z0-9 ]', '', last)
particles = ['์ด', '๊ฐ€', '์€', '๋Š”', '์„', '๋ฅผ', '์˜', '์—์„œ', '์—๊ฒŒ', 'ํ•œํ…Œ', '๋ณด๋‹ค']
for p in particles:
last = re.sub(rf'\b(\w+){p}\b', r'\1', last)
return last.strip()
def get_wikipedia_summary(query):
cleaned_query = extract_main_query(query)
url = f"https://ko.wikipedia.org/api/rest_v1/page/summary/{cleaned_query}"
res = requests.get(url)
if res.status_code == 200:
return res.json().get("extract", "์š”์•ฝ ์ •๋ณด๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.")
else:
return "์œ„ํ‚ค๋ฐฑ๊ณผ์—์„œ ์ •๋ณด๋ฅผ ๊ฐ€์ ธ์˜ฌ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค."
def textrank_summarize(text, top_n=3):
sentences = re.split(r'(?<=[.!?])\s+', text.strip())
sentences = [s.strip() for s in sentences if len(s.strip()) > 10]
if len(sentences) <= top_n:
return text
vectorizer = TfidfVectorizer()
tfidf_matrix = vectorizer.fit_transform(sentences)
sim_matrix = cosine_similarity(tfidf_matrix)
np.fill_diagonal(sim_matrix, 0)
def pagerank(matrix, damping=0.85, max_iter=100, tol=1e-4):
N = matrix.shape[0]
ranks = np.ones(N) / N
row_sums = np.sum(matrix, axis=1)
row_sums[row_sums == 0] = 1
for _ in range(max_iter):
prev_ranks = ranks.copy()
for i in range(N):
incoming = matrix[:, i]
ranks[i] = (1 - damping) / N + damping * np.sum(incoming * prev_ranks / row_sums)
if np.linalg.norm(ranks - prev_ranks) < tol:
break
return ranks
scores = pagerank(sim_matrix)
ranked_idx = np.argsort(scores)[::-1]
selected_idx = sorted(ranked_idx[:top_n])
summary = ' '.join([sentences[i] for i in selected_idx])
return summary
def summarize_from_wikipedia(query, top_n=3):
raw_summary = get_wikipedia_summary(query)
first_summary = textrank_summarize(raw_summary, top_n=top_n)
second_summary = textrank_summarize(first_summary, top_n=top_n)
return second_summary
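# Minimal keyword-based intent classifier: greeting vs. information question vs. small talk,
# decided by matching Korean keyword lists against the lowercased input.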
def simple_intent_classifier(text):
text = text.lower()
greet_keywords = ["์•ˆ๋…•", "๋ฐ˜๊ฐ€์›Œ", "์ด๋ฆ„", "๋ˆ„๊ตฌ", "์†Œ๊ฐœ", "์–ด๋””์„œ ์™”", "์ •์ฒด", "๋ช‡ ์‚ด", "๋„ˆ ๋ญ์•ผ"]
info_keywords = ["์„ค๋ช…", "์ •๋ณด", "๋ฌด์—‡", "๋ญ์•ผ", "์–ด๋””", "๋ˆ„๊ตฌ", "์™œ", "์–ด๋–ป๊ฒŒ", "์ข…๋ฅ˜", "๊ฐœ๋…"]
if any(kw in text for kw in greet_keywords):
return "์ธ์‚ฌ"
elif any(kw in text for kw in info_keywords):
return "์ •๋ณด์งˆ๋ฌธ"
else:
return "์ผ์ƒ๋Œ€ํ™”"
def respond(input_text):
    # 1) Store the user input in memory
memory.process_input(input_text)
intent = simple_intent_classifier(input_text)
if "์ด๋ฆ„" in input_text:
response = "์ œ ์ด๋ฆ„์€ Flexi์ž…๋‹ˆ๋‹ค."
        memory.process_input(response)  # the reply can also be stored in memory
return response
if "๋ˆ„๊ตฌ" in input_text:
response = "์ €๋Š” Flexi๋ผ๊ณ  ํ•ด์š”."
memory.process_input(response)
return response
if intent == "์ •๋ณด์งˆ๋ฌธ":
keyword = re.sub(r"(์— ๋Œ€ํ•ด|์— ๋Œ€ํ•œ|์— ๋Œ€ํ•ด์„œ)?\s*(์„ค๋ช…ํ•ด์ค˜|์•Œ๋ ค์ค˜|๋ญ์•ผ|๊ฐœ๋…|์ •์˜|์ •๋ณด)?", "", input_text).strip()
if not keyword:
response = "์–ด๋–ค ์ฃผ์ œ์— ๋Œ€ํ•ด ๊ถ๊ธˆํ•œ๊ฐ€์š”?"
memory.add(response)
return response
summary = summarize_from_wikipedia(keyword)
response = f"{summary}\n๋‹ค๋ฅธ ๊ถ๊ธˆํ•œ ์  ์žˆ์œผ์‹ ๊ฐ€์š”?"
return response
# ๊ธฐ์–ต์—์„œ ์œ ์‚ฌ ๋ฌธ์žฅ ๊บผ๋‚ด์„œ ํ”„๋กฌํ”„ํŠธ ๋งŒ๋“ค๊ธฐ
related_memories = memory.retrieve(input_text, top_k=3)
merged_prompt = merge_prompt_with_memory(input_text, related_memories)
# ๋ชจ๋ธ๋กœ ์‘๋‹ต ์ƒ์„ฑ
response = generate_text_sample(model, merged_prompt)
    # Validate the response; regenerate once if it fails
if not is_valid_response(response) or mismatch_tone(input_text, response):
response = generate_text_sample(model, merged_prompt)
    # Add the final response to memory as well
memory.process_input(response)
return response
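# HTTP entry point. A GET request such as /generate?prompt=์•ˆ๋…•ํ•˜์„ธ์š” returns the chatbot
# reply as plain text; the prompt defaults to "์•ˆ๋…•ํ•˜์„ธ์š”" when the query parameter is missing.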
@app.get("/generate", response_class=PlainTextResponse)
async def generate(request: Request):
prompt = request.query_params.get("prompt", "์•ˆ๋…•ํ•˜์„ธ์š”")
response_text = respond(prompt)
return response_text
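# Minimal local-run sketch (not part of the original file; a Hugging Face Space normally
# starts the ASGI server itself). The port 7860 here is an assumption.
if __name__ == "__main__":
    import uvicorn  # assumed to be installed alongside FastAPI
    uvicorn.run(app, host="0.0.0.0", port=7860)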