File size: 6,465 Bytes
7b5ebc3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
488c021
7b5ebc3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1d39fca
7b5ebc3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
from langchain_core.prompts import ChatPromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_community.document_loaders import DirectoryLoader
from langchain_text_splitters import CharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from typing import List
from langchain_chroma import Chroma
from typing_extensions import TypedDict
from typing import Annotated
from langgraph.graph.message import AnyMessage, add_messages
from langchain_core.messages import HumanMessage, AIMessage
from langgraph.graph import END, StateGraph, START
from langgraph.checkpoint.memory import MemorySaver
from fastapi import FastAPI, UploadFile, Form
from fastapi.middleware.cors import CORSMiddleware
from typing import Optional
from PIL import Image
import base64
from io import BytesIO
import os 
import logging
import sys

logger = logging.getLogger('uvicorn.error')
logger.setLevel(logging.DEBUG)

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0.5)

os.environ["TOKENIZERS_PARALLELISM"] = "false"
persist_directory = 'db'
embedding = HuggingFaceEmbeddings(model_name="OrdalieTech/Solon-embeddings-large-0.1")
memory = MemorySaver()

if  os.path.exists(persist_directory) :
    vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding)
else :
    glob_pattern="./*.md"
    directory_path = "./documents"
    loader = DirectoryLoader(directory_path, glob=glob_pattern)
    documents = loader.load()
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    texts = text_splitter.split_documents(documents)
    vectordb = Chroma.from_documents(documents=texts, embedding=embedding, persist_directory=persist_directory)

retriever = vectordb.as_retriever()

system = """
Tu es un assistant spécialisé dans l'enseignement de la spécialité Numérique et sciences informatiques en classe de première et de terminal
Tu as un bon niveau en langage Python
Ton interlocuteur est un élève qui suit la spécialité nsi en première et en terminale
Ton unique thème de conservation doit être l'enseignement de l'informatique. Tu ne dois pas aborder d'autres thèmes que l'enseignement de l'informatique
Tu ne dois pas faire d'erreur, répond à la question uniquement si tu es sûr de ta réponse
si tu ne trouves pas la réponse à une question, tu réponds que tu ne connais pas la réponse et que l'élève doit s'adresser à son professeur pour obtenir cette réponse
si l'élève n'arrive pas à trouver la réponse à un exercice, tu ne dois pas lui donner tout de suite la réponse, mais seulement lui donner des indications pour lui permettre de trouver la réponse par lui même
Tu dois uniquement répondre en langue française
Tu ne dois pas commencer tes réponses par "Assistant :"
Tu trouveras ci-dessous les programmes de la spécialité NSI en première et terminale, tu devras veiller à ce que tes réponses ne sortent pas du cadre de ces programmes
Si la question posée ne rentre pas dans le cadre du programme de NSI tu peux tout de même répondre en précisant bien que cette notion est hors programme
si tu proposes un exercice, tu dois bien vérifier que toutes les notions nécessaires à la résolution de l'exercice sont explicitement au programme de NSI
"""

prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "Extraits des programmes de NSI : \n {document} \n\n Historique conversation entre l'assistant et l'élève : \n {historical} \n\n Intervention de l'élève : {question}"),
    ]
)


def format_docs(docs):
    return "\n".join(doc.page_content for doc in docs)

def format_historical(hist):
    historical = []
    for i in range(0,len(hist)-2,2):
        historical.append("Utilisateur : "+hist[i].content[0]['text'])
        historical.append("Assistant : "+hist[i+1].content[0]['text'])
    return "\n".join(historical[-10:])


class GraphState(TypedDict):
    messages: Annotated[list[AnyMessage], add_messages]
    documents : str

def retrieve(state : GraphState):
    documents = format_docs(retriever.invoke(state['messages'][-1].content[0]['text']))
    return {'documents' : documents}

def chatbot(state : GraphState):
    question = prompt.invoke({'historical': format_historical(state['messages']), 'document':state['documents'] ,  'question' : state['messages'][-1].content[0]['text']})
    q = question.messages[0].content + question.messages[1].content
    if len(state['messages'][-1].content) > 1 :
        response = llm.invoke([HumanMessage(
            content=[
                {"type": "text", "text": q},
                state['messages'][-1].content[1]
            ])])
    else :
        response = llm.invoke([HumanMessage(
            content=[
                {"type": "text", "text": q}
            ])])
    return {"messages": [AIMessage(content=[{'type': 'text', 'text': response.content}])]}

workflow = StateGraph(GraphState)
workflow.add_node('retrieve', retrieve)
workflow.add_node('chatbot', chatbot)

workflow.add_edge(START, 'retrieve')
workflow.add_edge('retrieve','chatbot')
workflow.add_edge('chatbot', END)

app_chatbot = workflow.compile(checkpointer=memory)

@app.post('/request')
def request(id:Annotated[str, Form()], query:Annotated[str, Form()], image:Optional[UploadFile] = None):
    config = {"configurable": {"thread_id": id}}
    if image:
        try:
            img = Image.open(image.file)
            img_buffer = BytesIO()
            img.save(img_buffer, format='PNG')
            byte_data = img_buffer.getvalue()
            base64_img = base64.b64encode(byte_data).decode("utf-8")
            message = HumanMessage(
            content=[
                {'type': 'text', 'text': query},
                {'type': 'image_url', 'image_url': {"url": f"data:image/jpeg;base64,{base64_img}"}}
            ])
        except:
            return {"response":"Attention, vous m'avez fourni autre chose qu'une image. Renouvelez votre demande avec une image."}
        rep = app_chatbot.invoke({"messages": message},config, stream_mode="values")
    else :
        rep = app_chatbot.invoke({"messages": [HumanMessage(content=[{'type': 'text', 'text': query}])]},config, stream_mode="values")
    return {"response":rep['messages'][-1].content[0]['text']}