import json
from typing import Any, Dict, List

from agents import Agent, OpenAIChatCompletionsModel, Runner
from agents.agent_output import AgentOutputSchemaBase
from openai import AsyncOpenAI
from pydantic import BaseModel, ConfigDict, Field

from config.global_storage import get_model_config
from utils.bio_logger import bio_logger as logger

# NOTE: the pydantic models below are documented with `#` comments instead of
# docstrings on purpose. Pydantic copies a model's docstring into the generated
# JSON schema as its "description", which would alter the schema handed to the
# LLM.

# System instructions shared by every rewrite agent (the leading space is kept
# exactly as it was in the original prompt).
_AGENT_INSTRUCTIONS = ' Your task is to rewrite the query into a structured JSON format. Please do not answer the question.'


# Publication date window; both bounds default to the empty string.
class DateRange(BaseModel):
    model_config = ConfigDict(
        strict=True,
        extra="forbid",
        json_schema_extra={"required": ["start", "end"]},
    )

    start: str = Field('', description="Start date in YYYY-MM-DD format")
    end: str = Field('', description="End date in YYYY-MM-DD format")


# A journal identified by its name and EISSN; both fields are mandatory.
class Journal(BaseModel):
    model_config = ConfigDict(
        strict=True,
        extra="forbid",
        json_schema_extra={"required": ["name", "EISSN"]},
    )

    name: str = Field(..., description="Journal name")
    EISSN: str = Field(..., description="Journal EISSN")


# Author constraint: a name plus first/last-author flags.
class AuthorFilter(BaseModel):
    model_config = ConfigDict(
        strict=True,
        extra="forbid",
        json_schema_extra={"required": ["name", "first_author", "last_author"]},
    )

    name: str = Field("", description="Author name to filter")
    first_author: bool = Field(False, description="Is first author?")
    last_author: bool = Field(False, description="Is last author?")


# Search filters attached to a rewritten query.
class Filters(BaseModel):
    model_config = ConfigDict(
        strict=True,
        extra="forbid",
        json_schema_extra={
            "required": [
                "date_range",
                "article_types",
                "languages",
                "subjects",
                "journals",
                "author",
            ]
        },
    )

    date_range: DateRange = Field(..., default_factory=DateRange)
    article_types: List[str] = Field(..., default_factory=list)
    languages: List[str] = Field(["English"],)
    subjects: List[str] = Field(..., default_factory=list)
    journals: List[str] = Field([""])
    author: AuthorFilter = Field(..., default_factory=AuthorFilter)


# Full structured output expected from the rewrite agent.
class RewriteJsonOutput(BaseModel):
    model_config = ConfigDict(
        strict=True,
        extra="forbid",
        json_schema_extra={
            "required": ["category", "key_words", "key_journals", "queries", "filters"]
        },
    )

    category: str = Field(..., description="Query category")
    key_words: List[str] = Field(..., default_factory=list)
    key_journals: List[Journal] = Field(..., default_factory=list)
    queries: List[str] = Field(..., default_factory=list)
    filters: Filters = Field(..., default_factory=Filters)


# Reduced output used when `simple_version=True`: key words only.
class SimpleJsonOutput(BaseModel):
    key_words: List[str] = Field(..., default_factory=list)


class RewriteJsonOutputSchema(AgentOutputSchemaBase):
    """Output schema that validates agent output against `RewriteJsonOutput`."""

    def is_plain_text(self):
        """Return False: the agent output is structured JSON, not free text."""
        return False

    def name(self):
        """Return the name of the output type."""
        return "RewriteJsonOutput"

    def json_schema(self) -> Dict[str, Any]:
        """Return the JSON schema generated from `RewriteJsonOutput`."""
        return RewriteJsonOutput.model_json_schema()

    def is_strict_json_schema(self):
        """Return True: the schema is declared as strict."""
        return True

    def validate_json(self, json_data: Any) -> Any:
        """Validate `json_data` (a JSON string or an already-decoded object).

        Returns:
            The validated `RewriteJsonOutput` instance, or None when decoding
            or validation fails (the error is logged, not raised).
        """
        try:
            if isinstance(json_data, str):
                json_data = json.loads(json_data)
            return RewriteJsonOutput.model_validate(json_data)
        except Exception as e:
            logger.error(f"Validation error: {e}")
            # NOTE(review): failures are deliberately reported as None here;
            # RewriteAgent treats a None final output as a parse failure.
            # Confirm whether the SDK expects this method to raise instead.
            return None

    def parse(self, json_data: Any) -> Any:
        """Decode `json_data` if it is a string; otherwise return it unchanged."""
        if isinstance(json_data, str):
            json_data = json.loads(json_data)
        return json_data


class RewriteAgent:
    """Rewrites a user query into structured JSON via an LLM agent.

    The "main" model from the "rewrite-llm" configuration is tried first; if
    that attempt fails for any reason, the "backup" model is used.
    """

    def __init__(self):
        self.model_config = get_model_config()
        self.agent_name = "rewrite agent"
        self.selected_model = self._build_model("main")

    def _build_model(self, tier: str) -> OpenAIChatCompletionsModel:
        """Build the chat-completions model for config entry "rewrite-llm"/`tier`."""
        settings = self.model_config["rewrite-llm"][tier]
        return OpenAIChatCompletionsModel(
            model=settings["model"],
            openai_client=AsyncOpenAI(
                api_key=settings["api_key"],
                base_url=settings["base_url"],
                timeout=120.0,
                max_retries=2,
            ),
        )

    async def _rewrite_with(
        self,
        model: OpenAIChatCompletionsModel,
        query: str,
        INSTRUCTIONS: str,
        simple_version: bool,
        *,
        fallback_on_parse_error: bool,
    ) -> Any:
        """Run one rewrite attempt with `model` and return the parsed output.

        Errors raised while running the agent always propagate. Errors raised
        while converting the final output to JSON are logged and then either
        re-raised or, when `fallback_on_parse_error` is true, replaced by the
        empty result `{"queries": []}`.
        """
        rewrite_agent = Agent(
            name=self.agent_name,
            instructions=_AGENT_INSTRUCTIONS,
            model=model,
            # Use the Pydantic model for structured output
            output_type=SimpleJsonOutput if simple_version else RewriteJsonOutputSchema(),
        )
        result = await Runner.run(
            rewrite_agent,
            input=INSTRUCTIONS + 'Here is the question: ' + query,
        )
        try:
            # final_output is None when RewriteJsonOutputSchema.validate_json
            # rejected the model output, which makes this line raise.
            return self.parse_json_output(result.final_output.model_dump_json())
        except Exception as e:
            logger.error(f"Failed to parse JSON output: {e}")
            if not fallback_on_parse_error:
                raise
            return {"queries": []}

    async def rewrite_query(self, query: str, INSTRUCTIONS: str, simple_version=False) -> Any:
        """Rewrite `query` into structured JSON.

        Args:
            query: The user question to rewrite.
            INSTRUCTIONS: Prompt text placed in front of the question.
            simple_version: When true, request the reduced `SimpleJsonOutput`
                shape instead of the full `RewriteJsonOutput` shape.

        Returns:
            The parsed JSON output of the agent. If the backup model's output
            cannot be parsed, `{"queries": []}` is returned.

        Raises:
            Exception: Whatever the backup attempt raises while building the
                model or running the agent.
        """
        try:
            logger.info("Rewriting query with main configuration.")
            return await self._rewrite_with(
                self.selected_model,
                query,
                INSTRUCTIONS,
                simple_version,
                # A parse failure on the main model must reach the handler
                # below so that the backup model gets a chance.
                fallback_on_parse_error=False,
            )
        except Exception as main_error:
            logger.error(f"Error with main model: {main_error}", exc_info=main_error)
            logger.info("Trying backup model for rewriting query.")

        self.selected_model_backup = self._build_model("backup")
        return await self._rewrite_with(
            self.selected_model_backup,
            query,
            INSTRUCTIONS,
            simple_version,
            fallback_on_parse_error=True,
        )

    def parse_json_output(self, output: str) -> Any:
        """Parse `output` as JSON, tolerating code fences and surrounding text.

        Three strategies are tried in order: the raw string, the contents of
        the first ``` fenced block, and the first balanced `{...}` span.

        Returns:
            The decoded JSON value, or `{"queries": []}` if nothing parses.
        """
        # First try to load the string as JSON
        try:
            return json.loads(output)
        except json.JSONDecodeError as e:
            logger.info(f"Output is not valid JSON: {output}")
            logger.error(f"Failed to parse output as direct JSON: {e}")

        # If that fails, assume that the output is in a code block - remove the
        # code block markers and try again
        parsed_output = output
        if "```" in parsed_output:
            try:
                parts = parsed_output.split("```")
                if len(parts) >= 3:
                    parsed_output = parts[1]
                    # Drop the language tag that follows the opening fence.
                    if parsed_output.startswith(("json", "JSON")):
                        parsed_output = parsed_output[4:].strip()
                    return json.loads(parsed_output)
            except (IndexError, json.JSONDecodeError) as e:
                logger.error(f"Failed to parse output from code block: {e}")

        # As a last attempt, try to manually find the JSON object in the output
        # and parse it
        parsed_output = self.find_json_in_string(output)
        if parsed_output:
            try:
                return json.loads(parsed_output)
            except json.JSONDecodeError as e:
                logger.error(f"Failed to parse extracted JSON: {e}")
                logger.error(f"Extracted JSON: {parsed_output}")
                return {"queries": []}
        else:
            logger.error(f"No valid JSON found in the output:{output}")

        # Every strategy failed: fall back to an empty result instead of raising
        return {"queries": []}

    def find_json_in_string(self, string: str) -> str:
        """
        Method to extract all text in the left-most brace that appears in a string.
        Used to extract JSON from a string (note that this function does not validate the JSON).

        Closing braces that appear before any opening brace are ignored.

        Example:
            string = "bla bla bla {this is {some} text{{}and it's sneaky}} because {it's} confusing"
            output = "{this is {some} text{{}and it's sneaky}}"
        """
        depth = 0
        start_index = None

        for i, c in enumerate(string):
            if c == "{":
                if depth == 0:
                    start_index = i  # Start index of the outermost '{'
                depth += 1
            elif c == "}" and depth > 0:
                # Only count a '}' that closes an open '{'; a stray one must
                # not push the depth below zero and hide a later object.
                depth -= 1
                if depth == 0:
                    # Return the substring from the outermost '{' to this '}'
                    return string[start_index : i + 1]

        # If no complete set of braces is found, return an empty string
        return ""