Bahodir Nematjonov commited on
Commit
e4f5d4a
·
1 Parent(s): a6e37af

file update

Browse files
Files changed (11) hide show
  1. .gitignore +5 -1
  2. auth.py +69 -37
  3. cache.py +23 -0
  4. database.py +16 -0
  5. gen.py +0 -13
  6. main.py +107 -33
  7. middleware.py +0 -18
  8. models.py +10 -0
  9. requirements.txt +11 -0
  10. schemas.py +10 -43
  11. utils.py +21 -0
.gitignore CHANGED
@@ -1 +1,5 @@
1
- .venv
 
 
 
 
 
1
+ .venv
2
+ cache
3
+ users.db
4
+ .env
5
+ __pycache__
auth.py CHANGED
@@ -3,61 +3,93 @@ from fastapi.security import OAuth2PasswordBearer
3
  from jose import JWTError, jwt
4
  from passlib.context import CryptContext
5
  from datetime import datetime, timedelta
 
 
 
6
  import os
7
 
8
- # Load secrets from environment variables
9
- SECRET_KEY = os.getenv("SECRET_KEY")
10
- ALGORITHM = "HS512"
 
11
  ACCESS_TOKEN_EXPIRE_MINUTES = 30
 
12
 
13
- # Password hashing
 
 
 
14
  pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
15
 
16
- # OAuth2 scheme
17
- oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
 
 
 
 
 
18
 
19
- # Mock user database (replace with a real database in production)
20
- fake_users_db = {
21
- "admin": {
22
- "username": "admin",
23
- "hashed_password": pwd_context.hash(os.getenv("ADMIN_PASSWORD")),
24
- }
25
- }
 
26
 
27
- def verify_password(plain_password, hashed_password):
28
  return pwd_context.verify(plain_password, hashed_password)
29
 
30
- def authenticate_user(username: str, password: str):
31
- user = fake_users_db.get(username)
32
- if not user or not verify_password(password, user["hashed_password"]):
33
- return False
34
  return user
35
 
36
- def create_access_token(data: dict, expires_delta: timedelta = None):
 
37
  to_encode = data.copy()
38
- if expires_delta:
39
- expire = datetime.utcnow() + expires_delta
40
- else:
41
- expire = datetime.utcnow() + timedelta(minutes=15)
42
  to_encode.update({"exp": expire})
43
- encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
 
44
  return encoded_jwt
45
 
46
- async def get_current_user(token: str = Depends(oauth2_scheme)):
47
- credentials_exception = HTTPException(
48
- status_code=status.HTTP_401_UNAUTHORIZED,
49
- detail="Could not validate credentials",
50
- headers={"WWW-Authenticate": "Bearer"},
51
- )
52
  try:
53
- payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
54
  username: str = payload.get("sub")
55
  if username is None:
56
- raise credentials_exception
 
 
 
 
 
57
  except JWTError:
58
- raise credentials_exception
 
 
 
 
59
 
60
- user = fake_users_db.get(username)
61
- if user is None:
62
- raise credentials_exception
63
- return user
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  from jose import JWTError, jwt
4
  from passlib.context import CryptContext
5
  from datetime import datetime, timedelta
6
+ from sqlalchemy.orm import Session
7
+ from database import engine, get_db
8
+ from models import Base, User
9
  import os
10
 
11
+ Base.metadata.create_all(bind=engine)
12
+ # Load secrets from environment variables or set defaults
13
+ SECRET_KEY = os.getenv("SECRET_KEY", "def6nQHONW99pOPyba9DShny6FB1CJJBigZault")
14
+ ALGORITHM = "HS256"
15
  ACCESS_TOKEN_EXPIRE_MINUTES = 30
16
+ REFRESH_TOKEN_EXPIRE_DAYS = 7
17
 
18
+ # OAuth2 scheme (Ensure the token URL matches the actual login endpoint)
19
+ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login")
20
+
21
+ # Password hashing context
22
  pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
23
 
24
+ def hash_password(password: str) -> str:
25
+ return pwd_context.hash(password)
26
+
27
+ def register_user(username: str, password: str, db: Session):
28
+ existing_user = db.query(User).filter(User.username == username).first()
29
+ if existing_user:
30
+ raise HTTPException(status_code=400, detail="Username already taken")
31
 
32
+ hashed_password = hash_password(password)
33
+ new_user = User(username=username, password=hashed_password)
34
+
35
+ db.add(new_user)
36
+ db.commit()
37
+ db.refresh(new_user)
38
+
39
+ return new_user
40
 
41
+ def verify_password(plain_password: str, hashed_password: str) -> bool:
42
  return pwd_context.verify(plain_password, hashed_password)
43
 
44
+ def authenticate_user(username: str, password: str, db: Session):
45
+ user = db.query(User).filter(User.username == username).first()
46
+ if not user or not verify_password(password, user.password):
47
+ return None
48
  return user
49
 
50
+ def create_token(data: dict, expires_delta: timedelta, secret_key: str) -> str:
51
+ """Generate JWT token."""
52
  to_encode = data.copy()
53
+ expire = datetime.utcnow() + expires_delta
 
 
 
54
  to_encode.update({"exp": expire})
55
+
56
+ encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=ALGORITHM)
57
  return encoded_jwt
58
 
59
+ def verify_token(token: str, secret_key: str) -> str:
60
+ """Verifies JWT token and extracts the username."""
 
 
 
 
61
  try:
62
+ payload = jwt.decode(token, secret_key, algorithms=[ALGORITHM])
63
  username: str = payload.get("sub")
64
  if username is None:
65
+ raise HTTPException(
66
+ status_code=status.HTTP_401_UNAUTHORIZED,
67
+ detail="Invalid token",
68
+ headers={"WWW-Authenticate": "Bearer"},
69
+ )
70
+ return username
71
  except JWTError:
72
+ raise HTTPException(
73
+ status_code=status.HTTP_401_UNAUTHORIZED,
74
+ detail="Could not validate credentials",
75
+ headers={"WWW-Authenticate": "Bearer"},
76
+ )
77
 
78
+ def verify_access_token(token: str = Depends(oauth2_scheme)) -> str:
79
+ """Verifies access token and returns the username."""
80
+ try:
81
+ payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
82
+ username: str = payload.get("sub")
83
+ if username is None:
84
+ raise HTTPException(
85
+ status_code=status.HTTP_401_UNAUTHORIZED,
86
+ detail="Invalid access token",
87
+ headers={"WWW-Authenticate": "Bearer"},
88
+ )
89
+ return username
90
+ except JWTError:
91
+ raise HTTPException(
92
+ status_code=status.HTTP_401_UNAUTHORIZED,
93
+ detail="Could not validate credentials",
94
+ headers={"WWW-Authenticate": "Bearer"},
95
+ )
cache.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List, Optional
3
+ from sentence_transformers import SentenceTransformer
4
+ from functools import lru_cache
5
+ from dotenv import load_dotenv
6
+
7
+ @lru_cache(maxsize=1)
8
+ def get_sentence_transformer() -> SentenceTransformer:
9
+ """Loads and caches the Sentence Transformer model."""
10
+ try:
11
+ model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
12
+ print("✅ Sentence Transformer Model Loaded")
13
+ return model
14
+ except Exception as e:
15
+ print(f"❌ Error loading Sentence Transformer: {str(e)}")
16
+ raise RuntimeError("Failed to load Sentence Transformer model.")
17
+
18
+
19
+ def get_cached_embeddings(text: str, model_type: str) -> Optional[List[float]]:
20
+ pass
21
+
22
+ def set_cached_embeddings(text: str, model_type: str, embeddings: List[float]):
23
+ pass
database.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sqlalchemy import create_engine
2
+ from sqlalchemy.orm import sessionmaker, declarative_base
3
+
4
+ DATABASE_URL = "sqlite:///./users.db"
5
+
6
+ engine = create_engine(DATABASE_URL)
7
+ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
8
+ Base = declarative_base()
9
+
10
+ def get_db():
11
+ """Dependency function to get the database session."""
12
+ db = SessionLocal()
13
+ try:
14
+ yield db
15
+ finally:
16
+ db.close()
gen.py DELETED
@@ -1,13 +0,0 @@
1
- from fastapi import WebSocket
2
- from transformers import pipeline
3
- import asyncio
4
-
5
- # Load the model
6
- model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
7
- generator = pipeline("text-generation", model=model_name)
8
-
9
- async def generate_text_stream(prompt: str, websocket: WebSocket):
10
- for i in range(10): # Simulate streaming (replace with actual model inference)
11
- chunk = generator(prompt, max_length=i + 10, do_sample=True)[0]["generated_text"]
12
- await websocket.send_text(chunk)
13
- await asyncio.sleep(0.1) # Simulate delay
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main.py CHANGED
@@ -1,44 +1,118 @@
1
- from fastapi import FastAPI, WebSocket, Depends, HTTPException
2
- from auth import get_current_user, authenticate_user, create_access_token
3
- from gen import generate_text_stream
 
 
 
 
4
  from fastapi.security import OAuth2PasswordRequestForm
5
- from fastapi.middleware.cors import CORSMiddleware
6
- from middleware import setup_rate_limiter
7
  import os
 
 
 
 
 
 
 
 
 
8
 
9
  app = FastAPI()
10
 
 
 
 
 
 
11
 
12
- # Apply rate limiting middleware
13
- setup_rate_limiter(app)
 
 
 
14
 
15
- # CORS middleware
16
- app.add_middleware(
17
- CORSMiddleware,
18
- allow_origins=["*"],
19
- allow_credentials=True,
20
- allow_methods=["*"],
21
- allow_headers=["*"],
22
- )
23
 
24
- # Login endpoint
25
- @app.post("/token")
26
- async def login(form_data: OAuth2PasswordRequestForm = Depends()):
27
- user = authenticate_user(form_data.username, form_data.password)
28
- if not user:
29
- raise HTTPException(status_code=400, detail="Incorrect username or password")
30
- access_token = create_access_token(data={"sub": user["username"]})
31
- return {"access_token": access_token, "token_type": "bearer"}
32
 
33
- # WebSocket endpoint for streaming
34
- @app.websocket("/generate")
35
- async def websocket_generate(websocket: WebSocket, token: str):
36
- await websocket.accept()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  try:
38
- user = get_current_user(token)
39
- prompt = await websocket.receive_text()
40
- await generate_text_stream(prompt, websocket)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  except Exception as e:
42
- await websocket.send_text(f"Error: {str(e)}")
43
- finally:
44
- await websocket.close()
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, Depends, HTTPException, status
2
+ from fastapi.responses import FileResponse
3
+ from fastapi.staticfiles import StaticFiles
4
+ from jose import JWTError
5
+ from schemas import UserRegister, TokenResponse, RefreshTokenRequest, QueryInput
6
+ from auth import register_user, get_db, authenticate_user, create_token, verify_token, verify_access_token, Session
7
+ from utils import search_with_llm
8
  from fastapi.security import OAuth2PasswordRequestForm
9
+ from pathlib import Path
10
+ from datetime import timedelta
11
  import os
12
+ import logging
13
+
14
+ logging.basicConfig(level=logging.INFO)
15
+
16
+ SECRET_KEY = os.getenv("SECRET_KEY", 'def6nQHONW99pOPyba9DShny6FB1CJJBigZault')
17
+ REFRESH_SECRET_KEY = os.getenv("REFRESH_SECRET_KEY", SECRET_KEY)
18
+ ALGORITHM = "HS256"
19
+ ACCESS_TOKEN_EXPIRE_MINUTES = 30
20
+ REFRESH_TOKEN_EXPIRE_DAYS = 7
21
 
22
  app = FastAPI()
23
 
24
+ # Entry Endpoint
25
+ @app.get('/')
26
+ def index() -> FileResponse:
27
+ file_path = Path(__file__).parent / 'static' / 'index.html'
28
+ return FileResponse(path=str(file_path), media_type='text/html')
29
 
30
+ @app.post("/register")
31
+ async def register(user: UserRegister, db: Session = Depends(get_db)):
32
+ """Registers a new user."""
33
+ new_user = register_user(user.username, user.password, db)
34
+ return {"message": "User registered successfully", "user": new_user.username}
35
 
36
+ @app.post("/login", response_model=TokenResponse)
37
+ async def login(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
38
+ try:
39
+ user = authenticate_user(form_data.username, form_data.password, db)
 
 
 
 
40
 
41
+ if not user:
42
+ raise HTTPException(
43
+ status_code=status.HTTP_401_UNAUTHORIZED,
44
+ detail="Invalid username or password",
45
+ headers={"WWW-Authenticate": "Bearer"},
46
+ )
 
 
47
 
48
+ access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
49
+ refresh_token_expires = timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
50
+
51
+ access_token = create_token(
52
+ data={"sub": user.username},
53
+ expires_delta=access_token_expires,
54
+ secret_key=SECRET_KEY
55
+ )
56
+ refresh_token = create_token(
57
+ data={"sub": user.username},
58
+ expires_delta=refresh_token_expires,
59
+ secret_key=REFRESH_SECRET_KEY
60
+ )
61
+
62
+ return {
63
+ "access_token": access_token,
64
+ "refresh_token": refresh_token,
65
+ "token_type": "bearer"
66
+ }
67
+ except Exception as e:
68
+ logging.error(f"Login error: {str(e)}")
69
+ raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error")
70
+
71
+ @app.post("/refresh", response_model=TokenResponse)
72
+ async def refresh(refresh_request: RefreshTokenRequest):
73
  try:
74
+ # Verify the refresh token
75
+ username = verify_token(refresh_request.refresh_token, REFRESH_SECRET_KEY)
76
+
77
+ # Create new access token
78
+ access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
79
+ access_token = create_token(
80
+ data={"sub": username},
81
+ expires_delta=access_token_expires,
82
+ secret_key=SECRET_KEY
83
+ )
84
+
85
+ return {
86
+ "access_token": access_token,
87
+ "refresh_token": refresh_request.refresh_token, # Return the same refresh token
88
+ "token_type": "bearer"
89
+ }
90
+
91
+ except JWTError:
92
+ raise HTTPException(
93
+ status_code=status.HTTP_401_UNAUTHORIZED,
94
+ detail="Could not validate credentials",
95
+ headers={"WWW-Authenticate": "Bearer"},
96
+ )
97
+
98
+
99
+
100
+ @app.post("/search")
101
+ async def search(
102
+ query_input: QueryInput,
103
+ username: str = Depends(verify_access_token),
104
+ ):
105
+ try:
106
+ response = search_with_llm(query_input.query)
107
+ return {"response": response}
108
  except Exception as e:
109
+ raise HTTPException(status_code=500, detail=f"Search failed: {str(e)}")
110
+
111
+ # WebSocket endpoint for streaming
112
+ @app.on_event("startup")
113
+ async def startup_event():
114
+ os.makedirs("./cache", exist_ok=True)
115
+
116
+ if __name__ == "__main__":
117
+ import uvicorn
118
+ uvicorn.run(app, host="0.0.0.0", port=7860)
middleware.py DELETED
@@ -1,18 +0,0 @@
1
- from fastapi import Request
2
- from slowapi import Limiter
3
- from slowapi.util import get_remote_address
4
- from slowapi.middleware import SlowAPIMiddleware
5
-
6
-
7
- limiter = Limiter(key_func=get_remote_address)
8
-
9
- async def log_requests(request: Request, call_next):
10
- print(f"Request: {request.method} {request.url}")
11
- response = await call_next(request)
12
- print(f"Response: {response.status_code}")
13
- return response
14
-
15
- # Apply the rate limiter middleware
16
- def setup_rate_limiter(app):
17
- app.state.limiter = limiter
18
- app.add_middleware(SlowAPIMiddleware)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from sqlalchemy import Column, Integer, String
2
+ from database import Base
3
+
4
+ class User(Base):
5
+ """SQLAlchemy model for storing users."""
6
+ __tablename__ = "users"
7
+
8
+ id = Column(Integer, primary_key=True, index=True)
9
+ username = Column(String, unique=True, index=True)
10
+ password = Column(String) # Hashed password
requirements.txt CHANGED
@@ -22,6 +22,17 @@ loguru # Optional: for better logging
22
  # Testing (optional for development)
23
  pytest
24
  httpx # For async HTTP requests in tests
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  # Docker and deployment
27
  gunicorn # Optional: for production deployment
 
22
  # Testing (optional for development)
23
  pytest
24
  httpx # For async HTTP requests in tests
25
+ bcrypt
26
+ openai
27
+ datasets
28
+ diskcache
29
+ sentence_transformers
30
+ pinecone-client
31
+ python-multipart
32
+ sqlalchemy
33
+ psycopg2
34
+ python-dotenv
35
+ ollama
36
 
37
  # Docker and deployment
38
  gunicorn # Optional: for production deployment
schemas.py CHANGED
@@ -1,49 +1,16 @@
1
- from pydantic import BaseModel, EmailStr
2
 
3
- # Authentication
4
- class Token(BaseModel):
5
- access_token: str
6
- token_type: str
7
-
8
- class TokenData(BaseModel):
9
- username: str | None = None
10
-
11
- class User(BaseModel):
12
- username: str
13
- email: EmailStr | None = None
14
-
15
- class UserInDB(User):
16
- hashed_password: str
17
-
18
- class LoginRequest(BaseModel):
19
  username: str
20
  password: str
21
 
22
- # Generation
23
- class GenerationRequest(BaseModel):
24
- prompt: str
25
- max_length: int = 100
26
- temperature: float = 0.7
27
- top_k: int = 50
28
- top_p: float = 0.95
29
-
30
- class GenerationResponse(BaseModel):
31
- generated_text: str
32
-
33
- # WebSocket
34
- class WebSocketMessage(BaseModel):
35
- prompt: str
36
 
37
- # Error Handling
38
- class HTTPError(BaseModel):
39
- detail: str
40
-
41
- class Config:
42
- schema_extra = {
43
- "example": {"detail": "Error message"},
44
- }
45
 
46
- # Rate Limiting
47
- class RateLimitResponse(BaseModel):
48
- message: str
49
- retry_after: int
 
1
+ from pydantic import BaseModel, Field
2
 
3
+ class UserRegister(BaseModel):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  username: str
5
  password: str
6
 
7
+ class QueryInput(BaseModel):
8
+ query: str
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
+ class TokenResponse(BaseModel):
11
+ access_token: str
12
+ token_type: str
13
+ refresh_token: str
 
 
 
 
14
 
15
+ class RefreshTokenRequest(BaseModel):
16
+ refresh_token: str
 
 
utils.py CHANGED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import ollama
3
+ from typing import List, Dict
4
+
5
+ def cosine_similarity(embedding_0, embedding_1):
6
+ pass
7
+
8
+
9
+ def generate_embedding(model, text: str, model_type: str) -> List[float]:
10
+ pass
11
+
12
+ def search_with_llm(query: str, model: str = "llama3.2"):
13
+ try:
14
+ response = ollama.chat(
15
+ model=model,
16
+ messages=[{"role": "user", "content": query}]
17
+ )
18
+ return response["message"]["content"]
19
+ except Exception as e:
20
+ return f"❌ Error processing request: {str(e)}"
21
+