# app.py ── SHASHA AI "Hybrid" (FastAPI + Gradio + Static UI)

from __future__ import annotations
from typing import List, Tuple

import asyncio
import gradio as gr
from fastapi import FastAPI, WebSocket
from fastapi.staticfiles import StaticFiles

# ────────────────────────────────────────────────────────────────
# internal helpers (unchanged)
# ────────────────────────────────────────────────────────────────
from constants import (
    HTML_SYSTEM_PROMPT,
    TRANSFORMERS_JS_SYSTEM_PROMPT,
    AVAILABLE_MODELS,
)
from inference import generation_code

History = List[Tuple[str, str]]

# ────────────────────────────────────────────────────────────────
# 1.  Blocks-only "headless" API  (no UI, just the JSON endpoints under /api)
# ────────────────────────────────────────────────────────────────
with gr.Blocks(css="body{display:none}") as api_demo:  # invisible
    prompt_in     = gr.Textbox()
    file_in       = gr.File()
    url_in        = gr.Textbox()
    model_state   = gr.State(AVAILABLE_MODELS[0])
    search_chk    = gr.Checkbox()
    lang_dd       = gr.Dropdown(choices=["html", "python"], value="html")
    hist_state    = gr.State([])

    code_out      = gr.Textbox()       # plain JSON
    hist_out      = gr.State()
    preview_out   = gr.Textbox()
    chat_out      = gr.State()

    # registering the function on the load event also exposes it through
    # the JSON API; api_name gives it a stable endpoint name for clients
    api_demo.load(
        generation_code,
        inputs=[prompt_in, file_in, url_in, model_state,
                search_chk, lang_dd, hist_state],
        outputs=[code_out, hist_out, preview_out, chat_out],
        api_name="predict",
    )
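
# Example client call (a sketch, not part of the server: it assumes the
# server runs on localhost:8000 and the gradio_client package is installed;
# gr.State inputs are normally hidden from the client API, so only the
# visible components are passed, in the order they are listed above):
#
#   from gradio_client import Client
#   client = Client("http://localhost:8000/api")
#   result = client.predict(
#       "build a landing page",   # prompt_in
#       None,                     # file_in
#       "",                       # url_in
#       False,                    # search_chk
#       "html",                   # lang_dd
#       api_name="/predict",
#   )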

# ────────────────────────────────────────────────────────────────
# 2.  Hybrid FastAPI server mounts:
#      • /        → static/  (index.html, style.css, index.js …)
#      • /api/*   → Gradio JSON API (& websocket queue)
# ────────────────────────────────────────────────────────────────
app = FastAPI(title="SHASHA AI hybrid server")

# Gradio JSON API.  mount_gradio_app attaches the headless Blocks app to
# `app` under /api and returns `app`; no separate app.mount() is needed.
gr.mount_gradio_app(app, api_demo, path="/api")

# NOTE: the static mount for "/" is registered at the bottom of this file.
# Starlette matches routes in registration order, so a catch-all "/" mount
# placed here would shadow /api (including the websocket route below).
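
# Quick smoke test (a sketch; assumes httpx is installed and the server
# was started locally, e.g. with uvicorn):
#
#   import httpx
#   assert httpx.get("http://localhost:8000/").status_code == 200      # static UI
#   assert httpx.get("http://localhost:8000/api/").status_code == 200  # Gradio app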

# ────────────────────────────────────────────────────────────────
# 3.  Bonus: WebSocket streamer for lightning-fast preview
# ────────────────────────────────────────────────────────────────
@app.websocket("/api/stream")
async def stream(websocket: WebSocket) -> None:
    """
    Front-end connects, sends the same JSON payload as the predict
    endpoint, and receives chunks (string tokens) as they arrive.
    """
    await websocket.accept()
    payload = await websocket.receive_json()
    queue: asyncio.Queue[str] = asyncio.Queue()

    # spawn background generation; generation_code.stream is assumed to
    # be an async generator yielding string tokens
    async def _run() -> None:
        async for token in generation_code.stream(**payload):  # type: ignore
            await queue.put(token)
        await queue.put("__END__")

    task = asyncio.create_task(_run())  # keep a reference so it isn't GC'd

    while True:
        item = await queue.get()
        if item == "__END__":
            break
        await websocket.send_text(item)

    await websocket.close()

# ────────────────────────────────────────────────────────────────
# 4.  Static assets (registered last so the catch-all "/" mount
#     doesn't shadow the /api routes above)
# ────────────────────────────────────────────────────────────────
app.mount(
    "/", StaticFiles(directory="static", html=True), name="static"
)
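
# Example streaming client (a sketch; assumes the third-party `websockets`
# package, a running server, and that the payload keys match whatever
# generation_code.stream expects; "query" below is only an illustration):
#
#   import asyncio, json, websockets
#
#   async def demo() -> None:
#       async with websockets.connect("ws://localhost:8000/api/stream") as ws:
#           await ws.send(json.dumps({"query": "build a landing page"}))
#           async for token in ws:
#               print(token, end="", flush=True)
#
#   asyncio.run(demo())

if __name__ == "__main__":
    # local development entry point; in production you would typically run
    # `uvicorn app:app --host 0.0.0.0 --port 8000` instead
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)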