Kunal Pai
commited on
Commit
·
1b65cc5
1
Parent(s):
167a572
initial work on Auth0
Browse files
main.py
CHANGED
@@ -1,55 +1,98 @@
|
|
1 |
-
|
2 |
|
3 |
-
import base64
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
from src.manager.manager import GeminiManager
|
5 |
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
""
|
22 |
-
|
23 |
-
|
24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
|
26 |
|
27 |
def run_model(message, history):
|
28 |
-
history.append({
|
29 |
-
"role": "user",
|
30 |
-
"content": message,
|
31 |
-
})
|
32 |
yield "", history
|
33 |
for messages in model_manager.run(history):
|
34 |
-
for
|
35 |
-
if
|
36 |
-
print(
|
37 |
yield "", messages
|
38 |
|
39 |
|
40 |
-
def update_model(
|
41 |
-
print(
|
42 |
-
pass
|
43 |
|
44 |
|
45 |
-
with gr.Blocks(css=
|
46 |
model_manager = GeminiManager(gemini_model="gemini-2.0-flash")
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
gr.Markdown(_header_html)
|
51 |
model_dropdown = gr.Dropdown(
|
52 |
-
|
53 |
"HASHIRU",
|
54 |
"Static-HASHIRU",
|
55 |
"Cloud-Only HASHIRU",
|
@@ -59,20 +102,19 @@ with gr.Blocks(css=css, fill_width=True, fill_height=True) as demo:
|
|
59 |
value="HASHIRU",
|
60 |
interactive=True,
|
61 |
)
|
62 |
-
|
63 |
-
|
64 |
-
fn=update_model, inputs=model_dropdown, outputs=[])
|
65 |
-
with gr.Row(scale=1):
|
66 |
chatbot = gr.Chatbot(
|
67 |
avatar_images=("HASHIRU_2.png", "HASHIRU.png"),
|
68 |
-
type="messages",
|
69 |
-
|
70 |
-
editable="user",
|
71 |
-
scale=1,
|
72 |
-
render_markdown=True,
|
73 |
-
placeholder="Type your message here...",
|
74 |
)
|
75 |
-
gr.ChatInterface(
|
76 |
-
|
|
|
|
|
|
|
|
|
77 |
if __name__ == "__main__":
|
78 |
-
|
|
|
|
1 |
+
# ------------------------------ main.py ------------------------------
|
2 |
|
3 |
+
import os, base64
|
4 |
+
from dotenv import load_dotenv
|
5 |
+
from fastapi import FastAPI, Request
|
6 |
+
from fastapi.responses import RedirectResponse
|
7 |
+
from starlette.middleware.sessions import SessionMiddleware
|
8 |
+
from starlette.middleware.base import BaseHTTPMiddleware
|
9 |
+
from authlib.integrations.starlette_client import OAuth
|
10 |
+
import gradio as gr
|
11 |
from src.manager.manager import GeminiManager
|
12 |
|
13 |
+
# 1. Load environment --------------------------------------------------
|
14 |
+
load_dotenv()
|
15 |
+
AUTH0_DOMAIN = os.getenv("AUTH0_DOMAIN")
|
16 |
+
AUTH0_CLIENT_ID = os.getenv("AUTH0_CLIENT_ID")
|
17 |
+
AUTH0_CLIENT_SECRET = os.getenv("AUTH0_CLIENT_SECRET")
|
18 |
+
AUTH0_AUDIENCE = os.getenv("AUTH0_AUDIENCE")
|
19 |
+
SESSION_SECRET_KEY = os.getenv("SESSION_SECRET_KEY", "replace‑me")
|
20 |
+
|
21 |
+
# 2. Auth0 client ------------------------------------------------------
|
22 |
+
oauth = OAuth()
|
23 |
+
oauth.register(
|
24 |
+
"auth0",
|
25 |
+
client_id=AUTH0_CLIENT_ID,
|
26 |
+
client_secret=AUTH0_CLIENT_SECRET,
|
27 |
+
client_kwargs={"scope": "openid profile email"},
|
28 |
+
server_metadata_url=f"https://{AUTH0_DOMAIN}/.well-known/openid-configuration",
|
29 |
+
)
|
30 |
+
|
31 |
+
# 3. FastAPI app -------------------------------------------------------
|
32 |
+
app = FastAPI()
|
33 |
+
|
34 |
+
# 3a. *Inner* auth‑gate middleware (needs session already populated)
|
35 |
+
class RequireAuthMiddleware(BaseHTTPMiddleware):
|
36 |
+
async def dispatch(self, request: Request, call_next):
|
37 |
+
public = ("/login", "/auth", "/logout", "/static", "/assets", "/favicon")
|
38 |
+
if any(request.url.path.startswith(p) for p in public):
|
39 |
+
return await call_next(request)
|
40 |
+
if "user" not in request.session:
|
41 |
+
return RedirectResponse("/login")
|
42 |
+
return await call_next(request)
|
43 |
+
|
44 |
+
app.add_middleware(RequireAuthMiddleware) # Add **first** (inner)
|
45 |
+
app.add_middleware(SessionMiddleware, secret_key=SESSION_SECRET_KEY) # Add **second** (outer)
|
46 |
+
|
47 |
+
# 4. Auth routes -------------------------------------------------------
|
48 |
+
@app.get("/login")
|
49 |
+
async def login(request: Request):
|
50 |
+
return await oauth.auth0.authorize_redirect(request, request.url_for("auth"), audience=AUTH0_AUDIENCE)
|
51 |
+
|
52 |
+
@app.get("/auth")
|
53 |
+
async def auth(request: Request):
|
54 |
+
token = await oauth.auth0.authorize_access_token(request)
|
55 |
+
request.session["user"] = token["userinfo"]
|
56 |
+
return RedirectResponse("/")
|
57 |
+
|
58 |
+
@app.get("/logout")
|
59 |
+
async def logout(request: Request):
|
60 |
+
request.session.clear()
|
61 |
+
return RedirectResponse(
|
62 |
+
f"https://{AUTH0_DOMAIN}/v2/logout?client_id={AUTH0_CLIENT_ID}&returnTo=http://localhost:7860/"
|
63 |
+
)
|
64 |
+
|
65 |
+
# 5. Gradio UI ---------------------------------------------------------
|
66 |
+
_logo_b64 = base64.b64encode(open("HASHIRU_LOGO.png", "rb").read()).decode()
|
67 |
+
HEADER_HTML = f"""
|
68 |
+
<div style='display:flex;align-items:center;width:30%;'>
|
69 |
+
<img src='data:image/png;base64,{_logo_b64}' width='40' class='logo'/>
|
70 |
+
<h1>HASHIRU AI</h1>
|
71 |
+
</div>"""
|
72 |
+
CSS = ".logo{margin-right:20px;}"
|
73 |
|
74 |
|
75 |
def run_model(message, history):
|
76 |
+
history.append({"role": "user", "content": message})
|
|
|
|
|
|
|
77 |
yield "", history
|
78 |
for messages in model_manager.run(history):
|
79 |
+
for m in messages:
|
80 |
+
if m.get("role") == "summary":
|
81 |
+
print("Summary:", m["content"])
|
82 |
yield "", messages
|
83 |
|
84 |
|
85 |
+
def update_model(name):
|
86 |
+
print("Model changed to:", name)
|
|
|
87 |
|
88 |
|
89 |
+
with gr.Blocks(css=CSS, fill_width=True, fill_height=True) as demo:
|
90 |
model_manager = GeminiManager(gemini_model="gemini-2.0-flash")
|
91 |
+
with gr.Column():
|
92 |
+
with gr.Row():
|
93 |
+
gr.Markdown(HEADER_HTML)
|
|
|
94 |
model_dropdown = gr.Dropdown(
|
95 |
+
[
|
96 |
"HASHIRU",
|
97 |
"Static-HASHIRU",
|
98 |
"Cloud-Only HASHIRU",
|
|
|
102 |
value="HASHIRU",
|
103 |
interactive=True,
|
104 |
)
|
105 |
+
model_dropdown.change(update_model, model_dropdown)
|
106 |
+
with gr.Row():
|
|
|
|
|
107 |
chatbot = gr.Chatbot(
|
108 |
avatar_images=("HASHIRU_2.png", "HASHIRU.png"),
|
109 |
+
type="messages", show_copy_button=True, editable="user",
|
110 |
+
placeholder="Type your message here…",
|
|
|
|
|
|
|
|
|
111 |
)
|
112 |
+
gr.ChatInterface(run_model, type="messages", chatbot=chatbot, additional_outputs=[chatbot], save_history=True)
|
113 |
+
|
114 |
+
# Mount at root
|
115 |
+
gr.mount_gradio_app(app, demo, path="/")
|
116 |
+
|
117 |
+
# 6. Entrypoint --------------------------------------------------------
|
118 |
if __name__ == "__main__":
|
119 |
+
import uvicorn
|
120 |
+
uvicorn.run(app, host="0.0.0.0", port=7860)
|