Prathamesh1420 committed on
Commit
a8e0751
·
verified ·
1 Parent(s): b1da0c5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -45
app.py CHANGED
@@ -2,76 +2,77 @@ import json
2
  from pathlib import Path
3
 
4
  import cv2
5
- import gradio as gr
6
- from fastapi import FastAPI
7
- from fastapi.responses import HTMLResponse
8
- from fastrtc import Stream, get_twilio_turn_credentials
9
- from gradio.utils import get_space
10
- from huggingface_hub import hf_hub_download
11
  from pydantic import BaseModel, Field
 
 
 
12
 
13
  try:
14
  from demo.object_detection.inference import YOLOv10
15
  except (ImportError, ModuleNotFoundError):
16
  from inference import YOLOv10
17
 
18
-
 
19
  cur_dir = Path(__file__).parent
20
 
 
21
  model_file = hf_hub_download(
22
  repo_id="onnx-community/yolov10n", filename="onnx/model.onnx"
23
  )
24
-
25
  model = YOLOv10(model_file)
26
 
 
 
 
 
 
 
27
 
28
- def detection(image, conf_threshold=0.3):
29
- image = cv2.resize(image, (model.input_width, model.input_height))
30
- print("conf_threshold", conf_threshold)
31
- new_image = model.detect_objects(image, conf_threshold)
32
- return cv2.resize(new_image, (500, 500))
33
 
34
 
35
- stream = Stream(
36
- handler=detection,
37
- modality="video",
38
- mode="send-receive",
39
- additional_inputs=[gr.Slider(minimum=0, maximum=1, step=0.01, value=0.3)],
40
- rtc_configuration=get_twilio_turn_credentials() if get_space() else None,
41
- concurrency_limit=2 if get_space() else None,
42
- )
43
 
44
- app = FastAPI()
45
 
46
- stream.mount(app)
 
 
 
 
 
 
 
 
47
 
 
 
48
 
49
- @app.get("/")
50
- async def _():
51
- rtc_config = get_twilio_turn_credentials() if get_space() else None
52
- html_content = open(cur_dir / "index.html").read()
53
- html_content = html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config))
54
- return HTMLResponse(content=html_content)
55
 
 
 
56
 
57
- class InputData(BaseModel):
58
- webrtc_id: str
59
- conf_threshold: float = Field(ge=0, le=1)
60
 
61
 
62
- @app.post("/input_hook")
63
- async def _(data: InputData):
64
- stream.set_input(data.webrtc_id, data.conf_threshold)
 
65
 
66
 
67
  if __name__ == "__main__":
68
- import os
69
-
70
- if (mode := os.getenv("MODE")) == "UI":
71
- stream.ui.launch(server_port=7860)
72
- elif mode == "PHONE":
73
- stream.fastphone(host="0.0.0.0", port=7860)
74
- else:
75
- import uvicorn
76
-
77
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
2
  from pathlib import Path
3
 
4
  import cv2
5
+ import numpy as np
6
+ from PIL import Image
7
+ from fastapi import FastAPI, Request
8
+ from fastapi.responses import HTMLResponse, JSONResponse
 
 
9
  from pydantic import BaseModel, Field
10
+ from huggingface_hub import hf_hub_download
11
+ from io import BytesIO
12
+ import base64
13
 
14
  try:
15
  from demo.object_detection.inference import YOLOv10
16
  except (ImportError, ModuleNotFoundError):
17
  from inference import YOLOv10
18
 
19
+ # Define app and paths
20
+ app = FastAPI()
21
  cur_dir = Path(__file__).parent
22
 
23
+ # Load YOLOv10 ONNX model
24
  model_file = hf_hub_download(
25
  repo_id="onnx-community/yolov10n", filename="onnx/model.onnx"
26
  )
 
27
  model = YOLOv10(model_file)
28
 
29
+ # Serve the index.html file
30
+ @app.get("/", response_class=HTMLResponse)
31
+ async def serve_frontend():
32
+ html_path = cur_dir / "index.html"
33
+ with open(html_path, "r", encoding="utf-8") as f:
34
+ html_content = f.read()
35
 
36
+ # Replace placeholder with empty RTC config or other configs if needed
37
+ html_content = html_content.replace("__RTC_CONFIGURATION__", json.dumps({}))
38
+ return HTMLResponse(content=html_content)
 
 
39
 
40
 
41
+ # Model input format
42
+ class ImagePayload(BaseModel):
43
+ image: str # base64 string
44
+ conf_threshold: float = Field(default=0.3, ge=0, le=1)
 
 
 
 
45
 
 
46
 
47
+ # Inference route
48
+ @app.post("/detect")
49
+ async def detect_objects(payload: ImagePayload):
50
+ try:
51
+ # Decode base64 image
52
+ header, encoded = payload.image.split(",", 1)
53
+ img_bytes = base64.b64decode(encoded)
54
+ img = Image.open(BytesIO(img_bytes)).convert("RGB")
55
+ img_np = np.array(img)
56
 
57
+ # Resize for model input
58
+ img_resized = cv2.resize(img_np, (model.input_width, model.input_height))
59
 
60
+ # Run detection
61
+ output_image = model.detect_objects(img_resized, payload.conf_threshold)
 
 
 
 
62
 
63
+ # Return detections (if you want to send image back, convert to base64)
64
+ return JSONResponse(content={"status": "success"})
65
 
66
+ except Exception as e:
67
+ return JSONResponse(content={"status": "error", "message": str(e)}, status_code=500)
 
68
 
69
 
70
+ # Optional: health check
71
+ @app.get("/health")
72
+ async def health():
73
+ return {"status": "ok"}
74
 
75
 
76
  if __name__ == "__main__":
77
+ import uvicorn
78
+ uvicorn.run(app, host="0.0.0.0", port=7860)