File size: 2,651 Bytes
b5c4d16
 
99e2b34
b5c4d16
 
99e2b34
b5c4d16
f897f8d
 
99e2b34
b5c4d16
 
99e2b34
b5c4d16
 
99e2b34
 
b5c4d16
 
99e2b34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b5c4d16
99e2b34
b5c4d16
 
 
 
 
 
 
99e2b34
b5c4d16
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import base64
import mimetypes
import os

from langchain_community.retrievers import BM25Retriever
from mistralai import Mistral

class PlantUMLAgent:
    """Generate PlantUML code with Mistral models, grounded in reference documents.

    The agent holds a pre-processed document collection (objects exposing
    ``page_content`` and a ``metadata`` dict) and a Mistral client.
    ``recognize_image`` asks a vision model to describe an image, and
    ``predict`` asks a text model for PlantUML code. ``predict`` supplies the
    model with documents whose ``source`` metadata mentions the requested
    diagram type.
    """

    # Model identifiers used for the two kinds of request; override on a
    # subclass or instance to switch models without touching the methods.
    VISION_MODEL = "pixtral-12b-2409"
    TEXT_MODEL = "mistral-small-latest"

    def __init__(self, docs_processed, api_key: str):
        """Store the document collection and create a Mistral client.

        Args:
            docs_processed: Iterable of documents with ``page_content`` and a
                ``metadata`` dict (presumably LangChain ``Document`` objects).
            api_key: Mistral API key passed to the client.
        """
        self.docs_processed = docs_processed
        self.client = Mistral(api_key=api_key)

    def _retrieve_diagram_info(self, diagram_name: str, k: int = 10) -> str:
        """Return reference text for *diagram_name*, joined with newlines.

        Documents are pre-filtered to those whose ``source`` metadata contains
        the diagram name in lower-case, hyphenated form (e.g. "Class Diagram"
        -> "class-diagram"). The filtered documents are then ranked with BM25
        against that same name.

        Args:
            diagram_name: Human-readable diagram type.
            k: Maximum number of documents to return.

        Returns:
            The ``page_content`` of up to *k* documents, joined by "\\n", or
            "" when no document matches.
        """
        slug = diagram_name.replace(" ", "-").lower()
        filtered_docs = [
            doc
            for doc in self.docs_processed
            # .get() so documents lacking a "source" key are skipped rather
            # than raising KeyError.
            if slug in str(doc.metadata.get("source", "")).lower()
        ]
        if not filtered_docs:
            # BM25Retriever.from_documents raises on an empty document list,
            # so report "no reference material" instead of crashing.
            return ""

        retriever = BM25Retriever.from_documents(filtered_docs, k=k)
        documents = retriever.invoke(slug)
        return "\n".join(doc.page_content for doc in documents)

    @staticmethod
    def _image_mime_type(path: str) -> str:
        """Return the ``image/*`` MIME type for the file at *path*.

        The stdlib ``mimetypes`` table is tried first, which maps ".jpg" to
        "image/jpeg" and handles upper-case suffixes. If it has no image
        mapping, "image/<suffix>" is built from the lower-cased file extension.

        Raises:
            ValueError: If *path* has no extension to derive a type from.
        """
        mime_type, _ = mimetypes.guess_type(path)
        if mime_type and mime_type.startswith("image/"):
            return mime_type
        extension = os.path.splitext(path)[1].lstrip(".").lower()
        if not extension:
            raise ValueError(f"Cannot determine the image type of {path!r}")
        return f"image/{extension}"

    @staticmethod
    def _message_text(content) -> str:
        """Return the stripped text of a chat message's ``content``.

        Accepts ``None`` (returns ""), a plain string, or an iterable of
        content chunks. For chunks, the ``text`` attributes are concatenated
        and chunks without text are ignored.
        """
        if content is None:
            return ""
        if isinstance(content, str):
            return content.strip()
        return "".join(getattr(chunk, "text", None) or "" for chunk in content).strip()

    def _complete(self, model: str, messages: list) -> str:
        """Send *messages* to *model* and return the first choice's text, stripped."""
        chat_response = self.client.chat.complete(model=model, messages=messages)
        return self._message_text(chat_response.choices[0].message.content)

    def recognize_image(self, diagram_name: str, image_input) -> str:
        """Ask the vision model to extract diagram-relevant information from an image.

        Args:
            diagram_name: Diagram type the extracted information is meant for.
            image_input: Path (``str`` or ``os.PathLike``) to an image file.

        Returns:
            The model's stripped textual answer ("" if it returned no content).

        Raises:
            ValueError: If the image type cannot be determined from the path.
            OSError: If the file cannot be read.
        """
        path = os.fspath(image_input)
        # Resolve the MIME type before reading so a bad path fails fast.
        mime_type = self._image_mime_type(path)
        with open(path, "rb") as image_file:
            image_base64 = base64.b64encode(image_file.read()).decode("utf-8")

        image_url = f"data:{mime_type};base64,{image_base64}"
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": f"Analyze the provided image and extract relevant information for create a {diagram_name}."
                    },
                    {
                        "type": "image_url",
                        "image_url": image_url
                    }
                ]
            }
        ]
        return self._complete(self.VISION_MODEL, messages)

    def predict(self, diagram_name: str, message: str) -> str:
        """Generate PlantUML code for *message*, using matching docs as reference.

        Args:
            diagram_name: Diagram type used to select reference documents.
            message: User description of the diagram to generate.

        Returns:
            The model's stripped reply. The prompt asks for the code to be
            wrapped in a ```plantuml fenced block.
        """
        documents_texts = self._retrieve_diagram_info(diagram_name)
        messages = [
            {
                "role": "system",
                "content": "You are a PlantUML code generator. Generate PlantUML diagram code based on the provided description."
            },
            {
                "role": "user",
                "content": f"{message}\n\nUse the following documents as reference:\n{documents_texts}\n\nReturn the output between ```plantuml and ```."
            }
        ]
        return self._complete(self.TEXT_MODEL, messages)