Prathamesh1420 commited on
Commit
cb8a6cc
·
verified ·
1 Parent(s): a8e0751

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -60
app.py CHANGED
@@ -1,78 +1,58 @@
1
  import json
2
  from pathlib import Path
3
-
4
- import cv2
5
- import numpy as np
6
- from PIL import Image
7
- from fastapi import FastAPI, Request
8
- from fastapi.responses import HTMLResponse, JSONResponse
9
- from pydantic import BaseModel, Field
10
  from huggingface_hub import hf_hub_download
11
- from io import BytesIO
12
- import base64
 
13
 
14
- try:
15
- from demo.object_detection.inference import YOLOv10
16
- except (ImportError, ModuleNotFoundError):
17
- from inference import YOLOv10
18
 
19
- # Define app and paths
20
- app = FastAPI()
21
  cur_dir = Path(__file__).parent
22
 
23
- # Load YOLOv10 ONNX model
24
  model_file = hf_hub_download(
25
  repo_id="onnx-community/yolov10n", filename="onnx/model.onnx"
26
  )
27
- model = YOLOv10(model_file)
28
-
29
- # Serve the index.html file
30
- @app.get("/", response_class=HTMLResponse)
31
- async def serve_frontend():
32
- html_path = cur_dir / "index.html"
33
- with open(html_path, "r", encoding="utf-8") as f:
34
- html_content = f.read()
35
-
36
- # Replace placeholder with empty RTC config or other configs if needed
37
- html_content = html_content.replace("__RTC_CONFIGURATION__", json.dumps({}))
38
- return HTMLResponse(content=html_content)
39
 
 
40
 
41
- # Model input format
42
- class ImagePayload(BaseModel):
43
- image: str # base64 string
44
- conf_threshold: float = Field(default=0.3, ge=0, le=1)
45
-
46
-
47
- # Inference route
48
- @app.post("/detect")
49
- async def detect_objects(payload: ImagePayload):
50
- try:
51
- # Decode base64 image
52
- header, encoded = payload.image.split(",", 1)
53
- img_bytes = base64.b64decode(encoded)
54
- img = Image.open(BytesIO(img_bytes)).convert("RGB")
55
- img_np = np.array(img)
56
-
57
- # Resize for model input
58
- img_resized = cv2.resize(img_np, (model.input_width, model.input_height))
59
-
60
- # Run detection
61
- output_image = model.detect_objects(img_resized, payload.conf_threshold)
62
-
63
- # Return detections (if you want to send image back, convert to base64)
64
- return JSONResponse(content={"status": "success"})
65
 
66
- except Exception as e:
67
- return JSONResponse(content={"status": "error", "message": str(e)}, status_code=500)
68
 
 
 
 
 
69
 
70
- # Optional: health check
71
- @app.get("/health")
72
- async def health():
73
- return {"status": "ok"}
74
 
 
 
 
75
 
76
  if __name__ == "__main__":
77
- import uvicorn
78
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
 
 
 
 
1
  import json
2
  from pathlib import Path
3
+ from fastapi import FastAPI
4
+ from fastapi.responses import HTMLResponse
5
+ from fastrtc import Stream
 
 
 
 
6
  from huggingface_hub import hf_hub_download
7
+ from pydantic import BaseModel, Field
8
+ import os
9
+ import cv2
10
 
11
+ from inference import YOLOv10
 
 
 
12
 
 
 
13
  cur_dir = Path(__file__).parent
14
 
 
15
  model_file = hf_hub_download(
16
  repo_id="onnx-community/yolov10n", filename="onnx/model.onnx"
17
  )
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
+ model = YOLOv10(model_file)
20
 
21
+ def detection(image, conf_threshold=0.3):
22
+ image = cv2.resize(image, (model.input_width, model.input_height))
23
+ new_image = model.detect_objects(image, conf_threshold)
24
+ return cv2.resize(new_image, (500, 500))
25
+
26
+ stream = Stream(
27
+ handler=detection,
28
+ modality="video",
29
+ mode="send-receive",
30
+ additional_inputs=[gr.Slider(minimum=0, maximum=1, step=0.01, value=0.3)],
31
+ rtc_configuration=None,
32
+ concurrency_limit=None,
33
+ )
 
 
 
 
 
 
 
 
 
 
 
34
 
35
+ app = FastAPI()
36
+ stream.mount(app)
37
 
38
+ @app.get("/")
39
+ async def serve_index():
40
+ html_content = open(cur_dir / "index.html").read()
41
+ return HTMLResponse(content=html_content)
42
 
43
+ class InputData(BaseModel):
44
+ webrtc_id: str
45
+ conf_threshold: float = Field(ge=0, le=1)
 
46
 
47
+ @app.post("/input_hook")
48
+ async def update_input(data: InputData):
49
+ stream.set_input(data.webrtc_id, data.conf_threshold)
50
 
51
  if __name__ == "__main__":
52
+ if (mode := os.getenv("MODE")) == "UI":
53
+ stream.ui.launch(server_port=7860)
54
+ elif mode == "PHONE":
55
+ stream.fastphone(host="0.0.0.0", port=7860)
56
+ else:
57
+ import uvicorn
58
+ uvicorn.run(app, host="0.0.0.0", port=7860)