File size: 4,530 Bytes
1b65cc5 d648fe6 1b65cc5 2f85c93 d648fe6 1b65cc5 434b328 7c06e97 d648fe6 1b65cc5 d648fe6 1b65cc5 d648fe6 1b65cc5 b27e104 1b65cc5 0e58feb 1b65cc5 d648fe6 1b65cc5 d648fe6 1b65cc5 d648fe6 1b65cc5 d648fe6 1b65cc5 d648fe6 1b65cc5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
# ------------------------------ main.py ------------------------------
import base64
import os
import secrets
import warnings
from urllib.parse import urlencode

from dotenv import load_dotenv
from fastapi import FastAPI, Request
from fastapi.responses import RedirectResponse
from starlette.middleware.sessions import SessionMiddleware
from starlette.middleware.base import BaseHTTPMiddleware
from authlib.integrations.starlette_client import OAuth
import gradio as gr

from src.manager.manager import GeminiManager
# 1. Load environment --------------------------------------------------
load_dotenv()
AUTH0_DOMAIN = os.getenv("AUTH0_DOMAIN")
AUTH0_CLIENT_ID = os.getenv("AUTH0_CLIENT_ID")
AUTH0_CLIENT_SECRET = os.getenv("AUTH0_CLIENT_SECRET")
AUTH0_AUDIENCE = os.getenv("AUTH0_AUDIENCE")
# Secret passed to SessionMiddleware for the session cookie. The auth gate
# (RequireAuthMiddleware) admits any request whose session contains "user",
# so this key must never be a fixed, publicly known placeholder: anyone who
# knows it could mint a "logged-in" cookie. When SESSION_SECRET_KEY is unset
# or empty, fall back to a random per-process key. Sessions then do not
# survive a restart (and would not be shared across multiple worker
# processes), but they cannot be forged.
SESSION_SECRET_KEY = os.getenv("SESSION_SECRET_KEY")
if not SESSION_SECRET_KEY:
    SESSION_SECRET_KEY = secrets.token_urlsafe(32)
    warnings.warn(
        "SESSION_SECRET_KEY is not set; using a random per-process session key. "
        "Logins will not survive a restart. Set SESSION_SECRET_KEY in production.",
        RuntimeWarning,
        stacklevel=1,
    )
# 2. Auth0 client ------------------------------------------------------
# Authlib OAuth registry holding one client, registered under the name
# "auth0". The client is configured from the tenant's OpenID Connect metadata
# URL and requests the "openid profile email" scopes.
oauth = OAuth()
oauth.register(
    name="auth0",
    server_metadata_url=f"https://{AUTH0_DOMAIN}/.well-known/openid-configuration",
    client_id=AUTH0_CLIENT_ID,
    client_secret=AUTH0_CLIENT_SECRET,
    client_kwargs={"scope": "openid profile email"},
)
# 3. FastAPI app -------------------------------------------------------
app = FastAPI()


# 3a. *Inner* auth-gate middleware (needs session already populated)
class RequireAuthMiddleware(BaseHTTPMiddleware):
    """Redirect visitors without a session user to ``/login``.

    A request whose path starts with one of the public prefixes (the auth
    routes plus static / asset / favicon paths) is forwarded unchanged. Any
    other request is forwarded only if ``request.session`` contains a
    ``"user"`` entry; otherwise it is answered with a redirect to ``/login``.
    """

    async def dispatch(self, request: Request, call_next):
        public_prefixes = ("/login", "/auth", "/logout", "/static", "/assets", "/favicon")
        # The ``or`` short-circuits, so public paths never touch request.session.
        allowed = request.url.path.startswith(public_prefixes) or "user" in request.session
        if not allowed:
            return RedirectResponse("/login")
        return await call_next(request)


# Registration order matters: the middleware added last wraps the ones added
# before it. SessionMiddleware (outer) therefore populates request.session
# before RequireAuthMiddleware (inner) reads it.
app.add_middleware(RequireAuthMiddleware)  # Add **first** (inner)
app.add_middleware(SessionMiddleware, secret_key=SESSION_SECRET_KEY)  # Add **second** (outer)
# 4. Auth routes -------------------------------------------------------
@app.get("/login")
async def login(request: Request):
    """Begin the Auth0 login by redirecting the browser to Auth0.

    The callback URL is this app's route named ``auth``. ``AUTH0_AUDIENCE``
    is handed to Authlib's ``authorize_redirect`` as an extra ``audience``
    parameter.
    """
    start_redirect = oauth.auth0.authorize_redirect
    callback_url = request.url_for("auth")
    return await start_redirect(request, callback_url, audience=AUTH0_AUDIENCE)
@app.get("/auth")
async def auth(request: Request):
    """Auth0 callback: finish the login and remember the user in the session.

    Completes the flow with Authlib's ``authorize_access_token``, then stores
    the token's ``"userinfo"`` entry under ``session["user"]``. That is the key
    the auth-gate middleware checks. Finally redirects to ``/``.
    """
    token = await oauth.auth0.authorize_access_token(request)
    profile = token["userinfo"]
    request.session["user"] = profile
    return RedirectResponse("/")
@app.get("/logout")
async def logout(request: Request):
    """Clear the local session, then log the user out of Auth0 as well.

    Redirects to the tenant's ``/v2/logout`` endpoint with ``client_id`` and a
    ``returnTo`` URL.

    ``returnTo`` is this app's base URL as seen by the incoming request (for
    example ``http://localhost:7860/`` in local development). It is not
    hard-coded to localhost, so logout works on whichever host and port the
    app is served from.

    The query string is URL-encoded, so the embedded URL travels as a single
    parameter value.

    NOTE(review): Auth0 is expected to accept only ``returnTo`` values listed
    as Allowed Logout URLs for the application. Make sure each deployed
    origin is registered there.
    """
    request.session.clear()
    query = urlencode({"client_id": AUTH0_CLIENT_ID, "returnTo": str(request.base_url)})
    return RedirectResponse(f"https://{AUTH0_DOMAIN}/v2/logout?{query}")
# 5. Gradio UI ---------------------------------------------------------
# The logo is read once at import time and inlined into HEADER_HTML as a
# base64 data URI. The path is relative to the current working directory.
# The ``with`` block closes the file handle as soon as the bytes are read.
with open("HASHIRU_LOGO.png", "rb") as _logo_file:
    _logo_b64 = base64.b64encode(_logo_file.read()).decode()
HEADER_HTML = f"""
<div style='display:flex;align-items:center;width:30%;'>
<img src='data:image/png;base64,{_logo_b64}' width='40' class='logo'/>
<h1>HASHIRU AI</h1>
</div>"""
CSS = ".logo{margin-right:20px;}"
def run_model(message, history):
    """Chat handler for ``gr.ChatInterface``: stream the conversation as it grows.

    Appends the user's ``message`` to ``history``, mutating the list passed
    in, and yields that list once right away. It then yields every message
    list produced by ``model_manager.run(history)``. Each yield is
    ``("", messages)``.

    Entries whose role is ``"summary"`` are printed to stdout each time they
    appear in a yielded list.
    """
    history.append({"role": "user", "content": message})
    yield "", history
    for snapshot in model_manager.run(history):
        summaries = (item["content"] for item in snapshot if item.get("role") == "summary")
        for text in summaries:
            print("Summary:", text)
        yield "", snapshot
def update_model(name):
    """Report a model-selector change on stdout; changes no state and returns None."""
    line = " ".join(("Model changed to:", str(name)))
    print(line)
# Page layout: a header row (logo + model selector) above a chat row.
with gr.Blocks(css=CSS, fill_width=True, fill_height=True) as demo:
    # Module-level manager instance; run_model reads this global when called.
    model_manager = GeminiManager(gemini_model="gemini-2.0-flash")
    with gr.Column():
        with gr.Row():
            gr.Markdown(HEADER_HTML)
            # Variant selector. A change only calls update_model, which prints
            # the new name; model_manager is not reconfigured here.
            model_dropdown = gr.Dropdown(
                [
                    "HASHIRU",
                    "Static-HASHIRU",
                    "Cloud-Only HASHIRU",
                    "Local-Only HASHIRU",
                    "No-Economy HASHIRU",
                ],
                value="HASHIRU",
                interactive=True,
            )
            model_dropdown.change(update_model, model_dropdown)
        # NOTE(review): confirm the ChatInterface is meant to live inside this
        # second row (row/column nesting here should match the intended layout).
        with gr.Row():
            chatbot = gr.Chatbot(
                avatar_images=("HASHIRU_2.png", "HASHIRU.png"),
                type="messages", show_copy_button=True, editable="user",
                placeholder="Type your message here…",
            )
            # run_model yields ("", messages) pairs. With additional_outputs=[chatbot],
            # the second element is presumably routed to `chatbot` (Gradio's
            # additional_outputs contract); confirm against the Gradio version in use.
            gr.ChatInterface(run_model, type="messages", chatbot=chatbot, additional_outputs=[chatbot], save_history=True)
# Mount the Gradio UI at the site root. The /login, /auth and /logout routes
# were added to ``app`` before this mount.
app = gr.mount_gradio_app(app, demo, path="/")

# 6. Entrypoint --------------------------------------------------------
if __name__ == "__main__":
    import uvicorn

    # Serve on all interfaces, port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")
|