from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import os
import logging
import sys
import asyncio

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout)
    ]
)
logger = logging.getLogger(__name__)

app = FastAPI(title="RAG Pipeline API", description="Multi-dataset RAG API", version="1.0.0")

# Initialize pipelines for all datasets
pipelines = {}
google_api_key = os.getenv("GOOGLE_API_KEY")

logger.info(f"Starting RAG Pipeline API")
logger.info(f"Google API Key present: {'Yes' if google_api_key else 'No'}")

# Don't load datasets during startup - do it asynchronously after server starts
logger.info("RAG Pipeline API is ready to serve requests - datasets will load in background")

class Question(BaseModel):
    text: str
    dataset: str = "developer-portfolio"

@app.post("/answer")
async def get_answer(question: Question):
    try:
        # Check if any pipelines are loaded
        if not pipelines:
            return {
                "answer": "RAG Pipeline is running but datasets are still loading in the background. Please try again in a moment, or check /health for loading status.",
                "dataset": question.dataset,
                "status": "datasets_loading"
            }
        
        # Select the appropriate pipeline based on dataset
        if question.dataset not in pipelines:
            raise HTTPException(status_code=400, detail=f"Dataset '{question.dataset}' not available. Available datasets: {list(pipelines.keys())}")
        
        selected_pipeline = pipelines[question.dataset]
        answer = selected_pipeline.answer_question(question.text)
        return {"answer": answer, "dataset": question.dataset}
    except HTTPException:
        # Re-raise explicit HTTP errors (e.g. unknown dataset) instead of converting them to 500s
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/datasets")
async def list_datasets():
    """List all available datasets"""
    return {"datasets": list(pipelines.keys())}

async def load_datasets_background():
    """Load datasets in background after server starts"""
    global pipelines
    if google_api_key:
        try:
            # Import pipeline modules only when needed (sys is already imported at module level)
            sys.path.append('/app')
            from app.pipeline import RAGPipeline
            from app.config import DATASET_CONFIGS
            
            # Only load developer-portfolio to save memory
            dataset_name = "developer-portfolio"
            logger.info(f"Loading dataset: {dataset_name}")
            pipeline = RAGPipeline.from_preset(
                google_api_key=google_api_key,
                preset_name=dataset_name
            )
            pipelines[dataset_name] = pipeline
            logger.info(f"Successfully loaded {dataset_name}")
        except Exception as e:
            logger.error(f"Failed to load dataset: {e}")
        logger.info(f"Background loading complete - {len(pipelines)} datasets loaded")
    else:
        logger.warning("No Google API key provided - running in demo mode without datasets")

@app.on_event("startup")
async def startup_event():
    logger.info("FastAPI application startup complete")
    logger.info(f"Server should be running on port: 7860")
    
    # Start loading datasets in background (non-blocking)
    import asyncio
    asyncio.create_task(load_datasets_background())

@app.on_event("shutdown")
async def shutdown_event():
    logger.info("FastAPI application shutting down")

@app.get("/")
async def root():
    """Root endpoint"""
    return {"status": "ok", "message": "RAG Pipeline API", "version": "1.0.0", "datasets": list(pipelines.keys())}

@app.get("/health")
async def health_check():
    """Health check endpoint"""
    logger.info("Health check called")
    loading_status = "complete" if "developer-portfolio" in pipelines else "loading"
    return {
        "status": "healthy", 
        "datasets_loaded": len(pipelines), 
        "total_datasets": 1,  # Only loading developer-portfolio
        "loading_status": loading_status,
        "port": "7860"
    }
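
# Local launch sketch (module path is an assumption; adjust to wherever this file lives):
#   uvicorn app.main:app --host 0.0.0.0 --port 7860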