# iruno's picture
# Upload 245 files
# 498ffec verified
from abc import ABC, abstractmethod
import time
import requests
import json
import math
from langsmith import Client
from langchain_openai import ChatOpenAI
from .prompts import get_messages
from .prompts.judge_prompt import (
JUDGE_OURS_BT_MODELING_PROMPT_TEMPLATE,
JUDGE_OURS_BT_MODELING_WO_CHECKLIST_PROMPT_TEMPLATE,
JUDGE_OURS_BT_MODELING_MULTIMODAL_PROMPT_TEMPLATE,
JUDGE_OURS_BT_MODELING_MULTIMODAL_WO_CHECKLIST_PROMPT_TEMPLATE
)
from .prompts.image_utils import image_to_base64_url
# Retry policy used when sampling from the LLM with temperature != 0.
MAX_RETRY = 3  # number of attempts per sample
RETRY_SLEEP = 5  # seconds passed to time.sleep after a failed call

# Token prices per 1M tokens (costs are divided by 1,000,000 when applied),
# keyed by model name. Models missing from this table are costed at 0.
# NOTE(review): the currency is not stated anywhere; presumably USD.
MODEL_COST_MAPPING = {
    "gpt-4o-mini": dict(input_token_cost=0.15, output_token_cost=0.6),
    "gpt-4o": dict(input_token_cost=2.5, output_token_cost=10),
}
class Agent(ABC):
    """Abstract base for every agent in this module.

    Concrete agents implement ``generate_response``; the ones in this module
    take their own, richer argument lists.
    """

    @abstractmethod
    def generate_response(self, inputs: dict) -> str:
        """Produce a response for ``inputs``; must be overridden by subclasses."""
class BaseAgent(Agent):
    """LLM-backed agent built on a ``ChatOpenAI`` client.

    ``agent_config`` must provide ``model_name``, ``base_url``, ``api_key``,
    ``temperature`` and ``num_generate``; ``use_log_probs``, ``use_checklist``
    and ``use_multimodal`` are optional and default to False. Subclasses
    override :meth:`prepare_message` to build the prompt for their task.
    """

    def __init__(self, agent_config: dict):
        self.agent_config = agent_config
        self._setup()

    def _setup(self):
        """Create the chat client and cache sampling and pricing settings."""
        client_kwargs = dict(
            model=self.agent_config["model_name"],
            base_url=self.agent_config["base_url"],
            api_key=self.agent_config["api_key"],
            temperature=self.agent_config["temperature"],
            timeout=300,
        )
        if self.agent_config.get("use_log_probs", False):
            # Ask for per-token log-probabilities with the top 10 alternatives
            # per position (read back by ProgressJudgeAgent.generate_probs).
            client_kwargs.update(logprobs=True, top_logprobs=10)
        self.llm = ChatOpenAI(**client_kwargs)

        self.temperature = self.agent_config["temperature"]
        self.num_generate = self.agent_config["num_generate"]
        self.use_checklist = self.agent_config.get("use_checklist", False)
        self.use_multimodal = self.agent_config.get("use_multimodal", False)

        # Prices in MODEL_COST_MAPPING are per 1M tokens. They are applied only
        # to known models whose base_url contains "api" (presumably a hosted
        # endpoint); anything else is costed at 0.
        model_cost = MODEL_COST_MAPPING.get(self.agent_config["model_name"])
        if model_cost and "api" in self.agent_config["base_url"]:
            self.input_token_cost = model_cost["input_token_cost"]
            self.output_token_cost = model_cost["output_token_cost"]
        else:
            self.input_token_cost = 0.0
            self.output_token_cost = 0.0

    def _invoke_once(self, model_input):
        """Call the LLM once; return ``(content, prompt_tokens, completion_tokens)``."""
        response = self.llm.invoke(model_input)
        usage = response.response_metadata["token_usage"]
        return response.content, usage["prompt_tokens"], usage["completion_tokens"]

    def _cost(self, input_tokens: int, output_tokens: int) -> float:
        """Convert token counts into cost using the per-1M-token prices."""
        return (
            self.input_token_cost * input_tokens / 1000000
            + self.output_token_cost * output_tokens / 1000000
        )

    def generate_with_retry(self, model_input, constraint_str_list: list | None = None):
        """Get one response from the LLM and return ``(content, cost)``.

        With ``temperature == 0`` a single call is made, with no retry and no
        error handling (exceptions propagate). Otherwise up to ``MAX_RETRY``
        attempts are made: an attempt is accepted when every string in
        ``constraint_str_list`` occurs in its content (or when no constraints
        are given). Failed calls are printed and retried after
        ``RETRY_SLEEP`` seconds (no sleep after the final attempt).

        Returns:
            The accepted content; if no attempt met the constraints, the
            content of the most recent successful call; ``""`` if every call
            raised. ``cost`` covers the tokens of all successful calls.
        """
        total_input_tokens = 0
        total_output_tokens = 0
        content = None  # content of the most recent successful call

        if self.temperature == 0:
            content, input_tokens, output_tokens = self._invoke_once(model_input)
            total_input_tokens += input_tokens
            total_output_tokens += output_tokens
        else:
            for attempt in range(MAX_RETRY):
                try:
                    reply, input_tokens, output_tokens = self._invoke_once(model_input)
                except Exception as e:  # best-effort sampling: log and retry
                    print(f"Agent returned an Error: {e}")
                    if attempt < MAX_RETRY - 1:
                        time.sleep(RETRY_SLEEP)
                    continue
                total_input_tokens += input_tokens
                total_output_tokens += output_tokens
                # Keep this reply as the fallback even if it fails the format
                # check, so a later failing call cannot discard it.
                content = reply
                if not constraint_str_list or all(
                    constraint_str in reply for constraint_str in constraint_str_list
                ):
                    break
                print(f"Agent has fomat issue, retry... {attempt+1}/{MAX_RETRY}")
                print(reply)

        cost = self._cost(total_input_tokens, total_output_tokens)
        return ("" if content is None else content), cost

    def prepare_message(self, model_input: dict, prompt_type: str):
        """Build the prompt for ``model_input``.

        The base implementation returns an empty list; subclasses override it.
        """
        return []

    def generate_response(self, model_input: dict, prompt_type: str, constraint_str_list: list | None = None):
        """Sample ``num_generate`` responses for one prompt.

        Returns:
            ``(response_list, total_cost)``: one content string per sample
            (``""`` for samples whose calls all failed) and the summed cost.
        """
        message = self.prepare_message(model_input, prompt_type)
        response_list = []
        total_cost = 0
        for _ in range(self.num_generate):
            response, cost = self.generate_with_retry(message, constraint_str_list)
            response_list.append(response)
            total_cost += cost
        return response_list, total_cost
class GroundingJudgeAgent(BaseAgent):
    """Agent that builds prompts in the ``"judge_grounding"`` inference mode.

    Reads ``text_obs_type`` and ``image_obs_type`` from ``agent_config`` in
    addition to the keys used by BaseAgent. No ``__init__`` override is
    needed: ``BaseAgent.__init__`` already runs ``_setup()``, and calling it
    again would only build a second, discarded chat client.
    """

    def prepare_message(self, model_input: dict, prompt_type):
        """Return the grounding-judge prompt produced by ``get_messages``."""
        return get_messages(
            input_info=model_input,
            inference_mode="judge_grounding",
            prompt_type=prompt_type,
            use_multimodal=self.use_multimodal,
            text_obs=self.agent_config["text_obs_type"],
            image_obs=self.agent_config["image_obs_type"],
        )
class ProgressJudgeAgent(BaseAgent):
    """Agent for the ``"judge_progress"`` inference mode, with logprob-based scoring.

    Besides the keys read by BaseAgent, ``agent_config`` must provide
    ``input_type`` (``"text_only"``, ``"image_only"`` or ``"text_image"``),
    ``use_in_progress`` and the ``text_obs_type`` / ``image_obs_type``
    entries that the chosen input type needs. ``BaseAgent.__init__`` already
    runs ``_setup()``, so no constructor override is needed.
    """

    # Verdict spellings recognised inside <answer>...</answer>, per verdict.
    # "Ġ" stands for a leading space and "Ċ" for a newline in the token
    # strings reported in logprob output.
    _TARGET_JUDGE = {
        "yes": (
            " Yes", "ĠYes", "Yes", "ĊYes",
            "Ġyes", "yes", "Ċyes",
            "ĠYES", "YES", "ĊYES",
            "ĠDone", "Done", "ĊDone",
            "ĠCompleted", "Completed", "ĊCompleted",
            "ĠCorrect", "Correct", "ĊCorrect",
        ),
        "no": (
            " No", "ĠNo", "No", "ĊNo",
            "ĠNO", "NO", "ĊNO",
            "ĠNot", "Not", "ĊNot",
            "ĠNone", "None", "ĊNone",
            "ĠNope", "Nope", "ĊNope",
            "ĠUn", "Un", "ĊUn",
            "ĠWrong", "Wrong", "ĊWrong",
        ),
        "in": (
            " In", "ĠIn", "In", "ĊIn",
            "ĠPending", "Pending", "ĊPending",
            "ĠPart", "Part", "ĊPart",
            "ĠPartial", "Partial", "ĊPartial",
            "ĠInProgress", "InProgress", "ĊInProgress",
        ),
    }

    def prepare_message(self, model_input: dict, prompt_type):
        """Build the progress-judge prompt for the configured ``input_type``.

        Note that ``self.use_multimodal`` is not consulted here; modality is
        derived from ``input_type`` alone.

        Raises:
            ValueError: if ``agent_config["input_type"]`` is not supported.
        """
        input_type = self.agent_config["input_type"]
        if input_type == "text_only":
            use_multimodal = False
            text_obs = self.agent_config["text_obs_type"]
            image_obs = None
        elif input_type == "image_only":
            use_multimodal = True
            text_obs = None
            image_obs = self.agent_config["image_obs_type"]
        elif input_type == "text_image":
            use_multimodal = True
            text_obs = self.agent_config["text_obs_type"]
            image_obs = self.agent_config["image_obs_type"]
        else:
            raise ValueError(f"Invalid input type: {input_type}")

        return get_messages(
            input_info=model_input,
            inference_mode="judge_progress",
            prompt_type=prompt_type,
            use_checklist=self.use_checklist,
            use_multimodal=use_multimodal,
            text_obs=text_obs,
            image_obs=image_obs,
            use_in_progress=bool(self.agent_config["use_in_progress"]),
        )

    def add_logprob(self, ori_logprob: float | None, add_logprob: float) -> float:
        """Return ``log(exp(ori_logprob) + exp(add_logprob))``; ``None`` means no mass yet.

        Uses a stable log-sum-exp, so very negative log-probabilities (which
        underflow ``math.exp`` to 0.0) do not make ``math.log`` fail.
        """
        if ori_logprob is None:
            return add_logprob
        high = max(ori_logprob, add_logprob)
        low = min(ori_logprob, add_logprob)
        if low == float("-inf"):
            return high
        return high + math.log1p(math.exp(low - high))

    def _match_verdict(self, token: str) -> str | None:
        """Return the verdict whose spelling list contains ``token`` exactly, else None."""
        for verdict, spellings in self._TARGET_JUDGE.items():
            if token in spellings:
                return verdict
        return None

    def _verdict_logprobs(self, log_prob: dict, verdict: str) -> dict:
        """Aggregate each verdict's log-probability at one verdict-token position.

        ``verdict`` is the verdict matched by the sampled token itself. Unmatched
        verdicts get ``-inf``.
        """
        top_logprobs = log_prob.get("top_logprobs")
        if not top_logprobs:
            # Some servers (a vLLM bug, per the original note) give no
            # top_logprobs: fall back to the sampled token alone.
            return {
                judge_type: (log_prob["logprob"] if judge_type == verdict else float("-inf"))
                for judge_type in self._TARGET_JUDGE
            }

        aggregated = dict.fromkeys(self._TARGET_JUDGE)  # None = no mass seen yet
        for token_info in top_logprobs:
            candidate = token_info["token"]
            for judge_type, spellings in self._TARGET_JUDGE.items():
                # Substring match, but each candidate token contributes at most
                # once per verdict even when several spellings occur in it
                # (e.g. "No", "ĠNo", "Not" and "ĠNot" all occur in "ĠNot").
                if any(spelling in candidate for spelling in spellings):
                    aggregated[judge_type] = self.add_logprob(aggregated[judge_type], token_info["logprob"])
        return {
            judge_type: (float("-inf") if value is None else value)
            for judge_type, value in aggregated.items()
        }

    @staticmethod
    def _softmax(judge_logprobs: dict) -> dict:
        """Normalise verdict log-probabilities into probabilities.

        Returns all zeros, instead of dividing by zero, when every verdict
        has zero mass (all ``-inf``).
        """
        peak = max(judge_logprobs.values())
        if peak == float("-inf"):
            return {judge_type: 0.0 for judge_type in judge_logprobs}
        weights = {judge_type: math.exp(value - peak) for judge_type, value in judge_logprobs.items()}
        total = sum(weights.values())
        return {judge_type: weight / total for judge_type, weight in weights.items()}

    def get_judge_probs(self, logprobs: list):
        """Turn per-token logprobs into yes/no/in-progress distributions.

        Scans the generated tokens. For every token between ``<answer>`` and
        ``</answer>`` that exactly matches a spelling in ``_TARGET_JUDGE``,
        the mass of each verdict at that position is aggregated from the
        token's ``top_logprobs`` and normalised with a softmax.

        Args:
            logprobs: per-token dicts with ``"token"`` and ``"logprob"`` keys
                and, when provided, a ``"top_logprobs"`` list of
                ``{"token", "logprob"}`` dicts.

        Returns:
            One ``{"yes", "no", "in"}`` probability dict per verdict token, or
            a single all-zero dict if no verdict token was found.
        """
        judge_logprobs_list = []
        response_str = ""
        for log_prob in logprobs:
            # Stop once the answer span has closed; later tokens are not verdicts.
            if "</answer>" in response_str:
                break
            if "<answer>" in response_str:
                verdict = self._match_verdict(log_prob["token"])
                if verdict is not None:
                    judge_logprobs_list.append(self._verdict_logprobs(log_prob, verdict))
            response_str += log_prob["token"]

        if not judge_logprobs_list:
            return [{"yes": 0.0, "no": 0.0, "in": 0.0}]
        return [self._softmax(judge_logprobs) for judge_logprobs in judge_logprobs_list]

    def generate_probs(self, model_input: dict, prompt_type: str):
        """Sample ``num_generate`` judgments and extract verdict probabilities from each.

        Expects ``response_metadata["logprobs"]["content"]`` on each response,
        which is requested when the config sets ``use_log_probs``.

        Returns:
            ``(response_list, total_cost)``. Each entry is
            ``{"response": str, "judge_probs": list}``. If a sample fails, the
            error is printed and its ``judge_probs`` is ``[]``; ``response`` is
            ``""`` if the LLM call itself failed. ``total_cost`` sums the cost
            of every sample.
        """
        total_cost = 0
        response_list = []
        message = self.prepare_message(model_input, prompt_type)
        for _ in range(self.num_generate):
            response = None  # never reuse a previous sample's response
            try:
                response = self.llm.invoke(message)
                usage = response.response_metadata["token_usage"]
                total_cost += (
                    self.input_token_cost * usage["prompt_tokens"] / 1000000
                    + self.output_token_cost * usage["completion_tokens"] / 1000000
                )
                logprobs = response.response_metadata["logprobs"]["content"]
                judge_probs = self.get_judge_probs(logprobs)
            except Exception as e:  # best-effort per sample
                print(f"Error: {e}")
                judge_probs = []
            response_list.append(
                {
                    "response": "" if response is None else response.content,
                    "judge_probs": judge_probs,
                }
            )
        return response_list, total_cost
class ChecklistGenerationAgent(BaseAgent):
    """Agent that builds prompts in the ``"checklist_generation"`` inference mode.

    No ``__init__`` override is needed: ``BaseAgent.__init__`` already runs
    ``_setup()``, and calling it again would only build a second, discarded
    chat client.
    """

    def prepare_message(self, model_input: dict, prompt_type):
        """Return the checklist-generation prompt produced by ``get_messages``."""
        return get_messages(
            input_info=model_input,
            inference_mode="checklist_generation",
            prompt_type=prompt_type,
        )
class ClassifierRewardAgent(Agent):
    """Client for a remote reward-model server that scores judge prompts.

    Prompts are built from the ``JUDGE_OURS_BT_MODELING_*`` templates (with or
    without checklist, text-only or multimodal) and POSTed to ``url``. The URL
    selects the response format. URLs containing ``"get_reward"`` are treated
    as an OpenRLHF server and read from ``"rewards"``. URLs containing
    ``"pooling"`` are treated as a vLLM server and read from ``"reward"``.
    """

    def __init__(self, url: str, use_checklist: bool = False, use_multimodal: bool = False):
        self.url = url
        self.use_checklist = use_checklist
        self.use_multimodal = use_multimodal

    def _process_multimodal_message(self, prompt: str, image_list: list[str]):
        """Split ``prompt`` at the first ``<IMAGE_PLACEHOLDER>`` and insert ``image_list[0]`` there.

        Returns a ``[text, image, text]`` content-part list. Any text after the
        placeholder is kept in the trailing text part.

        Raises:
            ValueError: if ``prompt`` contains no ``<IMAGE_PLACEHOLDER>``.
        """
        prefix, placeholder, suffix = prompt.partition("<IMAGE_PLACEHOLDER>")
        if not placeholder:
            raise ValueError("Prompt does not contain '<IMAGE_PLACEHOLDER>'")
        return [
            {"type": "text", "text": prefix},
            # Content part of type "image" (not "image_url"); the value comes
            # from image_to_base64_url, presumably a base64 data URL.
            {"type": "image", "image": image_to_base64_url(image_list[0])},
            {"type": "text", "text": suffix},
        ]

    def _make_query(self, user_prompt_template: dict, model_input: dict):
        """Build a ``[user, assistant]`` turn by formatting both template parts with ``model_input``.

        In multimodal mode the user content becomes a text/image/text part
        list, using ``model_input["image_list"]``.
        """
        user_prompt = user_prompt_template["user"].format(**model_input)
        if self.use_multimodal:
            user_prompt = self._process_multimodal_message(user_prompt, model_input["image_list"])
        assistant_prompt = user_prompt_template["assistant"].format(**model_input)
        return [
            {"role": "user", "content": user_prompt},
            {"role": "assistant", "content": assistant_prompt},
        ]

    def _select_template(self) -> dict:
        """Pick the prompt template matching the checklist / multimodal flags."""
        if self.use_checklist:
            if self.use_multimodal:
                return JUDGE_OURS_BT_MODELING_MULTIMODAL_PROMPT_TEMPLATE
            return JUDGE_OURS_BT_MODELING_PROMPT_TEMPLATE
        if self.use_multimodal:
            return JUDGE_OURS_BT_MODELING_MULTIMODAL_WO_CHECKLIST_PROMPT_TEMPLATE
        return JUDGE_OURS_BT_MODELING_WO_CHECKLIST_PROMPT_TEMPLATE

    def prepare_message(self, model_input: dict | list[dict], batch: bool = False):
        """Build the request payload for one input (or a list of inputs when ``batch``).

        Returns:
            Multimodal mode: a list of queries (a single-element list when not
            batching). Text mode: ``{"query": ..., "prompts": []}``, where
            ``query`` is one query or, when batching, a list of queries.
        """
        template = self._select_template()
        if self.use_multimodal:
            items = model_input if batch else [model_input]
            return [self._make_query(template, item) for item in items]
        if batch:
            query = [self._make_query(template, item) for item in model_input]
        else:
            query = self._make_query(template, model_input)
        # "prompts" is sent under the same key in both modes (the batch branch
        # previously misspelled it as "promptts").
        return {"query": query, "prompts": []}

    def _reward_key(self) -> str:
        """Return the response key holding the rewards for this server type.

        Raises:
            ValueError: if ``self.url`` matches no supported server type.
        """
        if "get_reward" in self.url:
            return "rewards"  # OpenRLHF
        if "pooling" in self.url:
            return "reward"  # vLLM server
        raise ValueError(f"Invalid URL: {self.url}")

    def get_rm_score(self, message: dict | list):
        """POST ``message`` to the reward server and return its reward(s).

        Returns ``[]`` (after printing an error) on timeout, request failure,
        undecodable JSON, or when the expected reward key is missing.

        Raises:
            ValueError: if ``self.url`` matches no supported server type.
        """
        reward_key = self._reward_key()
        try:
            if self.use_multimodal:
                response = requests.post(
                    self.url,
                    json={"messages": message},
                    timeout=600,
                )
            else:
                response = requests.post(
                    self.url,
                    headers={"Content-Type": "application/json"},
                    data=json.dumps(message),
                    timeout=300,
                )
            response.raise_for_status()
            response_json = response.json()
        except requests.exceptions.Timeout:
            print(f"Error: Request timed out to {self.url}")
            return []
        except requests.exceptions.RequestException as e:
            print(f"Error during request to {self.url}: {e}")
            return []
        except json.JSONDecodeError:
            print(f"Error: Failed to decode JSON response from {self.url}")
            return []

        # Check the key this server type actually returns, so the vLLM
        # ("reward") path is reachable.
        if reward_key not in response_json:
            print(f"Error: '{reward_key}' key not found in API response: {response_json}")
            return []
        return response_json[reward_key]

    # Backward-compatible alias for the original, misspelled method name.
    get_rm_scroe = get_rm_score

    def generate_response(self, model_input: dict | list[dict], batch: bool = False):
        """Score one input (or a list when ``batch``); return ``(rewards, 0)`` — no cost is tracked."""
        message = self.prepare_message(model_input, batch=batch)
        return self.get_rm_score(message), 0