Create chatbot.py
chatbot.py ADDED (+156 -0)
@@ -0,0 +1,156 @@
import re
import json
import math
import numpy as np
import requests
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.text import tokenizer_from_json
from tensorflow.keras.preprocessing.sequence import pad_sequences

# Load the question/answer tokenizers saved as JSON
def load_tokenizer(filename):
    with open(filename, 'r', encoding='utf-8') as f:
        return tokenizer_from_json(json.load(f))

tokenizer_q = load_tokenizer('kossistant_q.json')
tokenizer_a = load_tokenizer('kossistant_a.json')

# Load the seq2seq model and derive decoding parameters from its input shapes
model = load_model('kossistant.h5', compile=False)
max_len_q = model.input_shape[0][1]   # max encoder (question) length
max_len_a = model.input_shape[1][1]   # max decoder (answer) length
index_to_word = {v: k for k, v in tokenizer_a.word_index.items()}
index_to_word[0] = ''
start_token = 'start'
end_token = 'end'

# Token sampling: temperature scaling, repetition penalty, then top-k / top-p (nucleus) filtering
def sample_from_top_p_top_k(prob_dist, top_p=0.85, top_k=40, temperature=0.8,
                            repetition_penalty=1.4, generated_ids=None):
    if generated_ids is None:
        generated_ids = []
    logits = np.log(prob_dist + 1e-9) / temperature
    # Penalize tokens that were already generated; log-probs are negative,
    # so multiplying by the penalty pushes them further down
    for idx in generated_ids:
        logits[idx] *= repetition_penalty
    probs = np.exp(logits)
    probs = probs / np.sum(probs)
    # Keep only the top-k most probable tokens
    top_k_indices = np.argsort(probs)[-top_k:]
    top_k_probs = probs[top_k_indices]
    sorted_indices = top_k_indices[np.argsort(top_k_probs)[::-1]]
    sorted_probs = probs[sorted_indices]
    # Keep the smallest prefix whose cumulative probability reaches top_p
    cumulative_probs = np.cumsum(sorted_probs)
    cutoff_index = np.searchsorted(cumulative_probs, top_p)
    final_indices = sorted_indices[:cutoff_index + 1]
    final_probs = probs[final_indices]
    final_probs = final_probs / np.sum(final_probs)
    return np.random.choice(final_indices, p=final_probs)

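# Illustrative sketch (assumed toy values, not part of the pipeline above):
# with prob_dist = np.array([0.1, 0.7, 0.2]) and the default settings,
# temperature scaling sharpens the distribution to roughly [0.07, 0.77, 0.16],
# so the 0.85 nucleus keeps only indices 1 and 2; the call then returns
# index 1 most of the time and index 2 occasionally.
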
# Decoding: sample one token at a time and feed it back into the decoder input
def decode_sequence_custom(input_text, max_attempts=2):
    input_seq = tokenizer_q.texts_to_sequences([input_text])
    input_seq = pad_sequences(input_seq, maxlen=max_len_q, padding='post')

    for _ in range(max_attempts + 1):
        target_seq = tokenizer_a.texts_to_sequences([start_token])[0]
        target_seq = pad_sequences([target_seq], maxlen=max_len_a, padding='post')

        decoded_sentence = ''
        generated_ids = []

        for i in range(max_len_a):
            predictions = model.predict([input_seq, target_seq], verbose=0)
            prob_dist = predictions[0, i, :]
            pred_id = sample_from_top_p_top_k(prob_dist, generated_ids=generated_ids)
            generated_ids.append(pred_id)
            pred_word = index_to_word.get(pred_id, '')
            if pred_word == end_token:
                break
            decoded_sentence += pred_word + ' '
            if i + 1 < max_len_a:
                target_seq[0, i + 1] = pred_id

        cleaned = re.sub(r'\b<end>\b', '', decoded_sentence)
        cleaned = re.sub(r'\s+', ' ', cleaned)
        if is_valid_response(cleaned):
            return cleaned.strip()

    return "죄송해요, 답변 생성에 실패했어요."  # "Sorry, I couldn't generate a response."

def is_valid_response(response):
    # Reject answers that are too short, contain runs of bare jamo, or are known filler tokens
    if len(response.strip()) < 2:
        return False
    if re.search(r'[ㄱ-ㅎㅏ-ㅣ]{3,}', response):
        return False
    if len(response.split()) < 2:
        return False
    if response.count(' ') < 2:
        return False
    if any(tok in response.lower() for tok in ['hello', 'this', 'ㅋㅋ']):
        return False
    return True

def extract_main_query(text):
    # Take the last sentence, strip punctuation, and drop common Korean particles
    sentences = re.split(r'[.?!]\s*', text)
    sentences = [s.strip() for s in sentences if s.strip()]
    if not sentences:
        return text
    last = sentences[-1]
    last = re.sub(r'[^가-힣a-zA-Z0-9 ]', '', last)
    particles = ['이', '가', '은', '는', '을', '를', '의', '에서', '에게', '한테', '보다']
    for p in particles:
        last = re.sub(rf'\b(\w+){p}\b', r'\1', last)
    return last.strip()

def get_wikipedia_summary(query):
    # Query the Korean Wikipedia REST summary endpoint for the cleaned keyword
    cleaned_query = extract_main_query(query)
    url = f"https://ko.wikipedia.org/api/rest_v1/page/summary/{cleaned_query}"
    res = requests.get(url)
    if res.status_code == 200:
        return res.json().get("extract", "요약 정보를 찾을 수 없습니다.")  # "No summary found."
    else:
        return "위키백과에서 정보를 가져올 수 없습니다."  # "Couldn't fetch information from Wikipedia."

def simple_intent_classifier(text):
    # Rule-based intent routing: greeting / information question / math question / small talk
    text = text.lower()
    greet_keywords = ["안녕", "반가워", "이름", "누구", "소개", "어디서 왔", "정체", "몇 살", "너 뭐야"]
    info_keywords = ["설명", "정보", "무엇", "뭐야", "어디", "누구", "왜", "어떻게", "종류", "개념"]
    math_keywords = ["더하기", "빼기", "곱하기", "나누기", "루트", "제곱", "+", "-", "*", "/", "=", "^", "√", "계산", "몇이야", "얼마야"]

    if any(kw in text for kw in greet_keywords):
        return "인사"        # greeting
    elif any(kw in text for kw in info_keywords):
        return "정보질문"    # information question
    elif any(kw in text for kw in math_keywords):
        return "수학질문"    # math question
    else:
        return "일상대화"    # small talk

def parse_math_question(text):
    # Convert spelled-out Korean operators into Python syntax, then evaluate
    text = text.replace("곱하기", "*").replace("더하기", "+").replace("빼기", "-").replace("나누기", "/").replace("제곱", "**2")
    text = re.sub(r'루트\s(\d+)', r'math.sqrt(\1)', text)
    try:
        # NOTE: eval on user input is unsafe outside a demo setting
        result = eval(text)
        return f"정답은 {result}입니다."  # "The answer is {result}."
    except:
        return "계산할 수 없는 수식이에요. 다시 한번 확인해 주세요!"  # "I can't evaluate that expression. Please check it again!"

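# Illustrative examples (assumed inputs, not from the original commit):
#   parse_math_question("3 곱하기 4")       -> "정답은 12입니다."
#   parse_math_question("루트 9 더하기 1")  -> "정답은 4.0입니다."
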
# Main response dispatcher: fixed replies, math, Wikipedia lookup, or the seq2seq model
def respond(input_text):
    intent = simple_intent_classifier(input_text)

    if "/사용법" in input_text:   # "/usage" command
        return "자유롭게 사용해주세요. 딱히 제약은 없습니다."

    if "이름" in input_text:      # asked for the bot's name
        return "제 이름은 kossistant입니다."

    if "누구" in input_text:      # asked who the bot is
        return "저는 kossistant이라고 해요."

    if intent == "수학질문":
        return parse_math_question(input_text)

    if intent == "정보질문":
        # Strip "about", "explain", "what is" style phrases to isolate the keyword
        keyword = re.sub(r"(에 대해|에 대한|에 대해서)?\s*(설명해줘|알려줘|뭐야|개념|정의|정보)?", "", input_text).strip()
        if not keyword:
            return "어떤 주제에 대해 궁금한가요?"
        summary = get_wikipedia_summary(keyword)
        return f"{summary}\n다른 궁금한 점 있으신가요?"

    return decode_sequence_custom(input_text)
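
# ---------------------------------------------------------------------------
# Minimal local test harness (illustrative sketch, not part of the original
# commit): the Space presumably wires respond() into its own UI entry point,
# which is not shown here. This loop simply echoes responses on the console.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    print("kossistant console test - type 'quit' to exit")
    while True:
        user_input = input("> ").strip()
        if user_input.lower() in ("quit", "exit"):
            break
        if user_input:
            print(respond(user_input))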