import base64
import json
import types

from fastapi import FastAPI
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, validator

from endpoint_handler import EndpointHandler  # your handler file

app = FastAPI()

handler = None

@app.on_event("startup")
async def load_handler():
    global handler
    handler = EndpointHandler()
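
# This app only assumes that EndpointHandler exposes a predict() method that
# takes the wrapped payload built in predict_endpoint() and returns either a
# dict, a JSON-encoded string, or a generator of chunks when streaming.
# A hypothetical sketch of that interface (names like _stream and _answer are
# illustrative only, not the real endpoint_handler.py):
#
#     class EndpointHandler:
#         def __init__(self):
#             ...  # load the model / processor once at startup
#
#         def predict(self, data: dict):
#             inputs = data["inputs"]
#             if inputs.get("stream"):
#                 return ({"token": t} for t in self._stream(inputs))  # generator
#             return {"answer": self._answer(inputs)}  # or a JSON string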

class PredictInput(BaseModel):
    image: str       # base64-encoded image string
    question: str
    stream: bool = False

    @validator("question")
    def question_not_empty(cls, v):
        if not v.strip():
            raise ValueError("Question must not be empty")
        return v

    @validator("image")
    def valid_base64_and_size(cls, v):
        try:
            decoded = base64.b64decode(v, validate=True)
        except Exception:
            raise ValueError("`image` must be valid base64")
        if len(decoded) > 10 * 1024 * 1024:  # 10 MB limit
            raise ValueError("Image exceeds 10 MB after decoding")
        return v
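
    # Note: if either validator above raises, FastAPI rejects the request with
    # a 422 response before predict_endpoint runs, so the handler only ever
    # sees payloads that passed these checks.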

class PredictRequest(BaseModel):
    inputs: PredictInput
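
# The request body /predict expects, as implied by the two models above:
#
#     {
#       "inputs": {
#         "image": "<base64-encoded image>",
#         "question": "What is shown in the picture?",
#         "stream": false
#       }
#     }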

@app.get("/")
async def root():
    return {"message": "FastAPI app is running on Hugging Face"}

@app.post("/predict")
async def predict_endpoint(payload: PredictRequest):
    print(f"[Request] Received question: {payload.inputs.question}")
    
    data = {
        "inputs": {
            "image": payload.inputs.image,
            "question": payload.inputs.question,
            "stream": payload.inputs.stream
        }
    }
    
    try:
        result = handler.predict(data)
    except ValueError as ve:
        return JSONResponse({"error": str(ve)}, status_code=400)
    except Exception as e:
        print(f"[Error] handler.predict failed: {e}")
        return JSONResponse({"error": "Internal server error"}, status_code=500)

    # ─── If it's a generator, return SSE/streaming ──────────────────────────
    if isinstance(result, types.GeneratorType):
        def event_stream():
            try:
                for chunk in result:
                    # Each chunk is either a Python dict or a string that is
                    # already JSON; only dicts need encoding, strings pass
                    # through as-is to avoid double-encoding them.
                    payload_str = chunk if isinstance(chunk, str) else json.dumps(chunk)
                    yield f"data: {payload_str}\n\n"
                # Finally, send an end-of-stream marker
                yield f"data: {json.dumps({'end': True})}\n\n"
            except Exception as e:
                yield f"data: {json.dumps({'error': str(e)})}\n\n"

        return StreamingResponse(event_stream(), media_type="text/event-stream")

    # ─── Otherwise (non-streaming), return a single JSON response ──────────
    #    result is expected to be a JSON string or a dict
    try:
        # If handler.predict returned a JSON-encoded str, parse it to a dict
        if isinstance(result, str):
            parsed = json.loads(result)
        else:
            parsed = result  # assume it's already a dict
    except Exception:
        # The handler returned a string that is not valid JSON
        return JSONResponse({"error": "Invalid JSON from handler"}, status_code=500)

    return JSONResponse(parsed)
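
# ─── Local testing (optional) ───────────────────────────────────────────────
# A minimal way to run and exercise the app locally; assumes uvicorn is
# installed and that this file is saved as main.py (adjust the module name,
# port, and image path to your setup). Example non-streaming request:
#
#     uvicorn main:app --host 0.0.0.0 --port 8000
#
#     IMG_B64=$(base64 -w 0 cat.jpg)   # GNU base64; flags differ on macOS
#     curl -X POST http://localhost:8000/predict \
#          -H "Content-Type: application/json" \
#          -d '{"inputs": {"image": "'"$IMG_B64"'", "question": "What is in the image?", "stream": false}}'
#
# Setting "stream": true instead returns Server-Sent Events ("data: ..." lines)
# terminated by a {"end": true} marker.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)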