Spaces:
Running
Running
from fastapi import Depends, HTTPException, status, Header, Query, Request | |
from typing import Optional | |
from database import get_users | |
from models import User, UserInDB | |
from token_store import token_store | |
def _strip_bearer_prefix(raw: str) -> str:
    """Return the credentials part of an Authorization value.

    A leading ``Bearer`` scheme (matched case-insensitively) followed by a
    space is dropped; any other value is returned as-is, minus surrounding
    whitespace.
    """
    value = raw.strip()
    scheme, separator, credentials = value.partition(" ")
    # Only the leading scheme is removed; a later "Bearer " inside the
    # credentials must survive, which a blanket str.replace would not allow.
    if separator and scheme.lower() == "bearer":
        return credentials.strip()
    return value


async def get_token(
    request: Request,
    authorization: Optional[str] = Header(None, convert_underscores=False),
    token: Optional[str] = Query(
        None, description="Access token (alternative to Authorization header)"
    ),
) -> str:
    """
    Extract the access token from the Authorization header or query parameter.

    Lookup order:
      1. the ``authorization`` header parameter,
      2. the ``authorization`` header read directly from ``request``,
      3. the ``token`` query parameter.

    A header value may carry a ``Bearer `` prefix (any letter case); it is
    removed. A header value without that prefix is used as the token itself.

    Returns:
        The token with surrounding whitespace removed.

    Raises:
        HTTPException: 401 when neither a header nor a query token is present.
    """
    # NOTE: request headers are deliberately not logged here; they contain
    # the credentials themselves (Authorization, cookies).
    # request.headers is used directly rather than a dict copy, so a single
    # lookup covers "authorization" / "Authorization".
    header_value = authorization or request.headers.get("authorization")

    # The header takes precedence over the query parameter.
    if header_value:
        return _strip_bearer_prefix(header_value)

    if token:
        return token.strip()

    # Nothing usable was supplied.
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Authorization header missing",
        headers={"WWW-Authenticate": "Bearer"},
    )
async def get_current_user_from_token(token: str = Depends(get_token)):
    """
    Resolve an access token to the user record it belongs to.

    The token is passed to ``token_store.validate_token``; a falsy result is
    treated as an invalid or expired token. The resulting username is then
    looked up in the mapping returned by ``get_users()`` and the stored
    entry is expanded into a ``UserInDB``.

    Returns:
        The ``UserInDB`` built from the stored entry for the token's user.

    Raises:
        HTTPException: 401 when the token is rejected or the username has
            no entry in the user mapping.
    """

    def _unauthorized() -> HTTPException:
        # Both failure paths answer with the same 401 response.
        return HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Could not validate credentials",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Step 1: map the token to a username.
    owner = token_store.validate_token(token)
    if not owner:
        print("Invalid or expired token")
        raise _unauthorized()

    # Step 2: make sure that username still has a stored record.
    known_users = get_users()
    if owner not in known_users:
        print(f"User not found: {owner}")
        raise _unauthorized()

    # Step 3: build the model from the stored fields.
    account = UserInDB(**known_users[owner])
    print(f"User authenticated: {account.username}")
    return account
def create_token_for_user(username: str) -> str:
    """
    Issue a token for ``username``.

    Delegates to ``token_store.create_token`` and returns its result
    unchanged.
    """
    issued_token = token_store.create_token(username)
    return issued_token
def remove_token(token: str) -> bool:
    """
    Drop ``token`` from the token store.

    Delegates to ``token_store.remove_token`` and returns its result
    unchanged.
    """
    removal_result = token_store.remove_token(token)
    return removal_result