from fastapi import APIRouter, Depends, HTTPException
from litellm.router import Router
from dependencies import get_llm_router
from schemas import ReqSearchLLMResponse, ReqSearchRequest, ReqSearchResponse
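
# The schemas module is not shown here. Based on how its classes are used
# below, they are assumed to be Pydantic models shaped roughly like this
# (a sketch, not the actual definitions):
#
#   class Requirement(BaseModel):
#       req_id: int
#       document: str
#       context: str
#       requirement: str
#
#   class ReqSearchRequest(BaseModel):
#       query: str
#       requirements: list[Requirement]
#
#   class ReqSearchLLMResponse(BaseModel):
#       selected: list[int]  # "Selection ID" indices into the request's requirements
#
#   class ReqSearchResponse(BaseModel):
#       requirements: list[Requirement]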

# Router for all requirement-search endpoints
router = APIRouter()


@router.post("/get_reqs_from_query", response_model=ReqSearchResponse)
def find_requirements_from_problem_description(
    req: ReqSearchRequest,
    llm_router: Router = Depends(get_llm_router),
):
    """Find the requirements that address a given problem description from an extracted list."""

    requirements = req.requirements
    query = req.query

    requirements_text = "\n".join(
        f"[Selection ID: {r.req_id} | Document: {r.document} | Context: {r.context} | Requirement: {r.requirement}]"
        for r in requirements
    )
    print("Called the LLM")
    resp_ai = llm_router.completion(
        model="gemini-v2",
        messages=[{
            "role": "user",
            "content": (
                f"Given all the requirements:\n{requirements_text}\n"
                f"and the problem description \"{query}\", return a list of "
                "'Selection ID' values for the most relevant requirements that "
                "reference or best cover the problem. If none of the requirements "
                "covers the problem, return an empty list."
            ),
        }],
        # litellm accepts a Pydantic model as response_format for structured
        # output on providers that support it.
        response_format=ReqSearchLLMResponse
    )
    print("Answered")
    print(resp_ai.choices[0].message.content)

    # message.content is a JSON string matching ReqSearchLLMResponse,
    # e.g. '{"selected": [0, 3]}' (illustrative values).
    out_llm = ReqSearchLLMResponse.model_validate_json(
        resp_ai.choices[0].message.content).selected

    # An empty selection is a valid "no match" answer, and max() would raise on
    # an empty list. The indexing below also assumes each req_id equals the
    # requirement's position in the request list.
    if out_llm and (min(out_llm) < 0 or max(out_llm) > len(requirements) - 1):
        raise HTTPException(
            status_code=500, detail="LLM error: generated an out-of-range index, please try again.")

    return ReqSearchResponse(requirements=[requirements[i] for i in out_llm])
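
# Example exchange (a sketch; field names assume the schema shapes above, and
# the query/requirement values are illustrative):
#
#   POST /get_reqs_from_query
#   {
#       "query": "The system must stay responsive under peak load",
#       "requirements": [
#           {"req_id": 0, "document": "spec_a.pdf", "context": "...", "requirement": "..."},
#           {"req_id": 1, "document": "spec_b.pdf", "context": "...", "requirement": "..."}
#       ]
#   }
#
# The response contains only the requirement objects the LLM selected, or an
# empty list when none of them covers the problem.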