Prathamesh1420's picture
Update app.py
a8e0751 verified
raw
history blame
2.23 kB
import base64
import json
import logging
from io import BytesIO
from pathlib import Path

import cv2
import numpy as np
from fastapi import FastAPI, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import HTMLResponse, JSONResponse
from huggingface_hub import hf_hub_download
from PIL import Image
from pydantic import BaseModel, Field

try:
    from demo.object_detection.inference import YOLOv10
except (ImportError, ModuleNotFoundError):
    from inference import YOLOv10
# Module-level application state.

# Hugging Face Hub location of the YOLOv10-nano ONNX export.
_MODEL_REPO_ID = "onnx-community/yolov10n"
_MODEL_FILENAME = "onnx/model.onnx"

app = FastAPI()
# Directory containing this module; index.html is resolved relative to it.
cur_dir = Path(__file__).parent

# Fetch the ONNX weights from the Hub and build the detector once, at import
# time, so every request handler uses the same `model` object.
model_file = hf_hub_download(repo_id=_MODEL_REPO_ID, filename=_MODEL_FILENAME)
model = YOLOv10(model_file)
# Serve the index.html file
@app.get("/", response_class=HTMLResponse)
async def serve_frontend():
html_path = cur_dir / "index.html"
with open(html_path, "r", encoding="utf-8") as f:
html_content = f.read()
# Replace placeholder with empty RTC config or other configs if needed
html_content = html_content.replace("__RTC_CONFIGURATION__", json.dumps({}))
return HTMLResponse(content=html_content)
# Model input format
class ImagePayload(BaseModel):
image: str # base64 string
conf_threshold: float = Field(default=0.3, ge=0, le=1)
_logger = logging.getLogger(__name__)


def _decode_image(data: str) -> np.ndarray:
    """Decode a base64 image string (bare or ``data:`` URL) to an RGB array.

    Raises:
        ValueError: if the payload is not valid base64 (``binascii.Error``
            is a ``ValueError`` subclass) or the image mode cannot be converted.
        OSError: if the decoded bytes are not a readable image
            (``PIL.UnidentifiedImageError`` is an ``OSError`` subclass).
    """
    # Accept both "data:image/...;base64,<payload>" and a bare base64 string.
    # The previous `split(",", 1)` unpacking raised on input without a comma.
    head, sep, tail = data.partition(",")
    encoded = tail if sep else head
    img_bytes = base64.b64decode(encoded)
    with Image.open(BytesIO(img_bytes)) as img:
        return np.array(img.convert("RGB"))


def _run_detection(img_np: np.ndarray, conf_threshold: float):
    """Resize *img_np* to the model's input size and run the detector on it."""
    img_resized = cv2.resize(img_np, (model.input_width, model.input_height))
    # NOTE(review): the decoded image is RGB. Confirm that
    # YOLOv10.detect_objects expects RGB rather than OpenCV's usual BGR order.
    return model.detect_objects(img_resized, conf_threshold)


def _error_response(exc: Exception, status_code: int) -> JSONResponse:
    """Build the JSON error body returned by ``/detect``."""
    return JSONResponse(
        content={"status": "error", "message": str(exc)}, status_code=status_code
    )


# Inference route
@app.post("/detect")
async def detect_objects(payload: ImagePayload):
    """Decode a base64-encoded image and run YOLOv10 detection on it.

    Args:
        payload: request body holding the base64 image (optionally as a
            ``data:`` URL) and the confidence threshold.

    Returns:
        ``{"status": "success"}`` once detection completes. On failure it
        returns ``{"status": "error", "message": ...}`` with HTTP 400 when the
        image cannot be decoded, or HTTP 500 for any other failure.
    """
    # Decoding and ONNX inference are synchronous and CPU-bound. Running them
    # in the threadpool keeps the event loop free for other requests.
    try:
        img_np = await run_in_threadpool(_decode_image, payload.image)
    except (ValueError, OSError) as e:
        # Malformed or non-image input is the client's error, not the server's.
        return _error_response(e, 400)
    except Exception as e:
        _logger.exception("Unexpected error while decoding the request image")
        return _error_response(e, 500)

    try:
        # TODO: the detector's output is not sent back yet. Encode it (e.g.
        # base64) and add it to the response if clients need it.
        await run_in_threadpool(_run_detection, img_np, payload.conf_threshold)
    except Exception as e:
        _logger.exception("Object detection failed")
        return _error_response(e, 500)

    return JSONResponse(content={"status": "success"})
# Liveness probe endpoint.
@app.get("/health")
async def health():
    """Report that the service is up with a fixed ``{"status": "ok"}`` body."""
    status_body = {"status": "ok"}
    return status_body
if __name__ == "__main__":
    # Direct-execution entry point: serve the app with uvicorn on all
    # interfaces, port 7860.
    import uvicorn

    bind_host, bind_port = "0.0.0.0", 7860
    uvicorn.run(app, host=bind_host, port=bind_port)