Spaces:
Sleeping
Sleeping
import re | |
# Record every generated bot response in the chat history
def generate_reply(ctx, makePipeLine, user_msg):
    """Generate one bot reply to ``user_msg`` and store it in the history.

    Args:
        ctx: Conversation context; must provide ``addHistory(role, text)``
            plus whatever ``generate_valid_response`` reads from it.
        makePipeLine: Model wrapper forwarded to ``generate_valid_response``.
        user_msg: The latest user message to answer.

    Returns:
        None. The reply is only recorded via ``ctx.addHistory("bot", ...)``.
    """
    # First (and currently only) reply for this turn.
    ctx.addHistory("bot", generate_valid_response(ctx, makePipeLine, user_msg))

    # Disabled on purpose: asking the model to continue a cut-off reply
    # tended to induce incomplete responses, so it is not used.
    #
    # if is_truncated_response(response):
    #     continuation = generate_valid_response(ctx, makePipeLine, response)
    #     ctx.addHistory("bot", continuation)
# Generate a single bot response
def generate_valid_response(ctx, makePipeline, user_msg) -> str:
    """Query the model until it yields a reply that passes validation.

    Args:
        ctx: Conversation context providing ``getUserName()``,
            ``getBotName()`` and ``getHistory()``.
        makePipeline: Object exposing ``character_chat(prompt)``, which
            returns the raw generated text.
        user_msg: Message appended to the prompt as the newest user turn.

    Returns:
        The first reply accepted by ``is_valid_response``, passed through
        ``clean_response``.
    """
    user_name, bot_name = ctx.getUserName(), ctx.getBotName()

    # NOTE(review): there is no retry limit; the loop only ends once a
    # candidate passes validation.
    while True:
        # The prompt is rebuilt on every attempt from the current history.
        prompt = build_prompt(ctx.getHistory(), user_msg, user_name, bot_name)
        candidate = extract_response(makePipeline.character_chat(prompt))
        print(f"debug: {candidate}")
        if is_valid_response(candidate, user_name, bot_name):
            return clean_response(candidate, bot_name)
# Assemble the input prompt
def build_prompt(history, user_msg, user_name, bot_name):
    """Build the instruction-style prompt sent to the model.

    The system prompt is read from ``assets/prompt/init.txt`` on every call.

    Args:
        history: Sliceable sequence of turns; each turn is indexed by
            ``"role"`` and ``"text"``. Only the last 16 turns are used.
        user_msg: Newest user message, appended after the history.
        user_name: Speaker label for turns whose role is ``"user"``.
        bot_name: Speaker label for every other turn, and the label the
            model is asked to continue after ``### Response:``.

    Returns:
        The full prompt string, ending with ``"<bot_name>:"``.
    """
    with open("assets/prompt/init.txt", "r", encoding="utf-8") as prompt_file:
        system_prompt = prompt_file.read().strip()

    # Re-render the recent conversation as plain "speaker: text" lines.
    transcript = [
        f"{user_name if turn['role'] == 'user' else bot_name}: {turn['text']}\n"
        for turn in history[-16:]
    ]
    transcript.append(f"{user_name}: {user_msg}\n")
    dialogue = "".join(transcript)

    # Layout expected by the model: instruction section, then response cue.
    return (
        "### Instruction:\n"
        f"{system_prompt}\n"
        f"{dialogue}\n"
        "### Response:\n"
        f"{bot_name}:"
    )
# Extract the response from the model output (following the HyperCLOVAX format)
def extract_response(full_text):
    """Return the text that follows the last ``### Response:`` marker.

    If the marker is absent, the whole input is returned. In both cases
    surrounding whitespace is stripped.
    """
    # rpartition yields the entire string as the tail when the marker is
    # missing, so a single expression covers both cases.
    _, _, tail = full_text.rpartition("### Response:")
    return tail.strip()
# Validate the response
def is_valid_response(text: str, user_name, bot_name) -> bool:
    """Return True unless ``text`` contains the user's speaker tag.

    A reply containing ``"<user_name>:"`` is rejected (presumably because
    the model started writing the user's turn as well). ``bot_name`` is
    accepted for signature symmetry but is not consulted.
    """
    speaker_tag = user_name + ":"
    return speaker_tag not in text
# Clean up the response format
def clean_response(text: str, bot_name):
    """Strip the ``"<bot_name>:"`` speaker tag from a reply.

    Every occurrence of the tag (plus any whitespace right after it) is
    removed, then the result is stripped of surrounding whitespace.

    Args:
        text: Raw reply text.
        bot_name: Speaker label to remove. An empty/None name leaves the
            text untouched apart from stripping.

    Returns:
        The cleaned reply.
    """
    if not bot_name:
        # Without a name the pattern would degrade to ":" and eat every colon.
        return text.strip()
    # The name is escaped so regex metacharacters in it are matched
    # literally; "\s*" (single backslash in a raw string) is the whitespace
    # class -- "\\s*" would instead require a literal backslash.
    tag_pattern = rf"{re.escape(str(bot_name))}:\s*"
    return re.sub(tag_pattern, "", text).strip()
# Check whether the response was cut off
def is_truncated_response(text: str) -> bool:
    """Return True when ``text`` does not end with a recognised terminator.

    Terminators are ``.``, ``?``, ``!``, the ellipsis character and a small
    set of emoji. Surrounding whitespace is ignored; an empty string counts
    as truncated.
    """
    # The emoji are written as full code points: a Python str never
    # contains UTF-16 surrogate halves for them, so the surrogate-pair
    # spelling (\uD83D\uDE0A) could not match.
    terminators = (
        ".",
        "?",
        "!",
        "\u2026",      # horizontal ellipsis
        "\u2639",      # frowning face
        "\u263A",      # smiling face
        "\u2764",      # heavy black heart
        "\U0001F60A",  # smiling face with smiling eyes
        "\U0001F622",  # crying face
    )
    # Emoji are often followed by the U+FE0F variation selector; drop it so
    # the emoji itself is what gets compared.
    tail = text.strip().rstrip("\uFE0F")
    return not tail.endswith(terminators)