Spaces:
Sleeping
Sleeping
import json | |
from typing import Any, List | |
from agents import Agent, OpenAIChatCompletionsModel, Runner | |
from agents.agent_output import AgentOutputSchemaBase | |
from openai import AsyncOpenAI | |
from config.global_storage import get_model_config | |
from utils.bio_logger import bio_logger as logger | |
from typing import List, Dict | |
from pydantic import BaseModel, Field,ConfigDict | |
class DateRange(BaseModel):
    # Date window used as a search filter. Both bounds are plain strings in
    # YYYY-MM-DD form and default to the empty string. The JSON schema lists
    # both keys as required; extra keys are rejected.
    model_config = ConfigDict(
        strict=True,
        extra="forbid",
        json_schema_extra={"required": ["start", "end"]},
    )

    start: str = Field(default='', description="Start date in YYYY-MM-DD format")
    end: str = Field(default='', description="End date in YYYY-MM-DD format")
class Journal(BaseModel):
    # A journal identified by name and electronic ISSN. Both fields have no
    # default, so validation fails if either is missing; extra keys are
    # rejected.
    model_config = ConfigDict(
        strict=True,
        extra="forbid",
        json_schema_extra={"required": ["name", "EISSN"]},
    )

    name: str = Field(default=..., description="Journal name")
    EISSN: str = Field(default=..., description="Journal EISSN")
class AuthorFilter(BaseModel):
    # Author constraint: a name plus flags for first/last authorship.
    # Defaults are an empty name and both flags False. The JSON schema lists
    # all three keys as required.
    model_config = ConfigDict(
        strict=True,
        extra="forbid",
        json_schema_extra={"required": ["name", "first_author", "last_author"]},
    )

    name: str = Field(default="", description="Author name to filter")
    first_author: bool = Field(default=False, description="Is first author?")
    last_author: bool = Field(default=False, description="Is last author?")
class Filters(BaseModel):
    # Search filters extracted from the user's query.
    # The JSON schema lists every field as required (for strict structured
    # output). The Python-side defaults below matter when Filters is
    # constructed without those keys, e.g. via RewriteJsonOutput's
    # default_factory.
    model_config = ConfigDict(
        strict=True,
        extra="forbid",
        json_schema_extra={
            "required": ["date_range", "article_types", "languages", "subjects", "journals", "author"]
        },
    )

    # default_factory alone: the former leading `...` contradicted the
    # factory (it suggested "required") and had no effect.
    date_range: DateRange = Field(default_factory=DateRange)
    article_types: List[str] = Field(default_factory=list)
    languages: List[str] = Field(default=["English"])
    subjects: List[str] = Field(default_factory=list)
    # NOTE(review): the default is a list holding one empty string, not an
    # empty list; downstream code may treat "" as a journal name. Confirm
    # intent.
    journals: List[str] = Field(default=[""])
    author: AuthorFilter = Field(default_factory=AuthorFilter)
class RewriteJsonOutput(BaseModel):
    # Full structured rewrite of a user question.
    # The JSON schema lists every field as required. On validation only
    # `category` is mandatory; the other fields fall back to their
    # default_factory.
    model_config = ConfigDict(
        strict=True,
        extra="forbid",
        json_schema_extra={"required": ["category", "key_words", "key_journals", "queries", "filters"]},
    )

    category: str = Field(..., description="Query category")
    # default_factory alone: a leading `...` alongside a factory is
    # contradictory and was ignored.
    key_words: List[str] = Field(default_factory=list)
    key_journals: List[Journal] = Field(default_factory=list)
    queries: List[str] = Field(default_factory=list)
    filters: Filters = Field(default_factory=Filters)
class SimpleJsonOutput(BaseModel):
    # Reduced rewrite output: only the extracted key words. It is passed
    # directly as an Agent output_type, so it has no custom model_config.
    # default_factory alone: the former leading `...` contradicted the
    # factory and had no effect.
    key_words: List[str] = Field(default_factory=list)
class RewriteJsonOutputSchema(AgentOutputSchemaBase):
    """Agent output schema that makes the agent produce a ``RewriteJsonOutput``.

    The JSON schema is ``RewriteJsonOutput.model_json_schema()`` verbatim and
    is reported as strict.
    """

    def is_plain_text(self) -> bool:
        # Output is structured JSON, never free text.
        return False

    def name(self) -> str:
        return "RewriteJsonOutput"

    def json_schema(self) -> Dict[str, Any]:
        return RewriteJsonOutput.model_json_schema()

    def is_strict_json_schema(self) -> bool:
        return True

    def validate_json(self, json_data: Any) -> RewriteJsonOutput:
        """Validate the model's raw output against ``RewriteJsonOutput``.

        Args:
            json_data: JSON text produced by the model, or an already-decoded dict.

        Returns:
            The validated ``RewriteJsonOutput`` instance.

        Raises:
            ModelBehaviorError: if the output is not valid JSON or does not
                match the model. (Previously errors were only logged and
                ``None`` was returned implicitly, which broke callers that
                use ``final_output.model_dump_json()``.)
        """
        from agents.exceptions import ModelBehaviorError

        try:
            if isinstance(json_data, str):
                json_data = json.loads(json_data)
            return RewriteJsonOutput.model_validate(json_data)
        except Exception as e:
            logger.error(f"Validation error: {e}")
            raise ModelBehaviorError(f"Invalid RewriteJsonOutput: {e}") from e

    def parse(self, json_data: Dict[str, Any]) -> Any:
        """Decode *json_data* if it is a JSON string; otherwise return it unchanged."""
        if isinstance(json_data, str):
            json_data = json.loads(json_data)
        return json_data
class RewriteAgent:
    """Rewrite a user question into structured JSON with an LLM agent.

    Model settings come from ``get_model_config()["rewrite-llm"]``. The
    ``"main"`` entry is tried first; if that attempt raises, the request is
    retried once with the ``"backup"`` entry.
    """

    # System instructions shared by every rewrite agent (main and backup).
    _AGENT_INSTRUCTIONS = (
        " Your task is to rewrite the query into a structured JSON format."
        " Please do not answer the question."
    )

    def __init__(self):
        self.model_config = get_model_config()
        self.agent_name = "rewrite agent"
        self.selected_model = self._build_model("main")
        # Built lazily on the first main-model failure, then reused, so
        # repeated failures do not each open a new AsyncOpenAI client.
        self.selected_model_backup = None

    def _build_model(self, tier: str) -> Any:
        """Build the chat-completions model for ``model_config["rewrite-llm"][tier]``."""
        settings = self.model_config["rewrite-llm"][tier]
        return OpenAIChatCompletionsModel(
            model=settings["model"],
            openai_client=AsyncOpenAI(
                api_key=settings["api_key"],
                base_url=settings["base_url"],
                timeout=120.0,
                max_retries=2,
            ),
        )

    def _build_agent(self, model: Any, simple_version: bool) -> Any:
        """Build a rewrite agent on *model*; *simple_version* selects ``SimpleJsonOutput``."""
        output_type = SimpleJsonOutput if simple_version else RewriteJsonOutputSchema()
        return Agent(
            name=self.agent_name,
            instructions=self._AGENT_INSTRUCTIONS,
            model=model,
            output_type=output_type,
        )

    async def _run_rewrite(
        self, model: Any, query: str, instructions: str, simple_version: bool
    ) -> Dict[str, Any]:
        """Run one rewrite attempt on *model* and return the decoded JSON output.

        Raises:
            Exception: whatever ``Runner.run`` raises, or the error from
                serializing ``result.final_output`` (e.g. when it is not a
                pydantic model).
        """
        agent = self._build_agent(model, simple_version)
        result = await Runner.run(agent, input=instructions + 'Here is the question: ' + query)
        try:
            return self.parse_json_output(result.final_output.model_dump_json())
        except Exception as e:
            logger.error(f"Failed to parse JSON output: {e}")
            # Re-raise the real cause so rewrite_query can fall back to the
            # backup model instead of hitting an unbound result variable.
            raise

    async def rewrite_query(
        self, query: str, INSTRUCTIONS: str, simple_version=False
    ) -> Dict[str, Any]:
        """Rewrite *query* into structured JSON, falling back to the backup model.

        Args:
            query: The user's question.
            INSTRUCTIONS: Prompt text placed before the question.
            simple_version: If True, request only ``key_words``
                (``SimpleJsonOutput``); otherwise the full ``RewriteJsonOutput``.

        Returns:
            The agent output decoded by ``parse_json_output`` (a dict;
            ``{"queries": []}`` if decoding fails).

        Raises:
            Exception: the error from the backup attempt when both models fail.
        """
        try:
            logger.info("Rewriting query with main configuration.")
            return await self._run_rewrite(self.selected_model, query, INSTRUCTIONS, simple_version)
        except Exception as main_error:
            logger.error(f"Error with main model: {main_error}", exc_info=main_error)
            logger.info("Trying backup model for rewriting query.")
            if getattr(self, "selected_model_backup", None) is None:
                self.selected_model_backup = self._build_model("backup")
            return await self._run_rewrite(
                self.selected_model_backup, query, INSTRUCTIONS, simple_version
            )

    def parse_json_output(self, output: str) -> Any:
        """Parse *output* as JSON, tolerating code fences and surrounding text.

        Attempts, in order:
        1. the whole string;
        2. the first ```-fenced block, with an optional ``json`` language tag
           (any case) stripped;
        3. the first balanced ``{...}`` span found by ``find_json_in_string``.

        Returns:
            The decoded JSON value, or ``{"queries": []}`` if every attempt
            fails (no exception is raised for undecodable text).
        """
        try:
            return json.loads(output)
        except json.JSONDecodeError as e:
            logger.info(f"Output is not valid JSON: {output}")
            logger.error(f"Failed to parse output as direct JSON: {e}")

        # A complete fence needs an opening and a closing ``` (>= 3 parts).
        parts = output.split("```")
        if len(parts) >= 3:
            fenced = parts[1]
            if fenced[:4].lower() == "json":
                fenced = fenced[4:]
            try:
                return json.loads(fenced.strip())
            except json.JSONDecodeError as e:
                logger.error(f"Failed to parse output from code block: {e}")

        # Last resort: pull the first balanced brace span out of the text.
        candidate = self.find_json_in_string(output)
        if not candidate:
            logger.error(f"No valid JSON found in the output:{output}")
            return {"queries": []}
        try:
            return json.loads(candidate)
        except json.JSONDecodeError as e:
            logger.error(f"Failed to parse extracted JSON: {e}")
            logger.error(f"Extracted JSON: {candidate}")
            return {"queries": []}

    def find_json_in_string(self, string: str) -> str:
        """Return the first balanced ``{...}`` span in *string*, or ``""``.

        Braces are counted without regard to quoting, so this does not
        validate JSON. A ``}`` seen before any opening brace is ignored
        rather than driving the depth negative.

        Example:
            string = "bla bla bla {this is {some} text{{}and it's sneaky}} because {it's} confusing"
            output = "{this is {some} text{{}and it's sneaky}}"
        """
        depth = 0
        start_index = 0
        for i, c in enumerate(string):
            if c == "{":
                if depth == 0:
                    start_index = i  # Opening brace of the outermost span.
                depth += 1
            elif c == "}" and depth > 0:
                depth -= 1
                if depth == 0:
                    return string[start_index : i + 1]
        # No complete outermost span was closed.
        return ""