from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import START, MessagesState, StateGraph
from dotenv import load_dotenv

load_dotenv()

# Initialize the model and tokenizer
print("Loading model and tokenizer...")
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "HuggingFaceTB/SmolLM2-1.7B-Instruct"

try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if device == "cuda":
        # Load the model in BF16 for better performance and lower memory usage on GPU
        print("Using GPU for the model...")
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            low_cpu_mem_usage=True,
        )
    else:
        print("Using CPU for the model...")
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map={"": device},
            torch_dtype=torch.float32,
        )
    print(f"Model loaded successfully on: {device}")
except Exception as e:
    print(f"Error loading the model: {str(e)}")
    raise

# Default system prompt used when a request does not provide one
DEFAULT_SYSTEM_PROMPT = (
    "You are a friendly Chatbot. Always reply in the language in which the user is writing to you."
)


# Define the function that calls the model
def call_model(state: MessagesState, config: RunnableConfig):
    """
    Call the model with the accumulated conversation.

    Args:
        state: MessagesState holding the conversation so far
        config: RunnableConfig; config["configurable"]["system_prompt"]
            overrides DEFAULT_SYSTEM_PROMPT for this invocation

    Returns:
        dict: the updated "messages" list, including the model's reply
    """
    system_prompt = config.get("configurable", {}).get("system_prompt", DEFAULT_SYSTEM_PROMPT)

    # Convert LangChain messages to the chat format expected by the tokenizer
    messages = [{"role": "system", "content": system_prompt}]
    for msg in state["messages"]:
        if isinstance(msg, HumanMessage):
            messages.append({"role": "user", "content": msg.content})
        elif isinstance(msg, AIMessage):
            messages.append({"role": "assistant", "content": msg.content})

    # Prepare the input using the chat template
    input_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)

    # Generate the response
    outputs = model.generate(
        inputs,
        max_new_tokens=512,  # Increase the number of tokens for longer responses
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode only the newly generated tokens (the assistant's reply)
    response = tokenizer.decode(
        outputs[0][inputs.shape[-1]:], skip_special_tokens=True
    ).strip()

    # Convert the response to LangChain format and append it to the history
    ai_message = AIMessage(content=response)
    return {"messages": state["messages"] + [ai_message]}


# Define the graph: a single "model" node reached from START
workflow = StateGraph(state_schema=MessagesState)
workflow.add_node("model", call_model)
workflow.add_edge(START, "model")

# Add memory so each thread_id keeps its own conversation history
memory = MemorySaver()
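# Illustrative sketch (not executed by the API): because the graph is compiled
# below with MemorySaver as checkpointer, repeated invocations that share a
# thread_id also share conversation history. The thread_id "demo" and the
# sample messages are made up for illustration only.
#
#   cfg = {"configurable": {"thread_id": "demo"}}
#   graph_app.invoke({"messages": [HumanMessage(content="Hi, my name is Ana")]}, cfg)
#   graph_app.invoke({"messages": [HumanMessage(content="What is my name?")]}, cfg)
#   # The second call sees the first exchange because MemorySaver restores the thread.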
# Compile the graph with the checkpointer so conversations persist per thread_id
graph_app = workflow.compile(checkpointer=memory)


# Define the data model for chat requests
class QueryRequest(BaseModel):
    query: str
    thread_id: str = "default"
    system_prompt: str = DEFAULT_SYSTEM_PROMPT


# Define the data model for summary requests
class SummaryRequest(BaseModel):
    text: str
    thread_id: str = "default"
    max_length: int = 200


# Create the FastAPI application
app = FastAPI(
    title="LangChain FastAPI",
    description=(
        "API to generate text using LangChain and LangGraph - "
        "Máximo Fernández Núñez IriusRisk test challenge"
    ),
)


# Welcome endpoint
@app.get("/")
async def api_home():
    """Welcome endpoint"""
    return {"detail": "Welcome to Máximo Fernández Núñez IriusRisk test challenge"}


# Generate endpoint
@app.post("/generate")
async def generate(request: QueryRequest):
    """
    Generate text with the language model.

    Args:
        request: QueryRequest
            query: str - the user message
            thread_id: str = "default" - conversation thread to continue
            system_prompt: str = DEFAULT_SYSTEM_PROMPT - system prompt for this request

    Returns:
        dict: the generated text and the thread ID
    """
    try:
        # Pass the thread ID and the per-request system prompt through the config
        config = {
            "configurable": {
                "thread_id": request.thread_id,
                "system_prompt": request.system_prompt,
            }
        }

        # Create the input message
        input_messages = [HumanMessage(content=request.query)]

        # Invoke the graph
        output = graph_app.invoke({"messages": input_messages}, config)

        # Get the model response
        response = output["messages"][-1].content

        return {
            "generated_text": response,
            "thread_id": request.thread_id,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating text: {str(e)}")


# Summarize endpoint
@app.post("/summarize")
async def summarize(request: SummaryRequest):
    """
    Generate a summary with the language model.

    Args:
        request: SummaryRequest
            text: str - the text to summarize
            thread_id: str = "default" - conversation thread to use
            max_length: int = 200 - maximum summary length in words

    Returns:
        dict: the summary and the thread ID
    """
    try:
        # Create a specific system prompt for summarization
        summary_system_prompt = (
            f"Make a summary of the following text in no more than {request.max_length} words. "
            "Keep the most important information and eliminate unnecessary details."
        )

        # Pass the thread ID and the summarization system prompt through the config
        config = {
            "configurable": {
                "thread_id": request.thread_id,
                "system_prompt": summary_system_prompt,
            }
        }

        # Create the input message
        input_messages = [HumanMessage(content=request.text)]

        # Invoke the graph
        output = graph_app.invoke({"messages": input_messages}, config)

        # Get the model response
        response = output["messages"][-1].content

        return {
            "summary": response,
            "thread_id": request.thread_id,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating summary: {str(e)}")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
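# Example requests (assuming the server is running locally on port 7860; the
# sample payloads are illustrative only):
#
#   curl -X POST http://localhost:7860/generate \
#        -H "Content-Type: application/json" \
#        -d '{"query": "Hello, how are you?", "thread_id": "demo"}'
#
#   curl -X POST http://localhost:7860/summarize \
#        -H "Content-Type: application/json" \
#        -d '{"text": "LangGraph lets you build stateful LLM workflows ...", "max_length": 50}'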