import logging
import os
from datetime import timedelta
from pathlib import Path

from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.responses import FileResponse
from fastapi.security import OAuth2PasswordRequestForm
from fastapi.staticfiles import StaticFiles  # NOTE(review): not used in this module
from jose import JWTError

from schemas import UserRegister, TokenResponse, RefreshTokenRequest, QueryInput
from auth import (
    register_user,
    get_db,
    authenticate_user,
    create_token,
    verify_token,
    verify_access_token,
    Session,
)
from utils import search_with_llm

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Fallback signing key used when SECRET_KEY is not set in the environment.
# It is committed to source, so anyone with the code can forge tokens signed
# with it; it is kept only for backward compatibility with existing setups.
_FALLBACK_SECRET_KEY = 'def6nQHONW99pOPyba9DShny6FB1CJJBigZault'

SECRET_KEY = os.getenv("SECRET_KEY", _FALLBACK_SECRET_KEY)
# Defaults to SECRET_KEY when unset, which makes access and refresh tokens
# verifiable with the same key (see the warning below).
REFRESH_SECRET_KEY = os.getenv("REFRESH_SECRET_KEY", SECRET_KEY)
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
REFRESH_TOKEN_EXPIRE_DAYS = 7

if SECRET_KEY == _FALLBACK_SECRET_KEY:
    logger.warning(
        "SECRET_KEY is not set; using the insecure built-in fallback key. "
        "Set SECRET_KEY in the environment."
    )
if REFRESH_SECRET_KEY == SECRET_KEY:
    logger.warning(
        "REFRESH_SECRET_KEY equals SECRET_KEY; access and refresh tokens are "
        "signed with the same key. Set a distinct REFRESH_SECRET_KEY."
    )

app = FastAPI()


# Entry Endpoint
@app.get('/')
def index() -> FileResponse:
    """Serve ``static/index.html`` located next to this module."""
    file_path = Path(__file__).parent / 'static' / 'index.html'
    return FileResponse(path=str(file_path), media_type='text/html')


# The endpoints below call synchronous helpers (their results are used without
# ``await``). Declaring them with plain ``def`` makes FastAPI run them in its
# threadpool instead of blocking the event loop for every request.

@app.post("/register")
def register(user: UserRegister, db: Session = Depends(get_db)):
    """Register a new user.

    Returns a confirmation message together with the stored username.
    """
    new_user = register_user(user.username, user.password, db)
    return {"message": "User registered successfully", "user": new_user.username}


@app.post("/login", response_model=TokenResponse)
def login(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
    """Authenticate with form credentials and issue an access/refresh token pair.

    Raises:
        HTTPException(401): the username/password pair is rejected.
        HTTPException(500): any unexpected failure while authenticating or
            creating tokens (details are logged, not returned).
    """
    try:
        user = authenticate_user(form_data.username, form_data.password, db)
        if not user:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid username or password",
                headers={"WWW-Authenticate": "Bearer"},
            )

        access_token = create_token(
            data={"sub": user.username},
            expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
            secret_key=SECRET_KEY,
        )
        refresh_token = create_token(
            data={"sub": user.username},
            expires_delta=timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS),
            secret_key=REFRESH_SECRET_KEY,
        )
        return {
            "access_token": access_token,
            "refresh_token": refresh_token,
            "token_type": "bearer",
        }
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 401 above) must reach the client
        # unchanged instead of being rewritten into a 500 below.
        raise
    except Exception as e:
        logger.exception("Login error: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Internal server error",
        ) from e


@app.post("/refresh", response_model=TokenResponse)
async def refresh(refresh_request: RefreshTokenRequest):
    """Exchange a valid refresh token for a new access token.

    The submitted refresh token is returned unchanged (no rotation).

    Raises:
        HTTPException(401): the refresh token fails verification or yields
            no subject.
    """
    credentials_error = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    try:
        username = verify_token(refresh_request.refresh_token, REFRESH_SECRET_KEY)
    except JWTError:
        raise credentials_error from None

    # Never mint an access token for an empty/missing subject.
    if not username:
        raise credentials_error

    access_token = create_token(
        data={"sub": username},
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
        secret_key=SECRET_KEY,
    )
    return {
        "access_token": access_token,
        "refresh_token": refresh_request.refresh_token,  # Return the same refresh token
        "token_type": "bearer",
    }


@app.post("/search")
def search(
    query_input: QueryInput,
    username: str = Depends(verify_access_token),
):
    """Run ``search_with_llm`` on the query for an authenticated user.

    Raises:
        HTTPException: re-raised unchanged if the search helper raises one.
        HTTPException(500): any other failure; the underlying error is logged
            server-side rather than echoed to the client.
    """
    try:
        response = search_with_llm(query_input.query)
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Search failed for user %s", username)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Search failed",
        ) from e
    return {"response": response}


# Ensure the on-disk cache directory exists before serving requests.
# NOTE(review): ``on_event`` is deprecated in recent FastAPI releases in favour
# of lifespan handlers; consider migrating.
@app.on_event("startup")
async def startup_event():
    """Create ``./cache`` (relative to the current working directory) if missing."""
    os.makedirs("./cache", exist_ok=True)


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)