Spaces:
Paused
Paused
Commit
·
c1aa475
1
Parent(s):
ddc0b88
stream
Browse files
- endpoint_handler.py +1 -2
- main.py +20 -3
endpoint_handler.py
CHANGED
@@ -22,7 +22,7 @@ class EndpointHandler:
|
|
22 |
|
23 |
print(f"[Model Load] Loading model from: {model_path}")
|
24 |
try:
|
25 |
-
self.tokenizer = AutoTokenizer.from_pretrained(model_path,
|
26 |
self.model = AutoModel.from_pretrained(
|
27 |
model_path,
|
28 |
trust_remote_code=True,
|
@@ -50,7 +50,6 @@ class EndpointHandler:
|
|
50 |
|
51 |
def predict(self, request):
|
52 |
print(f"[Predict] Received request: {request}")
|
53 |
-
|
54 |
image_base64 = request.get("inputs", {}).get("image")
|
55 |
question = request.get("inputs", {}).get("question")
|
56 |
stream = request.get("inputs", {}).get("stream", False)
|
|
|
22 |
|
23 |
print(f"[Model Load] Loading model from: {model_path}")
|
24 |
try:
|
25 |
+
self.tokenizer = AutoTokenizer.from_pretrained(model_path,trust_remote_code=True)
|
26 |
self.model = AutoModel.from_pretrained(
|
27 |
model_path,
|
28 |
trust_remote_code=True,
|
|
|
50 |
|
51 |
def predict(self, request):
|
52 |
print(f"[Predict] Received request: {request}")
|
|
|
53 |
image_base64 = request.get("inputs", {}).get("image")
|
54 |
question = request.get("inputs", {}).get("question")
|
55 |
stream = request.get("inputs", {}).get("stream", False)
|
main.py
CHANGED
@@ -63,15 +63,32 @@ async def predict_endpoint(payload: PredictRequest):
|
|
63 |
except Exception as e:
|
64 |
return JSONResponse({"error": "Internal server error"}, status_code=500)
|
65 |
|
|
|
66 |
if isinstance(result, types.GeneratorType):
|
67 |
def event_stream():
|
68 |
try:
|
69 |
for chunk in result:
|
|
|
|
|
70 |
yield f"data: {json.dumps(chunk)}\n\n"
|
71 |
-
#
|
72 |
yield f"data: {json.dumps({'end': True})}\n\n"
|
73 |
except Exception as e:
|
74 |
-
# Return structured JSON to indicate error
|
75 |
yield f"data: {json.dumps({'error': str(e)})}\n\n"
|
76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
|
|
|
63 |
except Exception as e:
|
64 |
return JSONResponse({"error": "Internal server error"}, status_code=500)
|
65 |
|
66 |
+
# ─── If it's a generator, return SSE/streaming ──────────────────────────
|
67 |
if isinstance(result, types.GeneratorType):
|
68 |
def event_stream():
|
69 |
try:
|
70 |
for chunk in result:
|
71 |
+
# Each chunk should already be a Python dict or
|
72 |
+
# a string containing JSON. We wrap it in "data: …\n\n"
|
73 |
yield f"data: {json.dumps(chunk)}\n\n"
|
74 |
+
# Finally send an end-of-stream marker
|
75 |
yield f"data: {json.dumps({'end': True})}\n\n"
|
76 |
except Exception as e:
|
|
|
77 |
yield f"data: {json.dumps({'error': str(e)})}\n\n"
|
78 |
+
|
79 |
+
return StreamingResponse(event_stream(), media_type="text/event-stream")
|
80 |
+
|
81 |
+
# ─── Otherwise (non-streaming), return a single JSON response ──────────
|
82 |
+
# result is expected to be a JSON-string or a dict
|
83 |
+
try:
|
84 |
+
# If handler.predict returned a JSON-encoded str, parse it to dict
|
85 |
+
if isinstance(result, str):
|
86 |
+
parsed = json.loads(result)
|
87 |
+
else:
|
88 |
+
parsed = result # assume it's already a dict
|
89 |
+
except Exception:
|
90 |
+
# Fall back to returning the raw result
|
91 |
+
return JSONResponse({"error": "Invalid JSON from handler"}, status_code=500)
|
92 |
+
|
93 |
+
return JSONResponse(parsed)
|
94 |
|