|
import hashlib
import logging
import os
from contextlib import asynccontextmanager
from datetime import datetime, timedelta, timezone
from functools import lru_cache
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
import pandas as pd
import torch as t
from datasets import Dataset, load_dataset
from diskcache import Cache
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.responses import FileResponse
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from fastapi.staticfiles import StaticFiles
from huggingface_hub import login
from jose import JWTError, jwt
from openai import OpenAI
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
|
|
|
|
|
logging.basicConfig(level=logging.INFO) |
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup work: ensure the cache directory exists and preload the embedding model."""
    os.makedirs("./cache", exist_ok=True)
    get_sentence_transformer()
    yield


app = FastAPI(lifespan=lifespan)
|
|
|
|
|
cache = Cache('./cache')  # persistent on-disk cache for computed query embeddings
|
|
|
|
|
# JWT settings: secrets are read from the environment, with hardcoded fallbacks.
SECRET_KEY = os.environ.get("prime_auth", "c0369f977b69e717dc16f6fc574039eb2b1ebde38014d2be")
REFRESH_SECRET_KEY = os.environ.get("prolonged_auth", "916018771b29084378c9362c0cd9e631fd4927b8aea07f91")
|
ALGORITHM = "HS256" |
|
ACCESS_TOKEN_EXPIRE_MINUTES = 30 |
|
REFRESH_TOKEN_EXPIRE_DAYS = 7 |
|
|
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login") |
|
|
|
|
|
class QueryInput(BaseModel): |
|
query: str |
|
|
|
class SearchResult(BaseModel): |
|
text: str |
|
similarity: float |
|
model_type: str |
|
|
|
class TokenResponse(BaseModel): |
|
access_token: str |
|
refresh_token: str |
|
token_type: str |
|
|
|
class SaveInput(BaseModel): |
|
user_type: str |
|
username: str |
|
query: str |
|
retrieved_text: str |
|
model_type: str |
|
reaction: str |
|
|
|
class SaveBatchInput(BaseModel): |
|
items: List[SaveInput] |
|
|
|
class RefreshRequest(BaseModel): |
|
refresh_token: str |
|
|
|
|
|
@lru_cache(maxsize=1) |
|
def get_sentence_transformer(): |
|
"""Load and cache the SentenceTransformer model with lru_cache""" |
|
return SentenceTransformer(model_name_or_path="all-mpnet-base-v2", device="cpu") |
|
|
|
def _cache_key(text: str, model_type: str) -> str:
    """Build a deterministic cache key; the built-in hash() is salted per process and would defeat a persistent cache."""
    return f"{model_type}_{hashlib.sha256(text.encode('utf-8')).hexdigest()}"


def get_cached_embeddings(text: str, model_type: str) -> Optional[List[float]]:
    """Try to get embeddings from the disk cache."""
    return cache.get(_cache_key(text, model_type))


def set_cached_embeddings(text: str, model_type: str, embeddings: List[float]):
    """Store embeddings in the disk cache for 24 hours."""
    cache.set(_cache_key(text, model_type), embeddings, expire=86400)
|
|
|
@lru_cache(maxsize=1) |
|
def load_dataframe(): |
|
"""Load and cache the parquet dataframe""" |
|
database_file = Path(__file__).parent / "[all_embedded] The Alchemy of Happiness (Ghazzālī, Claud Field) (Z-Library).parquet" |
|
return pd.read_parquet(database_file) |
|
|
|
|
|
def cosine_similarity(embedding_0, embedding_1): |
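    """Compute the cosine similarity between two embedding vectors in pure Python."""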
|
dot_product = sum(a * b for a, b in zip(embedding_0, embedding_1)) |
|
norm_0 = sum(a * a for a in embedding_0) ** 0.5 |
|
norm_1 = sum(b * b for b in embedding_1) ** 0.5 |
|
return dot_product / (norm_0 * norm_1) |
|
|
|
def generate_embedding(model, text: str, model_type: str) -> List[float]: |
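    """Embed *text* with the given model, reusing the disk cache when a cached embedding exists."""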
|
|
|
cached_embedding = get_cached_embeddings(text, model_type) |
|
if cached_embedding is not None: |
|
return cached_embedding |
|
|
|
|
|
if model_type == "all-mpnet-base-v2": |
|
chunk_embedding = model.encode( |
|
text, |
|
convert_to_tensor=True |
|
) |
|
embedding = np.array(t.Tensor.cpu(chunk_embedding)).tolist() |
|
elif model_type == "text-embedding-3-small": |
|
response = model.embeddings.create( |
|
input=text, |
|
model="text-embedding-3-small" |
|
) |
|
        embedding = response.data[0].embedding
    else:
        raise ValueError(f"Unsupported model_type: {model_type}")
|
|
|
|
|
set_cached_embeddings(text, model_type, embedding) |
|
return embedding |
|
|
|
def search_query(client, st_model, query: str, df: pd.DataFrame, n: int = 1) -> List[Dict]: |
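    """Embed the query with both models, score every row of the dataframe, and return the top-n hits from each model."""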
|
|
|
mpnet_embedding = generate_embedding(st_model, query, "all-mpnet-base-v2") |
|
openai_embedding = generate_embedding(client, query, "text-embedding-3-small") |
|
|
|
|
|
df['mpnet_similarities'] = df.all_mpnet_embedding.apply( |
|
lambda x: cosine_similarity(x, mpnet_embedding) |
|
) |
|
df['openai_similarities'] = df.openai_embedding.apply( |
|
lambda x: cosine_similarity(x, openai_embedding) |
|
) |
|
|
|
|
|
mpnet_results = df.nlargest(n, 'mpnet_similarities') |
|
openai_results = df.nlargest(n, 'openai_similarities') |
|
|
|
|
|
results = [] |
|
|
|
for _, row in mpnet_results.iterrows(): |
|
results.append({ |
|
"text": row["ext"], |
|
"similarity": float(row["mpnet_similarities"]), |
|
"model_type": "all-mpnet-base-v2" |
|
}) |
|
|
|
for _, row in openai_results.iterrows(): |
|
results.append({ |
|
"text": row["ext"], |
|
"similarity": float(row["openai_similarities"]), |
|
"model_type": "text-embedding-3-small" |
|
}) |
|
|
|
return results |
|
|
|
|
|
def load_credentials(): |
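    """Collect username/password pairs from the login_1..login_50 and password_1..password_50 environment variables."""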
|
credentials = {} |
|
for i in range(1, 51): |
|
username = os.environ.get(f"login_{i}") |
|
password = os.environ.get(f"password_{i}") |
|
if username and password: |
|
credentials[username] = password |
|
return credentials |
|
|
|
def authenticate_user(username: str, password: str): |
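    """Return the username if the submitted password matches the configured credentials, otherwise None."""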
|
credentials_dict = load_credentials() |
|
if username in credentials_dict and credentials_dict[username] == password: |
|
return username |
|
return None |
|
|
|
def create_token(data: dict, expires_delta: timedelta, secret_key: str): |
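    """Create a signed JWT from *data* with an expiry claim."""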
|
to_encode = data.copy() |
|
    expire = datetime.now(timezone.utc) + expires_delta
|
to_encode.update({"exp": expire}) |
|
encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=ALGORITHM) |
|
return encoded_jwt |
|
|
|
def verify_token(token: str, secret_key: str): |
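    """Decode a JWT and return its subject username, raising 401 if the token is invalid or has no subject."""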
|
credentials_exception = HTTPException( |
|
status_code=status.HTTP_401_UNAUTHORIZED, |
|
detail="Could not validate credentials", |
|
headers={"WWW-Authenticate": "Bearer"}, |
|
) |
|
try: |
|
payload = jwt.decode(token, secret_key, algorithms=[ALGORITHM]) |
|
username: str = payload.get("sub") |
|
if username is None: |
|
raise credentials_exception |
|
except JWTError: |
|
raise credentials_exception |
|
return username |
|
|
|
def verify_access_token(token: str = Depends(oauth2_scheme)): |
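    """FastAPI dependency that validates the bearer access token and returns the username."""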
|
return verify_token(token, SECRET_KEY) |
|
|
|
|
|
@app.get("/") |
|
def index() -> FileResponse: |
|
"""Serve the custom HTML page from the static directory""" |
|
file_path = Path(__file__).parent / "static" / "index.html" |
|
return FileResponse(path=str(file_path), media_type="text/html") |
|
|
|
@app.post("/login", response_model=TokenResponse) |
|
def login_app(form_data: OAuth2PasswordRequestForm = Depends()): |
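    """Verify the submitted credentials and issue an access/refresh token pair."""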
|
username = authenticate_user(form_data.username, form_data.password) |
|
if not username: |
|
raise HTTPException( |
|
status_code=status.HTTP_401_UNAUTHORIZED, |
|
detail="Invalid username or password", |
|
headers={"WWW-Authenticate": "Bearer"}, |
|
) |
|
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) |
|
refresh_token_expires = timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS) |
|
access_token = create_token( |
|
data={"sub": username}, |
|
expires_delta=access_token_expires, |
|
secret_key=SECRET_KEY |
|
) |
|
refresh_token = create_token( |
|
data={"sub": username}, |
|
expires_delta=refresh_token_expires, |
|
secret_key=REFRESH_SECRET_KEY |
|
) |
|
return { |
|
"access_token": access_token, |
|
"refresh_token": refresh_token, |
|
"token_type": "bearer" |
|
} |
|
|
|
@app.post("/refresh", response_model=TokenResponse) |
|
async def refresh(refresh_request: RefreshRequest): |
|
""" |
|
Endpoint to refresh an access token using a valid refresh token. |
|
Returns a new access token and the existing refresh token. |
|
""" |
|
try: |
|
|
|
username = verify_token(refresh_request.refresh_token, REFRESH_SECRET_KEY) |
|
|
|
|
|
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) |
|
access_token = create_token( |
|
data={"sub": username}, |
|
expires_delta=access_token_expires, |
|
secret_key=SECRET_KEY |
|
) |
|
|
|
return { |
|
"access_token": access_token, |
|
"refresh_token": refresh_request.refresh_token, |
|
"token_type": "bearer" |
|
} |
|
|
|
except JWTError: |
|
raise HTTPException( |
|
status_code=status.HTTP_401_UNAUTHORIZED, |
|
detail="Could not validate credentials", |
|
headers={"WWW-Authenticate": "Bearer"}, |
|
) |
|
|
|
@app.post("/search", response_model=List[SearchResult]) |
|
async def search( |
|
query_input: QueryInput, |
|
username: str = Depends(verify_access_token), |
|
): |
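    """Run the similarity search for the query and return the best match from each embedding model."""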
|
try: |
|
|
|
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY")) |
|
st_model = get_sentence_transformer() |
|
df = load_dataframe() |
|
|
|
|
|
results = search_query(client, st_model, query_input.query, df, n=1) |
|
return [SearchResult(**result) for result in results] |
|
|
|
except Exception as e: |
|
logging.error(f"Search error: {str(e)}") |
|
raise HTTPException( |
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, |
|
detail=f"Search failed: {str(e)}" |
|
) |
|
|
|
@app.post("/save") |
|
async def save_data( |
|
save_input: SaveBatchInput, |
|
username: str = Depends(verify_access_token) |
|
): |
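    """Append the submitted feedback items to the evaluation dataset on the Hugging Face Hub."""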
|
try: |
|
|
|
hf_token = os.environ.get("al_ghazali_rag_retrieval_evaluation") |
|
if not hf_token: |
|
raise HTTPException( |
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, |
|
detail="Hugging Face API token not found" |
|
) |
|
login(token=hf_token) |
|
|
|
|
|
data = { |
|
"user_type": [], |
|
"username": [], |
|
"query": [], |
|
"retrieved_text": [], |
|
"model_type": [], |
|
"reaction": [], |
|
"timestamp": [] |
|
} |
|
|
|
|
|
for item in save_input.items: |
|
data["user_type"].append(item.user_type) |
|
data["username"].append(item.username) |
|
data["query"].append(item.query) |
|
data["retrieved_text"].append(item.retrieved_text) |
|
data["model_type"].append(item.model_type) |
|
data["reaction"].append(item.reaction) |
|
data["timestamp"].append(timestamp or datetime.now(timezone.utc).isoformat().replace('+00:00', 'Z')) |
|
|
|
try: |
|
|
|
dataset = load_dataset( |
|
"HumbleBeeAI/al-ghazali-rag-retrieval-evaluation", |
|
split="train" |
|
) |
|
existing_data = dataset.to_dict() |
|
|
|
|
|
for key in data: |
|
if key not in existing_data: |
|
existing_data[key] = ["" if key in ["timestamp"] else None] * len(next(iter(existing_data.values()))) |
|
existing_data[key].extend(data[key]) |
|
|
|
except Exception as e: |
|
logging.warning(f"Could not load existing dataset, creating new one: {str(e)}") |
|
existing_data = data |
|
|
|
|
|
updated_dataset = Dataset.from_dict(existing_data) |
|
updated_dataset.push_to_hub( |
|
"HumbleBeeAI/al-ghazali-rag-retrieval-evaluation" |
|
) |
|
|
|
return {"message": "Data saved successfully"} |
|
|
|
except Exception as e: |
|
logging.error(f"Save error: {str(e)}") |
|
raise HTTPException( |
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, |
|
detail=f"Failed to save data: {str(e)}" |
|
) |
|
|
|
|
|
app.mount("/home", StaticFiles(directory="static", html=True), name="home") |
|
|
|
|
|
@app.on_event("startup") |
|
async def startup_event(): |
|
os.makedirs("./cache", exist_ok=True) |
|
|
|
if __name__ == "__main__": |
|
import uvicorn |
|
uvicorn.run(app, host="0.0.0.0", port=7860) |