dav74 commited on
Commit
7b5ebc3
·
verified ·
1 Parent(s): 45c6e94

Upload 2 files

Browse files
Files changed (2) hide show
  1. main.py +145 -0
  2. requirements.txt +11 -0
main.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_core.prompts import ChatPromptTemplate
2
+ from langchain_google_genai import ChatGoogleGenerativeAI
3
+ from langchain_community.document_loaders import DirectoryLoader
4
+ from langchain_text_splitters import CharacterTextSplitter
5
+ from langchain_huggingface import HuggingFaceEmbeddings
6
+ from typing import List
7
+ from langchain_chroma import Chroma
8
+ from typing_extensions import TypedDict
9
+ from typing import Annotated
10
+ from langgraph.graph.message import AnyMessage, add_messages
11
+ from langchain_core.messages import HumanMessage, AIMessage
12
+ from langgraph.graph import END, StateGraph, START
13
+ from langgraph.checkpoint.memory import MemorySaver
14
+ from fastapi import FastAPI, UploadFile, Form
15
+ from fastapi.middleware.cors import CORSMiddleware
16
+ from typing import Optional
17
+ from PIL import Image
18
+ import base64
19
+ from io import BytesIO
20
+ import os
21
+ import logging
22
+ import sys
23
+
24
+ logger = logging.getLogger('uvicorn.error')
25
+ logger.setLevel(logging.DEBUG)
26
+
27
+ app = FastAPI()
28
+
29
+ app.add_middleware(
30
+ CORSMiddleware,
31
+ allow_origins=["*"],
32
+ allow_credentials=True,
33
+ allow_methods=["*"],
34
+ allow_headers=["*"],
35
+ )
36
+
37
+
38
+ llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash-exp", temperature=0.5)
39
+
40
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
41
+ persist_directory = 'db'
42
+ embedding = HuggingFaceEmbeddings(model_name="OrdalieTech/Solon-embeddings-large-0.1")
43
+ memory = MemorySaver()
44
+
45
+ if os.path.exists(persist_directory) :
46
+ vectordb = Chroma(persist_directory=persist_directory, embedding_function=embedding)
47
+ else :
48
+ glob_pattern="./*.md"
49
+ directory_path = "./documents"
50
+ loader = DirectoryLoader(directory_path, glob=glob_pattern)
51
+ documents = loader.load()
52
+ text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
53
+ texts = text_splitter.split_documents(documents)
54
+ vectordb = Chroma.from_documents(documents=texts, embedding=embedding, persist_directory=persist_directory)
55
+
56
+ retriever = vectordb.as_retriever()
57
+
58
+ system = """
59
+ Tu es un assistant spécialisé dans l'enseignement de la spécialité Numérique et sciences informatiques en classe de première et de terminal
60
+ Tu as un bon niveau en langage Python
61
+ Ton interlocuteur est un élève qui suit la spécialité nsi en première et en terminale
62
+ Ton unique thème de conservation doit être l'enseignement de l'informatique. Tu ne dois pas aborder d'autres thèmes que l'enseignement de l'informatique
63
+ Tu ne dois pas faire d'erreur, répond à la question uniquement si tu es sûr de ta réponse
64
+ si tu ne trouves pas la réponse à une question, tu réponds que tu ne connais pas la réponse et que l'élève doit s'adresser à son professeur pour obtenir cette réponse
65
+ si l'élève n'arrive pas à trouver la réponse à un exercice, tu ne dois pas lui donner tout de suite la réponse, mais seulement lui donner des indications pour lui permettre de trouver la réponse par lui même
66
+ Tu dois uniquement répondre en langue française
67
+ Tu trouveras ci-dessous les programmes de la spécialité NSI en première et terminale, tu devras veiller à ce que tes réponses ne sortent pas du cadre de ces programmes
68
+ Si la question posée ne rentre pas dans le cadre du programme de NSI tu peux tout de même répondre en précisant bien que cette notion est hors programme
69
+ si tu proposes un exercice, tu dois bien vérifier que toutes les notions nécessaires à la résolution de l'exercice sont explicitement au programme de NSI
70
+ """
71
+
72
+ prompt = ChatPromptTemplate.from_messages(
73
+ [
74
+ ("system", system),
75
+ ("human", "Extraits des programmes de NSI : \n {document} \n\n Historique conversation entre l'assistant et l'élève : \n {historical} \n\n Intervention de l'élève : {question}"),
76
+ ]
77
+ )
78
+
79
+
80
+ def format_docs(docs):
81
+ return "\n".join(doc.page_content for doc in docs)
82
+
83
+ def format_historical(hist):
84
+ historical = []
85
+ for i in range(0,len(hist)-2,2):
86
+ historical.append("Utilisateur : "+hist[i].content[0]['text'])
87
+ historical.append("Assistant : "+hist[i+1].content[0]['text'])
88
+ return "\n".join(historical[-10:])
89
+
90
+
91
+ class GraphState(TypedDict):
92
+ messages: Annotated[list[AnyMessage], add_messages]
93
+ documents : str
94
+
95
+ def retrieve(state : GraphState):
96
+ documents = format_docs(retriever.invoke(state['messages'][-1].content[0]['text']))
97
+ return {'documents' : documents}
98
+
99
+ def chatbot(state : GraphState):
100
+ question = prompt.invoke({'historical': format_historical(state['messages']), 'document':state['documents'] , 'question' : state['messages'][-1].content[0]['text']})
101
+ q = question.messages[0].content + question.messages[1].content
102
+ if len(state['messages'][-1].content) > 1 :
103
+ response = llm.invoke([HumanMessage(
104
+ content=[
105
+ {"type": "text", "text": q},
106
+ state['messages'][-1].content[1]
107
+ ])])
108
+ else :
109
+ response = llm.invoke([HumanMessage(
110
+ content=[
111
+ {"type": "text", "text": q}
112
+ ])])
113
+ return {"messages": [AIMessage(content=[{'type': 'text', 'text': response.content}])]}
114
+
115
+ workflow = StateGraph(GraphState)
116
+ workflow.add_node('retrieve', retrieve)
117
+ workflow.add_node('chatbot', chatbot)
118
+
119
+ workflow.add_edge(START, 'retrieve')
120
+ workflow.add_edge('retrieve','chatbot')
121
+ workflow.add_edge('chatbot', END)
122
+
123
+ app_chatbot = workflow.compile(checkpointer=memory)
124
+
125
+ @app.post('/request')
126
+ def request(id:Annotated[str, Form()], query:Annotated[str, Form()], image:Optional[UploadFile] = None):
127
+ config = {"configurable": {"thread_id": id}}
128
+ if image:
129
+ try:
130
+ img = Image.open(image.file)
131
+ img_buffer = BytesIO()
132
+ img.save(img_buffer, format='PNG')
133
+ byte_data = img_buffer.getvalue()
134
+ base64_img = base64.b64encode(byte_data).decode("utf-8")
135
+ message = HumanMessage(
136
+ content=[
137
+ {'type': 'text', 'text': query},
138
+ {'type': 'image_url', 'image_url': {"url": f"data:image/jpeg;base64,{base64_img}"}}
139
+ ])
140
+ except:
141
+ return {"response":"Attention, vous m'avez fourni autre chose qu'une image. Renouvelez votre demande avec une image."}
142
+ rep = app_chatbot.invoke({"messages": message},config, stream_mode="values")
143
+ else :
144
+ rep = app_chatbot.invoke({"messages": [HumanMessage(content=[{'type': 'text', 'text': query}])]},config, stream_mode="values")
145
+ return {"response":rep['messages'][-1].content[0]['text']}
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ langchain
2
+ langchain-core
3
+ langchain-google-genai
4
+ langchain-community
5
+ langchain-huggingface
6
+ langgraph
7
+ unstructured
8
+ pillow
9
+ langchain-chroma
10
+ unstructured[md]
11
+ fastapi[all]