sreejith8100 committed on
Commit
5697315
·
1 Parent(s): f998819
Files changed (1) hide show
  1. main.py +24 -69
main.py CHANGED
@@ -1,95 +1,50 @@
1
- from fastapi import FastAPI, HTTPException, UploadFile, File, Form, BackgroundTasks
2
  from fastapi.responses import JSONResponse, StreamingResponse
3
- from fastapi.middleware.cors import CORSMiddleware
4
- from pydantic import BaseModel, validator
5
- import base64, types, json, logging, asyncio
6
- from endpoint_handler import EndpointHandler
7
 
8
# Reuse uvicorn's error logger for this module's log records.
logger = logging.getLogger("uvicorn.error")
app = FastAPI()

# Global EndpointHandler slot; assigned by the startup hook.
handler = None
13
 
14
  @app.on_event("startup")
15
  async def load_handler():
16
  global handler
17
- try:
18
- handler = EndpointHandler()
19
- logger.info("EndpointHandler initialized successfully.")
20
- except Exception as e:
21
- logger.error(f"Failed to initialize handler: {e}", exc_info=True)
22
- handler = None # so /predict will return 503
23
-
24
@app.on_event("shutdown")
async def cleanup_handler():
    """Close the global handler at application shutdown, if one exists.

    Errors raised by ``handler.close()`` are logged and not propagated.
    """
    if not handler:
        # Nothing was loaded (or it is falsy), so there is nothing to close.
        return

    try:
        handler.close()
        logger.info("Handler cleaned up on shutdown.")
    except Exception:
        logger.error("Error during handler cleanup", exc_info=True)
32
 
33
  class PredictInput(BaseModel):
34
- image: str
35
  question: str
36
  stream: bool = False
37
 
38
- @validator("question")
39
- def question_not_empty(cls, v):
40
- if not v.strip():
41
- raise ValueError("Question must not be empty")
42
- return v
43
-
44
- @validator("image")
45
- def valid_base64(cls, v):
46
- try:
47
- base64.b64decode(v, validate=True)
48
- except Exception:
49
- raise ValueError("`image` must be valid base64")
50
- return v
51
-
52
  class PredictRequest(BaseModel):
53
  inputs: PredictInput
54
 
 
 
 
 
55
  @app.post("/predict")
56
  async def predict_endpoint(payload: PredictRequest):
57
- if handler is None:
58
- return JSONResponse({"error": "Service unavailable"}, status_code=503)
59
-
60
- # Log input
61
- logger.info(f"Received question: {payload.inputs.question}")
62
-
63
- # Prepare the data dict exactly how EndpointHandler expects
64
- request_dict = {
65
  "inputs": {
66
  "image": payload.inputs.image,
67
  "question": payload.inputs.question,
68
- "stream": payload.inputs.stream,
69
  }
70
  }
71
-
72
- try:
73
- result = await asyncio.to_thread(handler.predict, request_dict)
74
- except ValueError as ve:
75
- return JSONResponse({"error": str(ve)}, status_code=400)
76
- except Exception as e:
77
- logger.error("Unexpected error in handler.predict", exc_info=True)
78
- return JSONResponse({"error": "Internal server error"}, status_code=500)
79
-
80
- # If handler.predict returned a generator, wrap in SSE
81
  if isinstance(result, types.GeneratorType):
82
- async def event_stream():
83
- try:
84
- for chunk in result:
85
- yield f"data: {json.dumps(chunk)}\n\n"
86
- except (asyncio.CancelledError, ConnectionResetError):
87
- logger.info("Client disconnected from stream.")
88
- except Exception:
89
- logger.error("Error during streaming", exc_info=True)
90
- yield f"data: {json.dumps({'error': 'Stream error'})}\n\n"
91
-
92
  return StreamingResponse(event_stream(), media_type="text/event-stream")
93
-
94
- # Otherwise normal JSON
95
- return JSONResponse(content=result, status_code=200)
 
1
+ from fastapi import FastAPI
2
  from fastapi.responses import JSONResponse, StreamingResponse
3
+ from pydantic import BaseModel
4
+ import types
5
+ import json
 
6
 
7
+ from endpoint_handler import EndpointHandler # your handler file
 
8
 
9
app = FastAPI()

# Global EndpointHandler slot; assigned by the startup hook.
handler = None
12
 
13
  @app.on_event("startup")
14
  async def load_handler():
15
  global handler
16
+ handler = EndpointHandler()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  class PredictInput(BaseModel):
19
+ image: str # base64-encoded image string
20
  question: str
21
  stream: bool = False
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  class PredictRequest(BaseModel):
24
  inputs: PredictInput
25
 
26
@app.get("/")
async def root():
    """Return a fixed status message for the index route."""
    status = {"message": "FastAPI app is running on Hugging Face"}
    return status
29
+
30
  @app.post("/predict")
31
  async def predict_endpoint(payload: PredictRequest):
32
+ print(f"[Request] Received question: {payload.inputs.question}")
33
+
34
+ data = {
 
 
 
 
 
35
  "inputs": {
36
  "image": payload.inputs.image,
37
  "question": payload.inputs.question,
38
+ "stream": payload.inputs.stream
39
  }
40
  }
41
+
42
+ result = handler.predict(data)
43
+
 
 
 
 
 
 
 
44
  if isinstance(result, types.GeneratorType):
45
+ def event_stream():
46
+ for chunk in result:
47
+ yield f"data: {json.dumps(chunk)}\n\n"
 
 
 
 
 
 
 
48
  return StreamingResponse(event_stream(), media_type="text/event-stream")
49
+
50
+ return JSONResponse(content=result)