sreejith8100 committed on
Commit
c1aa475
·
1 Parent(s): ddc0b88
Files changed (2) hide show
  1. endpoint_handler.py +1 -2
  2. main.py +20 -3
endpoint_handler.py CHANGED
@@ -22,7 +22,7 @@ class EndpointHandler:
22
 
23
  print(f"[Model Load] Loading model from: {model_path}")
24
  try:
25
- self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
26
  self.model = AutoModel.from_pretrained(
27
  model_path,
28
  trust_remote_code=True,
@@ -50,7 +50,6 @@ class EndpointHandler:
50
 
51
  def predict(self, request):
52
  print(f"[Predict] Received request: {request}")
53
-
54
  image_base64 = request.get("inputs", {}).get("image")
55
  question = request.get("inputs", {}).get("question")
56
  stream = request.get("inputs", {}).get("stream", False)
 
22
 
23
  print(f"[Model Load] Loading model from: {model_path}")
24
  try:
25
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path,trust_remote_code=True)
26
  self.model = AutoModel.from_pretrained(
27
  model_path,
28
  trust_remote_code=True,
 
50
 
51
  def predict(self, request):
52
  print(f"[Predict] Received request: {request}")
 
53
  image_base64 = request.get("inputs", {}).get("image")
54
  question = request.get("inputs", {}).get("question")
55
  stream = request.get("inputs", {}).get("stream", False)
main.py CHANGED
@@ -63,15 +63,32 @@ async def predict_endpoint(payload: PredictRequest):
63
  except Exception as e:
64
  return JSONResponse({"error": "Internal server error"}, status_code=500)
65
 
 
66
  if isinstance(result, types.GeneratorType):
67
  def event_stream():
68
  try:
69
  for chunk in result:
 
 
70
  yield f"data: {json.dumps(chunk)}\n\n"
71
- # Return structured JSON to indicate end of stream
72
  yield f"data: {json.dumps({'end': True})}\n\n"
73
  except Exception as e:
74
- # Return structured JSON to indicate error
75
  yield f"data: {json.dumps({'error': str(e)})}\n\n"
76
- return StreamingResponse(event_stream(), media_type="text/event-stream")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
 
63
  except Exception as e:
64
  return JSONResponse({"error": "Internal server error"}, status_code=500)
65
 
66
+ # ─── If it's a generator, return SSE/streaming ──────────────────────────
67
  if isinstance(result, types.GeneratorType):
68
  def event_stream():
69
  try:
70
  for chunk in result:
71
+ # Each chunk should already be a Python dict or
72
+ # a string containing JSON. We wrap it in "data: …\n\n"
73
  yield f"data: {json.dumps(chunk)}\n\n"
74
+ # Finally send an end‐of‐stream marker
75
  yield f"data: {json.dumps({'end': True})}\n\n"
76
  except Exception as e:
 
77
  yield f"data: {json.dumps({'error': str(e)})}\n\n"
78
+
79
+ return StreamingResponse(event_stream(), media_type="text/event-stream")
80
+
81
+ # ─── Otherwise (non‐streaming), return a single JSON response ──────────
82
+ # result is expected to be a JSON‐string or a dict
83
+ try:
84
+ # If handler.predict returned a JSON‐encoded str, parse it to dict
85
+ if isinstance(result, str):
86
+ parsed = json.loads(result)
87
+ else:
88
+ parsed = result # assume it's already a dict
89
+ except Exception:
90
+ # The handler's string result was not valid JSON; report a server error
91
+ return JSONResponse({"error": "Invalid JSON from handler"}, status_code=500)
92
+
93
+ return JSONResponse(parsed)
94