jackkuo's picture
add QA
79899c0
import json
from typing import Any, List
from agents import Agent, OpenAIChatCompletionsModel, Runner
from agents.agent_output import AgentOutputSchemaBase
from openai import AsyncOpenAI
from config.global_storage import get_model_config
from utils.bio_logger import bio_logger as logger
from typing import List, Dict
from pydantic import BaseModel, Field,ConfigDict
class DateRange(BaseModel):
    # Date window used as a search filter. Both bounds are plain strings that
    # default to '' when not supplied.
    #
    # NOTE: documented with `#` comments on purpose -- a class docstring on a
    # pydantic model is emitted as the schema `description`, which would change
    # the JSON schema this model produces.
    #
    # `json_schema_extra` adds an explicit `required` list naming every field;
    # presumably needed so strict structured-output schemas accept the model
    # even though the fields carry defaults -- TODO confirm.
    model_config = ConfigDict(
        strict=True,
        extra="forbid",
        json_schema_extra={"required": ["start", "end"]},
    )

    start: str = Field(default='', description="Start date in YYYY-MM-DD format")
    end: str = Field(default='', description="End date in YYYY-MM-DD format")
class Journal(BaseModel):
    # A journal reference: display name plus its EISSN. Both fields are
    # required (no defaults).
    #
    # NOTE: `#` comments are used instead of a class docstring so the generated
    # JSON schema (which would pick the docstring up as `description`) is
    # unchanged.
    model_config = ConfigDict(
        strict=True,
        extra="forbid",
        json_schema_extra={"required": ["name", "EISSN"]},
    )

    name: str = Field(default=..., description="Journal name")
    EISSN: str = Field(default=..., description="Journal EISSN")
class AuthorFilter(BaseModel):
    # Author-based search filter: a name plus two flags saying whether the
    # author must appear in the first and/or last author position.
    #
    # NOTE: `#` comments are used instead of a class docstring so the generated
    # JSON schema is unchanged.
    #
    # The explicit `required` list names every field even though each one has
    # a default; presumably for strict structured-output schemas -- TODO confirm.
    model_config = ConfigDict(
        strict=True,
        extra="forbid",
        json_schema_extra={
            "required": ["name", "first_author", "last_author"],
        },
    )

    name: str = Field(default="", description="Author name to filter")
    first_author: bool = Field(default=False, description="Is first author?")
    last_author: bool = Field(default=False, description="Is last author?")
class Filters(BaseModel):
    # Bundle of all search filters attached to a rewritten query.
    #
    # NOTE: `#` comments are used instead of a class docstring so the generated
    # JSON schema is unchanged.
    #
    # NOTE(review): several fields pass both `...` and `default_factory` to
    # Field(). The combination looks redundant -- confirm whether these fields
    # are meant to be required or to fall back to the factory value.
    model_config = ConfigDict(
        strict=True,
        extra="forbid",
        json_schema_extra={
            "required": [
                "date_range",
                "article_types",
                "languages",
                "subjects",
                "journals",
                "author",
            ],
        },
    )

    date_range: DateRange = Field(default=..., default_factory=DateRange)
    article_types: List[str] = Field(default=..., default_factory=list)
    # Defaults to English-only.
    languages: List[str] = Field(default=["English"])
    subjects: List[str] = Field(default=..., defaultfactory_placeholder=None) if False else Field(default=..., default_factory=list)
    # Default is a list holding one empty string (not an empty list).
    journals: List[str] = Field(default=[""])
    author: AuthorFilter = Field(default=..., default_factory=AuthorFilter)
class RewriteJsonOutput(BaseModel):
    # Full structured result of a query rewrite: a category label, key words,
    # key journals, the rewritten queries, and the filters to apply.
    #
    # NOTE: `#` comments are used instead of a class docstring so the generated
    # JSON schema is unchanged.
    #
    # NOTE(review): `...` together with `default_factory` looks redundant --
    # confirm which behaviour is intended.
    model_config = ConfigDict(
        strict=True,
        extra="forbid",
        json_schema_extra={
            "required": [
                "category",
                "key_words",
                "key_journals",
                "queries",
                "filters",
            ],
        },
    )

    category: str = Field(default=..., description="Query category")
    key_words: List[str] = Field(default=..., default_factory=list)
    key_journals: List[Journal] = Field(default=..., default_factory=list)
    queries: List[str] = Field(default=..., default_factory=list)
    filters: Filters = Field(default=..., default_factory=Filters)
class SimpleJsonOutput(BaseModel):
    # Reduced output shape carrying only a list of key words.
    #
    # NOTE: a `#` comment is used instead of a class docstring so the generated
    # JSON schema is unchanged.
    #
    # NOTE(review): `...` together with `default_factory` looks redundant --
    # confirm which behaviour is intended.
    key_words: List[str] = Field(default=..., default_factory=list)
class RewriteJsonOutputSchema(AgentOutputSchemaBase):
    """Agent output schema that constrains the model to ``RewriteJsonOutput``.

    The JSON schema is taken directly from the pydantic model, and the model's
    reply is validated against that same pydantic model.
    """

    def is_plain_text(self):
        """Return ``False``: the output is structured JSON, not free text."""
        return False

    def name(self):
        """Return the name under which the output type is reported."""
        return "RewriteJsonOutput"

    def json_schema(self):
        """Return the JSON schema generated from ``RewriteJsonOutput``."""
        return RewriteJsonOutput.model_json_schema()

    def is_strict_json_schema(self):
        """Return ``True``: the schema is declared as a strict JSON schema."""
        return True

    def validate_json(self, json_data: Any) -> Any:
        """Validate the model's reply and return it as a ``RewriteJsonOutput``.

        Args:
            json_data: The reply, either as a JSON string or as an already
                decoded object.

        Returns:
            The validated ``RewriteJsonOutput`` instance.

        Raises:
            ModelBehaviorError: If the reply is not valid JSON or does not
                satisfy the ``RewriteJsonOutput`` model. Previously the error
                was swallowed and ``None`` was returned, which then surfaced
                later as an unrelated failure on the ``None`` final output.
        """
        # Imported locally so the file's import block does not need to change.
        from agents.exceptions import ModelBehaviorError

        try:
            if isinstance(json_data, str):
                json_data = json.loads(json_data)
            return RewriteJsonOutput.model_validate(json_data)
        except Exception as e:
            logger.error(f"Validation error: {e}")
            raise ModelBehaviorError(
                f"Invalid JSON for RewriteJsonOutput: {e}"
            ) from e

    def parse(self, json_data: Any) -> Any:
        """Decode ``json_data`` if it is a JSON string; otherwise return it as is.

        No schema validation is performed here.
        """
        if isinstance(json_data, str):
            json_data = json.loads(json_data)
        return json_data
class RewriteAgent:
    """Rewrite a free-text question into structured JSON with an LLM agent.

    The "main" model from the ``rewrite-llm`` section of the model
    configuration is tried first. If anything goes wrong with it, the
    "backup" model from the same section is used instead.
    """

    # System prompt shared by every agent built by this class (main and backup).
    _AGENT_INSTRUCTIONS = ' Your task is to rewrite the query into a structured JSON format. Please do not answer the question.'

    def __init__(self):
        """Load the model configuration and build the main model."""
        self.model_config = get_model_config()
        self.agent_name = "rewrite agent"
        self.selected_model = self._build_model("main")

    @staticmethod
    def _empty_result() -> dict:
        """Return a fresh fallback result used when no JSON can be recovered."""
        return {"queries": []}

    def _build_model(self, profile: str) -> Any:
        """Create the chat-completions model for ``rewrite-llm[profile]``."""
        settings = self.model_config["rewrite-llm"][profile]
        return OpenAIChatCompletionsModel(
            model=settings["model"],
            openai_client=AsyncOpenAI(
                api_key=settings["api_key"],
                base_url=settings["base_url"],
                timeout=120.0,
                max_retries=2,
            ),
        )

    async def _run_agent(
        self, model: Any, query: str, INSTRUCTIONS: str, simple_version: bool
    ) -> Any:
        """Build a rewrite agent on ``model``, run it, and return the run result."""
        # The full schema goes through the custom output-schema wrapper; the
        # simple variant hands the pydantic model to the SDK directly.
        output_type = SimpleJsonOutput if simple_version else RewriteJsonOutputSchema()
        rewrite_agent = Agent(
            name=self.agent_name,
            instructions=self._AGENT_INSTRUCTIONS,
            model=model,
            output_type=output_type,
        )
        return await Runner.run(
            rewrite_agent,
            input=INSTRUCTIONS + 'Here is the question: ' + query,
        )

    def _result_to_json(self, result: Any) -> Any:
        """Serialise the run's final output and parse it back into plain JSON."""
        return self.parse_json_output(result.final_output.model_dump_json())

    async def rewrite_query(self, query: str, INSTRUCTIONS: str, simple_version=False) -> Any:
        """Rewrite ``query`` into structured JSON.

        Args:
            query: The user's question.
            INSTRUCTIONS: Prompt text that is prepended to the question in the
                agent input.
            simple_version: When true, request the reduced ``SimpleJsonOutput``
                shape instead of the full ``RewriteJsonOutput`` one.

        Returns:
            The parsed JSON produced by the agent. If the backup model's
            output cannot be converted, ``{"queries": []}`` is returned.

        Raises:
            Exception: Whatever building or running the backup model raises;
                there is no further fallback after it.
        """
        try:
            logger.info("Rewriting query with main configuration.")
            result = await self._run_agent(
                self.selected_model, query, INSTRUCTIONS, simple_version
            )
            # A conversion failure here is treated like any other main-model
            # failure and triggers the backup attempt below.
            return self._result_to_json(result)
        except Exception as main_error:
            logger.error(f"Error with main model: {main_error}", exc_info=main_error)
            logger.info("Trying backup model for rewriting query.")

        self.selected_model_backup = self._build_model("backup")
        result = await self._run_agent(
            self.selected_model_backup, query, INSTRUCTIONS, simple_version
        )
        try:
            return self._result_to_json(result)
        except Exception as e:
            # Previously this path fell through to an unbound local variable;
            # return the same empty fallback that parse_json_output uses.
            logger.error(f"Failed to parse JSON output: {e}")
            return self._empty_result()

    def parse_json_output(self, output: str) -> Any:
        """Parse ``output`` as JSON, tolerating code fences and surrounding text.

        Three strategies are tried in order: the whole string, the contents of
        the first fenced code block, and the first balanced ``{...}`` span.

        Returns:
            The decoded JSON value, or ``{"queries": []}`` if nothing parses.
        """
        # 1) The whole string is already JSON.
        try:
            return json.loads(output)
        except json.JSONDecodeError as e:
            logger.info(f"Output is not valid JSON: {output}")
            logger.error(f"Failed to parse output as direct JSON: {e}")

        # 2) JSON wrapped in a fenced code block. A complete fence pair splits
        #    the text into at least three parts; the block body is the second.
        fenced_parts = output.split("```")
        if len(fenced_parts) >= 3:
            candidate = fenced_parts[1].strip()
            # Drop an optional language tag such as "json" / "JSON".
            if candidate[:4].lower() == "json":
                candidate = candidate[4:].strip()
            try:
                return json.loads(candidate)
            except json.JSONDecodeError as e:
                logger.error(f"Failed to parse output from code block: {e}")

        # 3) Last resort: pull out the first balanced brace span.
        candidate = self.find_json_in_string(output)
        if not candidate:
            logger.error(f"No valid JSON found in the output:{output}")
            return self._empty_result()
        try:
            return json.loads(candidate)
        except json.JSONDecodeError as e:
            logger.error(f"Failed to parse extracted JSON: {e}")
            logger.error(f"Extracted JSON: {candidate}")
            return self._empty_result()

    def find_json_in_string(self, string: str) -> str:
        """Return the first balanced ``{...}`` span found in ``string``.

        The span is not validated as JSON, and braces inside quoted strings are
        counted like any other brace. A closing brace that appears before any
        opening brace is ignored.

        Example:
            string = "bla bla bla {this is {some} text{{}and it's sneaky}} because {it's} confusing"
            output = "{this is {some} text{{}and it's sneaky}}"

        Returns:
            The matched span including its outer braces, or ``""`` if no
            complete span exists.
        """
        depth = 0
        start_index = None
        for i, c in enumerate(string):
            if c == "{":
                if depth == 0:
                    start_index = i  # outermost opening brace
                depth += 1
            elif c == "}" and depth > 0:
                # Only count a closer that matches an opener; an unmatched "}"
                # must not push the depth negative and hide later objects.
                depth -= 1
                if depth == 0:
                    return string[start_index : i + 1]
        return ""