Yuchan5386 committed
Commit c57f761 · verified
1 Parent(s): ac18de1

Create chatbot.py

Files changed (1)
chatbot.py +156 -0
chatbot.py ADDED
@@ -0,0 +1,156 @@
+ import re
+ import json
+ import math  # math.sqrt is used in eval()'d expressions below
+ import numpy as np
+ import requests
+ from tensorflow.keras.models import load_model
+ from tensorflow.keras.preprocessing.text import tokenizer_from_json
+ from tensorflow.keras.preprocessing.sequence import pad_sequences
+
+ # Load the tokenizers
+ def load_tokenizer(filename):
+     with open(filename, 'r', encoding='utf-8') as f:
+         return tokenizer_from_json(json.load(f))
+
+ tokenizer_q = load_tokenizer('kossistant_q.json')
+ tokenizer_a = load_tokenizer('kossistant_a.json')
+
+ # Load the model and decoding parameters
+ model = load_model('kossistant.h5', compile=False)
+ max_len_q = model.input_shape[0][1]
+ max_len_a = model.input_shape[1][1]
+ index_to_word = {v: k for k, v in tokenizer_a.word_index.items()}
+ index_to_word[0] = ''
+ start_token = 'start'
+ end_token = 'end'
+
+ # Token sampling function (temperature, repetition penalty, top-k, then top-p)
+ def sample_from_top_p_top_k(prob_dist, top_p=0.85, top_k=40, temperature=0.8, repetition_penalty=1.4, generated_ids=()):
+     logits = np.log(prob_dist + 1e-9) / temperature
+     for idx in generated_ids:
+         logits[idx] *= repetition_penalty  # log-probs are negative, so scaling up makes repeats less likely
+     probs = np.exp(logits)
+     probs = probs / np.sum(probs)
+     top_k_indices = np.argsort(probs)[-top_k:]
+     top_k_probs = probs[top_k_indices]
+     sorted_indices = top_k_indices[np.argsort(top_k_probs)[::-1]]
+     sorted_probs = probs[sorted_indices]
+     cumulative_probs = np.cumsum(sorted_probs)
+     cutoff_index = np.searchsorted(cumulative_probs, top_p)
+     final_indices = sorted_indices[:cutoff_index + 1]
+     final_probs = probs[final_indices]
+     final_probs = final_probs / np.sum(final_probs)
+     return np.random.choice(final_indices, p=final_probs)
+
+ # Decoding loop
+ def decode_sequence_custom(input_text, max_attempts=2):
+     input_seq = tokenizer_q.texts_to_sequences([input_text])
+     input_seq = pad_sequences(input_seq, maxlen=max_len_q, padding='post')
+
+     for _ in range(max_attempts + 1):
+         target_seq = tokenizer_a.texts_to_sequences([start_token])[0]
+         target_seq = pad_sequences([target_seq], maxlen=max_len_a, padding='post')
+
+         decoded_sentence = ''
+         generated_ids = []
+
+         for i in range(max_len_a):
+             predictions = model.predict([input_seq, target_seq], verbose=0)
+             prob_dist = predictions[0, i, :]  # distribution over the next token
+             pred_id = sample_from_top_p_top_k(prob_dist, generated_ids=generated_ids)
+             generated_ids.append(pred_id)
+             pred_word = index_to_word.get(pred_id, '')
+             if pred_word == end_token:
+                 break
+             decoded_sentence += pred_word + ' '
+             if i + 1 < max_len_a:
+                 target_seq[0, i + 1] = pred_id  # feed the sampled token back in
+
+         cleaned = re.sub(rf'\b{end_token}\b', '', decoded_sentence)  # defensive: strip a stray end token
+         cleaned = re.sub(r'\s+', ' ', cleaned)
+         if is_valid_response(cleaned):
+             return cleaned.strip()
+
+     return "죄송해요, 답변 생성에 실패했어요."  # "Sorry, I couldn't generate a response."
+
+ # Simple heuristics to reject degenerate generations
+ def is_valid_response(response):
+     if len(response.strip()) < 2:
+         return False
+     if re.search(r'[ㄱ-ㅎㅏ-ㅣ]{3,}', response):  # runs of bare Korean jamo
+         return False
+     if len(response.split()) < 2:
+         return False
+     if response.count(' ') < 2:
+         return False
+     if any(tok in response.lower() for tok in ['hello', 'this', 'ㅋㅋ']):
+         return False
+     return True
+
+ def extract_main_query(text):
+     sentences = re.split(r'[.?!]\s*', text)
+     sentences = [s.strip() for s in sentences if s.strip()]
+     if not sentences:
+         return text
+     last = sentences[-1]
+     last = re.sub(r'[^가-힣a-zA-Z0-9 ]', '', last)
+     # strip common Korean particles from word endings
+     particles = ['이', '가', '은', '는', '을', '를', '의', '에서', '에게', '한테', '보다']
+     for p in particles:
+         last = re.sub(rf'\b(\w+){p}\b', r'\1', last)
+     return last.strip()
+
+ def get_wikipedia_summary(query):
+     cleaned_query = extract_main_query(query)
+     url = f"https://ko.wikipedia.org/api/rest_v1/page/summary/{cleaned_query}"
+     res = requests.get(url, timeout=5)  # don't hang the chatbot on a slow network
+     if res.status_code == 200:
+         return res.json().get("extract", "요약 정보를 찾을 수 없습니다.")  # "No summary found."
+     else:
+         return "위키백과에서 정보를 가져올 수 없습니다."  # "Couldn't fetch from Wikipedia."
+
+ def simple_intent_classifier(text):
+     text = text.lower()
+     greet_keywords = ["안녕", "반가워", "이름", "누구", "소개", "어디서 왔", "정체", "몇 살", "너 뭐야"]
+     info_keywords = ["설명", "정보", "무엇", "뭐야", "어디", "누구", "왜", "어떻게", "종류", "개념"]
+     math_keywords = ["더하기", "빼기", "곱하기", "나누기", "루트", "제곱", "+", "-", "*", "/", "=", "^", "√", "계산", "몇이야", "얼마야"]
+
+     if any(kw in text for kw in greet_keywords):
+         return "인사"  # greeting
+     elif any(kw in text for kw in info_keywords):
+         return "정보질문"  # information question
+     elif any(kw in text for kw in math_keywords):
+         return "수학질문"  # math question
+     else:
+         return "일상대화"  # small talk
+
+ def parse_math_question(text):
+     text = text.replace("곱하기", "*").replace("더하기", "+").replace("빼기", "-").replace("나누기", "/").replace("제곱", "**2")
+     text = re.sub(r'루트\s*(\d+)', r'math.sqrt(\1)', text)
+     try:
+         result = eval(text)  # caution: eval() runs the lightly-sanitized user expression
+         return f"정답은 {result}입니다."  # "The answer is {result}."
+     except:
+         return "계산할 수 없는 수식이에요. 다시 한번 확인해 주세요!"  # "I can't compute that expression. Please check it again!"
+
+ # Top-level response routing
+ def respond(input_text):
+     intent = simple_intent_classifier(input_text)
+
+     if "/사용법" in input_text:  # "/usage" command
+         return "자유롭게 사용해주세요. 딱히 제약은 없습니다."  # "Use it freely; no particular restrictions."
+
+     if "이름" in input_text:  # asked for the bot's name
+         return "제 이름은 kossistant입니다."  # "My name is kossistant."
+
+     if "누구" in input_text:  # asked who the bot is
+         return "저는 kossistant이라고 해요."  # "I go by kossistant."
+
+     if intent == "수학질문":
+         return parse_math_question(input_text)
+
+     if intent == "정보질문":
+         keyword = re.sub(r"(에 대해|에 대한|에 대해서)?\s*(설명해줘|알려줘|뭐야|개념|정의|정보)?", "", input_text).strip()
+         if not keyword:
+             return "어떤 주제에 대해 궁금한가요?"  # "What topic are you curious about?"
+         summary = get_wikipedia_summary(keyword)
+         return f"{summary}\n다른 궁금한 점 있으신가요?"  # "... Anything else you're curious about?"
+
+     return decode_sequence_custom(input_text)