"""Build chat-style message lists (system + user) for judge and checklist prompts.

Each message class turns an ``input_info`` dict into a two-element list of
``{"role": ..., "content": ...}`` dicts. The module-level ``get_messages``
function selects the message class from an inference mode string.
"""
from abc import ABC, abstractmethod
from typing import Optional

from .action import ACTION_SPACE_PROMPT
from .eval_type import (
    GROUNDING,
    PROGRESS_LIKERT_SCALE,
    PROGRESS_THREE_CLASS,
    PROGRESS_WITH_CHECKLIST,
    PROGRESS_WITH_CHECKLIST_IN_PROGRESS,
    PROGRESS_OURS,
)
from .input_information import (
    USER_INSTRUCTION,
    TRAJECTORY,
    AGENT_RESPONSE,
    CHECKLIST,
    CURRENT_URL,
    TEXT_OBSERVATION,
    SOM_IMAGE_OBSERVATION,
    COORD_IMAGE_OBSERVATION,
)
from .judge_prompt import (
    JUDGE_GROUNDING_PROMPT_TEMPLATE,
    JUDGE_LIKERT_SCALE_PROMPT_TEMPLATE,
    JUDGE_THREE_CLASS_PROMPT_TEMPLATE,
    JUDGE_WITH_CHECKLIST_PROMPT_TEMPLATE,
    JUDGE_OURS_PROMPT_TEMPLATE,
    JUDGE_OURS_WO_CHECKLIST_PROMPT_TEMPLATE,
)
from .checklist_prompt import (
    CHECKLIST_SYSTEM_PROMPT,
    CHECKLIST_USER_PROMPT,
    CHECKLIST_OURS_SYSTEM_PROMPT,
    CHECKLIST_OURS_USER_PROMPT,
)
from .image_utils import image_to_base64_url

# Marker in a prompt at which the images are spliced into a multimodal message.
# NOTE(review): the previous code split on an empty string, which always raises
# ValueError; the token below is assumed to be the one used by the image
# observation templates in .input_information -- confirm before merging.
_IMAGE_PLACEHOLDER = "<IMAGE_PLACEHOLDER>"


def _lookup_prompt_type(table: dict, prompt_type: str):
    """Return ``table[prompt_type]`` or raise ``ValueError`` for an unknown type."""
    try:
        return table[prompt_type]
    except KeyError:
        raise ValueError(f"Invalid prompt type: {prompt_type}") from None


def _build_observation_template(text_obs, image_obs) -> str:
    """Assemble the "Current State" section template.

    The URL part is always included; the text observation is added when
    ``text_obs`` is truthy; an image observation part is added only when
    ``image_obs`` is ``"som"`` or ``"coord"`` (any other value adds nothing).
    """
    parts = ["## Current State\n", CURRENT_URL]
    if text_obs:
        parts.append(TEXT_OBSERVATION)
    if image_obs == "som":
        parts.append(SOM_IMAGE_OBSERVATION)
    elif image_obs == "coord":
        parts.append(COORD_IMAGE_OBSERVATION)
    return "".join(parts)


class Message(ABC):
    """Interface for objects that can produce a list of chat messages."""

    @abstractmethod
    def get_messages(self):
        pass


class BaseMessage(Message):
    """Default message builder; subclasses override the system/user hooks."""

    def __init__(self, input_info: dict, use_multimodal: bool = False):
        # input_info supplies the values substituted into prompt templates and,
        # optionally, an "image_list" entry used for multimodal messages.
        self.input_info = input_info
        self.use_multimodal = use_multimodal

    def _get_system_message(self) -> dict:
        """Return the generic system message."""
        return {"role": "system", "content": "You are a helpful assistant."}

    def _process_multimodal_message(self, prompt: str, image_list: list[str]) -> dict:
        """Return a user message whose content interleaves text and images.

        The prompt is cut at the first image placeholder: the text before it,
        then one ``image_url`` part per image, then the text after it. When the
        prompt contains no placeholder, the images follow the whole prompt and
        no trailing text part is added.
        """
        # partition() scans once and, unlike indexing the result of split(),
        # does not fail when the placeholder is missing.
        prefix, separator, suffix = prompt.partition(_IMAGE_PLACEHOLDER)

        content = [{"type": "text", "text": prefix}]
        # TODO: add a per-image text label when multiple images are supplied.
        content.extend(
            {
                "type": "image_url",
                "image_url": {"url": image_to_base64_url(image), "detail": "low"},
            }
            for image in image_list
        )
        if separator:
            # Placeholder found: keep the trailing text part, even if empty.
            content.append({"type": "text", "text": suffix})
        return {"role": "user", "content": content}

    def _build_user_message(self, user_prompt: str) -> dict:
        """Wrap a finished prompt as a plain or multimodal user message."""
        if self.use_multimodal:
            image_list = self.input_info.get("image_list", [])
            return self._process_multimodal_message(user_prompt, image_list)
        return {"role": "user", "content": user_prompt}

    def _get_user_message(self) -> dict:
        """Return a fixed sample question as the user message."""
        return self._build_user_message("What is the capital of France?")

    def get_messages(self) -> list:
        """Return ``[system_message, user_message]``."""
        system_message = self._get_system_message()
        user_message = self._get_user_message()
        return [system_message, user_message]


class ProgressMessage(BaseMessage):
    """Progress Judge Message"""

    # Judge templates per prompt type; each is indexed with "system" / "user".
    _USER_TEMPLATES = {
        "likert_scale": JUDGE_LIKERT_SCALE_PROMPT_TEMPLATE,
        "three_class": JUDGE_THREE_CLASS_PROMPT_TEMPLATE,
        "with_checklist": JUDGE_WITH_CHECKLIST_PROMPT_TEMPLATE,
    }
    # "ours" has a system prompt here, but its user prompt is built separately
    # in _get_user_message.
    _SYSTEM_TEMPLATES = {**_USER_TEMPLATES, "ours": JUDGE_OURS_PROMPT_TEMPLATE}

    def __init__(
        self,
        input_info: dict,
        use_multimodal: bool,
        prompt_type: str,
        text_obs: Optional[str],
        image_obs: Optional[str],
        use_checklist: bool,
        use_in_progress: bool,
    ):
        super().__init__(input_info, use_multimodal)
        self.prompt_type = prompt_type
        self.text_obs = text_obs  # only its truthiness is used
        self.image_obs = image_obs  # "som", "coord", or anything else for none
        self.use_checklist = use_checklist
        self.use_in_progress = use_in_progress

    def _get_system_message(self) -> dict:
        """Return the system message for the configured prompt type.

        Raises:
            ValueError: if ``prompt_type`` is not a known progress prompt type.
        """
        template = _lookup_prompt_type(self._SYSTEM_TEMPLATES, self.prompt_type)
        return {"role": "system", "content": template["system"]}

    def _setup_input_information(self) -> str:
        """Return the unformatted input-information template.

        Order: user instruction, trajectory, current state, the checklist
        (only when ``use_checklist`` is set), then the agent response.
        """
        observation = _build_observation_template(self.text_obs, self.image_obs)
        sections = [USER_INSTRUCTION, TRAJECTORY, observation]
        if self.use_checklist:
            sections.append(CHECKLIST)
        sections.append(AGENT_RESPONSE)
        return "".join(sections)

    def _setup_task_info(self):
        """Return ``(task_description, output_format)`` for the prompt type.

        Raises:
            ValueError: if ``prompt_type`` is not "likert_scale", "three_class"
                or "with_checklist".
        """
        if self.use_in_progress:
            checklist_spec = PROGRESS_WITH_CHECKLIST_IN_PROGRESS
        else:
            checklist_spec = PROGRESS_WITH_CHECKLIST
        spec = _lookup_prompt_type(
            {
                "likert_scale": PROGRESS_LIKERT_SCALE,
                "three_class": PROGRESS_THREE_CLASS,
                "with_checklist": checklist_spec,
            },
            self.prompt_type,
        )
        return spec["task_description"], spec["output_format"]

    def _get_user_prompt_template(self) -> str:
        """Return the user prompt template for the prompt type.

        Raises:
            ValueError: if ``prompt_type`` is not "likert_scale", "three_class"
                or "with_checklist".
        """
        return _lookup_prompt_type(self._USER_TEMPLATES, self.prompt_type)["user"]

    def _get_user_message(self) -> dict:
        """Build the user message for the progress judge.

        Raises:
            ValueError: if ``prompt_type`` is not a known progress prompt type.
        """
        # Fill the input-information template from input_info first; this
        # happens before the prompt type is validated, as before.
        input_information = self._setup_input_information().format(**self.input_info)

        if self.prompt_type == "ours":
            if self.use_checklist:
                template = JUDGE_OURS_PROMPT_TEMPLATE
            else:
                template = JUDGE_OURS_WO_CHECKLIST_PROMPT_TEMPLATE
            user_prompt = template["user"].format(input_information=input_information)
        else:
            task_description, output_format = self._setup_task_info()
            user_prompt = self._get_user_prompt_template().format(
                action_space=ACTION_SPACE_PROMPT,
                task_description=task_description,
                input_information=input_information,
                output_format=output_format,
            )

        return self._build_user_message(user_prompt)


class GroundingMessage(BaseMessage):
    """Grounding Judge Message"""

    def __init__(
        self,
        input_info: dict,
        use_multimodal: bool,
        prompt_type: str,
        text_obs: Optional[str],
        image_obs: Optional[str],
    ):
        super().__init__(input_info, use_multimodal)
        self.prompt_type = prompt_type
        self.text_obs = text_obs  # only its truthiness is used
        self.image_obs = image_obs  # "som", "coord", or anything else for none

    def _get_system_message(self) -> dict:
        """Return the system message for the configured prompt type.

        Raises:
            ValueError: if ``prompt_type`` is neither "ours" nor "default".
        """
        if self.prompt_type == "ours":
            # TODO: implement ours
            return {"role": "system", "content": "You are a helpful assistant."}
        if self.prompt_type == "default":
            return {"role": "system", "content": JUDGE_GROUNDING_PROMPT_TEMPLATE["system"]}
        raise ValueError(f"Invalid prompt type: {self.prompt_type}")

    def _setup_input_information(self) -> str:
        """Return the unformatted input-information template.

        The trajectory section is deliberately left out here: user
        instruction, current state, then the agent response.
        """
        observation = _build_observation_template(self.text_obs, self.image_obs)
        return USER_INSTRUCTION + observation + AGENT_RESPONSE

    def _get_user_message(self) -> dict:
        """Build the user message for the grounding judge.

        Raises:
            ValueError: if ``prompt_type`` is neither "ours" nor "default".
        """
        if self.prompt_type == "ours":
            # TODO: implement ours -- currently a literal placeholder message.
            return {"role": "user", "content": "TODO"}
        if self.prompt_type != "default":
            raise ValueError(f"Invalid prompt type: {self.prompt_type}")

        input_information = self._setup_input_information().format(**self.input_info)
        user_prompt = JUDGE_GROUNDING_PROMPT_TEMPLATE["user"].format(
            action_space=ACTION_SPACE_PROMPT,
            task_description=GROUNDING["task_description"],
            input_information=input_information,
            output_format=GROUNDING["output_format"],
        )
        return self._build_user_message(user_prompt)


class ChecklistMessage(BaseMessage):
    """Checklist Message"""

    def __init__(self, input_info: dict, use_multimodal: bool, prompt_type: str):
        super().__init__(input_info, use_multimodal)
        self.prompt_type = prompt_type

    def _get_system_message(self) -> dict:
        """Return the checklist system message for the prompt type.

        Raises:
            ValueError: if ``prompt_type`` is neither "ours" nor "default".
        """
        if self.prompt_type == "ours":
            return {"role": "system", "content": CHECKLIST_OURS_SYSTEM_PROMPT}
        if self.prompt_type == "default":
            return {"role": "system", "content": CHECKLIST_SYSTEM_PROMPT}
        raise ValueError(f"Invalid prompt type: {self.prompt_type}")

    def _get_user_message(self) -> dict:
        """Return the checklist user prompt filled from ``input_info``.

        The message is always text-only; ``use_multimodal`` is not consulted.

        Raises:
            ValueError: if ``prompt_type`` is neither "ours" nor "default".
        """
        if self.prompt_type == "ours":
            template = CHECKLIST_OURS_USER_PROMPT
        elif self.prompt_type == "default":
            template = CHECKLIST_USER_PROMPT
        else:
            raise ValueError(f"Invalid prompt type: {self.prompt_type}")
        return {"role": "user", "content": template.format(**self.input_info)}


def get_messages(
    input_info: dict,
    inference_mode: str,
    prompt_type: str,
    text_obs: Optional[str] = None,
    image_obs: Optional[str] = None,
    use_multimodal: bool = False,
    use_checklist: bool = False,
    use_in_progress: bool = False,
) -> list:
    """Build ``[system_message, user_message]`` for the given inference mode.

    Args:
        input_info: Values substituted into the prompt templates; may also
            carry an "image_list" entry for multimodal messages.
        inference_mode: "judge_grounding", "judge_progress" or
            "checklist_generation".
        prompt_type: Prompt variant, validated by the selected message class.
        text_obs: Include the text observation section when truthy.
        image_obs: "som" or "coord" to include that image observation section.
        use_multimodal: Attach images to the user message (ignored for
            checklist generation, which is always text-only).
        use_checklist: Progress judge only; include the checklist section.
        use_in_progress: Progress judge only; use the in-progress variant of
            the "with_checklist" task description and output format.

    Raises:
        ValueError: if ``inference_mode`` is unknown, or the selected message
            class rejects ``prompt_type``.
    """
    if inference_mode == "judge_grounding":
        message = GroundingMessage(
            input_info,
            use_multimodal=use_multimodal,
            prompt_type=prompt_type,
            text_obs=text_obs,
            image_obs=image_obs,
        )
    elif inference_mode == "judge_progress":
        message = ProgressMessage(
            input_info,
            use_multimodal=use_multimodal,
            prompt_type=prompt_type,
            text_obs=text_obs,
            image_obs=image_obs,
            use_checklist=use_checklist,
            use_in_progress=use_in_progress,
        )
    elif inference_mode == "checklist_generation":
        message = ChecklistMessage(input_info, use_multimodal=False, prompt_type=prompt_type)
    else:
        raise ValueError(f"Invalid inference mode: {inference_mode}")

    # Unpacking keeps the guarantee that exactly two messages are returned.
    system_message, user_message = message.get_messages()
    return [system_message, user_message]