Spaces:
Sleeping
Sleeping
from abc import ABC, abstractmethod | |
from .action import ACTION_SPACE_PROMPT | |
from .eval_type import ( | |
GROUNDING, | |
PROGRESS_LIKERT_SCALE, | |
PROGRESS_THREE_CLASS, | |
PROGRESS_WITH_CHECKLIST, | |
PROGRESS_WITH_CHECKLIST_IN_PROGRESS, | |
PROGRESS_OURS | |
) | |
from .input_information import ( | |
USER_INSTRUCTION, | |
TRAJECTORY, | |
AGENT_RESPONSE, | |
CHECKLIST, | |
CURRENT_URL, | |
TEXT_OBSERVATION, | |
SOM_IMAGE_OBSERVATION, | |
COORD_IMAGE_OBSERVATION | |
) | |
from .judge_prompt import ( | |
JUDGE_GROUNDING_PROMPT_TEMPLATE, | |
JUDGE_LIKERT_SCALE_PROMPT_TEMPLATE, | |
JUDGE_THREE_CLASS_PROMPT_TEMPLATE, | |
JUDGE_WITH_CHECKLIST_PROMPT_TEMPLATE, | |
JUDGE_OURS_PROMPT_TEMPLATE, | |
JUDGE_OURS_WO_CHECKLIST_PROMPT_TEMPLATE | |
) | |
from .checklist_prompt import ( | |
CHECKLIST_SYSTEM_PROMPT, | |
CHECKLIST_USER_PROMPT, | |
CHECKLIST_OURS_SYSTEM_PROMPT, | |
CHECKLIST_OURS_USER_PROMPT | |
) | |
from .image_utils import image_to_base64_url | |
class Message(ABC): | |
def get_messages(self): | |
pass | |
class BaseMessage(Message): | |
def __init__(self, input_info:dict, use_multimodal:bool=False): | |
self.input_info = input_info | |
self.use_multimodal = use_multimodal | |
def _get_system_message(self): | |
system_message = {"role": "system", "content": "You are a helpful assistant."} | |
return system_message | |
def _process_multimodal_message(self, prompt: str, image_list: list[str]): | |
multimodal_message = [] | |
text_prompt_prefix = prompt.split("<IMAGE_PLACEHOLDER>")[0] | |
text_prompt_suffix = prompt.split("<IMAGE_PLACEHOLDER>")[1] | |
multimodal_message.append({"type": "text", "text": text_prompt_prefix}) | |
for i, image in enumerate(image_list): | |
# TODO: text prompt for multiple images | |
# multimodal_message.append({"type": "text", "text": f"IMAGE {i+1}\n"}) | |
multimodal_message.append({"type": "image_url", "image_url": {"url": image_to_base64_url(image), "detail": "low"}}) | |
multimodal_message.append({"type": "text", "text": text_prompt_suffix}) | |
return {"role": "user", "content": multimodal_message} | |
def _get_user_message(self): | |
user_prompt = "What is the capital of France?" | |
if self.use_multimodal: | |
image_list = self.input_info.get("image_list", []) | |
user_message = self._process_multimodal_message(user_prompt, image_list) | |
else: | |
user_message = {"role": "user", "content": user_prompt} | |
return user_message | |
def get_messages(self): | |
message = [] | |
system_message = self._get_system_message() | |
user_message = self._get_user_message() | |
message.append(system_message) | |
# message.append({"role": "system", "content": ""}) | |
message.append(user_message) | |
return message | |
class ProgressMessage(BaseMessage): | |
''' | |
Progress Judge Message | |
''' | |
def __init__(self, input_info:dict, use_multimodal:bool, prompt_type:str, text_obs: str, image_obs: str, use_checklist:bool, use_in_progress:bool): | |
super().__init__(input_info, use_multimodal) | |
self.prompt_type = prompt_type | |
self.text_obs = text_obs | |
self.image_obs = image_obs | |
self.use_checklist = use_checklist | |
self.use_in_progress = use_in_progress | |
def _get_system_message(self): | |
if self.prompt_type == "likert_scale": | |
system_message = {"role": "system", "content": JUDGE_LIKERT_SCALE_PROMPT_TEMPLATE["system"]} | |
elif self.prompt_type == "three_class": | |
system_message = {"role": "system", "content": JUDGE_THREE_CLASS_PROMPT_TEMPLATE["system"]} | |
elif self.prompt_type == "with_checklist": | |
system_message = {"role": "system", "content": JUDGE_WITH_CHECKLIST_PROMPT_TEMPLATE["system"]} | |
elif self.prompt_type == "ours": | |
system_message = {"role": "system", "content": JUDGE_OURS_PROMPT_TEMPLATE["system"]} | |
else: | |
raise ValueError(f"Invalid prompt type: {self.prompt_type}") | |
return system_message | |
def _setup_input_information(self): | |
observation = "## Current State\n" | |
observation += CURRENT_URL | |
# text observation | |
if self.text_obs: | |
observation += TEXT_OBSERVATION | |
# image observation (som, coord, none) | |
if self.image_obs == "som": | |
observation += SOM_IMAGE_OBSERVATION | |
elif self.image_obs == "coord": | |
observation += COORD_IMAGE_OBSERVATION | |
if self.use_checklist: | |
input_information = USER_INSTRUCTION + TRAJECTORY + observation + CHECKLIST + AGENT_RESPONSE | |
else: | |
input_information = USER_INSTRUCTION + TRAJECTORY + observation + AGENT_RESPONSE | |
return input_information | |
def _setup_task_info(self): | |
if self.prompt_type == "likert_scale": | |
task_description = PROGRESS_LIKERT_SCALE["task_description"] | |
output_format = PROGRESS_LIKERT_SCALE["output_format"] | |
elif self.prompt_type == "three_class": | |
task_description = PROGRESS_THREE_CLASS["task_description"] | |
output_format = PROGRESS_THREE_CLASS["output_format"] | |
elif self.prompt_type == "with_checklist": | |
if self.use_in_progress: | |
task_description = PROGRESS_WITH_CHECKLIST_IN_PROGRESS["task_description"] | |
output_format = PROGRESS_WITH_CHECKLIST_IN_PROGRESS["output_format"] | |
else: | |
task_description = PROGRESS_WITH_CHECKLIST["task_description"] | |
output_format = PROGRESS_WITH_CHECKLIST["output_format"] | |
else: | |
raise ValueError(f"Invalid prompt type: {self.prompt_type}") | |
return task_description, output_format | |
def _get_user_prompt_template(self): | |
if self.prompt_type == "likert_scale": | |
user_prompt = JUDGE_LIKERT_SCALE_PROMPT_TEMPLATE["user"] | |
elif self.prompt_type == "three_class": | |
user_prompt = JUDGE_THREE_CLASS_PROMPT_TEMPLATE["user"] | |
elif self.prompt_type == "with_checklist": | |
user_prompt = JUDGE_WITH_CHECKLIST_PROMPT_TEMPLATE["user"] | |
else: | |
raise ValueError(f"Invalid prompt type: {self.prompt_type}") | |
return user_prompt | |
def _get_user_message(self): | |
# setup input information (user_instruction, trajectory, current_state, agent_response, checklist) | |
input_information_template = self._setup_input_information() | |
input_information = input_information_template.format(**self.input_info) | |
if self.prompt_type == "ours": | |
if self.use_checklist: | |
user_prompt = JUDGE_OURS_PROMPT_TEMPLATE["user"].format( | |
input_information=input_information, | |
) | |
else: | |
user_prompt = JUDGE_OURS_WO_CHECKLIST_PROMPT_TEMPLATE["user"].format( | |
input_information=input_information, | |
) | |
else: | |
task_description, output_format = self._setup_task_info() | |
# get user prompt template by prompt type | |
user_prompt_template = self._get_user_prompt_template() | |
user_prompt = user_prompt_template.format( | |
action_space=ACTION_SPACE_PROMPT, | |
task_description=task_description, | |
input_information=input_information, | |
output_format=output_format | |
) | |
# process multimodal message | |
if self.use_multimodal: | |
image_list = self.input_info.get("image_list", []) | |
user_message = self._process_multimodal_message(user_prompt, image_list) | |
else: | |
user_message = {"role": "user", "content": user_prompt} | |
return user_message | |
class GroundingMessage(BaseMessage): | |
''' | |
Grounding Judge Message | |
''' | |
def __init__(self, input_info:dict, use_multimodal:bool, prompt_type:str, text_obs: str, image_obs: str): | |
super().__init__(input_info, use_multimodal) | |
self.prompt_type = prompt_type | |
self.text_obs = text_obs | |
self.image_obs = image_obs | |
def _get_system_message(self): | |
if self.prompt_type == "ours": | |
# TODO: implement ours | |
system_message = {"role": "system", "content": "You are a helpful assistant."} | |
elif self.prompt_type == "default": | |
system_message = {"role": "system", "content": JUDGE_GROUNDING_PROMPT_TEMPLATE["system"]} | |
else: | |
raise ValueError(f"Invalid prompt type: {self.prompt_type}") | |
return system_message | |
def _setup_input_information(self): | |
observation = "## Current State\n" | |
observation += CURRENT_URL | |
# text observation | |
if self.text_obs: | |
observation += TEXT_OBSERVATION | |
# image observation (som, coord, none) | |
if self.image_obs == "som": | |
observation += SOM_IMAGE_OBSERVATION | |
elif self.image_obs == "coord": | |
observation += COORD_IMAGE_OBSERVATION | |
# input_information = USER_INSTRUCTION + TRAJECTORY + observation + AGENT_RESPONSE # with trajectory | |
input_information = USER_INSTRUCTION + observation + AGENT_RESPONSE # without trajectory | |
return input_information | |
def _get_user_message(self): | |
if self.prompt_type == "ours": | |
# TODO: implement ours | |
user_message = {"role": "user", "content": "TODO"} | |
elif self.prompt_type == "default": | |
action_space = ACTION_SPACE_PROMPT | |
task_description = GROUNDING["task_description"] | |
output_format = GROUNDING["output_format"] | |
input_information_template = self._setup_input_information() | |
input_information = input_information_template.format(**self.input_info) | |
user_prompt = JUDGE_GROUNDING_PROMPT_TEMPLATE["user"].format( | |
action_space=action_space, | |
task_description=task_description, | |
input_information=input_information, | |
output_format=output_format | |
) | |
# process multimodal message | |
if self.use_multimodal: | |
image_list = self.input_info.get("image_list", []) | |
user_message = self._process_multimodal_message(user_prompt, image_list) | |
else: | |
user_message = {"role": "user", "content": user_prompt} | |
else: | |
raise ValueError(f"Invalid prompt type: {self.prompt_type}") | |
return user_message | |
class ChecklistMessage(BaseMessage): | |
''' | |
Checklist Message | |
''' | |
def __init__(self, input_info:dict, use_multimodal:bool, prompt_type:str): | |
super().__init__(input_info, use_multimodal) | |
self.prompt_type = prompt_type | |
def _get_system_message(self): | |
if self.prompt_type == "ours": | |
# TODO: implement ours | |
system_message = {"role": "system", "content": CHECKLIST_OURS_SYSTEM_PROMPT} | |
elif self.prompt_type == "default": | |
system_message = {"role": "system", "content": CHECKLIST_SYSTEM_PROMPT} | |
else: | |
raise ValueError(f"Invalid prompt type: {self.prompt_type}") | |
return system_message | |
def _get_user_message(self): | |
if self.prompt_type == "ours": | |
user_message = {"role": "user", "content": CHECKLIST_OURS_USER_PROMPT.format(**self.input_info)} | |
elif self.prompt_type == "default": | |
user_message = {"role": "user", "content": CHECKLIST_USER_PROMPT.format(**self.input_info)} | |
else: | |
raise ValueError(f"Invalid prompt type: {self.prompt_type}") | |
return user_message | |
def get_messages(input_info:dict, inference_mode:str, prompt_type:str, text_obs:str=None, image_obs:str=None, use_multimodal:bool=False, use_checklist:bool=False, use_in_progress:bool=False): | |
message_list = [] | |
if inference_mode == "judge_grounding": | |
message = GroundingMessage(input_info, use_multimodal=use_multimodal, prompt_type=prompt_type, text_obs=text_obs, image_obs=image_obs) | |
elif inference_mode == "judge_progress": | |
message = ProgressMessage(input_info, use_multimodal=use_multimodal, prompt_type=prompt_type, text_obs=text_obs, image_obs=image_obs, use_checklist=use_checklist, use_in_progress=use_in_progress) | |
elif inference_mode == "checklist_generation": | |
message = ChecklistMessage(input_info, use_multimodal=False, prompt_type=prompt_type) | |
else: | |
raise ValueError(f"Invalid inference mode: {inference_mode}") | |
system_message, user_message = message.get_messages() | |
message_list.append(system_message) | |
message_list.append(user_message) | |
return message_list |