|
import base64
import mimetypes
import os

from langchain_community.retrievers import BM25Retriever
from mistralai import Mistral
|
|
|
class PlantUMLAgent:
    """Generate PlantUML code with Mistral chat models, grounded by BM25 retrieval.

    Attributes:
        docs_processed: Documents searched by ``_retrieve_diagram_info``. Each
            document is read through ``page_content`` and the ``"source"``
            entry of its ``metadata`` mapping.
        client: ``Mistral`` client used for every chat completion.
    """

    def __init__(self, docs_processed, api_key):
        self.docs_processed = docs_processed
        self.client = Mistral(api_key=api_key)

    def _retrieve_diagram_info(self, diagram_name, k=10):
        """Return reference text relevant to ``diagram_name``, one chunk per line.

        Only documents whose ``metadata["source"]`` contains the hyphenated,
        lowercased diagram name (e.g. ``"sequence-diagram"``) are considered.
        They are then ranked with BM25 and the top ``k`` are joined.

        Args:
            diagram_name: Human-readable diagram type, e.g. ``"Sequence Diagram"``.
            k: Maximum number of document chunks to return.

        Returns:
            The page contents of the retrieved documents joined by newlines,
            or ``""`` when no document's source matches the diagram name.
        """
        words = diagram_name.lower()
        slug = words.replace(" ", "-")

        filtered_docs = [
            doc
            for doc in self.docs_processed
            # .get(): documents without a "source" entry simply do not match.
            if slug in str(doc.metadata.get("source", "")).lower()
        ]
        if not filtered_docs:
            # BM25 cannot be built over an empty corpus (it divides by the
            # corpus size), and there is no reference material to return anyway.
            return ""

        retriever = BM25Retriever.from_documents(filtered_docs, k=k)
        # The default BM25 preprocessing splits on whitespace. The hyphenated
        # slug alone is a single token that rarely occurs in prose, so the
        # plain words are added to the query as well.
        query = slug if words == slug else f"{words} {slug}"
        documents = retriever.invoke(query)
        return "\n".join(doc.page_content for doc in documents)

    @staticmethod
    def _image_mime_type(image_input):
        """Return the ``image/*`` MIME type to use for the file at ``image_input``.

        ``mimetypes`` is consulted first, so e.g. ``.jpg`` maps to the valid
        ``image/jpeg`` rather than ``image/jpg``. Unknown suffixes fall back to
        ``image/<suffix>``.

        Raises:
            ValueError: if the path has no file extension to infer a type from.
        """
        path = os.fspath(image_input)
        mime_type, _ = mimetypes.guess_type(path)
        if mime_type and mime_type.startswith("image/"):
            return mime_type
        # splitext only inspects the final path component, so a dot in a
        # directory name is not mistaken for an extension.
        extension = os.path.splitext(path)[1].lstrip(".").lower()
        if not extension:
            raise ValueError(f"Cannot determine image type of {path!r}: no file extension")
        return f"image/{extension}"

    @staticmethod
    def _response_text(chat_response):
        """Return the stripped text of the first choice of a chat completion.

        Returns ``""`` when the message has no content.
        """
        content = chat_response.choices[0].message.content
        if content is None:
            return ""
        if not isinstance(content, str):
            # NOTE(review): the Mistral SDK appears to allow content to be a
            # list of chunks; keep only their text parts. Confirm against the
            # installed SDK version.
            content = "".join(getattr(chunk, "text", None) or "" for chunk in content)
        return content.strip()

    def recognize_image(self, diagram_name, image_input):
        """Describe an image with the Pixtral vision model, to seed a diagram.

        Args:
            diagram_name: Diagram type the extracted information is meant for.
            image_input: Path to the image file. It is read in full and sent
                inline as a base64 data URL.

        Returns:
            The model's stripped text response.

        Raises:
            OSError: if the image file cannot be read.
            ValueError: if the image type cannot be inferred from the path.
        """
        mime_type = self._image_mime_type(image_input)
        with open(image_input, "rb") as image_file:
            image_base64 = base64.b64encode(image_file.read()).decode("utf-8")
        image_url = f"data:{mime_type};base64,{image_base64}"

        chat_response = self.client.chat.complete(
            model="pixtral-12b-2409",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": f"Analyze the provided image and extract relevant information for create a {diagram_name}.",
                        },
                        {
                            "type": "image_url",
                            "image_url": image_url,
                        },
                    ],
                }
            ],
        )
        return self._response_text(chat_response)

    def predict(self, diagram_name, message):
        """Generate PlantUML code for ``message``, using retrieved docs as reference.

        Args:
            diagram_name: Diagram type used to select reference documents.
            message: User description of the diagram to generate.

        Returns:
            The model's stripped response. The prompt asks for the code to be
            wrapped in a plantuml code fence; the response is not parsed here.
        """
        documents_texts = self._retrieve_diagram_info(diagram_name)

        chat_response = self.client.chat.complete(
            model="mistral-small-latest",
            messages=[
                {
                    "role": "system",
                    "content": "You are a PlantUML code generator. Generate PlantUML diagram code based on the provided description.",
                },
                {
                    "role": "user",
                    "content": f"{message}\n\nUse the following documents as reference:\n{documents_texts}\n\nReturn the output between ```plantuml and ```.",
                },
            ],
        )
        return self._response_text(chat_response)