# app.py ── SHASHA AI "Hybrid" (FastAPI + Gradio + Static UI)
from __future__ import annotations
from typing import List, Tuple
import asyncio
import gradio as gr
from fastapi import FastAPI, WebSocket
from fastapi.staticfiles import StaticFiles
# ────────────────────────────────────────────────────────────────
# internal helpers (unchanged)
# ────────────────────────────────────────────────────────────────
from constants import (
    HTML_SYSTEM_PROMPT,
    TRANSFORMERS_JS_SYSTEM_PROMPT,
    AVAILABLE_MODELS,
)
from inference import generation_code
History = List[Tuple[str, str]]
# ────────────────────────────────────────────────────────────────
# 1. Blocks-only "headless" API (no UI, just /api/predict JSON)
# ────────────────────────────────────────────────────────────────
with gr.Blocks(css="body{display:none}") as api_demo:  # invisible UI
    prompt_in   = gr.Textbox()
    file_in     = gr.File()
    url_in      = gr.Textbox()
    model_state = gr.State(AVAILABLE_MODELS[0])
    search_chk  = gr.Checkbox()
    lang_dd     = gr.Dropdown(choices=["html", "python"], value="html")
    hist_state  = gr.State([])

    code_out    = gr.Textbox()  # generated code, returned as plain text
    hist_out    = gr.State()
    preview_out = gr.Textbox()
    chat_out    = gr.State()

    api_demo.load(
        generation_code,
        inputs=[prompt_in, file_in, url_in, model_state,
                search_chk, lang_dd, hist_state],
        outputs=[code_out, hist_out, preview_out, chat_out],
        api_name="predict",  # stable endpoint name for API clients
    )
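
# Illustrative client sketch (not invoked by the app): calling the headless
# API above with `gradio_client`. Assumes the server is reachable at
# http://localhost:7860 with Gradio mounted under /api (see below); note
# that gr.State inputs/outputs are usually hidden from the public API, so
# the exact argument list depends on the installed Gradio version.
def _example_api_call(prompt: str):
    from gradio_client import Client  # pip install gradio_client

    client = Client("http://localhost:7860/api")
    return client.predict(
        prompt,   # prompt_in
        None,     # file_in
        "",       # url_in
        False,    # search_chk
        "html",   # lang_dd
        api_name="/predict",
    )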
# ────────────────────────────────────────────────────────────────
# 2. Hybrid FastAPI server mounts:
#    • /      → static/  (index.html, style.css, index.js …)
#    • /api/* → Gradio JSON API (& websocket queue)
# ────────────────────────────────────────────────────────────────
app = FastAPI(title="SHASHA AI hybrid server")
# Gradio API: mount this BEFORE the static catch-all, otherwise the "/"
# mount would shadow every /api request (Starlette matches mounts in the
# order they are registered). mount_gradio_app returns the same app.
app = gr.mount_gradio_app(app, api_demo, path="/api")

# static assets, served for everything not under /api
app.mount(
    "/", StaticFiles(directory="static", html=True), name="static"
)
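
# Resulting route map (given the mount order above):
#   GET /            → static/index.html and friends
#   *   /api/...     → Gradio's JSON + queue endpoints
#   WS  /api/stream  → custom token streamer (section 3 below)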
# ────────────────────────────────────────────────────────────────
# 3. Bonus: WebSocket streamer for lightning-fast preview
# ────────────────────────────────────────────────────────────────
@app.websocket("/api/stream")
async def stream(websocket: WebSocket) -> None:
    """
    Front-end connects, sends the same JSON payload as /predict,
    and receives chunks (string tokens) as they arrive.
    """
    await websocket.accept()
    payload = await websocket.receive_json()
    queue: asyncio.Queue[str] = asyncio.Queue()

    # spawn background generation; a sentinel marks the end of the stream
    async def _run() -> None:
        async for token in generation_code.stream(**payload):  # type: ignore
            await queue.put(token)
        await queue.put("__END__")

    task = asyncio.create_task(_run())  # keep a reference so the task is not GC'd

    while True:
        item = await queue.get()
        if item == "__END__":
            break
        await websocket.send_text(item)
    await task  # surface any exception raised during generation
    await websocket.close()
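
# Illustrative counterpart client for the websocket streamer above (not
# invoked by the app). Assumes the `websockets` package and a local server;
# the payload keys are placeholders and must match generation_code's
# keyword arguments.
async def _example_stream_client(prompt: str) -> None:
    import json
    import websockets  # pip install websockets

    async with websockets.connect("ws://localhost:7860/api/stream") as ws:
        await ws.send(json.dumps({"query": prompt}))
        async for token in ws:
            print(token, end="", flush=True)


# Local entry point; Hugging Face Spaces serves on port 7860 by default.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)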