# plantuml-agent / src/agent.py
# Author: Sávio Santos
# Commit 99e2b34: "added input (optional)"
import base64
import mimetypes
import os

from langchain_community.retrievers import BM25Retriever
from mistralai import Mistral
class PlantUMLAgent:
    """Generate PlantUML code with Mistral chat models, grounded in local docs.

    Reference passages are selected from ``docs_processed`` with a BM25
    retriever, restricted to documents whose ``metadata["source"]`` contains
    the normalised diagram name.
    """

    # Chat model used to describe an input image (sent as an image_url part).
    VISION_MODEL = "pixtral-12b-2409"
    # Chat model used to generate the PlantUML code.
    TEXT_MODEL = "mistral-small-latest"

    def __init__(self, docs_processed, api_key):
        """Store the reference documents and create a Mistral client.

        Args:
            docs_processed: Iterable of objects exposing ``page_content`` and a
                ``metadata`` dict (presumably LangChain ``Document`` objects --
                confirm against the caller).
            api_key: Mistral API key passed to ``Mistral``.
        """
        self.docs_processed = docs_processed
        self.client = Mistral(api_key=api_key)

    @staticmethod
    def _message_text(chat_response):
        """Return the stripped text of the first choice ("" if it has no content)."""
        content = chat_response.choices[0].message.content
        # The content may be None (e.g. an empty reply); avoid AttributeError on .strip().
        return (content or "").strip()

    @staticmethod
    def _image_to_data_url(image_input):
        """Read an image file and return it as a base64 ``data:`` URL.

        Args:
            image_input: Path to the image, as ``str`` or ``os.PathLike``.

        Returns:
            A string ``"data:<mime>;base64,<payload>"``.
        """
        path = os.fspath(image_input)
        with open(path, "rb") as image_file:
            image_base64 = base64.b64encode(image_file.read()).decode("utf-8")
        # guess_type maps e.g. ".jpg"/".JPG" to the registered "image/jpeg";
        # the raw extension ("image/jpg") is not a registered MIME type.
        mime_type, _ = mimetypes.guess_type(path)
        if mime_type is None or not mime_type.startswith("image/"):
            # Unknown to mimetypes: fall back to deriving the subtype from the extension.
            extension = os.path.splitext(path)[1].lstrip(".").lower()
            mime_type = f"image/{extension}"
        return f"data:{mime_type};base64,{image_base64}"

    def _retrieve_diagram_info(self, diagram_name, k=10):
        """Return up to ``k`` reference passages for a diagram type.

        The name is normalised to hyphenated lower case (``"Class Diagram"`` ->
        ``"class-diagram"``). It is used both to filter documents by their
        ``metadata["source"]`` and as the BM25 query.

        Args:
            diagram_name: Human-readable diagram type.
            k: Maximum number of passages to retrieve.

        Returns:
            The matching passages joined with newlines, or ``""`` when no
            document's source contains the normalised name.
        """
        diagram_name = diagram_name.replace(" ", "-").lower()
        filtered_docs = [
            doc
            for doc in self.docs_processed
            # Documents without a source are skipped rather than raising KeyError.
            if diagram_name in (doc.metadata.get("source") or "").lower()
        ]
        if not filtered_docs:
            # BM25Retriever.from_documents fails on an empty document list.
            return ""
        retriever = BM25Retriever.from_documents(filtered_docs, k=k)
        documents = retriever.invoke(diagram_name)
        return "\n".join(doc.page_content for doc in documents)

    def recognize_image(self, diagram_name, image_input):
        """Ask the vision model to extract diagram-relevant info from an image.

        Args:
            diagram_name: Diagram type named in the prompt.
            image_input: Path (``str`` or ``os.PathLike``) to the image file.

        Returns:
            The model's reply text, stripped of surrounding whitespace.
        """
        image_url = self._image_to_data_url(image_input)
        chat_response = self.client.chat.complete(
            model=self.VISION_MODEL,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": f"Analyze the provided image and extract relevant information for create a {diagram_name}."
                        },
                        {
                            "type": "image_url",
                            "image_url": image_url
                        }
                    ]
                }
            ]
        )
        return self._message_text(chat_response)

    def predict(self, diagram_name, message):
        """Generate PlantUML code for ``message``, using retrieved docs as reference.

        Args:
            diagram_name: Diagram type used to select reference documents.
            message: The user's description of the diagram.

        Returns:
            The model's reply, stripped. The prompt asks for the code to be
            fenced between ```plantuml and ```.
        """
        documents_texts = self._retrieve_diagram_info(diagram_name)
        chat_response = self.client.chat.complete(
            model=self.TEXT_MODEL,
            messages=[
                {
                    "role": "system",
                    "content": "You are a PlantUML code generator. Generate PlantUML diagram code based on the provided description."
                },
                {
                    "role": "user",
                    "content": f"{message}\n\nUse the following documents as reference:\n{documents_texts}\n\nReturn the output between ```plantuml and ```."
                }
            ]
        )
        return self._message_text(chat_response)