import re
import json

import numpy as np
import requests
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.text import tokenizer_from_json
from tensorflow.keras.preprocessing.sequence import pad_sequences


def load_tokenizer(filename):
    """Load a Keras Tokenizer previously serialized to *filename* as JSON."""
    with open(filename, 'r', encoding='utf-8') as f:
        # NOTE(review): tokenizer_from_json expects the JSON *string* produced
        # by Tokenizer.to_json(). json.load(f) only yields that string when the
        # file was written via json.dump(tokenizer.to_json(), f) — confirm
        # against the code that saved these files.
        return tokenizer_from_json(json.load(f))


tokenizer_q = load_tokenizer('kossistant_q.json')
tokenizer_a = load_tokenizer('kossistant_a.json')

# 모델 및 파라미터 로드 — sequence lengths are read off the model's two inputs.
model = load_model('kossistant.h5', compile=False)
max_len_q = model.input_shape[0][1]
max_len_a = model.input_shape[1][1]

# Reverse lookup for decoding; id 0 is padding and decodes to the empty string.
index_to_word = {v: k for k, v in tokenizer_a.word_index.items()}
index_to_word[0] = ''
start_token = 'start'
end_token = 'end'


# 토큰 샘플링 함수
def sample_from_top_p_top_k(prob_dist, top_p=0.85, top_k=40, temperature=0.8,
                            repetition_penalty=1.4, generated_ids=None):
    """Sample one token id from *prob_dist* with temperature scaling, top-k
    and nucleus (top-p) filtering, penalizing ids already in *generated_ids*.

    Returns a numpy integer token id.
    """
    if generated_ids is None:  # was a mutable default ([]) — shared across calls
        generated_ids = []
    logits = np.log(prob_dist + 1e-9) / temperature
    # These logits are log-probabilities, hence always <= 0. To make a repeated
    # token LESS likely its logit must become MORE negative, i.e. multiply by
    # the penalty. The original divided, which boosted repeats instead.
    for idx in generated_ids:
        logits[idx] *= repetition_penalty
    probs = np.exp(logits)
    probs = probs / np.sum(probs)
    # Top-k: keep the k most probable ids, ordered by descending probability.
    top_k_indices = np.argsort(probs)[-top_k:]
    top_k_probs = probs[top_k_indices]
    sorted_indices = top_k_indices[np.argsort(top_k_probs)[::-1]]
    sorted_probs = probs[sorted_indices]
    # Top-p: smallest prefix whose cumulative mass reaches top_p.
    cumulative_probs = np.cumsum(sorted_probs)
    cutoff_index = np.searchsorted(cumulative_probs, top_p)
    final_indices = sorted_indices[:cutoff_index + 1]
    final_probs = probs[final_indices]
    final_probs = final_probs / np.sum(final_probs)
    return np.random.choice(final_indices, p=final_probs)


# 디코딩
def decode_sequence_custom(input_text, max_attempts=2):
    """Autoregressively decode a reply to *input_text*, retrying up to
    max_attempts + 1 times until is_valid_response() accepts the output.

    Returns the cleaned reply, or a fixed apology string on failure.
    """
    input_seq = tokenizer_q.texts_to_sequences([input_text])
    input_seq = pad_sequences(input_seq, maxlen=max_len_q, padding='post')
    for _ in range(max_attempts + 1):
        # Seed the decoder input with the start token, padded to max_len_a.
        target_seq = tokenizer_a.texts_to_sequences([start_token])[0]
        target_seq = pad_sequences([target_seq], maxlen=max_len_a, padding='post')
        decoded_sentence = ''
        generated_ids = []
        for i in range(max_len_a):
            predictions = model.predict([input_seq, target_seq], verbose=0)
            prob_dist = predictions[0, i, :]
            pred_id = sample_from_top_p_top_k(prob_dist, generated_ids=generated_ids)
            generated_ids.append(pred_id)
            pred_word = index_to_word.get(pred_id, '')
            if pred_word == end_token:
                break
            decoded_sentence += pred_word + ' '
            # Feed the sampled id back in as the next decoder input.
            if i + 1 < max_len_a:
                target_seq[0, i + 1] = pred_id
        # Collapse whitespace runs. (The original's re.sub(r'\b\b', '', ...)
        # only ever matched empty strings — a no-op — and was removed.)
        cleaned = re.sub(r'\s+', ' ', decoded_sentence)
        if is_valid_response(cleaned):
            return cleaned.strip()
    return "죄송해요, 답변 생성에 실패했어요."


def is_valid_response(response):
    """Heuristic filter: reject replies that are too short, jamo-garbled,
    too few words, or contain known junk tokens."""
    if len(response.strip()) < 2:
        return False
    # Three+ consecutive bare jamo (ㄱ-ㅎ / ㅏ-ㅣ) indicates garbled output.
    if re.search(r'[ㄱ-ㅎㅏ-ㅣ]{3,}', response):
        return False
    if len(response.split()) < 2:
        return False
    if response.count(' ') < 2:
        return False
    if any(tok in response.lower() for tok in ['hello', 'this', 'ㅋㅋ']):
        return False
    return True


def extract_main_query(text):
    """Return a search keyword: the last sentence of *text*, with punctuation
    removed and common trailing Korean particles stripped."""
    sentences = re.split(r'[.?!]\s*', text)
    sentences = [s.strip() for s in sentences if s.strip()]
    if not sentences:
        return text
    last = sentences[-1]
    last = re.sub(r'[^가-힣a-zA-Z0-9 ]', '', last)
    particles = ['이', '가', '은', '는', '을', '를', '의', '에서', '에게', '한테', '보다']
    for p in particles:
        last = re.sub(rf'\b(\w+){p}\b', r'\1', last)
    return last.strip()


def get_wikipedia_summary(query):
    """Fetch the Korean-Wikipedia REST summary for the main keyword of *query*.

    Returns the extract text, or a Korean error message when the page is
    missing or the request fails.
    """
    cleaned_query = extract_main_query(query)
    url = f"https://ko.wikipedia.org/api/rest_v1/page/summary/{cleaned_query}"
    try:
        # timeout added: without one a stalled connection blocks the bot forever.
        res = requests.get(url, timeout=5)
    except requests.RequestException:
        return "위키백과에서 정보를 가져올 수 없습니다."
    if res.status_code == 200:
        return res.json().get("extract", "요약 정보를 찾을 수 없습니다.")
    else:
        return "위키백과에서 정보를 가져올 수 없습니다."
def simple_intent_classifier(text):
    """Classify *text* into one coarse intent via keyword matching.

    Returns "인사" (greeting), "정보질문" (info question), "수학질문"
    (math question) or "일상대화" (small talk), checked in that priority order.
    """
    text = text.lower()
    greet_keywords = ["안녕", "반가워", "이름", "누구", "소개", "어디서 왔", "정체", "몇 살", "너 뭐야"]
    info_keywords = ["설명", "정보", "무엇", "뭐야", "어디", "누구", "왜", "어떻게", "종류", "개념"]
    math_keywords = ["더하기", "빼기", "곱하기", "나누기", "루트", "제곱", "+", "-", "*", "/", "=", "^", "√", "계산", "몇이야", "얼마야"]
    if any(kw in text for kw in greet_keywords):
        return "인사"
    elif any(kw in text for kw in info_keywords):
        return "정보질문"
    elif any(kw in text for kw in math_keywords):
        return "수학질문"
    else:
        return "일상대화"


def parse_math_question(text):
    """Translate a Korean-worded arithmetic question into a Python expression
    and evaluate it.

    Returns "정답은 {result}입니다." on success, otherwise a Korean error
    message.
    """
    import math  # local import: 'math' is not imported at module level

    # "제곱" means "squared": the original substituted "*2" (doubling) —
    # squaring is "**2".
    expr = (text.replace("곱하기", "*").replace("더하기", "+")
                .replace("빼기", "-").replace("나누기", "/")
                .replace("제곱", "**2"))
    expr = re.sub(r'루트\s(\d+)', r'math.sqrt(\1)', expr)
    # eval() on raw user text is dangerous. Whitelist: digits, arithmetic
    # operators, parentheses, whitespace, and math.sqrt(<digits>) calls only.
    if not re.fullmatch(r'[\d+\-*/().\s]*(math\.sqrt\(\d+\)[\d+\-*/().\s]*)*', expr):
        return "계산할 수 없는 수식이에요. 다시 한번 확인해 주세요!"
    try:
        # Empty builtins namespace so eval cannot reach anything but math.
        result = eval(expr, {"__builtins__": {}}, {"math": math})
        return f"정답은 {result}입니다."
    except Exception:  # bad/partial expressions land here (was a bare except)
        return "계산할 수 없는 수식이에요. 다시 한번 확인해 주세요!"


# 전체 응답 함수
def respond(input_text):
    """Route *input_text* to the appropriate handler.

    Order: command/identity shortcuts, then math, then Wikipedia lookup for
    info questions, falling back to the seq2seq decoder for small talk.
    """
    intent = simple_intent_classifier(input_text)

    if "/사용법" in input_text:
        return "자유롭게 사용해주세요. 딱히 제약은 없습니다."
    if "이름" in input_text:
        return "제 이름은 kossistant입니다."
    if "누구" in input_text:
        return "저는 kossistant이라고 해요."

    if intent == "수학질문":
        return parse_math_question(input_text)

    if intent == "정보질문":
        # Strip question framing ("~에 대해 설명해줘" etc.) to isolate the topic.
        keyword = re.sub(r"(에 대해|에 대한|에 대해서)?\s*(설명해줘|알려줘|뭐야|개념|정의|정보)?", "", input_text).strip()
        if not keyword:
            return "어떤 주제에 대해 궁금한가요?"
        summary = get_wikipedia_summary(keyword)
        return f"{summary}\n다른 궁금한 점 있으신가요?"

    return decode_sequence_custom(input_text)