import asyncio
import json
import logging
from typing import Awaitable, Callable, TypeVar

from fastapi import APIRouter, Depends, HTTPException, Response
from httpx import AsyncClient
from jinja2 import Environment, TemplateNotFound
from litellm.router import Router

from dependencies import INSIGHT_FINDER_BASE_URL, get_http_client, get_llm_router, get_prompt_templates
from schemas import (
    _RefinedSolutionModel, _BootstrappedSolutionModel, _SolutionCriticismOutput,
    CriticizeSolutionsRequest, CritiqueResponse, InsightFinderConstraintsList,
    PriorArtSearchRequest, PriorArtSearchResponse, ReqGroupingCategory,
    ReqGroupingRequest, ReqGroupingResponse, ReqSearchLLMResponse,
    ReqSearchRequest, ReqSearchResponse, SolutionCriticism, SolutionModel,
    SolutionBootstrapResponse, SolutionBootstrapRequest, TechnologyData,
)

# Router for solution generation and critique
router = APIRouter(tags=["solution generation and critique"])

# ============== utilities ===========================

T = TypeVar("T")
A = TypeVar("A")


async def retry_until(
    func: Callable[[A], Awaitable[T]],
    arg: A,
    predicate: Callable[[T], bool],
    max_retries: int,
) -> T:
    """Retries the given async function until the passed-in validation predicate returns true.

    Returns the last computed value even if it never satisfies the predicate.
    """
    last_value = await func(arg)
    for _ in range(max_retries):
        if predicate(last_value):
            return last_value
        last_value = await func(arg)
    return last_value
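
# Example usage of `retry_until` (a sketch, not part of this module's API): wrap an LLM call and
# re-issue it until the output passes a cheap validation check. `draft_answer`, the emptiness
# predicate, and the availability of an `llm_router: Router` are hypothetical assumptions here.
#
#   async def draft_answer(prompt: str) -> str:
#       completion = await llm_router.acompletion("gemini-v2", messages=[{"role": "user", "content": prompt}])
#       return completion.choices[0].message.content
#
#   answer = await retry_until(draft_answer, prompt, lambda out: len(out.strip()) > 0, max_retries=2)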


# =================================================== Search solutions ============================================================================

async def bootstrap_solutions(req: SolutionBootstrapRequest, prompt_env: Environment = Depends(get_prompt_templates), llm_router: Router = Depends(get_llm_router), http_client: AsyncClient = Depends(get_http_client)) -> SolutionBootstrapResponse:
    """
    Bootstraps a solution for each of the passed-in requirement categories using Insight Finder's API.
    """

    async def _bootstrap_solution_inner(cat: ReqGroupingCategory):
        # process requirements into insight finder format
        fmt_completion = await llm_router.acompletion("gemini-v2", messages=[
            {
                "role": "user",
                "content": await prompt_env.get_template("format_requirements.txt").render_async(**{
                    "category": cat.model_dump(),
                    "response_schema": InsightFinderConstraintsList.model_json_schema()
                })
            }], response_format=InsightFinderConstraintsList)

        fmt_model = InsightFinderConstraintsList.model_validate_json(
            fmt_completion.choices[0].message.content)

        # translate from a structured output to a dict for insight finder
        formatted_constraints = {'constraints': {
            cons.title: cons.description for cons in fmt_model.constraints}}

        # fetch technologies from insight finder
        technologies_req = await http_client.post(INSIGHT_FINDER_BASE_URL + "process-constraints", content=json.dumps(formatted_constraints))
        technologies = TechnologyData.model_validate(technologies_req.json())

        # =============================================================== synthesize solution using LLM =========================================
        format_solution = await llm_router.acompletion("gemini-v2", messages=[{
            "role": "user",
            "content": await prompt_env.get_template("bootstrap_solution.txt").render_async(**{
                "category": cat.model_dump(),
                "technologies": technologies.model_dump()["technologies"],
                "user_constraints": req.user_constraints,
                "response_schema": _BootstrappedSolutionModel.model_json_schema()
            })}
        ], response_format=_BootstrappedSolutionModel)

        format_solution_model = _BootstrappedSolutionModel.model_validate_json(
            format_solution.choices[0].message.content)

        final_solution = SolutionModel(
            context="",
            requirements=[
                cat.requirements[i].requirement for i in format_solution_model.requirement_ids
            ],
            problem_description=format_solution_model.problem_description,
            solution_description=format_solution_model.solution_description,
            references=[],
            category_id=cat.id,
        )
        # ========================================================================================================================================

        return final_solution

    tasks = await asyncio.gather(*[_bootstrap_solution_inner(cat) for cat in req.categories], return_exceptions=True)
    # keep only the categories that bootstrapped successfully; failed ones are dropped
    final_solutions = [sol for sol in tasks if not isinstance(sol, Exception)]

    return SolutionBootstrapResponse(solutions=final_solutions)


async def criticize_solution(params: CriticizeSolutionsRequest, prompt_env: Environment = Depends(get_prompt_templates), llm_router: Router = Depends(get_llm_router)) -> CritiqueResponse:
    """Criticizes the challenges, weaknesses and limitations of the provided solutions."""

    async def __criticize_single(solution: SolutionModel):
        req_prompt = await prompt_env.get_template("criticize.txt").render_async(**{
            "solutions": [solution.model_dump()],
            "response_schema": _SolutionCriticismOutput.model_json_schema()
        })

        req_completion = await llm_router.acompletion(
            model="gemini-v2",
            messages=[{"role": "user", "content": req_prompt}],
            response_format=_SolutionCriticismOutput
        )
        criticism_out = _SolutionCriticismOutput.model_validate_json(
            req_completion.choices[0].message.content
        )

        # exactly one solution was passed in, so take the single returned criticism
        return SolutionCriticism(solution=solution, criticism=criticism_out.criticisms[0])

    critiques = await asyncio.gather(*[__criticize_single(sol) for sol in params.solutions], return_exceptions=False)

    return CritiqueResponse(critiques=critiques)


# =================================================================== Refine solution ====================================

async def refine_solutions(params: CritiqueResponse, prompt_env: Environment = Depends(get_prompt_templates), llm_router: Router = Depends(get_llm_router)) -> SolutionBootstrapResponse:
    """Refines the previously critiqued solutions."""

    async def __refine_solution(crit: SolutionCriticism):
        req_prompt = await prompt_env.get_template("refine_solution.txt").render_async(**{
            "solution": crit.solution.model_dump(),
            "criticism": crit.criticism,
            "response_schema": _RefinedSolutionModel.model_json_schema(),
        })

        req_completion = await llm_router.acompletion(model="gemini-v2", messages=[
            {"role": "user", "content": req_prompt}
        ], response_format=_RefinedSolutionModel)

        req_model = _RefinedSolutionModel.model_validate_json(
            req_completion.choices[0].message.content)

        # copy previous solution model
        refined_solution = crit.solution.model_copy(deep=True)
        refined_solution.problem_description = req_model.problem_description
        refined_solution.solution_description = req_model.solution_description

        return refined_solution

    refined_solutions = await asyncio.gather(*[__refine_solution(crit) for crit in params.critiques], return_exceptions=False)

    return SolutionBootstrapResponse(solutions=refined_solutions)


async def search_prior_art(req: PriorArtSearchRequest, prompt_env: Environment = Depends(get_prompt_templates), llm_router: Router = Depends(get_llm_router)) -> PriorArtSearchResponse:
    """Performs a comprehensive prior art / freedom-to-operate (FTO) search against the provided topics for a drafted solution."""

    # limit the number of concurrent topic searches
    sema = asyncio.Semaphore(4)

    async def __search_topic(topic: str) -> dict:
        search_prompt = await prompt_env.get_template("search/search_topic.txt").render_async(**{
            "topic": topic
        })

        async with sema:
            search_completion = await llm_router.acompletion(model="gemini-v2", messages=[
                {"role": "user", "content": search_prompt}
            ], temperature=0.3, tools=[{"googleSearch": {}}])
            return {"topic": topic, "content": search_completion.choices[0].message.content}

    # Dispatch the individual tasks for topic search
    searches = await asyncio.gather(*[__search_topic(top) for top in req.topics], return_exceptions=False)

    # Then consolidate everything into a single detailed report
    consolidation_prompt = await prompt_env.get_template("search/build_final_report.txt").render_async(**{
        "searches": searches
    })
    consolidation_completion = await llm_router.acompletion(model="gemini-v2", messages=[
        {"role": "user", "content": consolidation_prompt}
    ], temperature=0.5)

    return PriorArtSearchResponse(content=consolidation_completion.choices[0].message.content, references=[])
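
# The handlers above use FastAPI's `Depends(...)` defaults, so they are meant to be registered as
# routes on `router`; the registration itself (e.g. `@router.post(...)` decorators) is not shown in
# this section. A minimal sketch, assuming route paths that are NOT confirmed by this file:
#
#   router.post("/bootstrap_solutions")(bootstrap_solutions)
#   router.post("/criticize_solutions")(criticize_solution)
#   router.post("/refine_solutions")(refine_solutions)
#   router.post("/search_prior_art")(search_prior_art)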