# test_wprm3/agent/mini_bench/reward_agent.py
from abc import ABC, abstractmethod
import time
import requests
import json
import math
from langchain_openai import ChatOpenAI
from .prompts import get_messages
from .prompts.judge_prompt import (
JUDGE_OURS_BT_MODELING_PROMPT_TEMPLATE,
JUDGE_OURS_BT_MODELING_WO_CHECKLIST_PROMPT_TEMPLATE,
JUDGE_OURS_BT_MODELING_MULTIMODAL_PROMPT_TEMPLATE,
JUDGE_OURS_BT_MODELING_MULTIMODAL_WO_CHECKLIST_PROMPT_TEMPLATE
)
from .prompts.image_utils import image_to_base64_url
from .prompts.utils import convert_dict_messages
MAX_RETRY = 3  # max attempts per generation call
RETRY_SLEEP = 5  # seconds to wait between retries
# Prices are USD per 1M tokens, matching the cost formulas below.
MODEL_COST_MAPPING = {
"gpt-4o-mini": {
"input_token_cost": 0.15,
"output_token_cost": 0.6
},
"gpt-4o": {
"input_token_cost": 2.5,
"output_token_cost": 10
},
}
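
# A hedged worked example of how the mapping above is consumed (the token
# counts are made up): a gpt-4o-mini call with 1,200 prompt tokens and 300
# completion tokens would cost
#   0.15 * 1200 / 1_000_000 + 0.60 * 300 / 1_000_000 = 0.00036 USD,
# the same formula BaseAgent.generate_with_retry applies below.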
class Agent(ABC):
@abstractmethod
def generate_response(self, inputs: dict) -> str:
pass
class BaseAgent(Agent):
def __init__(self, agent_config: dict):
self.agent_config = agent_config
self._setup()
    def _init_llm_object(self, **extra_kwargs):
        # Merge per-call overrides (e.g. n, temperature) into a copy so the
        # shared agent_config is not mutated between calls.
        config = {**self.agent_config, **extra_kwargs}
        llm_kwargs = dict(
            model=config["model_name"],
            base_url=config["base_url"],
            api_key=config["api_key"],
            temperature=config["temperature"],
            timeout=300,
            n=config.get("n", None),
        )
        if config.get("use_log_probs", False):
            # Also request per-token log probabilities (top 10 alternatives
            # per position) so judge probabilities can be recovered later.
            llm_kwargs.update(logprobs=True, top_logprobs=10)
        self.llm = ChatOpenAI(**llm_kwargs)
def _setup(self):
self._init_llm_object()
self.temperature = self.agent_config["temperature"]
self.num_generate = self.agent_config["num_generate"]
self.use_checklist = self.agent_config.get("use_checklist", False)
self.use_multimodal = self.agent_config.get("use_multimodal", False)
        # Per-token pricing (USD per 1M tokens): only applied for known hosted
        # API models; self-hosted endpoints are treated as free.
        model_cost = MODEL_COST_MAPPING.get(self.agent_config["model_name"], None)
        if model_cost and "api" in self.agent_config["base_url"]:
            self.input_token_cost = model_cost["input_token_cost"]
            self.output_token_cost = model_cost["output_token_cost"]
        else:
            self.input_token_cost = 0.0
            self.output_token_cost = 0.0
def generate_with_retry(self, model_input, constraint_str_list: list = None):
total_input_tokens = 0
total_output_tokens = 0
        if self.temperature == 0:  # greedy decoding: a resample would be identical, so no retry loop
response = self.llm.invoke(model_input)
total_input_tokens += response.response_metadata["token_usage"]["prompt_tokens"]
total_output_tokens += response.response_metadata["token_usage"]["completion_tokens"]
else:
for i in range(MAX_RETRY):
try:
response = self.llm.invoke(model_input)
total_input_tokens += response.response_metadata["token_usage"]["prompt_tokens"]
total_output_tokens += response.response_metadata["token_usage"]["completion_tokens"]
                    # Accept the sample only if it contains every required
                    # constraint string; otherwise resample.
                    if not constraint_str_list or all(
                        s in response.content for s in constraint_str_list
                    ):
                        break
                    print(f"Agent response has a format issue, retrying... {i+1}/{MAX_RETRY}")
except Exception as e:
print(f"Agent returned an Error: {e}")
response = None
time.sleep(RETRY_SLEEP)
        # Pricing is USD per 1M tokens (see MODEL_COST_MAPPING).
        cost = (self.input_token_cost * total_input_tokens
                + self.output_token_cost * total_output_tokens) / 1_000_000
if response is None:
return "", cost
else:
return response.content, cost
    def prepare_message(self, model_input: dict, prompt_type: str):
        # Stub; each subclass builds its own prompt messages.
        message = []
        return message
    def generate_response(self, model_input: dict, prompt_type: str, constraint_str_list: list = None):
total_cost = 0
response_list = []
# prepare message
message = self.prepare_message(model_input, prompt_type)
# n sampling
for i in range(self.num_generate):
response, cost = self.generate_with_retry(message, constraint_str_list)
response_list.append(response)
total_cost += cost
return response_list, total_cost
class GroundingJudgeAgent(BaseAgent):
    # BaseAgent.__init__ already runs _setup(); no further init is needed.
def prepare_message(self, model_input: dict, prompt_type):
message = get_messages(
input_info=model_input,
inference_mode="judge_grounding",
prompt_type=prompt_type,
use_multimodal=self.use_multimodal,
text_obs=self.agent_config["text_obs_type"],
image_obs=self.agent_config["image_obs_type"]
)
return message
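    # Minimal usage sketch (hedged: the config values and prompt_type are
    # illustrative placeholders, and an OpenAI-compatible endpoint is assumed):
    #
    #   agent = GroundingJudgeAgent({
    #       "model_name": "gpt-4o-mini",
    #       "base_url": "https://api.openai.com/v1",
    #       "api_key": "sk-...",
    #       "temperature": 0.7,
    #       "num_generate": 4,
    #       "text_obs_type": "axtree",
    #       "image_obs_type": "screenshot",
    #   })
    #   responses, cost = agent.generate_response(
    #       model_input, prompt_type="judge",
    #       constraint_str_list=["<answer>", "</answer>"],
    #   )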
class ProgressJudgeAgent(BaseAgent):
    # BaseAgent.__init__ already runs _setup(); no further init is needed.
def prepare_message(self, model_input: dict, prompt_type):
        # Select observation modalities from the configured input type.
        if self.agent_config["input_type"] == "text_only":
            use_multimodal = False
            text_obs = self.agent_config["text_obs_type"]
            image_obs = None
        elif self.agent_config["input_type"] == "image_only":
            use_multimodal = True
            text_obs = None
            image_obs = self.agent_config["image_obs_type"]
        elif self.agent_config["input_type"] == "text_image":
            use_multimodal = True
            text_obs = self.agent_config["text_obs_type"]
            image_obs = self.agent_config["image_obs_type"]
        else:
            raise ValueError(f"Invalid input type: {self.agent_config['input_type']}")
        use_in_progress = bool(self.agent_config["use_in_progress"])
message = get_messages(
input_info=model_input,
inference_mode="judge_progress",
prompt_type=prompt_type,
use_checklist=self.use_checklist,
use_multimodal=use_multimodal,
text_obs=text_obs,
image_obs=image_obs,
use_in_progress=use_in_progress
)
return message
def get_judge_probs(self, logprobs: list):
        # Token variants counted as a "yes" / "no" / "in progress" verdict.
        # "Ġ" and "Ċ" are BPE vocabulary markers for a leading space and a
        # leading newline respectively, so e.g. "ĠYes" is the token " Yes".
        target_judge = {
"yes": [
"ĠYes", "Yes", "ĊYes",
"Ġyes", "yes", "Ċyes",
"ĠYES", "YES", "ĊYES",
"ĠDone", "Done", "ĊDone",
"ĠCompleted", "Completed", "ĊCompleted",
"ĠCorrect", "Correct", "ĊCorrect"
],
"no": [
"ĠNo", "No", "ĊNo",
"ĠNO", "NO", "ĊNO",
"ĠNot", "Not", "ĊNot",
"ĠNone", "None", "ĊNone",
"ĠNope", "Nope", "ĊNope",
"ĠUn", "Un", "ĊUn",
"ĠWrong", "Wrong", "ĊWrong"
],
"in": [
"ĠIn", "In", "ĊIn",
"ĠPending", "Pending", "ĊPending",
"ĠPart", "Part", "ĊPart",
"ĠPartial", "Partial", "ĊPartial",
"ĠInProgress", "InProgress", "ĊInProgress"
]
}
response_str = ""
judge_probs_list = []
for i, log_prob in enumerate(logprobs):
            # Only scan for verdict tokens once "<answer>" has been emitted.
if "<answer>" in response_str:
                find_judge_str = False
                for judge_type in target_judge:
                    if log_prob["token"] in target_judge[judge_type]:
                        find_judge_str = True
                        break
if find_judge_str:
token_judge_dict = {
"yes": None,
"no": None,
"in": None
}
                    for token_info in log_prob["top_logprobs"]:
                        for judge_type in target_judge:
                            for judge_str in target_judge[judge_type]:
                                # Substring match: accumulate the probability
                                # mass of every top-10 candidate that looks
                                # like this verdict type.
                                if judge_str in token_info["token"]:
                                    if token_judge_dict[judge_type] is None:
                                        token_judge_dict[judge_type] = math.exp(token_info["logprob"])
                                    else:
                                        token_judge_dict[judge_type] += math.exp(token_info["logprob"])
                    # Back to log space; -inf marks a verdict type that never
                    # appeared among the top-10 candidates.
                    token_judge_dict = {
                        judge_type: math.log(prob) if prob is not None else -float("inf")
                        for judge_type, prob in token_judge_dict.items()
                    }
judge_probs_list.append(token_judge_dict)
if "</answer>" in response_str:
break
response_str += log_prob["token"]
        if len(judge_probs_list) == 0:
            # No verdict token was found; fall back to a neutral all-zero entry.
return [{
"yes": 0.0,
"no": 0.0,
"in": 0.0
}]
        else:
            # Normalize each position's three verdict log-probs with a softmax
            # (-inf entries contribute zero mass).
            final_judge_probs_list = []
            for judge_probs in judge_probs_list:
                exp_logprobs = [math.exp(judge_probs[key]) for key in ("yes", "no", "in")]
                sum_exp_logprobs = sum(exp_logprobs)
                final_judge_probs_list.append({
                    "yes": exp_logprobs[0] / sum_exp_logprobs,
                    "no": exp_logprobs[1] / sum_exp_logprobs,
                    "in": exp_logprobs[2] / sum_exp_logprobs,
                })
return final_judge_probs_list
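    # Worked example (illustrative numbers): if the verdict token's top-10
    # candidates sum to P(yes)=0.70, P(no)=0.20, P(in)=0.05, the entry stores
    # log(0.70), log(0.20), log(0.05); re-exponentiating and normalizing gives
    # softmax probs of ~0.737 / 0.211 / 0.053, i.e. the ~5% of mass that fell
    # on unmatched tokens is redistributed across the three verdict classes.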
def generate_probs(self, model_input: dict, prompt_type: str, n=1, temperature=None):
total_cost = 0
# prepare message
message = self.prepare_message(model_input, prompt_type)
messages = convert_dict_messages(message)
        # Rebuild the LLM client so the n-sampling / temperature overrides
        # apply to this call only.
        kwargs = {"n": n}
        if temperature is not None:
            kwargs["temperature"] = temperature
        self._init_llm_object(**kwargs)
        try:
            response = self.llm.generate([messages])  # assume a single batch
        finally:
            # Log the endpoint even if the request raised.
            print("request url: ", self.agent_config["base_url"])
# parse responses
response_list = []
        for generation in response.generations[0]:  # assume a single batch
# parse logprobs
logprobs = generation.message.response_metadata["logprobs"]["content"]
response_list.append(
{
"response": generation.message.content,
"judge_probs": self.get_judge_probs(logprobs)
}
)
        # calculate cost (prices are USD per 1M tokens)
        total_input_tokens = response.llm_output["token_usage"]["prompt_tokens"]
        total_output_tokens = response.llm_output["token_usage"]["completion_tokens"]
        total_cost = (self.input_token_cost * total_input_tokens
                      + self.output_token_cost * total_output_tokens) / 1_000_000
return response_list, total_cost
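    # Hedged usage sketch: draw 8 samples at temperature 1.0 and average the
    # final "yes" probability (the averaging policy is an assumption of this
    # example, not something this class prescribes):
    #
    #   results, cost = judge.generate_probs(model_input, "judge", n=8, temperature=1.0)
    #   avg_yes = sum(r["judge_probs"][-1]["yes"] for r in results) / len(results)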
class ChecklistGenerationAgent(BaseAgent):
    # BaseAgent.__init__ already runs _setup(); no further init is needed.
def prepare_message(self, model_input: dict, prompt_type):
message = get_messages(
input_info=model_input,
inference_mode="checklist_generation",
prompt_type=prompt_type
)
return message
class ClassifierRewardAgent(Agent):
    # Scores trajectories by calling a remote classifier reward model over
    # HTTP (an OpenRLHF reward server or a vLLM pooling server) rather than
    # prompting a chat LLM.
def __init__(self, url: str, use_checklist: bool = False, use_multimodal: bool = False):
self.url = url
self.use_checklist = use_checklist
self.use_multimodal = use_multimodal
    def _process_multimodal_message(self, prompt: str, image_list: list[str]):
        # Split the prompt around the image placeholder and splice the first
        # screenshot in between as a base64 data URL; a single
        # <IMAGE_PLACEHOLDER> is assumed and only image_list[0] is used.
        text_prompt_prefix, _, text_prompt_suffix = prompt.partition("<IMAGE_PLACEHOLDER>")
        multimodal_message = [
            {"type": "text", "text": text_prompt_prefix},
            {"type": "image", "image": image_to_base64_url(image_list[0])},
            {"type": "text", "text": text_prompt_suffix},
        ]
return multimodal_message
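    # Illustration: the prompt "Observe <IMAGE_PLACEHOLDER> then judge." maps to
    #   [{"type": "text", "text": "Observe "},
    #    {"type": "image", "image": "data:image/png;base64,..."},
    #    {"type": "text", "text": " then judge."}]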
def _make_query(self, user_prompt_template: dict, model_input: dict | list[dict]):
if self.use_multimodal:
tmp_user_prompt = user_prompt_template["user"].format(
**model_input
)
user_prompt = self._process_multimodal_message(tmp_user_prompt, model_input["image_list"])
else:
user_prompt = user_prompt_template["user"].format(
**model_input
)
assistant_prompt = user_prompt_template["assistant"].format(
**model_input
)
query = [
{"role": "user", "content": user_prompt},
{"role": "assistant", "content": assistant_prompt}
]
return query
def prepare_message(self, model_input: dict | list[dict], batch: bool = False):
if self.use_checklist:
if self.use_multimodal:
user_prompt_template = JUDGE_OURS_BT_MODELING_MULTIMODAL_PROMPT_TEMPLATE
else:
user_prompt_template = JUDGE_OURS_BT_MODELING_PROMPT_TEMPLATE
else:
if self.use_multimodal:
user_prompt_template = JUDGE_OURS_BT_MODELING_MULTIMODAL_WO_CHECKLIST_PROMPT_TEMPLATE
else:
user_prompt_template = JUDGE_OURS_BT_MODELING_WO_CHECKLIST_PROMPT_TEMPLATE
        if self.use_multimodal:
            if batch:
                message = [self._make_query(user_prompt_template, item) for item in model_input]
            else:
                message = [self._make_query(user_prompt_template, model_input)]
        else:
            if batch:
                message = {
                    "query": [self._make_query(user_prompt_template, item) for item in model_input],
                    "prompts": []
                }
            else:
                message = {
                    "query": self._make_query(user_prompt_template, model_input),
                    "prompts": []
                }
return message
    def get_rm_score(self, message: dict | list):
headers = {"Content-Type": "application/json"}
try:
if self.use_multimodal:
response = requests.post(
self.url,
json={"messages": message},
timeout=600
)
else:
response = requests.post(
self.url,
headers=headers,
data=json.dumps(message),
timeout=300
)
response.raise_for_status()
response_json = response.json()
if "rewards" not in response_json:
print(f"Error: 'rewards' key not found in API response: {response_json}")
return []
if "get_reward" in self.url:
# use openrlhf
return response_json["rewards"]
elif "pooling" in self.url:
# use vllm server
return response_json["reward"]
else:
# error
raise ValueError(f"Invalid URL: {self.url}")
except requests.exceptions.Timeout:
print(f"Error: Request timed out to {self.url}")
return []
except requests.exceptions.RequestException as e:
print(f"Error during request to {self.url}: {e}")
return []
except json.JSONDecodeError:
print(f"Error: Failed to decode JSON response from {self.url}")
return []
except KeyError as e:
print(f"Error: Missing key {e} in response from {self.url}")
return []
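    # Expected response shapes, as implied by the dispatch above (assumed, not
    # verified against either server's docs):
    #   OpenRLHF /get_reward : {"rewards": [0.73, -0.12, ...]}
    #   vLLM     /pooling    : {"reward":  [0.73, -0.12, ...]}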
def generate_response(self, model_input: dict | list[dict], batch: bool = False):
        message = self.prepare_message(model_input, batch=batch)
        rewards = self.get_rm_score(message)
        return rewards, 0  # a classifier RM incurs no per-token API cost
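
# Hedged end-to-end sketch (URL and inputs are placeholders; the server must
# expose either an OpenRLHF "get_reward" or a vLLM "pooling" route, as
# dispatched in get_rm_score):
#
#   rm = ClassifierRewardAgent("http://localhost:8000/get_reward", use_checklist=True)
#   rewards, _ = rm.generate_response([input_a, input_b], batch=True)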