sreejith8100 commited on
Commit
e28e6dd
·
1 Parent(s): 5697315
Files changed (1) hide show
  1. main.py +33 -6
main.py CHANGED
@@ -3,8 +3,9 @@ from fastapi.responses import JSONResponse, StreamingResponse
3
  from pydantic import BaseModel
4
  import types
5
  import json
6
-
7
  from endpoint_handler import EndpointHandler # your handler file
 
8
 
9
  app = FastAPI()
10
 
@@ -20,6 +21,22 @@ class PredictInput(BaseModel):
20
  question: str
21
  stream: bool = False
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  class PredictRequest(BaseModel):
24
  inputs: PredictInput
25
 
@@ -39,12 +56,22 @@ async def predict_endpoint(payload: PredictRequest):
39
  }
40
  }
41
 
42
- result = handler.predict(data)
43
-
 
 
 
 
 
44
  if isinstance(result, types.GeneratorType):
45
  def event_stream():
46
- for chunk in result:
47
- yield f"data: {json.dumps(chunk)}\n\n"
 
 
 
 
 
 
48
  return StreamingResponse(event_stream(), media_type="text/event-stream")
49
-
50
  return JSONResponse(content=result)
 
3
  from pydantic import BaseModel
4
  import types
5
  import json
6
+ from pydantic import validator
7
  from endpoint_handler import EndpointHandler # your handler file
8
+ import base64
9
 
10
  app = FastAPI()
11
 
 
21
  question: str
22
  stream: bool = False
23
 
24
+ @validator("question")
25
+ def question_not_empty(cls, v):
26
+ if not v.strip():
27
+ raise ValueError("Question must not be empty")
28
+ return v
29
+
30
+ @validator("image")
31
+ def valid_base64_and_size(cls, v):
32
+ try:
33
+ decoded = base64.b64decode(v, validate=True)
34
+ except Exception:
35
+ raise ValueError("`image` must be valid base64")
36
+ if len(decoded) > 10 * 1024 * 1024: # 10 MB limit
37
+ raise ValueError("Image exceeds 10 MB after decoding")
38
+ return v
39
+
40
  class PredictRequest(BaseModel):
41
  inputs: PredictInput
42
 
 
56
  }
57
  }
58
 
59
+ try:
60
+ result = handler.predict(data)
61
+ except ValueError as ve:
62
+ return JSONResponse({"error": str(ve)}, status_code=400)
63
+ except Exception as e:
64
+ return JSONResponse({"error": "Internal server error"}, status_code=500)
65
+
66
  if isinstance(result, types.GeneratorType):
67
  def event_stream():
68
+ try:
69
+ for chunk in result:
70
+ yield f"data: {json.dumps(chunk)}\n\n"
71
+ # Add [END] marker after generator ends
72
+ yield 'data: "[END]"\n\n'
73
+ except Exception as e:
74
+ # Send error in stream
75
+ yield f'data: "[ERROR] {str(e)}"\n\n'
76
  return StreamingResponse(event_stream(), media_type="text/event-stream")
 
77
  return JSONResponse(content=result)