|
import os |
|
import openai |
|
import gradio as gr |
|
import pandas as pd |
|
import numpy as np |
|
from sentence_transformers import SentenceTransformer |
|
from sklearn.metrics.pairwise import cosine_similarity |
|
|
|
|
|
|
|
# OpenAI credentials come from the environment; never hard-code the key.
openai.api_key = os.getenv("OPENAI_API_KEY")

# Korean sentence-embedding model used for KB retrieval (kb_search below).
model = SentenceTransformer("jhgan/ko-sroberta-multitask")

# Wellness counseling Q&A dataset; one row per (user utterance, chatbot reply) pair.
# NOTE(review): loaded from a remote URL at import time — fails offline; consider caching locally.
df = pd.read_csv("https://raw.githubusercontent.com/kairess/mental-health-chatbot/master/wellness_dataset_original.csv")
df = df.dropna()
# Precompute an embedding per user-utterance row so kb_search only encodes the query.
df["embedding"] = df["์ ์ "].map(lambda x: model.encode(str(x)))

# Maximum number of Socratic-questioning (SQ) turns before moving to the ADVICE stage.
MAX_TURN = 5
|
|
|
|
|
# Prompt for the EMPATHY stage: restate the user's utterance with an empathic
# ending; includes a one-shot example. "{sentence}" is filled via str.format.
EMPATHY_PROMPT = """\
๋น์ ์ ์น์ ํ ์ ์ ์ํ๊ณผ ์ ๋ฌธ์์ด๋ฉฐ ์ฌ๋ฆฌ์๋ด ์ ๋ฌธ๊ฐ์
๋๋ค.
์ฌ์ฉ์์ ๋ฌธ์ฅ์ ๊ฑฐ์ ๊ทธ๋๋ก ์์ฝํ๋, ๋์ '๋๊ตฐ์.' ๊ฐ์ ๊ณต๊ฐ ์ด๋ฏธ๋ฅผ ๋ถ์ฌ ์์ฐ์ค๋ฝ๊ฒ ์๋ตํ์ธ์.

์์:
์ฌ์ฉ์: "์ํ์ ์๋๊ณ ๋ถ์ํด์ ๋ฉฐ์น ์งธ ์ ์ด ์ ์์."
์ฑ๋ด: "์ํ์ ์๋๊ณ ๋ถ์ํด์ ๋ฉฐ์น ์งธ ์ ์ด ์ ์ค๋๊ตฐ์."

์ด์ ์ฌ์ฉ์ ๋ฐํ๋ฅผ ์๋์ ์ฃผ๊ฒ ์ต๋๋ค.
์ฌ์ฉ์ ๋ฐํ: "{sentence}"
์ฑ๋ด:
"""

# Prompt for the SQ stage: generate exactly one Socratic follow-up question.
# The conversation context is appended after this template (see call_socratic_question).
SOCRATIC_PROMPT = """\
๋น์ ์ ์ ์ ์ํ๊ณผ ์ ๋ฌธ์์ด๋ฉฐ Socratic CBT ๊ธฐ๋ฒ์ ์ฌ์ฉํ๋ ์ฌ๋ฆฌ์๋ด ์ ๋ฌธ๊ฐ์
๋๋ค.
์ด์ ๋ํ ๋ด์ฉ๊ณผ ํํธ๋ฅผ ์ฐธ๊ณ ํ์ฌ, ์ฌ์ฉ์์ ์ธ์ง๋ฅผ ํ์ํ๋ ์์ฐ์ค๋ฝ๊ณ ๊ตฌ์ฒด์ ์ธ ํ์ ์ง๋ฌธ์ ํ ๋ฌธ์ฅ์ผ๋ก ์์ฑํ์ธ์.

- ์ง๋ฌธ์ ๋ฐ๋์ ๋ฌผ์ํ๋ก ๋๋์ผ ํฉ๋๋ค.
- "์ง๋ฌธ:" ๊ฐ์ ์ ๋์ด ์์ด ๋ฐ๋ก ์ง๋ฌธ ๋ฌธ์ฅ๋ง ์์ฑํ์ธ์.
- ๊ฐ๋ฅํ ํ ์ฌ์ฉ์์ ์ํฉ์ ๋ ๊น์ด ์ดํดํ ์ ์๋ ํ์์ ์ง๋ฌธ์ ํด์ฃผ์ธ์.
"""

# Prompt for the final ADVICE stage: produce personalized CBT advice from the
# accumulated hints. "{hints}" is filled via str.format with the joined hint list.
ADVICE_PROMPT = """\
๋น์ ์ ์ ์ ์ํ๊ณผ ์ ๋ฌธ์์ด๋ฉฐ Socratic CBT ๊ธฐ๋ฒ์ ์ฌ์ฉํ๋ ์ฌ๋ฆฌ์๋ด ์ ๋ฌธ๊ฐ์
๋๋ค.
์๋ ํํธ(๋ํ ์์ฝ)๋ฅผ ๋ฐํ์ผ๋ก, ์ฌ์ฉ์ ๋ง์ถคํ์ผ๋ก ๊ตฌ์ฒด์ ์ด๊ณ ๊ณต๊ฐ ์ด๋ฆฐ ์กฐ์ธ์ ํ๊ตญ์ด๋ก ์์ฑํ์ธ์.

- ๋ถ์์ ์ํํ๊ธฐ ์ํ ์ฌ๋ฌ CBT ๊ธฐ๋ฒ(์ธ์ง ์ฌ๊ตฌ์กฐํ, ์ ์ง์ ๊ทผ์ก ์ด์, ํธํก์กฐ์ , ๊ฑฑ์ ์๊ฐ ์ ํ๊ธฐ ๋ฑ)์ ์์ฐ์ค๋ฝ๊ฒ ๋
น์ด๋
์ฌ์ฉ์์ ํ์ฌ ์ํฉ๊ณผ ์ฐ๊ฒฐํด ์ด์ผ๊ธฐํด์ฃผ์ธ์.
- ๋๋ฌด ๋ฑ๋ฑํ์ง ์๊ฒ ๋ถ๋๋ฝ๊ณ ์น์ ํ ๋งํฌ๋ฅผ ์ฌ์ฉํ์ธ์.

ํํธ:
{hints}

์กฐ์ธ:
"""
|
|
|
def set_openai_model() -> str:
    """Return the chat-completion model name used by all LLM calls.

    Centralized here so the model can be swapped in one place.
    """
    return "gpt-4o"
|
|
|
|
|
|
|
def kb_search(user_input: str) -> str:
    """Return the stored chatbot reply most similar to *user_input*.

    Encodes the query with the SentenceTransformer model and ranks the
    precomputed row embeddings in ``df`` by cosine similarity.

    Fixes over the previous version:
    - no longer mutates the shared module-level ``df`` (the old code wrote a
      temporary "sim" column on every call — a hidden side effect);
    - computes all similarities in a single vectorized NumPy pass instead of
      one sklearn ``cosine_similarity`` call per row.
    """
    query = np.asarray(model.encode(user_input), dtype=np.float64)
    # Stack the per-row embedding vectors into an (n_rows, dim) matrix.
    matrix = np.stack(df["embedding"].to_numpy())
    # Cosine similarity = dot product / (product of L2 norms).
    sims = (matrix @ query) / (np.linalg.norm(matrix, axis=1) * np.linalg.norm(query))
    best_label = df.index[int(np.argmax(sims))]
    return df.loc[best_label, "์ฑ๋ด"]
|
|
|
def call_empathy(user_input: str) -> str:
    """EMPATHY stage: restate the user's utterance with an empathic tone.

    Fills EMPATHY_PROMPT with the raw user text and asks the LLM for a
    short paraphrase; returns the stripped completion text.
    """
    filled_prompt = EMPATHY_PROMPT.format(sentence=user_input)
    messages = [
        {"role": "system", "content": "๋น์ ์ ์น์ ํ ์ฌ๋ฆฌ์๋ด ์ ๋ฌธ๊ฐ์๋๋ค."},
        {"role": "user", "content": filled_prompt},
    ]
    response = openai.ChatCompletion.create(
        model=set_openai_model(),
        messages=messages,
        max_tokens=150,
        temperature=0.7,
    )
    return response.choices[0].message.content.strip()
|
|
|
def call_socratic_question(context: str) -> str:
    """SQ stage: generate one Socratic follow-up question.

    Appends the conversation context to SOCRATIC_PROMPT and returns the
    LLM's single-sentence question, stripped of surrounding whitespace.
    """
    combined = f"{SOCRATIC_PROMPT}\n\n๋ํ ํํธ:\n{context}"
    response = openai.ChatCompletion.create(
        model=set_openai_model(),
        messages=[
            {"role": "system", "content": "๋น์ ์ Socratic CBT ์ ๋ฌธ๊ฐ์๋๋ค."},
            {"role": "user", "content": combined},
        ],
        max_tokens=200,
        temperature=0.7,
    )
    return response.choices[0].message.content.strip()
|
|
|
def call_advice(hints: str) -> str:
    """ADVICE stage: produce the final CBT advice from accumulated hints.

    *hints* is the newline-joined hint list built up during the session;
    it is substituted into ADVICE_PROMPT before the LLM call.
    """
    advice_request = ADVICE_PROMPT.format(hints=hints)
    chat_messages = [
        {"role": "system", "content": "๋น์ ์ Socratic CBT ๊ธฐ๋ฒ ์ ๋ฌธ๊ฐ์๋๋ค."},
        {"role": "user", "content": advice_request},
    ]
    response = openai.ChatCompletion.create(
        model=set_openai_model(),
        messages=chat_messages,
        max_tokens=700,
        temperature=0.8,
    )
    return response.choices[0].message.content.strip()
|
|
|
def predict(user_input: str, state: dict):
    """Gradio callback driving the EMPATHY -> SQ -> ADVICE chatbot flow.

    Returns ``(history, new_state)`` where *history* is a list of
    ``(role, text)`` tuples and *new_state* carries ``history``, ``stage``,
    ``turn`` and ``hints`` forward to the next call.
    """
    history = state.get("history", [])
    stage = state.get("stage", "EMPATHY")
    turn = state.get("turn", 0)
    hints = state.get("hints", [])

    history.append(("User", user_input))

    # Every turn: retrieve the closest knowledge-base reply and keep it as a hint.
    hints.append(f"[KB] {kb_search(user_input)}")

    if stage == "EMPATHY":
        # First turn: empathic restatement, then move to Socratic questioning.
        empathic_reply = call_empathy(user_input)
        history.append(("Chatbot", empathic_reply))
        hints.append(empathic_reply)
        return history, {"history": history, "stage": "SQ", "turn": 0, "hints": hints}

    if stage == "SQ" and turn < MAX_TURN:
        # Socratic questioning: feed the full transcript plus hints as context.
        transcript = "\n".join(f"{role}: {text}" for role, text in history)
        question = call_socratic_question(transcript + "\n" + "\n".join(hints))
        history.append(("Chatbot", question))
        hints.append(question)
        return history, {"history": history, "stage": stage, "turn": turn + 1, "hints": hints}

    # SQ budget exhausted (or stage already past SQ): deliver the final advice.
    advice = call_advice("\n".join(hints))
    history.append(("Chatbot", advice))
    return history, {"history": history, "stage": "END", "turn": turn, "hints": hints}
|
|
|
def gradio_predict(user_input, chat_state):
    """Adapt predict() output to the ``[user, bot]`` pair format gr.Chatbot expects.

    User turns open a new pair; a chatbot turn fills the bot slot of the
    last open pair, or starts a bot-only pair when none is open.
    """
    new_history, new_state = predict(user_input, chat_state)

    pairs = []
    for role, text in new_history:
        if role == "User":
            pairs.append([text, ""])
        elif pairs and pairs[-1][1] == "":
            # Fill the bot slot of the most recent user turn.
            pairs[-1][1] = text
        else:
            # Bot message with no open user turn: bot-only row.
            pairs.append(["", text])
    return pairs, new_state
|
|
|
def create_app():
    """Assemble and return the Gradio Blocks UI for the CBT chatbot."""
    with gr.Blocks() as ui:
        gr.Markdown("## ๐ฅ ์ํฌ๋ผํ์ค CBT ์ฑ๋ด (GPT-4o)")

        chat_window = gr.Chatbot(label="Socratic CBT Chatbot")
        # Per-session conversation state consumed and produced by gradio_predict.
        session = gr.State({"history": [], "stage": "EMPATHY", "turn": 0, "hints": []})
        user_box = gr.Textbox(show_label=False, placeholder="๊ณ ๋ฏผ์ด๋ ๊ถ๊ธํ ์ ์ ์๋ ฅํ์ธ์.")

        user_box.submit(
            fn=gradio_predict,
            inputs=[user_box, session],
            outputs=[chat_window, session],
            scroll_to_output=True,
        )
    return ui
|
|
|
# Build the UI at import time so hosting platforms (e.g. HF Spaces) can serve it.
app = create_app()

if __name__ == "__main__":
    # share=True opens a temporary public tunnel URL; debug=True enables verbose logs.
    app.launch(debug=True, share=True)