Spaces:
Sleeping
Sleeping
from langchain_core.prompts import ChatPromptTemplate | |
from langchain_google_genai import ChatGoogleGenerativeAI | |
from langchain_community.document_loaders import DirectoryLoader | |
from langchain_community.document_loaders import PyPDFLoader | |
from typing import List | |
from typing_extensions import TypedDict | |
from typing import Annotated | |
from langgraph.graph.message import AnyMessage, add_messages | |
from langchain_core.messages import HumanMessage, AIMessage | |
from langgraph.graph import END, StateGraph, START | |
from langgraph.checkpoint.memory import MemorySaver | |
from fastapi import FastAPI, UploadFile, Form | |
from fastapi.middleware.cors import CORSMiddleware | |
from typing import Optional | |
from PIL import Image | |
import base64 | |
from io import BytesIO | |
import os | |
import logging | |
import sys | |
import matplotlib | |
matplotlib.use('Agg') # Configuration du backend avant d'importer pyplot | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import re | |
import json | |
logger = logging.getLogger('uvicorn.error') | |
logger.setLevel(logging.DEBUG) | |
app = FastAPI() | |
app.add_middleware( | |
CORSMiddleware, | |
allow_origins=["*"], | |
allow_credentials=True, | |
allow_methods=["*"], | |
allow_headers=["*"], | |
) | |
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0.5) | |
memory = MemorySaver() | |
file_path_sec = "./documents/seconde.pdf" | |
loader_sec = PyPDFLoader(file_path_sec) | |
sec = loader_sec.load() | |
file_path_prem = "./documents/premiere.pdf" | |
loader_prem = PyPDFLoader(file_path_prem) | |
prem = loader_prem.load() | |
file_path_term = "./documents/term.pdf" | |
loader_term = PyPDFLoader(file_path_term) | |
term = loader_term.load() | |
plot_graph = False | |
system = """ | |
Tu es un assistant expert en pédagogie et en mathématiques. Ta spécialité est l'enseignement de mathématiques au lycée. | |
Ton domaine de compétences couvre la classe de seconde générale, la spécialité de la classe de première et la spécialité de la classe de terminale. | |
Ton interlocuteur est, soit un élève de seconde, soit un élève de première, soit un élève de terminale. | |
Ton rôle est d'aider l'élève à progresser en mathématiques : | |
- en répondant à ces questions | |
- en l'aidant à résoudre un exercice | |
- en lui proposant des exercices pour voir s'il a bien assimilé les conceptes vus en classe avec son professeur | |
**ATTENTION** : Si l'élève te demande de résoudre un exercice à sa place, tu ne dois pas le faire, tu dois l'aider à trouver les réponses aux questions, mais jamais lui donner directement la réponse | |
Si tu ne connais pas la réponse à une question, propose à l'élève de demander à son professeur. | |
Tu dois obligatoirement utiliser le format Latex pour écrire les formules mathématiques | |
Tu dois bien vérifier que la demande de l'élève est en adéquation avec sa classe. Si ce n'est pas le cas, tu peux lui répondre en précisant que cette notion sera vue en classe plus tard. | |
Tu ne dois jamais aborder d'autre sujet que les mathématiques | |
""" | |
prompt = ChatPromptTemplate.from_messages( | |
[ | |
("system", system), | |
("human", """ | |
Voici les différents programme officiel qui te permettront d'aider l'élève : | |
Le programme de la classe de seconde : | |
{sec} | |
Le programme de la spécialité mathématiques en classe de première : | |
{prem} | |
Le programme de la spécialité mathématiques en classe de terminale : | |
{term} | |
Tu trouveras aussi l'historique conversation que tu as eu avec l'élève : \n {historical} | |
Et enfin l'intervention de l'élève : {question}"), | |
{graph} | |
""") | |
] | |
) | |
def format_historical(hist): | |
historical = [] | |
for i in range(0,len(hist)-2,2): | |
historical.append("Utilisateur : "+hist[i].content[0]['text']) | |
historical.append("Assistant : "+hist[i+1].content[0]['text']) | |
return "\n".join(historical[-20:]) | |
def generate_function_plot(expression, x_range=(-10, 10), num_points=1000): | |
try: | |
# Créer les points x | |
x = np.linspace(x_range[0], x_range[1], num_points) | |
# Évaluer la fonction | |
# Remplacer les expressions mathématiques courantes | |
expression = expression.replace('^', '**') | |
expression = expression.replace('sin', 'np.sin') | |
expression = expression.replace('cos', 'np.cos') | |
expression = expression.replace('tan', 'np.tan') | |
expression = expression.replace('exp', 'np.exp') | |
expression = expression.replace('log', 'np.log') | |
expression = expression.replace('ln', 'np.log') | |
expression = expression.replace('sqrt', 'np.sqrt') | |
expression = expression.replace('e', 'np.exp') | |
expression = expression.replace('exp', 'np.exp') | |
# Évaluer l'expression | |
y = eval(expression) | |
# Créer le graphique | |
plt.figure(figsize=(10, 6)) | |
plt.plot(x, y) | |
plt.grid(True) | |
plt.axhline(y=0, color='k', linestyle='-', alpha=0.3) | |
plt.axvline(x=0, color='k', linestyle='-', alpha=0.3) | |
# Sauvegarder le graphique en mémoire | |
img_buffer = BytesIO() | |
plt.savefig(img_buffer, format='PNG') | |
plt.close() | |
img_buffer.seek(0) | |
# Convertir en base64 | |
base64_img = base64.b64encode(img_buffer.getvalue()).decode('utf-8') | |
return base64_img | |
except Exception as e: | |
return None | |
class GraphState(TypedDict): | |
messages: Annotated[list[AnyMessage], add_messages] | |
def should_plot(state: GraphState): | |
system_prompt = """Tu es un assistant expert en pédagogie et en mathématiques. Ta spécialité est l'enseignement de mathématiques au lycée. | |
Ta tâche est de déterminer si la demande de l'utilisateur nécessite la représentation graphique d'une fonction. | |
Si c'est le cas, tu dois extraire l'expression de la fonction. | |
Réponds sur une seule ligne avec le format suivant : | |
OUI:expression si un graphique est nécessaire | |
NON si aucun graphique n'est nécessaire | |
Exemples de réponses : | |
OUI:x**2 | |
OUI:sin(x) | |
NON | |
""" | |
prompt = ChatPromptTemplate.from_messages([ | |
("system", system_prompt), | |
("human", "Analyse cette demande : {question}") | |
]) | |
question = state['messages'][-1].content[0]['text'] | |
response = llm.invoke(prompt.format_messages(question=question)) | |
try: | |
response_text = response.content.strip() | |
if response_text.startswith("OUI:"): | |
plot_graph = True | |
expression = response_text[4:].strip() | |
return {"should_plot": True, "expression": expression} | |
else : | |
plot_graph = False | |
return {"should_plot": False, "expression": None} | |
except Exception as e: | |
return {"should_plot": False, "expression": None} | |
def chatbot(state : GraphState): | |
plot_decision = should_plot(state) | |
if plot_graph : | |
msg_graph = "Une représentation graphique de la fonction a été fournis à l'élève, tu dois préciser dans ta réponse que cette représentation graphique a été fournis à l'élève" | |
else : | |
msg_graph = "" | |
question = prompt.invoke({'historical': format_historical(state['messages']),'sec':sec, 'prem':prem, 'term':term, 'graph':msg_graph, 'question' : state['messages'][-1].content[0]['text']}) | |
q = question.messages[0].content + question.messages[1].content | |
if len(state['messages'][-1].content) > 1 : | |
response = llm.invoke([HumanMessage(content=[{"type": "text", "text": q},state['messages'][-1].content[1]])]) | |
else : | |
response = llm.invoke([HumanMessage(content=[{"type": "text", "text": q}])]) | |
if plot_decision["should_plot"] and plot_decision["expression"]: | |
plot_base64 = generate_function_plot(plot_decision["expression"]) | |
if plot_base64: | |
return {"messages": [AIMessage(content=[{'type': 'text', 'text': response.content},{'type': 'image_url', 'image_url': {"url": f"data:image/png;base64,{plot_base64}"}}])]} | |
return {"messages": [AIMessage(content=[{'type': 'text', 'text': response.content}])]} | |
workflow = StateGraph(GraphState) | |
workflow.add_node('chatbot', chatbot) | |
workflow.add_edge(START, 'chatbot') | |
workflow.add_edge('chatbot', END) | |
app_chatbot = workflow.compile(checkpointer=memory) | |
def request(id:Annotated[str, Form()], query:Annotated[str, Form()], image:Optional[UploadFile] = None): | |
config = {"configurable": {"thread_id": id}} | |
if image: | |
try: | |
img = Image.open(image.file) | |
img_buffer = BytesIO() | |
img.save(img_buffer, format='PNG') | |
byte_data = img_buffer.getvalue() | |
base64_img = base64.b64encode(byte_data).decode("utf-8") | |
message = HumanMessage( | |
content=[ | |
{'type': 'text', 'text': query}, | |
{'type': 'image_url', 'image_url': {"url": f"data:image/jpeg;base64,{base64_img}"}} | |
]) | |
except: | |
return {"response":"Attention, vous m'avez fourni autre chose qu'une image. Renouvelez votre demande avec une image."} | |
rep = app_chatbot.invoke({"messages": message},config, stream_mode="values") | |
else : | |
rep = app_chatbot.invoke({"messages": [HumanMessage(content=[{'type': 'text', 'text': query}])]},config, stream_mode="values") | |
if len(rep['messages'][-1].content) > 1 and rep['messages'][-1].content[1].get('type') == 'image_url': | |
return {"response": rep['messages'][-1].content[0]['text'],"image": rep['messages'][-1].content[1]['image_url']['url']} | |
return {"response": rep['messages'][-1].content[0]['text']} |