File size: 5,019 Bytes
91f974c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import os
import logging
import sys
from .config import DATASET_CONFIGS
# Lazy imports to avoid blocking startup
# from .pipeline import RAGPipeline  # Will import when needed
# import umap  # Will import when needed for visualization
# import plotly.express as px  # Will import when needed for visualization
# import plotly.graph_objects as go  # Will import when needed for visualization
# from plotly.subplots import make_subplots  # Will import when needed for visualization
# import numpy as np  # Will import when needed for visualization
# from sklearn.preprocessing import normalize  # Will import when needed for visualization
# import pandas as pd  # Will import when needed for visualization
import json

# Configure the root logger once, at import time: INFO level, timestamped
# format, written to stdout.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger(__name__)

app = FastAPI(title="RAG Pipeline API", description="Multi-dataset RAG API", version="1.0.0")

# Registry of loaded pipelines, keyed by dataset name. It starts empty and is
# filled by the background loader after startup; request handlers only read it.
pipelines = {}
# May be None; without a key the background loader loads nothing (demo mode).
google_api_key = os.getenv("GOOGLE_API_KEY")

logger.info("Starting RAG Pipeline API")
logger.info(f"Port from env: {os.getenv('PORT', 'Not set - will use 8000')}")
logger.info(f"Google API Key present: {'Yes' if google_api_key else 'No'}")
logger.info(f"Available datasets: {list(DATASET_CONFIGS.keys())}")

# Don't load datasets during startup - do it asynchronously after server starts
logger.info("RAG Pipeline API is ready to serve requests - datasets will load in background")

# Visualization function disabled to speed up startup
# def create_3d_visualization(pipeline):
#     ... (commented out for faster startup)

class Question(BaseModel):
    # Request body for POST /answer.
    text: str  # the question passed to the selected pipeline's answer_question()
    dataset: str = "developer-portfolio"  # Default dataset; must name a loaded pipeline, else /answer returns 400

@app.post("/answer")
async def get_answer(question: Question):
    """Answer ``question.text`` with the pipeline registered for ``question.dataset``.

    Returns ``{"answer": ..., "dataset": ...}``. While no pipeline has been
    loaded yet, returns a placeholder answer with ``"status": "datasets_loading"``
    instead of an error.

    Raises:
        HTTPException(400): the requested dataset has no loaded pipeline.
        HTTPException(500): the pipeline raised while answering.
    """
    # Nothing loaded yet: reply with a friendly placeholder rather than failing.
    if not pipelines:
        return {
            "answer": "RAG Pipeline is running but datasets are still loading in the background. Please try again in a moment, or check /health for loading status.",
            "dataset": question.dataset,
            "status": "datasets_loading"
        }

    # Validate outside the try below: a broad `except Exception` there would
    # otherwise catch this HTTPException and turn the 400 into a 500.
    if question.dataset not in pipelines:
        raise HTTPException(status_code=400, detail=f"Dataset '{question.dataset}' not available. Available datasets: {list(pipelines.keys())}")

    selected_pipeline = pipelines[question.dataset]
    try:
        answer = selected_pipeline.answer_question(question.text)
    except Exception as e:
        # Keep the full traceback in the server log; the client sees only str(e).
        logger.exception(f"Failed to answer question for dataset '{question.dataset}'")
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"answer": answer, "dataset": question.dataset}

@app.get("/datasets")
async def list_datasets():
    """Return the names of the datasets whose pipelines are currently loaded.

    The list is empty until the background loader has registered a pipeline.
    """
    loaded_names = list(pipelines)
    return {"datasets": loaded_names}

def _build_pipeline(dataset_name):
    """Lazily import ``RAGPipeline`` and build the preset pipeline *dataset_name*.

    Intended to run in a worker thread. Both the import and ``from_preset`` are
    synchronous calls (presumably slow, since they load the dataset). Running
    them on the event loop would stall every request until loading finished.
    """
    # Imported here rather than at module top to keep server startup fast.
    from .pipeline import RAGPipeline
    return RAGPipeline.from_preset(
        google_api_key=google_api_key,
        preset_name=dataset_name
    )


async def load_datasets_background():
    """Load datasets in background after server starts.

    Only ``developer-portfolio`` is loaded, to save memory. Without a Google
    API key nothing is loaded and the API runs in demo mode. Failures
    (including a failing import) are logged and never raised, so the server
    keeps serving.
    """
    import asyncio

    if not google_api_key:
        logger.warning("No Google API key provided - running in demo mode without datasets")
        return

    # Only load developer-portfolio to save memory
    dataset_name = "developer-portfolio"
    try:
        logger.info(f"Loading dataset: {dataset_name}")
        # Offload the blocking import/construction to the default executor so
        # the event loop keeps answering /health and /answer meanwhile.
        loop = asyncio.get_running_loop()
        pipeline = await loop.run_in_executor(None, _build_pipeline, dataset_name)
    except Exception as e:
        logger.exception(f"Failed to load {dataset_name}: {e}")
    else:
        # Registered here, on the event-loop thread, not inside the worker.
        pipelines[dataset_name] = pipeline
        logger.info(f"Successfully loaded {dataset_name}")
    logger.info(f"Background loading complete - {len(pipelines)} datasets loaded")

# Strong references to fire-and-forget tasks started at startup. The asyncio
# event loop keeps only weak references to tasks, so an unreferenced task can
# be garbage-collected before it finishes.
_background_tasks = set()


@app.on_event("startup")
async def startup_event():
    """Log startup details and start dataset loading without awaiting it."""
    import asyncio

    logger.info("FastAPI application startup complete")
    logger.info(f"Server should be running on port: {os.getenv('PORT', '8000')}")

    # Start loading datasets in background (non-blocking). Hold a reference
    # until the task completes so it cannot disappear mid-execution.
    task = asyncio.create_task(load_datasets_background())
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)

@app.on_event("shutdown")
async def shutdown_event():
    """Record that the application is stopping; no other cleanup happens here."""
    farewell = "FastAPI application shutting down"
    logger.info(farewell)

@app.get("/")
async def root():
    """Basic service info: status, name, version and currently loaded datasets."""
    payload = {
        "status": "ok",
        "message": "RAG Pipeline API",
        "version": "1.0.0",
    }
    # Snapshot of the dataset names registered so far (may be empty).
    payload["datasets"] = list(pipelines)
    return payload

@app.get("/health")
async def health_check():
    """Report service health plus background dataset-loading progress.

    ``loading_status`` is ``"complete"`` once the ``developer-portfolio``
    pipeline is registered and ``"loading"`` otherwise. Note that this also
    covers the cases where loading failed or was skipped for lack of an API key.
    """
    logger.info("Health check called")
    portfolio_ready = "developer-portfolio" in pipelines
    return {
        "status": "healthy",
        "datasets_loaded": len(pipelines),
        "total_datasets": 1,  # only developer-portfolio is ever loaded
        "loading_status": "complete" if portfolio_ready else "loading",
        "port": os.getenv('PORT', '8000'),
    }