from abc import ABC, abstractmethod import time import requests import json import math from langsmith import Client import numpy as np from langchain_openai import ChatOpenAI from .prompts import get_messages from .prompts.judge_prompt import ( JUDGE_OURS_BT_MODELING_PROMPT_TEMPLATE, JUDGE_OURS_BT_MODELING_WO_CHECKLIST_PROMPT_TEMPLATE, JUDGE_OURS_BT_MODELING_MULTIMODAL_PROMPT_TEMPLATE, JUDGE_OURS_BT_MODELING_MULTIMODAL_WO_CHECKLIST_PROMPT_TEMPLATE ) from .prompts.image_utils import image_to_base64_url from .prompts.utils import convert_dict_messages MAX_RETRY = 3 RETRY_SLEEP = 5 MODEL_COST_MAPPING = { "gpt-4o-mini": { "input_token_cost": 0.15, "output_token_cost": 0.6 }, "gpt-4o": { "input_token_cost": 2.5, "output_token_cost": 10 }, } class Agent(ABC): @abstractmethod def generate_response(self, inputs: dict) -> str: pass class BaseAgent(Agent): def __init__(self, agent_config: dict): self.agent_config = agent_config self._setup() def _init_llm_object(self, **extra_kwargs): config = self.agent_config config.update(extra_kwargs) use_log_probs = config.get("use_log_probs", False) if use_log_probs: self.llm = ChatOpenAI( model=config["model_name"], base_url=config["base_url"], api_key=config["api_key"], temperature=config["temperature"], timeout=300, logprobs=True, top_logprobs=10, n=config.get('n', None) ) else: self.llm = ChatOpenAI( model=config["model_name"], base_url=config["base_url"], api_key=config["api_key"], temperature=config["temperature"], timeout=300, n=config.get('n', None) ) def _setup(self): self._init_llm_object() self.temperature = self.agent_config["temperature"] self.num_generate = self.agent_config["num_generate"] self.use_checklist = self.agent_config.get("use_checklist", False) self.use_multimodal = self.agent_config.get("use_multimodal", False) # setup cost model_cost = MODEL_COST_MAPPING.get(self.agent_config["model_name"], None) if model_cost and "api" in self.agent_config["base_url"]: self.input_token_cost = model_cost["input_token_cost"] self.output_token_cost = model_cost["output_token_cost"] else: self.input_token_cost = 0.0 self.output_token_cost = 0.0 def generate_with_retry(self, model_input, constraint_str_list: list = None): total_input_tokens = 0 total_output_tokens = 0 if self.temperature == 0: response = self.llm.invoke(model_input) total_input_tokens += response.response_metadata["token_usage"]["prompt_tokens"] total_output_tokens += response.response_metadata["token_usage"]["completion_tokens"] else: for i in range(MAX_RETRY): try: response = self.llm.invoke(model_input) total_input_tokens += response.response_metadata["token_usage"]["prompt_tokens"] total_output_tokens += response.response_metadata["token_usage"]["completion_tokens"] if constraint_str_list: pass_constraint_num = 0 for constraint_str in constraint_str_list: if constraint_str in response.content: pass_constraint_num += 1 if pass_constraint_num == len(constraint_str_list): break else: print(f"Agent has fomat issue, retry... {i+1}/{MAX_RETRY}") else: break except Exception as e: print(f"Agent returned an Error: {e}") response = None time.sleep(RETRY_SLEEP) cost = self.input_token_cost * total_input_tokens / 1000000 + self.output_token_cost * total_output_tokens / 1000000 if response is None: return "", cost else: return response.content, cost def prepare_message(self, model_input: dict, prompt_type: str): message = [] return message def generate_response(self, model_input: dict, prompt_type: str, constraint_str_list: list = None,): total_cost = 0 response_list = [] # prepare message message = self.prepare_message(model_input, prompt_type) # n sampling for i in range(self.num_generate): response, cost = self.generate_with_retry(message, constraint_str_list) response_list.append(response) total_cost += cost return response_list, total_cost class GroundingJudgeAgent(BaseAgent): def __init__(self, agent_config: dict): super().__init__(agent_config) self._setup() def prepare_message(self, model_input: dict, prompt_type): message = get_messages( input_info=model_input, inference_mode="judge_grounding", prompt_type=prompt_type, use_multimodal=self.use_multimodal, text_obs=self.agent_config["text_obs_type"], image_obs=self.agent_config["image_obs_type"] ) return message class ProgressJudgeAgent(BaseAgent): def __init__(self, agent_config: dict): super().__init__(agent_config) self._setup() def prepare_message(self, model_input: dict, prompt_type): if self.agent_config["input_type"]=="text_only": use_multimodal = False text_obs = self.agent_config["text_obs_type"] image_obs = None elif self.agent_config["input_type"]=="image_only": use_multimodal = True text_obs = None image_obs = self.agent_config["image_obs_type"] elif self.agent_config["input_type"]=="text_image": use_multimodal = True text_obs = self.agent_config["text_obs_type"] image_obs = self.agent_config["image_obs_type"] else: raise ValueError(f"Invalid input type: {self.agent_config['input_type']}") if self.agent_config["use_in_progress"]: use_in_progress = True else: use_in_progress = False message = get_messages( input_info=model_input, inference_mode="judge_progress", prompt_type=prompt_type, use_checklist=self.use_checklist, use_multimodal=use_multimodal, text_obs=text_obs, image_obs=image_obs, use_in_progress=use_in_progress ) return message def get_judge_probs(self, logprobs: list): # target_judge = { # "yes": [" Yes", "Yes", "ĠYes", "ĊYes"], # "no": [" No", "No", "ĠNo", "ĊNo"], # "in": [" In", "In", "ĠIn", "ĊIn"] # } target_judge = { "yes": [ "ĠYes", "Yes", "ĊYes", "Ġyes", "yes", "Ċyes", "ĠYES", "YES", "ĊYES", "ĠDone", "Done", "ĊDone", "ĠCompleted", "Completed", "ĊCompleted", "ĠCorrect", "Correct", "ĊCorrect" ], "no": [ "ĠNo", "No", "ĊNo", "ĠNO", "NO", "ĊNO", "ĠNot", "Not", "ĊNot", "ĠNone", "None", "ĊNone", "ĠNope", "Nope", "ĊNope", "ĠUn", "Un", "ĊUn", "ĠWrong", "Wrong", "ĊWrong" ], "in": [ "ĠIn", "In", "ĊIn", "ĠPending", "Pending", "ĊPending", "ĠPart", "Part", "ĊPart", "ĠPartial", "Partial", "ĊPartial", "ĠInProgress", "InProgress", "ĊInProgress" ] } response_str = "" judge_probs_list = [] for i, log_prob in enumerate(logprobs): # Start to find judge string if "" in response_str: find_judge_str = False for judge_type in target_judge: if log_prob["token"] in target_judge[judge_type]: # print(log_prob) find_judge_str = True break if find_judge_str: token_judge_dict = { "yes": None, "no": None, "in": None } for token_info in log_prob["top_logprobs"]: for judge_type in target_judge: for judge_str in target_judge[judge_type]: if judge_str in token_info["token"] : if token_judge_dict[judge_type] is None: token_judge_dict[judge_type] = math.exp(token_info["logprob"]) else: token_judge_dict[judge_type] += math.exp(token_info["logprob"]) token_judge_dict = { "yes": math.log(token_judge_dict["yes"]) if token_judge_dict["yes"] is not None else -float('inf'), "no": math.log(token_judge_dict["no"]) if token_judge_dict["no"] is not None else -float('inf'), "in": math.log(token_judge_dict["in"]) if token_judge_dict["in"] is not None else -float('inf') } judge_probs_list.append(token_judge_dict) if "" in response_str: break response_str += log_prob["token"] if len(judge_probs_list) == 0: return [{ "yes": 0.0, "no": 0.0, "in": 0.0 }] else: # convert with softmax final_judge_probs_list = [] max_in_prob = -float('inf') for idx, judge_probs in enumerate(judge_probs_list): exp_logprobs = [math.exp(x) for x in [judge_probs["yes"], judge_probs["no"], judge_probs["in"]]] sum_exp_logprobs = sum(exp_logprobs) softmax_probs = [x / sum_exp_logprobs for x in exp_logprobs] if softmax_probs[2] > max_in_prob: max_in_prob = softmax_probs[2] final_judge_probs_list.append({ "yes": softmax_probs[0], "no": softmax_probs[1], "in": softmax_probs[2] }) return final_judge_probs_list def generate_probs(self, model_input: dict, prompt_type: str, n=1, temperature=None): total_cost = 0 # prepare message message = self.prepare_message(model_input, prompt_type) messages = convert_dict_messages(message) kwargs = {'n': n} if temperature is not None: kwargs['temperature'] = temperature self._init_llm_object(**kwargs) try: response = self.llm.generate([messages]) # assume single batch finally: print('request url: ', self.agent_config['base_url']) # parse responses response_list = [] for generation in response.generations[0]: # assume singel batch # parse logprobs logprobs = generation.message.response_metadata["logprobs"]["content"] response_list.append( { "response": generation.message.content, "judge_probs": self.get_judge_probs(logprobs) } ) # calculate cost total_input_tokens = response.llm_output["token_usage"]["prompt_tokens"] total_output_tokens = response.llm_output["token_usage"]["completion_tokens"] total_cost = self.input_token_cost * total_input_tokens / 1000000 + self.output_token_cost * total_output_tokens / 1000000 return response_list, total_cost class ChecklistGenerationAgent(BaseAgent): def __init__(self, agent_config: dict): super().__init__(agent_config) self._setup() def prepare_message(self, model_input: dict, prompt_type): message = get_messages( input_info=model_input, inference_mode="checklist_generation", prompt_type=prompt_type ) return message class ClassifierRewardAgent(Agent): def __init__(self, url: str, use_checklist: bool = False, use_multimodal: bool = False): self.url = url self.use_checklist = use_checklist self.use_multimodal = use_multimodal def _process_multimodal_message(self, prompt: str, image_list: list[str]): multimodal_message = [] text_prompt_prefix = prompt.split("")[0] text_prompt_suffix = prompt.split("")[1] multimodal_message = [ {"type": "text", "text": text_prompt_prefix}, # {"type": "image_url", "image_url": {"url": image_to_base64_url(image_list[0])}}, {"type": "image", "image": image_to_base64_url(image_list[0])}, {"type": "text", "text": text_prompt_suffix} ] return multimodal_message def _make_query(self, user_prompt_template: dict, model_input: dict | list[dict]): if self.use_multimodal: tmp_user_prompt = user_prompt_template["user"].format( **model_input ) user_prompt = self._process_multimodal_message(tmp_user_prompt, model_input["image_list"]) else: user_prompt = user_prompt_template["user"].format( **model_input ) assistant_prompt = user_prompt_template["assistant"].format( **model_input ) query = [ {"role": "user", "content": user_prompt}, {"role": "assistant", "content": assistant_prompt} ] return query def prepare_message(self, model_input: dict | list[dict], batch: bool = False): if self.use_checklist: if self.use_multimodal: user_prompt_template = JUDGE_OURS_BT_MODELING_MULTIMODAL_PROMPT_TEMPLATE else: user_prompt_template = JUDGE_OURS_BT_MODELING_PROMPT_TEMPLATE else: if self.use_multimodal: user_prompt_template = JUDGE_OURS_BT_MODELING_MULTIMODAL_WO_CHECKLIST_PROMPT_TEMPLATE else: user_prompt_template = JUDGE_OURS_BT_MODELING_WO_CHECKLIST_PROMPT_TEMPLATE if self.use_multimodal: if batch: message = [self._make_query(user_prompt_template, input) for input in model_input] else: message = [self._make_query(user_prompt_template, model_input)] else: if batch: message = { "query": [self._make_query(user_prompt_template, input) for input in model_input], "promptts": [] } else: message = { "query": self._make_query(user_prompt_template, model_input), "prompts": [] } return message def get_rm_scroe(self, message: dict | list): headers = {"Content-Type": "application/json"} try: if self.use_multimodal: response = requests.post( self.url, json={"messages": message}, timeout=600 ) else: response = requests.post( self.url, headers=headers, data=json.dumps(message), timeout=300 ) response.raise_for_status() response_json = response.json() if "rewards" not in response_json: print(f"Error: 'rewards' key not found in API response: {response_json}") return [] if "get_reward" in self.url: # use openrlhf return response_json["rewards"] elif "pooling" in self.url: # use vllm server return response_json["reward"] else: # error raise ValueError(f"Invalid URL: {self.url}") except requests.exceptions.Timeout: print(f"Error: Request timed out to {self.url}") return [] except requests.exceptions.RequestException as e: print(f"Error during request to {self.url}: {e}") return [] except json.JSONDecodeError: print(f"Error: Failed to decode JSON response from {self.url}") return [] except KeyError as e: print(f"Error: Missing key {e} in response from {self.url}") return [] def generate_response(self, model_input: dict | list[dict], batch: bool = False): if batch: message = self.prepare_message(model_input, batch=True) else: message = self.prepare_message(model_input) rewards = self.get_rm_scroe(message) return rewards, 0