from abc import ABC, abstractmethod
import time
import requests
import json
import math
from langsmith import Client
from langchain_openai import ChatOpenAI
from .prompts import get_messages
from .prompts.judge_prompt import (
    JUDGE_OURS_BT_MODELING_PROMPT_TEMPLATE,
    JUDGE_OURS_BT_MODELING_WO_CHECKLIST_PROMPT_TEMPLATE,
    JUDGE_OURS_BT_MODELING_MULTIMODAL_PROMPT_TEMPLATE,
    JUDGE_OURS_BT_MODELING_MULTIMODAL_WO_CHECKLIST_PROMPT_TEMPLATE
)
from .prompts.image_utils import image_to_base64_url

MAX_RETRY = 3
RETRY_SLEEP = 5

# Token prices are interpreted as USD per 1M tokens (see the cost computation
# in BaseAgent.generate_with_retry).
MODEL_COST_MAPPING = {
    "gpt-4o-mini": {
        "input_token_cost": 0.15,
        "output_token_cost": 0.6
    },
    "gpt-4o": {
        "input_token_cost": 2.5,
        "output_token_cost": 10
    },
}
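
# Illustrative cost check (hypothetical token counts, not part of the original
# pipeline): with the gpt-4o rates above and the per-1M-token formula used in
# BaseAgent.generate_with_retry, a call with 1,200 input and 350 output tokens
# would cost 2.5 * 1200 / 1_000_000 + 10 * 350 / 1_000_000 = 0.0065 USD.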


class Agent(ABC):
    @abstractmethod
    def generate_response(self, inputs: dict) -> str:
        pass


class BaseAgent(Agent):
    def __init__(self, agent_config: dict):
        self.agent_config = agent_config
        self._setup()

    def _setup(self):
        llm_kwargs = dict(
            model=self.agent_config["model_name"],
            base_url=self.agent_config["base_url"],
            api_key=self.agent_config["api_key"],
            temperature=self.agent_config["temperature"],
            timeout=300
        )
        # Request token-level log probabilities when the config asks for them.
        if self.agent_config.get("use_log_probs", False):
            llm_kwargs["logprobs"] = True
            llm_kwargs["top_logprobs"] = 10
        self.llm = ChatOpenAI(**llm_kwargs)
        self.temperature = self.agent_config["temperature"]
        self.num_generate = self.agent_config["num_generate"]
        self.use_checklist = self.agent_config.get("use_checklist", False)
        self.use_multimodal = self.agent_config.get("use_multimodal", False)
        # setup cost: only charge for hosted API endpoints with known pricing
        model_cost = MODEL_COST_MAPPING.get(self.agent_config["model_name"], None)
        if model_cost and "api" in self.agent_config["base_url"]:
            self.input_token_cost = model_cost["input_token_cost"]
            self.output_token_cost = model_cost["output_token_cost"]
        else:
            self.input_token_cost = 0.0
            self.output_token_cost = 0.0

    def generate_with_retry(self, model_input, constraint_str_list: list = None):
        total_input_tokens = 0
        total_output_tokens = 0
        if self.temperature == 0:
            # Greedy decoding is deterministic, so retrying would not change the output.
            response = self.llm.invoke(model_input)
            total_input_tokens += response.response_metadata["token_usage"]["prompt_tokens"]
            total_output_tokens += response.response_metadata["token_usage"]["completion_tokens"]
        else:
            for i in range(MAX_RETRY):
                try:
                    response = self.llm.invoke(model_input)
                    total_input_tokens += response.response_metadata["token_usage"]["prompt_tokens"]
                    total_output_tokens += response.response_metadata["token_usage"]["completion_tokens"]
                    if constraint_str_list:
                        pass_constraint_num = 0
                        for constraint_str in constraint_str_list:
                            if constraint_str in response.content:
                                pass_constraint_num += 1
                        if pass_constraint_num == len(constraint_str_list):
                            break
                        else:
                            print(f"Agent has format issue, retry... {i+1}/{MAX_RETRY}")
                            print(response.content)
                    else:
                        break
                except Exception as e:
                    print(f"Agent returned an Error: {e}")
                    response = None
                    time.sleep(RETRY_SLEEP)
        cost = self.input_token_cost * total_input_tokens / 1000000 + self.output_token_cost * total_output_tokens / 1000000
        if response is None:
            return "", cost
        else:
            return response.content, cost

    def prepare_message(self, model_input: dict, prompt_type: str):
        # Overridden by subclasses to build the actual prompt messages.
        message = []
        return message

    def generate_response(self, model_input: dict, prompt_type: str, constraint_str_list: list = None):
        total_cost = 0
        response_list = []
        # prepare message
        message = self.prepare_message(model_input, prompt_type)
        # print(message)
        # n sampling
        for i in range(self.num_generate):
            response, cost = self.generate_with_retry(message, constraint_str_list)
            response_list.append(response)
            total_cost += cost
        return response_list, total_cost
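
# A minimal usage sketch (illustrative; the config values below are hypothetical
# and the keys mirror what _setup() and the judge agents' prepare_message() read):
#
#   config = {
#       "model_name": "gpt-4o-mini",
#       "base_url": "https://api.openai.com/v1",
#       "api_key": "sk-...",
#       "temperature": 0.7,
#       "num_generate": 3,
#       "use_checklist": True,
#   }
#   agent = ChecklistGenerationAgent(config)
#   responses, cost = agent.generate_response(model_input, prompt_type)
#
# where prompt_type must name a template known to get_messages.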


class GroundingJudgeAgent(BaseAgent):
    def __init__(self, agent_config: dict):
        super().__init__(agent_config)
        self._setup()

    def prepare_message(self, model_input: dict, prompt_type):
        message = get_messages(
            input_info=model_input,
            inference_mode="judge_grounding",
            prompt_type=prompt_type,
            use_multimodal=self.use_multimodal,
            text_obs=self.agent_config["text_obs_type"],
            image_obs=self.agent_config["image_obs_type"]
        )
        return message


class ProgressJudgeAgent(BaseAgent):
    def __init__(self, agent_config: dict):
        super().__init__(agent_config)
        self._setup()

    def prepare_message(self, model_input: dict, prompt_type):
        # Select which observation modalities go into the prompt.
        if self.agent_config["input_type"] == "text_only":
            use_multimodal = False
            text_obs = self.agent_config["text_obs_type"]
            image_obs = None
        elif self.agent_config["input_type"] == "image_only":
            use_multimodal = True
            text_obs = None
            image_obs = self.agent_config["image_obs_type"]
        elif self.agent_config["input_type"] == "text_image":
            use_multimodal = True
            text_obs = self.agent_config["text_obs_type"]
            image_obs = self.agent_config["image_obs_type"]
        else:
            raise ValueError(f"Invalid input type: {self.agent_config['input_type']}")
        use_in_progress = bool(self.agent_config["use_in_progress"])
        message = get_messages(
            input_info=model_input,
            inference_mode="judge_progress",
            prompt_type=prompt_type,
            use_checklist=self.use_checklist,
            use_multimodal=use_multimodal,
            text_obs=text_obs,
            image_obs=image_obs,
            use_in_progress=use_in_progress
        )
        return message

    def add_logprob(self, ori_logprob: float, add_logprob: float):
        # Accumulate probabilities in log space: log(exp(a) + exp(b)).
        if ori_logprob is None:
            return add_logprob
        else:
            ori_prob = math.exp(ori_logprob)
            add_prob = math.exp(add_logprob)
            return math.log(ori_prob + add_prob)

    def get_judge_probs(self, logprobs: list):
        # target_judge = {
        #     "yes": [" Yes", "Yes"],
        #     "no": [" No", "No"],
        #     "in": [" In", "In"]
        # }
        # "Ġ" and "Ċ" are the byte-level BPE markers for a leading space and a
        # newline, so the lists below also match tokenizers that expose them.
        target_judge = {
            "yes": [
                " Yes", "ĠYes", "Yes", "ĊYes",
                "Ġyes", "yes", "Ċyes",
                "ĠYES", "YES", "ĊYES",
                "ĠDone", "Done", "ĊDone",
                "ĠCompleted", "Completed", "ĊCompleted",
                "ĠCorrect", "Correct", "ĊCorrect"
            ],
            "no": [
                " No", "ĠNo", "No", "ĊNo",
                "ĠNO", "NO", "ĊNO",
                "ĠNot", "Not", "ĊNot",
                "ĠNone", "None", "ĊNone",
                "ĠNope", "Nope", "ĊNope",
                "ĠUn", "Un", "ĊUn",
                "ĠWrong", "Wrong", "ĊWrong"
            ],
            "in": [
                " In", "ĠIn", "In", "ĊIn",
                "ĠPending", "Pending", "ĊPending",
                "ĠPart", "Part", "ĊPart",
                "ĠPartial", "Partial", "ĊPartial",
                "ĠInProgress", "InProgress", "ĊInProgress"
            ]
        }
        response_str = ""
        judge_probs_list = []
        # print(logprobs)
        for i, log_prob in enumerate(logprobs):
            # Start to find judge string
            if "<answer>" in response_str:
                find_judge_str = None
                for judge_type in target_judge:
                    if log_prob["token"] in target_judge[judge_type]:
                        # print(log_prob)
                        find_judge_str = judge_type
                        break
                if find_judge_str:
                    # print("find judge str")
                    token_judge_dict = {
                        "yes": None,
                        "no": None,
                        "in": None
                    }
                    if "top_logprobs" in log_prob:
                        for token_info in log_prob["top_logprobs"]:
                            for judge_type in target_judge:
                                for judge_str in target_judge[judge_type]:
                                    # if judge_str in token_info["token"] and token_info["logprob"] > token_judge_dict[judge_type]:
                                    #     token_judge_dict[judge_type] = token_info["logprob"]
                                    if judge_str in token_info["token"]:
                                        # print(token_info["logprob"])
                                        token_judge_dict[judge_type] = self.add_logprob(token_judge_dict[judge_type], token_info["logprob"])
                        # for None case
                        for judge_type in token_judge_dict:
                            if token_judge_dict[judge_type] is None:
                                token_judge_dict[judge_type] = float("-inf")
                        judge_probs_list.append(token_judge_dict)
                    else:
                        # for vllm bugs : no top_logprobs
                        for judge_type in token_judge_dict:
                            if judge_type == find_judge_str:
                                token_judge_dict[judge_type] = log_prob["logprob"]
                            else:
                                token_judge_dict[judge_type] = float("-inf")
                        judge_probs_list.append(token_judge_dict)
                    # print(token_judge_dict)
            if "</answer>" in response_str:
                break
            response_str += log_prob["token"]
        # print(response_str.replace("Ġ", " ").replace("Ċ", "\n"))
        # print(judge_probs_list)
        if len(judge_probs_list) == 0:
            return [{
                "yes": 0.0,
                "no": 0.0,
                "in": 0.0
            }]
        else:
            # convert with softmax
            final_judge_probs_list = []
            for judge_probs in judge_probs_list:
                exp_logprobs = [math.exp(x) for x in [judge_probs["yes"], judge_probs["no"], judge_probs["in"]]]
                sum_exp_logprobs = sum(exp_logprobs)
                softmax_probs = [x / sum_exp_logprobs for x in exp_logprobs]
                final_judge_probs_list.append({
                    "yes": softmax_probs[0],
                    "no": softmax_probs[1],
                    "in": softmax_probs[2]
                })
            return final_judge_probs_list
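
    # Worked example (illustrative numbers): if a judged token accumulates
    # logprobs {"yes": -0.1, "no": -3.0, "in": -5.0}, the softmax step above
    # yields roughly {"yes": 0.941, "no": 0.052, "in": 0.007}.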

    def generate_probs(self, model_input: dict, prompt_type: str):
        total_cost = 0
        response_list = []
        # prepare message
        message = self.prepare_message(model_input, prompt_type)
        # print(message)
        for i in range(self.num_generate):
            response = None
            try:
                response = self.llm.invoke(message)
                total_input_tokens = response.response_metadata["token_usage"]["prompt_tokens"]
                total_output_tokens = response.response_metadata["token_usage"]["completion_tokens"]
                total_cost += self.input_token_cost * total_input_tokens / 1000000 + self.output_token_cost * total_output_tokens / 1000000
                logprobs = response.response_metadata["logprobs"]["content"]
                response_list.append(
                    {
                        "response": response.content,
                        "judge_probs": self.get_judge_probs(logprobs)
                    }
                )
            except Exception as e:
                print(f"Error: {e}")
                # print(response.response_metadata["logprobs"])
                response_list.append(
                    {
                        "response": response.content if response is not None else "",
                        "judge_probs": []
                    }
                )
        return response_list, total_cost


class ChecklistGenerationAgent(BaseAgent):
    def __init__(self, agent_config: dict):
        super().__init__(agent_config)
        self._setup()

    def prepare_message(self, model_input: dict, prompt_type):
        message = get_messages(
            input_info=model_input,
            inference_mode="checklist_generation",
            prompt_type=prompt_type
        )
        return message


class ClassifierRewardAgent(Agent):
    def __init__(self, url: str, use_checklist: bool = False, use_multimodal: bool = False):
        self.url = url
        self.use_checklist = use_checklist
        self.use_multimodal = use_multimodal

    def _process_multimodal_message(self, prompt: str, image_list: list[str]):
        # Split the prompt around the image placeholder and insert the first
        # image as a base64 URL between the two text segments.
        text_prompt_prefix = prompt.split("<IMAGE_PLACEHOLDER>")[0]
        text_prompt_suffix = prompt.split("<IMAGE_PLACEHOLDER>")[1]
        multimodal_message = [
            {"type": "text", "text": text_prompt_prefix},
            # {"type": "image_url", "image_url": {"url": image_to_base64_url(image_list[0])}},
            {"type": "image", "image": image_to_base64_url(image_list[0])},
            {"type": "text", "text": text_prompt_suffix}
        ]
        return multimodal_message

    def _make_query(self, user_prompt_template: dict, model_input: dict | list[dict]):
        if self.use_multimodal:
            tmp_user_prompt = user_prompt_template["user"].format(
                **model_input
            )
            user_prompt = self._process_multimodal_message(tmp_user_prompt, model_input["image_list"])
        else:
            user_prompt = user_prompt_template["user"].format(
                **model_input
            )
        assistant_prompt = user_prompt_template["assistant"].format(
            **model_input
        )
        query = [
            {"role": "user", "content": user_prompt},
            {"role": "assistant", "content": assistant_prompt}
        ]
        return query

    def prepare_message(self, model_input: dict | list[dict], batch: bool = False):
        if self.use_checklist:
            if self.use_multimodal:
                user_prompt_template = JUDGE_OURS_BT_MODELING_MULTIMODAL_PROMPT_TEMPLATE
            else:
                user_prompt_template = JUDGE_OURS_BT_MODELING_PROMPT_TEMPLATE
        else:
            if self.use_multimodal:
                user_prompt_template = JUDGE_OURS_BT_MODELING_MULTIMODAL_WO_CHECKLIST_PROMPT_TEMPLATE
            else:
                user_prompt_template = JUDGE_OURS_BT_MODELING_WO_CHECKLIST_PROMPT_TEMPLATE
        if self.use_multimodal:
            if batch:
                message = [self._make_query(user_prompt_template, input) for input in model_input]
            else:
                message = [self._make_query(user_prompt_template, model_input)]
        else:
            if batch:
                message = {
                    "query": [self._make_query(user_prompt_template, input) for input in model_input],
                    "prompts": []
                }
            else:
                message = {
                    "query": self._make_query(user_prompt_template, model_input),
                    "prompts": []
                }
        return message

    def get_rm_score(self, message: dict | list):
        headers = {"Content-Type": "application/json"}
        try:
            if self.use_multimodal:
                response = requests.post(
                    self.url,
                    json={"messages": message},
                    timeout=600
                )
            else:
                response = requests.post(
                    self.url,
                    headers=headers,
                    data=json.dumps(message),
                    timeout=300
                )
            response.raise_for_status()
            response_json = response.json()
            if "rewards" not in response_json and "reward" not in response_json:
                print(f"Error: no reward key found in API response: {response_json}")
                return []
            if "get_reward" in self.url:
                # use openrlhf
                return response_json["rewards"]
            elif "pooling" in self.url:
                # use vllm server
                return response_json["reward"]
            else:
                # error
                raise ValueError(f"Invalid URL: {self.url}")
        except requests.exceptions.Timeout:
            print(f"Error: Request timed out to {self.url}")
            return []
        except requests.exceptions.RequestException as e:
            print(f"Error during request to {self.url}: {e}")
            return []
        except json.JSONDecodeError:
            print(f"Error: Failed to decode JSON response from {self.url}")
            return []
        except KeyError as e:
            print(f"Error: Missing key {e} in response from {self.url}")
            return []

    def generate_response(self, model_input: dict | list[dict], batch: bool = False):
        message = self.prepare_message(model_input, batch=batch)
        rewards = self.get_rm_score(message)
        return rewards, 0
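
# Usage sketch (illustrative; the endpoint URL is hypothetical and must contain
# "get_reward" for an OpenRLHF reward server or "pooling" for a vLLM pooling
# server, as dispatched in get_rm_score):
#
#   rm_agent = ClassifierRewardAgent("http://localhost:5000/get_reward", use_checklist=True)
#   rewards, _ = rm_agent.generate_response(model_input)               # single example
#   rewards, _ = rm_agent.generate_response(model_inputs, batch=True)  # list of examples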