Harshil Patel commited on
Commit
a5b8c89
·
1 Parent(s): 67a949c

Make login mandatory

Browse files
Files changed (1) hide show
  1. main.py +28 -2
main.py CHANGED
@@ -1,6 +1,6 @@
1
  import os, base64
2
  from dotenv import load_dotenv
3
- from fastapi import FastAPI, Request
4
  from fastapi.responses import RedirectResponse, JSONResponse, FileResponse
5
  from fastapi.staticfiles import StaticFiles
6
  from starlette.middleware.sessions import SessionMiddleware
@@ -48,6 +48,20 @@ app.add_middleware(
48
  )
49
 
50
  # 4. Auth routes -------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  @app.get("/login")
52
  async def login(request: Request):
53
  print("Session cookie:", request.cookies.get("session"))
@@ -236,6 +250,18 @@ def run_model(message, history):
236
 
237
  def update_model(name):
238
  print("Model changed to:", name)
 
 
 
 
 
 
 
 
 
 
 
 
239
 
240
  with gr.Blocks(css=CSS, fill_width=True, fill_height=True) as demo:
241
  model_manager = GeminiManager(gemini_model="gemini-2.0-flash")
@@ -348,7 +374,7 @@ with gr.Blocks(css=CSS, fill_width=True, fill_height=True) as demo:
348
  }
349
  """)
350
 
351
- gr.mount_gradio_app(app, demo, path="/")
352
 
353
  # 6. Entrypoint --------------------------------------------------------
354
  if __name__ == "__main__":
 
1
  import os, base64
2
  from dotenv import load_dotenv
3
+ from fastapi import FastAPI, Request, Depends
4
  from fastapi.responses import RedirectResponse, JSONResponse, FileResponse
5
  from fastapi.staticfiles import StaticFiles
6
  from starlette.middleware.sessions import SessionMiddleware
 
48
  )
49
 
50
  # 4. Auth routes -------------------------------------------------------
51
+ # Dependency to get the current user
52
+ def get_user(request: Request):
53
+ user = request.session.get('user')
54
+ if user:
55
+ return user['name']
56
+ return None
57
+
58
+ @app.get('/')
59
+ def public(request: Request, user = Depends(get_user)):
60
+ if user:
61
+ return RedirectResponse("/gradio")
62
+ else:
63
+ return RedirectResponse("/main")
64
+
65
  @app.get("/login")
66
  async def login(request: Request):
67
  print("Session cookie:", request.cookies.get("session"))
 
250
 
251
  def update_model(name):
252
  print("Model changed to:", name)
253
+
254
+ with gr.Blocks() as login:
255
+ btn = gr.Button("Login")
256
+ _js_redirect = """
257
+ () => {
258
+ url = '/login' + window.location.search;
259
+ window.open(url, '_blank');
260
+ }
261
+ """
262
+ btn.click(None, js=_js_redirect)
263
+
264
+ app = gr.mount_gradio_app(app, login, path="/main")
265
 
266
  with gr.Blocks(css=CSS, fill_width=True, fill_height=True) as demo:
267
  model_manager = GeminiManager(gemini_model="gemini-2.0-flash")
 
374
  }
375
  """)
376
 
377
+ app = gr.mount_gradio_app(app, demo, path="/gradio",auth_dependency=get_user)
378
 
379
  # 6. Entrypoint --------------------------------------------------------
380
  if __name__ == "__main__":