import re
import json
import math  # used by parse_math_question for math.sqrt
import numpy as np
import requests
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.text import tokenizer_from_json
from tensorflow.keras.preprocessing.sequence import pad_sequences

def load_tokenizer(filename):
    with open(filename, 'r', encoding='utf-8') as f:
        return tokenizer_from_json(json.load(f))

tokenizer_q = load_tokenizer('kossistant_q.json')
tokenizer_a = load_tokenizer('kossistant_a.json')

# Load the model and its parameters
model = load_model('kossistant.h5', compile=False)
max_len_q = model.input_shape[0][1]  # question (encoder input) length
max_len_a = model.input_shape[1][1]  # answer (decoder input) length
index_to_word = {v: k for k, v in tokenizer_a.word_index.items()}
index_to_word[0] = ''
start_token = 'start'
end_token = 'end'

# Token sampling function: top-k + nucleus (top-p) filtering with a repetition penalty
def sample_from_top_p_top_k(prob_dist, top_p=0.85, top_k=40, temperature=0.8, repetition_penalty=1.4, generated_ids=None):
    if generated_ids is None:
        generated_ids = []
    logits = np.log(prob_dist + 1e-9) / temperature
    # Log-probabilities are non-positive, so multiply by the penalty to push
    # already-generated tokens further down (dividing would reward repetition).
    for idx in generated_ids:
        logits[idx] *= repetition_penalty
    probs = np.exp(logits)
    probs = probs / np.sum(probs)
    # Keep the top_k most likely token ids
    top_k_indices = np.argsort(probs)[-top_k:]
    top_k_probs = probs[top_k_indices]
    sorted_indices = top_k_indices[np.argsort(top_k_probs)[::-1]]
    sorted_probs = probs[sorted_indices]
    # Within the top-k, keep the smallest prefix whose cumulative probability reaches top_p
    cumulative_probs = np.cumsum(sorted_probs)
    cutoff_index = np.searchsorted(cumulative_probs, top_p)
    final_indices = sorted_indices[:cutoff_index + 1]
    final_probs = probs[final_indices]
    final_probs = final_probs / np.sum(final_probs)
    return np.random.choice(final_indices, p=final_probs)
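
# Illustrative smoke test for the sampler -- my addition, not part of the original Space.
# The helper name and the toy numbers are made up for the example: with a sharply peaked
# toy distribution, nearly every sampled id should be 2 (the 0.90 mass). Call it manually
# when debugging the sampling parameters.
def _sampler_smoke_test():
    toy_dist = np.array([0.01, 0.02, 0.90, 0.05, 0.02])
    samples = [sample_from_top_p_top_k(toy_dist, top_k=3, top_p=0.95) for _ in range(20)]
    print(samples)  # expect index 2 to dominate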

# Decoding
def decode_sequence_custom(input_text, max_attempts=2):
    input_seq = tokenizer_q.texts_to_sequences([input_text])
    input_seq = pad_sequences(input_seq, maxlen=max_len_q, padding='post')
    for _ in range(max_attempts + 1):
        target_seq = tokenizer_a.texts_to_sequences([start_token])[0]
        target_seq = pad_sequences([target_seq], maxlen=max_len_a, padding='post')
        decoded_sentence = ''
        generated_ids = []
        for i in range(max_len_a):
            predictions = model.predict([input_seq, target_seq], verbose=0)
            prob_dist = predictions[0, i, :]
            pred_id = sample_from_top_p_top_k(prob_dist, generated_ids=generated_ids)
            generated_ids.append(pred_id)
            pred_word = index_to_word.get(pred_id, '')
            if pred_word == end_token:
                break
            decoded_sentence += pred_word + ' '
            if i + 1 < max_len_a:
                target_seq[0, i + 1] = pred_id
        # Strip any stray end tokens and collapse whitespace
        cleaned = re.sub(rf'\b{end_token}\b', '', decoded_sentence)
        cleaned = re.sub(r'\s+', ' ', cleaned)
        if is_valid_response(cleaned):
            return cleaned.strip()
    return "죄송해요, 답변 생성에 실패했어요."  # "Sorry, I failed to generate a reply."

def is_valid_response(response):
    if len(response.strip()) < 2:
        return False
    # Reject runs of three or more bare jamo (keyboard-mash like 'ㅋㅋㅋ')
    if re.search(r'[ㄱ-ㅎㅏ-ㅣ]{3,}', response):
        return False
    if len(response.split()) < 2:
        return False
    if response.count(' ') < 2:
        return False
    if any(tok in response.lower() for tok in ['hello', 'this', 'ㅋㅋ']):
        return False
    return True

def extract_main_query(text):
    sentences = re.split(r'[.?!]\s*', text)
    sentences = [s.strip() for s in sentences if s.strip()]
    if not sentences:
        return text
    last = sentences[-1]
    # Keep only Hangul, Latin letters, digits, and spaces
    last = re.sub(r'[^가-힣a-zA-Z0-9 ]', '', last)
    # Strip common Korean particles (이/가/은/는/...) from word endings
    particles = ['이', '가', '은', '는', '을', '를', '에', '에서', '에게', '한테', '보다']
    for p in particles:
        last = re.sub(rf'\b(\w+){p}\b', r'\1', last)
    return last.strip()

def get_wikipedia_summary(query):
    cleaned_query = extract_main_query(query)
    url = f"https://ko.wikipedia.org/api/rest_v1/page/summary/{cleaned_query}"
    res = requests.get(url)
    if res.status_code == 200:
        return res.json().get("extract", "요약 정보를 찾을 수 없습니다.")
    else:
        return "위키백과에서 정보를 가져올 수 없습니다."

def simple_intent_classifier(text):
    text = text.lower()
    greet_keywords = ["안녕", "반가워", "이름", "누구", "소개", "어디서 왔", "정체", "몇 살", "너 뭐야"]
    info_keywords = ["설명", "정보", "무엇", "뭐야", "어디", "누구", "왜", "어떻게", "종류", "개념"]
    math_keywords = ["더하기", "빼기", "곱하기", "나누기", "루트", "제곱", "+", "-", "*", "/", "=", "^", "√", "계산", "몇이야", "얼마야"]
    if any(kw in text for kw in greet_keywords):
        return "인사"  # greeting
    elif any(kw in text for kw in info_keywords):
        return "정보질문"  # information question
    elif any(kw in text for kw in math_keywords):
        return "수학질문"  # math question
    else:
        return "일상대화"  # small talk

def parse_math_question(text):
    # Map Korean arithmetic words (and '^') to Python operators; '제곱' means squaring, so it becomes '**2'
    text = text.replace("곱하기", "*").replace("더하기", "+").replace("빼기", "-").replace("나누기", "/").replace("제곱", "**2").replace("^", "**")
    text = re.sub(r'루트\s(\d+)', r'math.sqrt(\1)', text)
    try:
        # Caution: eval() on user-supplied text is unsafe in general; a hardened version should use a real expression parser.
        result = eval(text)
        return f"정답은 {result}입니다."  # "The answer is {result}."
    except Exception:
        return "계산할 수 없는 수식이에요. 다시 한번 확인해 주세요!"  # "I can't evaluate that expression. Please check it again!"

# Full response pipeline
def respond(input_text):
    intent = simple_intent_classifier(input_text)
    if "/사용법" in input_text:  # "/usage" command
        return "자유롭게 사용해주세요. 딱히 제약은 없습니다."  # "Feel free to use it. There are no particular restrictions."
    if "이름" in input_text:  # asked for its name
        return "제 이름은 kossistant입니다."  # "My name is kossistant."
    if "누구" in input_text:  # "who are you?"
        return "저는 kossistant이라고 해요."  # "I'm called kossistant."
    if intent == "수학질문":  # math question
        return parse_math_question(input_text)
    if intent == "정보질문":  # information question -> Wikipedia lookup
        keyword = re.sub(r"(에 대해|에 대한|에 대해서)?\s*(설명해줘|알려줘|뭐야|개념|정의|정보)?", "", input_text).strip()
        if not keyword:
            return "어떤 주제에 대해 궁금한가요?"  # "Which topic are you curious about?"
        summary = get_wikipedia_summary(keyword)
        return f"{summary}\n다른 궁금한 점 있으신가요?"  # "...\nAnything else you'd like to know?"
    return decode_sequence_custom(input_text)
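
# Minimal local usage sketch -- an assumption on my part, not shipped with the Space:
# a simple REPL around respond(), assuming kossistant.h5, kossistant_q.json and
# kossistant_a.json are present in the working directory.
if __name__ == "__main__":
    print("kossistant (type 'quit' to exit)")
    while True:
        user_input = input("> ").strip()
        if user_input.lower() in ("quit", "exit"):
            break
        if user_input:
            print(respond(user_input))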