Yuchan5386 committed on
Commit
77d5918
·
verified ·
1 Parent(s): 8dae3bf

Create api.py

Files changed (1)
  1. api.py +360 -0
api.py ADDED
@@ -0,0 +1,360 @@
import requests
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
import asyncio
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse, PlainTextResponse
import sentencepiece as spm
import re
import math
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

app = FastAPI()

from fastapi.middleware.cors import CORSMiddleware

origins = [
    "https://insect5386.github.io",
    "https://insect5386.github.io/insect5386"
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

sp = spm.SentencePieceProcessor()
sp.load("kolig_unigram.model")

pad_id = sp.piece_to_id("<pad>")
if pad_id == -1: pad_id = 0
start_id = sp.piece_to_id("<start>")
if start_id == -1: start_id = 1
end_id = sp.piece_to_id("<end>")
if end_id == -1: end_id = 2
unk_id = sp.piece_to_id("<unk>")
if unk_id == -1: unk_id = 3

vocab_size = sp.get_piece_size()
max_len = 100

def text_to_ids(text):
    return sp.encode(text, out_type=int)

def ids_to_text(ids):
    return sp.decode(ids)


class RotaryPositionalEmbedding(layers.Layer):
    def __init__(self, dim):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim))
        self.inv_freq = tf.constant(inv_freq, dtype=tf.float32)

    def call(self, x):
        batch, heads, seq_len, depth = tf.unstack(tf.shape(x))
        t = tf.range(seq_len, dtype=tf.float32)
        freqs = tf.einsum('i,j->ij', t, self.inv_freq)
        emb_sin = tf.sin(freqs)
        emb_cos = tf.cos(freqs)
        emb_cos = tf.reshape(emb_cos, [1, 1, seq_len, -1])
        emb_sin = tf.reshape(emb_sin, [1, 1, seq_len, -1])
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        x_rotated = tf.stack([
            x1 * emb_cos - x2 * emb_sin,
            x1 * emb_sin + x2 * emb_cos
        ], axis=-1)
        x_rotated = tf.reshape(x_rotated, tf.shape(x))
        return x_rotated

class SwiGLU(tf.keras.layers.Layer):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.proj = tf.keras.layers.Dense(d_ff * 2)
        self.out = tf.keras.layers.Dense(d_model)

    def call(self, x):
        x_proj = self.proj(x)
        x_val, x_gate = tf.split(x_proj, 2, axis=-1)
        return self.out(x_val * tf.nn.silu(x_gate))

class GPTBlock(tf.keras.layers.Layer):
    def __init__(self, d_model, d_ff, num_heads=8, dropout_rate=0.1, adapter_dim=64):
        super().__init__()
        self.ln1 = tf.keras.layers.LayerNormalization(epsilon=1e-5)
        self.mha = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model // num_heads)
        self.dropout1 = tf.keras.layers.Dropout(dropout_rate)
        self.adapter_down = tf.keras.layers.Dense(adapter_dim, activation='gelu')
        self.adapter_up = tf.keras.layers.Dense(d_model)

        self.ln2 = tf.keras.layers.LayerNormalization(epsilon=1e-5)
        self.ffn = SwiGLU(d_model, d_ff)
        self.dropout2 = tf.keras.layers.Dropout(dropout_rate)
        self.rope = RotaryPositionalEmbedding(d_model // num_heads)

    def call(self, x, training=False):
        x_norm = self.ln1(x)
        b, s, _ = tf.shape(x_norm)[0], tf.shape(x_norm)[1], tf.shape(x_norm)[2]
        h = self.mha.num_heads
        d = x_norm.shape[-1] // h

        # Split into heads and apply rotary embeddings to the query/key views
        qkv = tf.reshape(x_norm, [b, s, h, d])
        qkv = tf.transpose(qkv, [0, 2, 1, 3])
        q = self.rope(qkv)
        k = self.rope(qkv)
        q = tf.reshape(tf.transpose(q, [0, 2, 1, 3]), [b, s, h * d])
        k = tf.reshape(tf.transpose(k, [0, 2, 1, 3]), [b, s, h * d])

        attn_out = self.mha(query=q, value=x_norm, key=k, use_causal_mask=True, training=training)
        attn_out = self.dropout1(attn_out, training=training)

        # Bottleneck adapter applied to the attention output
        adapter_out = self.adapter_up(self.adapter_down(attn_out))
        attn_out = attn_out + adapter_out

        x = x + attn_out
        ffn_out = self.ffn(self.ln2(x))
        x = x + self.dropout2(ffn_out, training=training)
        return x

class InteractGPT(tf.keras.Model):
    def __init__(self, vocab_size, seq_len, d_model, d_ff, n_layers, num_heads=8, dropout_rate=0.1):
        super().__init__()
        self.token_embedding = tf.keras.layers.Embedding(vocab_size, d_model)
        self.blocks = [GPTBlock(d_model, d_ff, num_heads, dropout_rate) for _ in range(n_layers)]
        self.ln_f = tf.keras.layers.LayerNormalization(epsilon=1e-5)

    def call(self, x, training=False):
        x = self.token_embedding(x)
        for block in self.blocks:
            x = block(x, training=training)
        x = self.ln_f(x)
        # Weight-tied output projection: reuse the embedding matrix for the logits
        logits = tf.matmul(x, self.token_embedding.embeddings, transpose_b=True)
        return logits

model = InteractGPT(vocab_size=vocab_size, seq_len=max_len, d_model=256, d_ff=1024, n_layers=6)

dummy_input = tf.zeros((1, max_len), dtype=tf.int32)  # batch 1, sequence length max_len
_ = model(dummy_input)  # builds the model's weights
model.load_weights("Flexi.weights.h5")
print("Model weights loaded!")


def is_greedy_response_acceptable(text):
    text = text.strip()

    # Reject very short responses
    if len(text) < 5:
        return False

    # Reject responses with too few words
    if len(text.split()) < 3:
        return False

    # Reject runs of bare jamo such as "ㅋㅋㅋ" (allowed when the text contains "ㅋㅋ")
    if re.search(r'[ㄱ-ㅎㅏ-ㅣ]{3,}', text) and 'ㅋㅋ' not in text:
        return False

    # Reject awkward endings (must end in 다/요/죠 or punctuation)
    if not re.search(r'(다|요|죠|다\.|요\.|죠\.|다!|요!|죠!|\!|\?|\.)$', text):
        return False

    return True

def generate_text_sample(model, prompt, max_len=100, max_gen=98,
                         temperature=0.7, top_k=40, top_p=0.9, min_len=12):
    model_input = text_to_ids(f"<start> {prompt} <sep>")
    model_input = model_input[:max_len]
    generated = list(model_input)

    for _ in range(max_gen):
        pad_len = max(0, max_len - len(generated))
        input_padded = np.pad(generated, (0, pad_len), constant_values=pad_id)
        input_tensor = tf.convert_to_tensor([input_padded])
        logits = model(input_tensor, training=False)
        next_logits = logits[0, len(generated) - 1].numpy()

        # Apply temperature
        next_logits = next_logits / temperature
        probs = np.exp(next_logits - np.max(next_logits))
        probs = probs / probs.sum()

        # Top-k filtering
        if top_k is not None and top_k > 0:
            indices_to_remove = probs < np.sort(probs)[-top_k]
            probs[indices_to_remove] = 0
            probs /= probs.sum()

        # Top-p (nucleus) filtering
        if top_p is not None and 0 < top_p < 1:
            sorted_indices = np.argsort(probs)[::-1]
            sorted_probs = probs[sorted_indices]
            cumulative_probs = np.cumsum(sorted_probs)
            # Drop the tokens that fall past the top_p cumulative-probability cutoff
            cutoff_index = np.searchsorted(cumulative_probs, top_p, side='right')
            probs_to_keep = sorted_indices[:cutoff_index + 1]

            mask = np.ones_like(probs, dtype=bool)
            mask[probs_to_keep] = False
            probs[mask] = 0
            probs /= probs.sum()

        # Sample the next token
        next_token = np.random.choice(len(probs), p=probs)
        generated.append(int(next_token))

        # Decode and post-process
        decoded = sp.decode(generated)
        for t in ["<start>", "<sep>", "<end>"]:
            decoded = decoded.replace(t, "")
        decoded = decoded.strip()

        if len(generated) >= min_len and (next_token == end_id or decoded.endswith(('요', '다', '.', '!', '?'))):
            if is_greedy_response_acceptable(decoded):
                return decoded
            else:
                continue

    decoded = sp.decode(generated)
    for t in ["<start>", "<sep>", "<end>"]:
        decoded = decoded.replace(t, "")
    return decoded.strip()


def mismatch_tone(input_text, output_text):
    if "ㅋㅋ" in input_text and not re.search(r'ㅋㅋ|ㅎ|재밌|놀|만나|맛집|여행', output_text):
        return True
    return False

# Check whether a generated response is usable
def is_valid_response(response):
    if len(response.strip()) < 2:
        return False
    if re.search(r'[ㄱ-ㅎㅏ-ㅣ]{3,}', response):
        return False
    if len(response.split()) < 2:
        return False
    if response.count(' ') < 2:
        return False
    if any(tok in response.lower() for tok in ['hello', 'this', 'ㅋㅋ']):
        return False
    return True

# Wikipedia summary helpers
def extract_main_query(text):
    sentences = re.split(r'[.?!]\s*', text)
    sentences = [s.strip() for s in sentences if s.strip()]
    if not sentences:
        return text
    last = sentences[-1]
    last = re.sub(r'[^가-힣a-zA-Z0-9 ]', '', last)
    particles = ['이', '가', '은', '는', '을', '를', '의', '에서', '에게', '한테', '보다']
    for p in particles:
        last = re.sub(rf'\b(\w+){p}\b', r'\1', last)
    return last.strip()

def get_wikipedia_summary(query):
    cleaned_query = extract_main_query(query)
    url = f"https://ko.wikipedia.org/api/rest_v1/page/summary/{cleaned_query}"
    res = requests.get(url)
    if res.status_code == 200:
        return res.json().get("extract", "요약 정보를 찾을 수 없습니다.")
    else:
        return "위키백과에서 정보를 가져올 수 없습니다."

def textrank_summarize(text, top_n=3):
    sentences = re.split(r'(?<=[.!?])\s+', text.strip())
    sentences = [s.strip() for s in sentences if len(s.strip()) > 10]
    if len(sentences) <= top_n:
        return text
    vectorizer = TfidfVectorizer()
    tfidf_matrix = vectorizer.fit_transform(sentences)
    sim_matrix = cosine_similarity(tfidf_matrix)
    np.fill_diagonal(sim_matrix, 0)

    def pagerank(matrix, damping=0.85, max_iter=100, tol=1e-4):
        N = matrix.shape[0]
        ranks = np.ones(N) / N
        row_sums = np.sum(matrix, axis=1)
        row_sums[row_sums == 0] = 1
        for _ in range(max_iter):
            prev_ranks = ranks.copy()
            for i in range(N):
                incoming = matrix[:, i]
                ranks[i] = (1 - damping) / N + damping * np.sum(incoming * prev_ranks / row_sums)
            if np.linalg.norm(ranks - prev_ranks) < tol:
                break
        return ranks

    scores = pagerank(sim_matrix)
    ranked_idx = np.argsort(scores)[::-1]
    selected_idx = sorted(ranked_idx[:top_n])
    summary = ' '.join([sentences[i] for i in selected_idx])
    return summary

def summarize_from_wikipedia(query, top_n=3):
    raw_summary = get_wikipedia_summary(query)
    first_summary = textrank_summarize(raw_summary, top_n=top_n)
    second_summary = textrank_summarize(first_summary, top_n=top_n)
    return second_summary


def simple_intent_classifier(text):
    text = text.lower()
    greet_keywords = ["안녕", "반가워", "이름", "누구", "소개", "어디서 왔", "정체", "몇 살", "너 뭐야"]
    info_keywords = ["설명", "정보", "무엇", "뭐야", "어디", "누구", "왜", "어떻게", "종류", "개념"]
    math_keywords = ["더하기", "빼기", "곱하기", "나누기", "루트", "제곱", "+", "-", "*", "/", "=", "^", "√", "계산", "몇이야", "얼마야"]
    if any(kw in text for kw in greet_keywords):
        return "인사"
    elif any(kw in text for kw in info_keywords):
        return "정보질문"
    elif any(kw in text for kw in math_keywords):
        return "수학질문"
    else:
        return "일상대화"

def parse_math_question(text):
    # Map Korean operator words to Python operators before evaluating
    text = text.replace("곱하기", "*").replace("더하기", "+").replace("빼기", "-").replace("나누기", "/").replace("제곱", "**2")
    text = re.sub(r'루트\s(\d+)', r'math.sqrt(\1)', text)
    try:
        result = eval(text)
        return f"정답은 {result}입니다."
    except Exception:
        return "계산할 수 없는 수식이에요. 다시 한번 확인해 주세요!"

# Final response function
def respond(input_text):
    intent = simple_intent_classifier(input_text)

    if "이름" in input_text:
        return "제 이름은 Flexi입니다."

    if "누구" in input_text:
        return "저는 Flexi라고 해요."

    if intent == "수학질문":
        return parse_math_question(input_text)

    if intent == "인사":
        return "반가워요! 무엇을 도와드릴까요?"

    if intent == "정보질문":
        keyword = re.sub(r"(에 대해|에 대한|에 대해서)?\s*(설명해줘|알려줘|뭐야|개념|정의|정보)?", "", input_text).strip()
        if not keyword:
            return "어떤 주제에 대해 궁금한가요?"
        summary = summarize_from_wikipedia(keyword)
        return f"{summary}\n다른 궁금한 점 있으신가요?"

    # Small talk: sampled generation with a single fallback retry
    response = generate_text_sample(model, input_text)
    if not is_valid_response(response) or mismatch_tone(input_text, response):
        response = generate_text_sample(model, input_text)
    return response

@app.get("/generate", response_class=PlainTextResponse)
async def generate(request: Request):
    prompt = request.query_params.get("prompt", "안녕하세요")
    response_text = respond(prompt)
    return response_text
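
For reference, a minimal sketch of how this service might be run and queried locally. This is not part of the commit: it assumes uvicorn and requests are installed, that kolig_unigram.model and Flexi.weights.h5 sit next to api.py, and that the server listens on an illustrative host/port; only the file name api.py, the app object, and the /generate route come from the diff above.

# Not part of the commit: a hedged local-run sketch.
# Start the server first (host/port are assumptions, not from the diff):
#   uvicorn api:app --host 0.0.0.0 --port 7860
import requests

def ask_flexi(prompt: str, base_url: str = "http://localhost:7860") -> str:
    # /generate returns plain text (PlainTextResponse), so resp.text is the reply
    resp = requests.get(f"{base_url}/generate", params={"prompt": prompt}, timeout=60)
    resp.raise_for_status()
    return resp.text

if __name__ == "__main__":
    print(ask_flexi("안녕하세요"))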