# app.py — SHASHA AI "Hybrid" (FastAPI + Gradio + static UI)
from __future__ import annotations

import asyncio
from typing import Any, Dict, List, Optional, Tuple

import gradio as gr
from fastapi import FastAPI, WebSocket
from fastapi.staticfiles import StaticFiles
# ────────────────────────────────────────────────────────────────
# internal helpers (unchanged)
# ────────────────────────────────────────────────────────────────
from constants import ( | |
HTML_SYSTEM_PROMPT, | |
TRANSFORMERS_JS_SYSTEM_PROMPT, | |
AVAILABLE_MODELS, | |
) | |
from inference import generation_code | |
# Chat history: a list of (user_message, assistant_reply) pairs.
History = List[Tuple[str, str]]
# ────────────────────────────────────────────────────────────────
# 1. Blocks-only "headless" API (no UI, just /api/predict JSON)
# ────────────────────────────────────────────────────────────────
# Invisible Blocks graph: the CSS hides the whole page body, so this app
# exists purely to expose ``generation_code`` over Gradio's JSON API.
with gr.Blocks(css="body{display:none}") as api_demo:
    # Request components.  NOTE: construction order matters — it defines
    # the component IDs the front-end payload refers to, so do not reorder.
    prompt_box = gr.Textbox()
    upload = gr.File()
    url_box = gr.Textbox()
    model = gr.State(AVAILABLE_MODELS[0])
    web_search = gr.Checkbox()
    language = gr.Dropdown(choices=["html", "python"], value="html")
    history_in = gr.State([])

    # Response components.
    generated_code = gr.Textbox()  # plain JSON
    history_out = gr.State()
    preview = gr.Textbox()
    chat = gr.State()

    request_fields = [
        prompt_box, upload, url_box, model,
        web_search, language, history_in,
    ]
    response_fields = [generated_code, history_out, preview, chat]

    # NOTE(review): ``load`` fires on page load — confirm this is the
    # intended trigger for exposing the function as an API endpoint.
    api_demo.load(
        generation_code,
        inputs=request_fields,
        outputs=response_fields,
    )
# ────────────────────────────────────────────────────────────────
# 2. Hybrid FastAPI server mounts:
#      • /      → static/ (index.html, style.css, index.js, …)
#      • /api/* → Gradio JSON (and websocket queue)
# ────────────────────────────────────────────────────────────────
# Hybrid FastAPI server.  Route registration order matters: Starlette
# matches mounts in the order they are added, and a mount at "/" is a
# catch-all, so the Gradio API MUST be mounted before the static root.
app = FastAPI(title="SHASHA AI hybrid server")

# Gradio JSON API under /api (e.g. POST /api/predict, websocket queue).
# ``gr.mount_gradio_app`` mounts the Blocks app onto *app* in place and
# returns the same FastAPI instance — its return value must not be passed
# back into ``app.mount`` (doing so mounts the server inside itself).
app = gr.mount_gradio_app(app, api_demo, path="/api")

# Static assets last, so "/" serves index.html etc. without shadowing /api.
app.mount(
    "/", StaticFiles(directory="static", html=True), name="static"
)
# ────────────────────────────────────────────────────────────────
# 3. Bonus: websocket streamer for lightning-fast preview
# ────────────────────────────────────────────────────────────────
@app.websocket("/api/stream")
async def stream(websocket: WebSocket) -> None:
    """
    Stream generated tokens to the front-end over a websocket.

    The client connects, sends the same JSON payload as /predict, and
    receives string tokens one message at a time as they are produced.
    The handler is registered on ``app`` (the original function was never
    wired to a route) and the ``WebSocket`` annotation is required for
    FastAPI's dependency injection.
    """
    await websocket.accept()
    payload = await websocket.receive_json()

    queue: asyncio.Queue[str] = asyncio.Queue()
    end_sentinel = "__END__"  # internal marker; never sent to the client

    async def _run() -> None:
        # NOTE(review): assumes ``generation_code`` exposes an async
        # ``.stream`` iterator — confirm against inference.py.
        try:
            async for token in generation_code.stream(**payload):  # type: ignore
                await queue.put(token)
        finally:
            # Always unblock the consumer, even if generation raises,
            # so the websocket loop below cannot hang forever.
            await queue.put(end_sentinel)

    # Keep a reference to the task so it is not garbage-collected mid-run.
    producer = asyncio.create_task(_run())
    try:
        while True:
            item = await queue.get()
            if item == end_sentinel:
                break
            await websocket.send_text(item)
    finally:
        # Stop generating if the client disconnected early, then close.
        producer.cancel()
        await websocket.close()