
Add debug output for A* paths that may get blocked at routing nodes. Add a function to allow an offset of 3 segment classes
40967d9
# Access: https://BinKhoaLe1812-Sall-eGarbageDetection.hf.space/ui | |
# ───────────────────────── app.py (Sall-e demo) ───────────────────────── | |
# FastAPI ▸ upload image ▸ multi-model garbage detection ▸ ADE-20K | |
# semantic segmentation (Water / Garbage) ▸ A* navigation ▸ H.264 video | |
# ======================================================================= | |
import os, uuid, threading, shutil, time, heapq, cv2, numpy as np | |
from PIL import Image | |
import uvicorn | |
from fastapi import FastAPI, File, UploadFile, Request | |
from fastapi.responses import HTMLResponse, StreamingResponse, Response | |
from fastapi.staticfiles import StaticFiles | |
# ── Vision libs ───────────────────────────────────────────────────────── | |
import torch, yolov5, ffmpeg | |
from ultralytics import YOLO | |
from transformers import ( | |
DetrImageProcessor, DetrForObjectDetection, | |
SegformerFeatureExtractor, SegformerForSemanticSegmentation | |
) | |
# from sklearn.neighbors import NearestNeighbors | |
from inference_sdk import InferenceHTTPClient | |
# ── Folders / files ─────────────────────────────────────────────────────
# Every path hangs off BASE; uploads are kept inside the cache directory.
BASE = "/home/user/app"
CACHE = f"{BASE}/cache"
UPLOAD_DIR = f"{CACHE}/uploads"
OUTPUT_DIR = f"{BASE}/outputs"
MODEL_DIR = f"{BASE}/model"
SPRITE = f"{BASE}/sprite.png"
for _folder in (UPLOAD_DIR, OUTPUT_DIR, CACHE):
    os.makedirs(_folder, exist_ok=True)
del _folder
# NOTE(review): `transformers` is imported earlier in the file than this point; the
# Hugging Face libraries appear to read these variables when first imported, so
# setting them here may not change where hub downloads are cached — confirm.
os.environ.update({"TRANSFORMERS_CACHE": CACHE, "HF_HOME": CACHE})
# ── Load models once ───────────────────────────────────────────────────
# The TRANSFORMERS_CACHE / HF_HOME assignments in the folder setup run after
# `transformers` has already been imported, and the Hugging Face libraries read
# those variables at import time, so they do not redirect hub downloads.
# Pass cache_dir explicitly so the SegFormer weights are cached under CACHE.
print("🔄 Loading models …")
model_self = YOLO(f"{MODEL_DIR}/garbage_detector.pt")  # YOLOv11(l), self-trained
model_yolo5 = yolov5.load(f"{MODEL_DIR}/yolov5-detect-trash-classification.pt")
processor_detr = DetrImageProcessor.from_pretrained(f"{MODEL_DIR}/detr")
model_detr = DetrForObjectDetection.from_pretrained(f"{MODEL_DIR}/detr")
_SEGFORMER_ID = "nvidia/segformer-b4-finetuned-ade-512-512"
feat_extractor = SegformerFeatureExtractor.from_pretrained(_SEGFORMER_ID, cache_dir=CACHE)
segformer = SegformerForSemanticSegmentation.from_pretrained(_SEGFORMER_ID, cache_dir=CACHE)
model_animal = YOLO(f"{MODEL_DIR}/yolov8n.pt")  # COCO pre-trained YOLOv8n for animal detection
print("✅ Models ready\n")
# ── ADE-20K palette + custom mapping (verbatim) ─────────────────────────
# ADE20K palette: row i is the RGB colour that build_masks decodes class id i to.
# Only 145 colours are listed here. The SegFormer ADE-20K head is assumed to emit
# ids 0..149, so the table is zero-padded to 150 rows; padded ids decode to (0, 0, 0).
ade_palette = np.array([
    [0, 0, 0], [120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50],
    [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], [230, 230, 230],
    [4, 250, 7], [224, 5, 255], [235, 255, 7], [150, 5, 61], [120, 120, 70],
    [8, 255, 51], [255, 6, 82], [143, 255, 140], [204, 255, 4], [255, 51, 7],
    [204, 70, 3], [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255],
    [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220],
    [112, 9, 255], [8, 255, 214], [7, 255, 224], [255, 184, 6], [10, 255, 71],
    [255, 41, 10], [7, 255, 255], [224, 255, 8], [102, 8, 255], [255, 61, 6],
    [255, 194, 7], [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153],
    [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], [140, 140, 140],
    [250, 10, 15], [20, 255, 0], [31, 255, 0], [255, 31, 0], [255, 224, 0],
    [153, 255, 0], [0, 0, 255], [255, 71, 0], [0, 235, 255], [0, 173, 255],
    [31, 0, 255], [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255],
    [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0], [255, 102, 0],
    [194, 255, 0], [0, 143, 255], [51, 255, 0], [0, 82, 255], [0, 255, 41],
    [255, 0, 255], [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20],
    [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255], [255, 0, 204],
    [0, 255, 194], [0, 255, 82], [0, 10, 255], [0, 112, 255], [51, 0, 255],
    [0, 194, 255], [0, 122, 255], [0, 255, 163], [255, 153, 0], [0, 255, 10],
    [255, 112, 0], [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0],
    [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255], [255, 0, 31],
    [0, 184, 255], [0, 214, 255], [255, 0, 112], [92, 255, 0], [0, 224, 255],
    [112, 224, 255], [70, 184, 160], [163, 0, 255], [153, 0, 255], [71, 255, 0],
    [255, 0, 163], [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0],
    [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0], [10, 190, 212],
    [214, 255, 0], [0, 204, 255], [20, 0, 255], [255, 255, 0], [0, 153, 255],
    [0, 41, 255], [0, 255, 204], [41, 0, 255], [41, 255, 0], [173, 0, 255],
    [0, 245, 255], [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255],
    [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194], [102, 255, 0],
    [92, 0, 255]
], dtype=np.uint8)
# Pad with black rows up to 150 so every class id has a colour.
missing = 150 - len(ade_palette)
if missing > 0:
    padding = np.zeros((missing, 3), dtype=np.uint8)
    ade_palette = np.concatenate((ade_palette, padding))
# Coarse classes expressed as groups of palette colours (RGB). build_masks matches
# decoded pixels against the "Water" and "Garbage" groups within ±TOL per channel.
custom_class_map = {
    "Garbage": [
        (255, 8, 41), (235, 255, 7), (255, 5, 153), (255, 0, 102),
    ],
    "Water": [
        (0, 102, 200), (11, 102, 255), (31, 0, 255), (10, 0, 255), (9, 7, 230),
    ],
    "Grass / Vegetation": [
        (10, 255, 71), (143, 255, 140),
    ],
    "Tree / Natural Obstacle": [
        (4, 200, 3), (235, 12, 255), (255, 6, 82), (255, 163, 0),
    ],
    "Sand / Soil / Ground": [
        (80, 50, 50), (230, 230, 230),
    ],
    "Buildings / Structures": [
        (255, 0, 255), (184, 0, 255), (120, 120, 120), (7, 255, 224),
    ],
    "Sky / Background": [
        (180, 120, 120),
    ],
    # Key spelling kept as-is: other code may look it up by this exact name.
    "Undetecable": [
        (0, 0, 0),
    ],
    "Unknown Class": [],
}
TOL = 30  # per-channel RGB tolerance used when matching palette colours
# Segment colour (150, 5, 61) is only treated as garbage when it covers a large enough area.
def interpret_rgb_class(decoded_img):
    """
    Decide how to label the ambiguous ADE-20K colour (150, 5, 61).

    decoded_img: colour image whose last axis holds RGB values (build_masks passes
    ade_palette[seg]).

    Returns "garbage" when more than 15 % of pixels lie within TOL of (150, 5, 61)
    on every channel, otherwise "sand".
    """
    reference = np.array([150, 5, 61])
    near = (np.abs(decoded_img - reference) <= TOL).all(axis=-1)
    coverage = np.count_nonzero(near) / near.size
    if coverage > 0.15:
        return "garbage"
    return "sand"
# Masking zones (water and garbage zones are travelable)
def build_masks(seg):
    """
    Build the navigation masks for a SegFormer class-id map.

    seg: integer array of ADE-20K class ids of shape (H, W), decoded through ade_palette.

    Returns three uint8 masks of shape (H, W):
        water_mask   – 1 where the decoded colour is within TOL of a "Water" colour
        garbage_mask – 1 where it is within TOL of a "Garbage" colour
        movable_mask – water_mask | garbage_mask (the robot can travel here)

    The ambiguous colour (150, 5, 61) joins the Garbage colours when
    interpret_rgb_class() labels the image "garbage", and the Sand colours otherwise.
    The decision is made per call; the global custom_class_map is never modified.
    """
    decoded = ade_palette[seg]
    ambiguous = (150, 5, 61)
    # dict.copy() is shallow: appending to the copied lists used to mutate the global
    # custom_class_map, leaking one image's decision into every later request.
    resolved_map = {name: list(colours) for name, colours in custom_class_map.items()}
    if interpret_rgb_class(decoded) == "garbage":
        gains, loses = "Garbage", "Sand / Soil / Ground"
    else:  # fall back to treating it as ground
        gains, loses = "Sand / Soil / Ground", "Garbage"
    resolved_map[gains].append(ambiguous)
    resolved_map[loses] = [rgb for rgb in resolved_map[loses] if rgb != ambiguous]

    # Signed copy so colour differences cannot wrap around in uint8.
    decoded_i16 = decoded.astype(np.int16)

    def colour_mask(colours):
        """1 where max channel-wise |decoded - rgb| <= TOL for some rgb in *colours*."""
        mask = np.zeros(seg.shape, np.uint8)
        for rgb in colours:
            mask |= (np.abs(decoded_i16 - np.asarray(rgb, dtype=np.int16)).max(axis=-1) <= TOL)
        return mask

    # Match against the *resolved* colours; the previous code read the global map and
    # ignored the ambiguity decision made above.
    water_mask = colour_mask(resolved_map["Water"])
    garbage_mask = colour_mask(resolved_map["Garbage"])
    movable_mask = water_mask | garbage_mask
    return water_mask, garbage_mask, movable_mask
# Garbage chunks are tinted by collection status (pending / collected / unreachable)
def highlight_chunk_masks_on_frame(
    frame, labels, objs,
    color_uncollected=(0, 0, 128),
    color_collected=(0, 128, 0),
    color_unreachable=(0, 255, 255),
    alpha=0.8
):
    """
    Return a copy of *frame* with the connected-component region under each target tinted.

    labels: label image (same height/width as frame); 0 means "no chunk".
    objs:   dicts with "pos" ([x, y]) and "col" (collected flag), plus an optional
            "unreachable" flag. Unreachable wins over collected when choosing the colour.
    Colours are given in the frame's channel order. The filled outer contours of each
    region are blended over the original at opacity *alpha*.
    """
    painted = frame.copy()
    for obj in objs:
        px, py = obj["pos"]
        chunk_id = labels[py, px]
        if chunk_id == 0:  # target centre is not on a segmented chunk
            continue
        if obj.get("unreachable"):
            fill = color_unreachable
        else:
            fill = color_collected if obj["col"] else color_uncollected
        region = (labels == chunk_id).astype(np.uint8)
        outlines, _ = cv2.findContours(region, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(painted, outlines, -1, fill, thickness=cv2.FILLED)
    return cv2.addWeighted(painted, alpha, frame, 1 - alpha, 0)
# Water is tinted (blue by default, in BGR order)
def highlight_water_mask_on_frame(frame, binary_mask, color=(255, 0, 0), alpha=0.3):
    """
    Return a copy of *frame* with the regions of *binary_mask* tinted by *color*.

    Only outer contours are traced and then filled, so holes inside a region are
    tinted as well. The tint is blended at opacity *alpha*.
    """
    mask_u8 = binary_mask.astype(np.uint8) * 255
    outlines, _ = cv2.findContours(mask_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    painted = frame.copy()
    cv2.drawContours(painted, outlines, -1, color, thickness=cv2.FILLED)
    return cv2.addWeighted(painted, alpha, frame, 1 - alpha, 0)
# ── A* and greedy nearest-target routing over the binary movable grid ───
def astar(start, goal, occ):
    """
    A* search on an 8-connected occupancy grid.

    start, goal: (x, y) cell coordinates; occ is indexed occ[y, x] and any non-zero
    cell is traversable. Every move (straight or diagonal) costs 1. A diagonal move
    is only allowed when both orthogonal cells it passes are traversable, so the path
    never squeezes between two blocked corners. Grid bounds come from occ.shape.

    Returns the list of (x, y) cells from start to goal inclusive, or [] when the goal
    cannot be reached. When a search fails, the explored cells are written (best effort)
    to OUTPUT_DIR/debug_astar_failure.png. That single file is overwritten by every
    failure.

    Note: the Manhattan heuristic over-estimates diagonal moves (cost 1, estimate 2),
    so the result is not guaranteed to be a shortest path; it trades optimality for
    a faster search.
    """
    height, width = occ.shape[:2]
    start, goal = tuple(start), tuple(goal)

    def h(a, b):
        return abs(a[0] - b[0]) + abs(a[1] - b[1])

    if start == goal:
        return [start]
    gx, gy = goal
    if not (0 <= gx < width and 0 <= gy < height) or occ[gy, gx] == 0:
        # No search can end on a blocked or off-grid cell; skip the full flood fill.
        print(f"🧨 A* skipped: goal {goal} is not a traversable cell")
        return []

    N8 = [(-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1)]
    openq = [(0, start)]
    g = {start: 0}
    came = {}
    visited = set()
    while openq:
        _, cur = heapq.heappop(openq)
        if cur == goal:  # reconstruct by walking the parent links back to start
            path = [cur]
            while cur in came:
                cur = came[cur]
                path.append(cur)
            return path[::-1]
        if cur in visited:  # stale heap entry
            continue
        visited.add(cur)
        cx, cy = cur
        for dx, dy in N8:
            nx, ny = cx + dx, cy + dy
            if not (0 <= nx < width and 0 <= ny < height) or occ[ny, nx] == 0:
                continue
            # Diagonal: both orthogonal neighbours it cuts past must be free.
            if dx and dy and (occ[cy + dy, cx] == 0 or occ[cy, cx + dx] == 0):
                continue
            neighbor = (nx, ny)
            ng = g[cur] + 1
            if neighbor not in g or ng < g[neighbor]:
                g[neighbor] = ng
                came[neighbor] = cur
                heapq.heappush(openq, (ng + h(neighbor, goal), neighbor))

    # Best-effort debug dump: explored cells in grey (127), start/goal in white (255).
    # The old code swapped x/y for the markers and doubled the image, which made
    # the markers (255 * 2 -> 254) indistinguishable from explored cells.
    try:
        debug = np.zeros((height, width), dtype=np.uint8)
        if visited:
            vx, vy = zip(*visited)
            debug[list(vy), list(vx)] = 127
        cv2.circle(debug, (int(start[0]), int(start[1])), 3, 255, -1)  # cv2 points are (x, y)
        cv2.circle(debug, (int(goal[0]), int(goal[1])), 3, 255, -1)
        cv2.imwrite(f"{OUTPUT_DIR}/debug_astar_failure.png", debug)
    except Exception as exc:  # a debugging aid must never break routing
        print(f"⚠️ could not save A* debug image: {exc!r}")
    print(f"🧨 A* failed from {start} to {goal} — frontier saved to debug_astar_failure.png")
    return []
# Greedy route through all targets (nearest by actual A* path length first)
def knn_path(start, targets, occ):
    """
    Build a route visiting as many *targets* as possible, greedily.

    Each round, astar() is run from the current position to every remaining target,
    and the robot moves to the one with the shortest path (ties go to the earlier
    target). This is a greedy tour, not an optimal ordering, and costs one A* search
    per remaining target per round.

    start:   (x, y) start cell.
    targets: sequence of [x, y] targets (not modified).
    occ:     occupancy grid passed through to astar().

    Returns (path, unreachable). path is the concatenated list of (x, y) cells; the
    endpoint shared by consecutive segments appears once. unreachable lists the
    elements of *targets* that astar() could not reach.
    """
    todo = list(targets)
    path = []
    unreachable = []
    cur = tuple(start)
    while todo:
        best_i, best_seg = None, None
        still_todo = []
        for t in todo:
            seg = astar(cur, tuple(t), occ)
            if not seg:
                # Moves are symmetric and every position we stand on was reached from
                # the start, so a target unreachable now stays unreachable. Drop it
                # instead of re-running a failing full-flood search every round.
                print(f"⚠️ Garbage unreachable at {t}")
                unreachable.append(t)
                continue
            still_todo.append(t)
            if best_seg is None or len(seg) < len(best_seg):
                best_i, best_seg = len(still_todo) - 1, seg
        if best_seg is None:  # nothing left that can be reached
            break
        if path and path[-1] == best_seg[0]:
            best_seg = best_seg[1:]  # skip the duplicated junction cell
        path.extend(best_seg)
        cur = tuple(still_todo.pop(best_i))
        todo = still_todo
    return path, unreachable
# ── Robot sprite/class ──────────────────────────────────────────────────
class Robot:
    """Robot drawn with a 40x40 RGBA sprite that walks along a list of (x, y) waypoints."""

    def __init__(self, sprite, speed=2000):
        """
        Load the sprite and set up the robot's motion state.

        sprite: path of the sprite image. It is converted to RGBA, resized to 40x40, and
                kept in self.png as a numpy array (channels in RGBA order).
        speed:  the largest distance, in pixels, that one step() call moves toward the
                next waypoint.
        """
        # Decode the sprite once (it used to be opened and resized twice) and close the file.
        with Image.open(sprite) as img:
            self.png = np.array(img.convert("RGBA").resize((40, 40)))
        if self.png.shape[-1] != 4:  # defensive: convert("RGBA") should always give 4 channels
            raise ValueError("Sprite image must have 4 channels (RGBA)")
        self.speed = speed
        self.pos = [20, 20]  # fallback spawn: top-left, offset by half the 40 px body

    def step(self, path):
        """
        Make one move toward path[0], consuming *path* in place.

        If the next waypoint is within self.speed pixels, the robot lands on it and the
        waypoint is popped. Otherwise the robot moves self.speed pixels toward it, with
        coordinates truncated to ints and clamped to [20, 620] so the 40 px sprite stays
        inside the 640x640 frame. Does nothing when *path* is empty.
        """
        if not path:
            return
        target = path[0]
        dx, dy = target[0] - self.pos[0], target[1] - self.pos[1]
        dist = (dx * dx + dy * dy) ** 0.5
        if dist <= self.speed:
            self.pos = list(path.pop(0))
            return
        ratio = self.speed / dist
        self.pos = [
            int(np.clip(self.pos[0] + dx * ratio, 20, 640 - 20)),
            int(np.clip(self.pos[1] + dy * ratio, 20, 640 - 20)),
        ]
# ── Static-web ────────────────────────────────────────────────────────── | |
from fastapi.responses import JSONResponse, FileResponse | |
from fastapi.middleware.cors import CORSMiddleware | |
app = FastAPI()
# Permissive CORS: any origin, method and header may call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# Static assets, served from ./statics relative to the working directory.
app.mount("/statics", StaticFiles(directory="statics"), name="statics")
# uid -> True once _pipeline has finished with that upload.
video_ready = dict()
async def serve_index():
    """
    Serve statics/index.html, or a JSON 404 when the file is missing.

    NOTE(review): no FastAPI route decorator is visible on this function in the
    reviewed source — confirm it is registered on `app`.
    """
    index_path = "statics/index.html"
    if not os.path.exists(index_path):
        print("[STATIC] index.html not found")
        return JSONResponse(status_code=404, content={"detail": "Not found"})
    print("[STATIC] Serving index.html")
    return FileResponse(index_path)
def _uid(): return uuid.uuid4().hex[:8] | |
# ── End-points ──────────────────────────────────────────────────────────
# User uploads the environment image here
async def upload(file: UploadFile = File(...)):
    """
    Save an uploaded image and start _pipeline on it in a background thread.

    Returns {"user_id": uid}; poll chk(uid) until ready, then fetch stream(uid).
    NOTE(review): no FastAPI route decorator is visible on this function in the
    reviewed source — confirm it is registered on `app`.
    """
    uid = _uid()
    # file.filename is client-controlled: keep only its last path component so
    # separators ("/" or "\\") in the name cannot redirect or break the destination.
    raw_name = (file.filename or "").replace("\\", "/")
    safe_name = os.path.basename(raw_name) or "upload"
    dest = f"{UPLOAD_DIR}/{uid}_{safe_name}"
    with open(dest, "wb") as bf:
        shutil.copyfileobj(file.file, bf)
    threading.Thread(target=_pipeline, args=(uid, dest)).start()
    return {"user_id": uid}
# Status poll: has the video for this upload id been produced yet? Several uploads
# can be in flight in one worker, each tracked by its own id.
def chk(uid: str):
    """Return {"ready": True} once processing for *uid* has finished, else {"ready": False}."""
    finished = video_ready.get(uid, False)
    return {"ready": finished}
# Serves the final H.264 video produced for an upload
def stream(uid: str):
    """
    Return OUTPUT_DIR/<uid>.mp4 as video/mp4, or an empty 404 if it does not exist (yet).

    NOTE(review): no FastAPI route decorator is visible on this function in the
    reviewed source — confirm it is registered on `app`.
    """
    # uid comes from the client; _uid() only produces alphanumeric ids, so reject
    # anything else (e.g. "../x") instead of building a path from it.
    if not uid.isalnum():
        return Response(status_code=404)
    vid = f"{OUTPUT_DIR}/{uid}.mp4"
    if not os.path.exists(vid):
        return Response(status_code=404)
    # FileResponse opens and closes the file itself and sends Content-Length, whereas a
    # bare open() handed to StreamingResponse was never closed explicitly and was
    # sent chunked with no length.
    return FileResponse(vid, media_type="video/mp4")
# ─── Detect animal/wildlife ───────────────────────────────────────────────── | |
# Init clients | |
# https://universe.roboflow.com/team-hope-mmcyy/hydroquest | https://universe.roboflow.com/sky-sd2zq/bird_only-pt0bm/model/1 | |
import base64, requests | |
def roboflow_infer(image_path, api_url, api_key, timeout=30):
    """
    POST an image file to a Roboflow hosted-inference endpoint and return its JSON reply.

    image_path: image file, uploaded as the multipart field "file".
    api_url:    model endpoint, e.g. "https://detect.roboflow.com/<project>/<version>".
    api_key:    Roboflow API key, sent as the api_key query parameter.
    timeout:    seconds to wait for the server. Without a timeout, a stalled
                connection would hang the calling request handler indefinitely.

    A confidence threshold of 70 is sent in the query string. Returns {} when the
    reply body is not valid JSON. Network errors from requests, including timeouts,
    propagate to the caller.
    """
    with open(image_path, "rb") as image_file:
        res = requests.post(
            api_url,
            params={"api_key": api_key, "confidence": 70},  # threshold added to the URL
            files={"file": image_file},
            timeout=timeout,
        )
    print(f"[Roboflow] {res.status_code} response")
    try:
        return res.json()
    except ValueError as e:  # requests' JSON decode errors subclass ValueError
        print("[Roboflow JSON decode error]", e)
        return {}
# Animal detection endpoint (animal, fish, bird as target classes)
def _roboflow_boxes(response, label, min_conf=0.70):
    """
    Turn a Roboflow prediction payload into [((x1, y1, x2, y2), text), ...].

    Roboflow reports each box as a centre (x, y) plus width/height. Predictions with
    confidence below *min_conf* are dropped. The text is "<label> <confidence>".
    """
    boxes = []
    for pred in response.get("predictions", []):
        acc = pred["confidence"]
        if acc < min_conf:
            continue
        half_w, half_h = pred["width"] / 2, pred["height"] / 2
        box = (
            int(pred["x"] - half_w), int(pred["y"] - half_h),
            int(pred["x"] + half_w), int(pred["y"] + half_h),
        )
        boxes.append((box, f"{label} {acc}"))
    return boxes


async def detect_animals(file: UploadFile = File(...)):
    """
    Detect wildlife in an uploaded image and return it annotated as a JPEG.

    Boxes come from two sources:
      * the local COCO YOLOv8 model, restricted to a fixed set of animal classes;
      * the Roboflow fish and bird models.
    Only detections with confidence >= 0.70 are drawn, in red, with the label
    "<Kind> Alert <confidence>". Returns a 400 JSON response when the upload cannot
    be decoded as an image.
    NOTE(review): no FastAPI route decorator is visible on this function in the
    reviewed source — confirm it is registered on `app`.
    """
    img_id = _uid()
    # file.filename is client-controlled: keep only its last path component.
    raw_name = (file.filename or "").replace("\\", "/")
    safe_name = os.path.basename(raw_name) or "upload"
    img_path = f"{UPLOAD_DIR}/{img_id}_{safe_name}"
    with open(img_path, "wb") as f:
        shutil.copyfileobj(file.file, f)
    print(f"[Animal] Uploaded image: {img_path}")
    # Read and prepare detection
    image = cv2.imread(img_path)
    if image is None:  # not a decodable image; drawing/writing below would crash
        print(f"[Animal] Could not decode image: {img_path}")
        return JSONResponse(status_code=400, content={"detail": "Uploaded file is not a readable image"})
    detections = []
    # 1. YOLOv8 local
    print("[Animal] Detecting via YOLOv8…")
    animal_labels = {"dog", "cat", "cow", "horse", "elephant", "bear", "zebra", "giraffe", "bird"}
    try:
        results = model_animal(image)[0]
        for box in results.boxes:
            conf = box.conf[0].item()
            if conf < 0.70:
                continue
            label = model_animal.names[int(box.cls[0].item())].lower()
            if label in animal_labels:
                x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
                detections.append(((x1, y1, x2, y2), f"Animal Alert {conf}"))
    except Exception as e:
        print("[YOLOv8 Error]", e)
    # Never log the key itself, only whether it is configured.
    api_key = os.getenv("ROBOFLOW_KEY", "")
    print("[API] Roboflow key:", "✅ set" if api_key else "❌ not set")
    # 2. Roboflow Fish, 3. Roboflow Bird
    roboflow_models = (
        ("Fish", "https://detect.roboflow.com/hydroquest/1"),
        ("Bird", "https://detect.roboflow.com/bird_only-pt0bm/1"),
    )
    for kind, url in roboflow_models:
        try:
            print(f"[Animal] Detecting via Roboflow {kind} model…")
            response = roboflow_infer(img_path, url, api_key=api_key)
            detections += _roboflow_boxes(response, f"{kind} Alert")
            print(f"[Roboflow {kind} Response]", response)
        except Exception as e:
            print(f"[Roboflow {kind} Error]", e)
    # Count detection
    print(f"[Animal] Total detections: {len(detections)}")
    # Write label
    for (x1, y1, x2, y2), label in detections:
        cv2.rectangle(image, (x1, y1), (x2, y2), (0, 0, 255), 2)
        cv2.putText(image, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
    # Write img
    result_path = f"{OUTPUT_DIR}/{img_id}_animal.jpg"
    cv2.imwrite(result_path, image)
    return FileResponse(result_path, media_type="image/jpeg")
# ── Core pipeline (runs in background thread) ───────────────────────────
def _pipeline(uid, img_path):
    """
    Process one uploaded image; upload() starts this in a background thread.

    Steps:
      1. SegFormer segmentation -> water / garbage / movable masks.
      2. Garbage targets: segmentation chunks plus YOLOv11, YOLOv5 and DETR boxes
         that lie on the movable area.
      3. Robot spawn and greedy A* routing.
      4. mp4v rendering.
      5. H.264 re-encode to OUTPUT_DIR/<uid>.mp4.

    video_ready[uid] is set to True whenever this function ends: after success,
    after an early exit (nothing to render) and after an unexpected error. Clients
    polling chk() therefore never wait forever. In the non-success cases no video is
    written and stream() answers 404.
    """
    print(f"▶️ [{uid}] processing")
    try:
        _run_pipeline(uid, img_path)
    except Exception:
        # This runs in a bare thread, so nothing else would ever report the failure.
        import traceback
        print(f"❌ [{uid}] pipeline crashed")
        traceback.print_exc()
    finally:
        video_ready[uid] = True


def _run_pipeline(uid, img_path):
    """Body of _pipeline; returns early (after logging why) when there is nothing to render."""
    raw = cv2.imread(img_path)
    if raw is None:
        print(f"❌ [{uid}] could not read image {img_path}")
        return
    bgr = cv2.resize(raw, (640, 640))
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    pil = Image.fromarray(rgb)

    # 1- Segmentation -> masks (movable zone = water | garbage)
    seg = _segment(uid, pil)
    if seg is None:
        return
    water_mask, garbage_mask, movable_mask = build_masks(seg)

    # 2- Garbage targets: segmentation chunks + detector boxes lying on the movable zone
    labels, chunk_centres = _garbage_chunks(garbage_mask)
    print(f"🧠 {len(chunk_centres)} garbage chunk detected")
    detections = _detect_garbage_boxes(bgr, pil)
    centres = _centres_on_movable(detections, movable_mask) + chunk_centres
    # De-duplicate while keeping first-seen order.
    centres = [list(c) for c in dict.fromkeys(tuple(c) for c in centres)]
    # Debug image is drawn only now that the centres exist (it previously used `centres`
    # before assignment, and drew the dots into the mask used for navigation).
    _save_movable_debug(uid, movable_mask, centres)
    if not centres:  # no garbage within the travelable zone
        print(f"🛑 [{uid}] no reachable garbage")
        return
    print(f"🧠 {len(centres)} garbage objects on water selected from {len(detections)} detections")

    # 3- Robot initialisation, position and navigation
    spawn = _spawn_point(water_mask)
    if spawn is None:
        print(f"❌ [{uid}] no water to spawn on")
        return
    robot = Robot(SPRITE)
    robot.pos = list(spawn)
    path, unreachable = knn_path(robot.pos, centres, movable_mask)
    if unreachable:
        print(f"⚠️ Unreachable garbage chunks at: {unreachable}")

    # 4- Video synthesis, then 5- conversion to H.264
    out_tmp = f"{OUTPUT_DIR}/{uid}_tmp.mp4"
    _render_video(out_tmp, bgr, labels, water_mask, centres, unreachable, robot, path)
    final = f"{OUTPUT_DIR}/{uid}.mp4"
    ffmpeg.input(out_tmp).output(final, vcodec="libx264", pix_fmt="yuv420p").run(overwrite_output=True, quiet=True)
    os.remove(out_tmp)
    print(f"✅ [{uid}] video ready → {final}")


def _segment(uid, pil):
    """Run SegFormer on *pil*; return a (640, 640) uint8 map of class ids, or None if empty."""
    with torch.no_grad():
        inputs = feat_extractor(pil, return_tensors="pt")
        seg_logits = segformer(**inputs).logits
    seg_tensor = seg_logits.argmax(1)[0].cpu()
    if seg_tensor.numel() == 0:
        print(f"❌ [{uid}] segmentation failed (empty tensor)")
        return None
    # argmax yields int64, which OpenCV 4.x cannot resize (no 64-bit integer Mat depth);
    # ADE-20K class ids are < 150, so uint8 holds them exactly.
    seg = cv2.resize(seg_tensor.numpy().astype(np.uint8), (640, 640), interpolation=cv2.INTER_NEAREST)
    print(f"🧪 [{uid}] segmentation input shape: {inputs['pixel_values'].shape}")
    return seg


def _garbage_chunks(garbage_mask):
    """Label connected garbage regions; return (label image, [[cx, cy], ...] centroids)."""
    num_cc, labels = cv2.connectedComponents(garbage_mask.astype(np.uint8))
    centres = []
    for lab in range(1, num_cc):  # label 0 is the background
        ys, xs = np.where(labels == lab)
        if xs.size:
            centres.append([int(xs.mean()), int(ys.mean())])
    return labels, centres


def _detect_garbage_boxes(bgr, pil):
    """Collect [x1, y1, x2, y2] garbage boxes from YOLOv11 (self-trained), YOLOv5 and DETR."""
    boxes = []
    for r in model_self(bgr):  # YOLOv11 (self-trained)
        boxes += [b.xyxy[0].tolist() for b in r.boxes]
    r = model_yolo5(bgr)  # YOLOv5
    if hasattr(r, "pred") and len(r.pred) > 0:
        boxes += [p[:4].tolist() for p in r.pred[0]]
    inp = processor_detr(images=pil, return_tensors="pt")
    with torch.no_grad():
        out = model_detr(**inp)  # DETR
    post = processor_detr.post_process_object_detection(
        outputs=out,
        target_sizes=torch.tensor([pil.size[::-1]]),
        threshold=0.5,
    )[0]
    boxes += [b.tolist() for b in post["boxes"]]
    return boxes


def _centres_on_movable(detections, movable_mask, min_ratio=0.5):
    """
    Keep detections lying on the travelable zone and return their integer box centres.

    A box is kept when at least *min_ratio* (50 % by default) of its pixels are movable
    (water or garbage). Box corners are clamped to the mask first.
    """
    h, w = movable_mask.shape[:2]
    centres = []
    for x1, y1, x2, y2 in detections:
        x1, x2 = (min(max(int(v), 0), w - 1) for v in (x1, x2))
        y1, y2 = (min(max(int(v), 0), h - 1) for v in (y1, y2))
        box = movable_mask[y1:y2, x1:x2]
        if box.size and np.count_nonzero(box) / box.size >= min_ratio:
            centres.append([(x1 + x2) // 2, (y1 + y2) // 2])
    return centres


def _save_movable_debug(uid, movable_mask, centres):
    """Write OUTPUT_DIR/<uid>_movable_with_centres.png: movable area white, target centres grey."""
    debug = (movable_mask > 0).astype(np.uint8) * 255  # separate image: never touch the nav mask
    for cx, cy in centres:
        cv2.circle(debug, (int(cx), int(cy)), 3, 127, -1)  # grey centre dots
    debug_path = f"{OUTPUT_DIR}/{uid}_movable_with_centres.png"
    cv2.imwrite(debug_path, debug)
    print(f"🧩 Saved debug movable_mask: {debug_path}")


def _spawn_point(water_mask, margin=20, size=640):
    """
    Pick the robot spawn: the top-most, then left-most, water pixel, clamped *margin* px
    inside the size x size frame so the sprite never pokes out.

    Returns (x, y) ints, or None when there is no water.
    NOTE(review): the clamp can move the spawn off the water for pixels near the edge.
    """
    ys, xs = np.nonzero(water_mask)
    if ys.size == 0:
        return None
    # nonzero() returns indices in row-major order: index 0 has the smallest y, then x.
    x = int(np.clip(xs[0], margin, size - margin))
    y = int(np.clip(ys[0], margin, size - margin))
    return x, y


def _paste_sprite(frame, sprite, cx, cy):
    """
    Alpha-blend an RGBA *sprite* onto the BGR *frame* in place, centred on (cx, cy).

    The sprite is clipped at the frame borders. Its colour channels are reversed
    (RGB -> BGR) because the sprite is loaded with PIL while frames come from OpenCV;
    pasting RGB directly used to swap red and blue.
    """
    sh, sw = sprite.shape[:2]
    x1, y1 = int(cx) - sw // 2, int(cy) - sh // 2
    x2, y2 = x1 + sw, y1 + sh
    fx1, fy1 = max(0, x1), max(0, y1)
    fx2, fy2 = min(frame.shape[1], x2), min(frame.shape[0], y2)
    if fx1 >= fx2 or fy1 >= fy2:
        return  # sprite entirely outside the frame
    crop = sprite[fy1 - y1:fy2 - y1, fx1 - x1:fx2 - x1]
    alpha = crop[:, :, 3:4] / 255.0
    colour_bgr = crop[:, :, 2::-1]
    roi = frame[fy1:fy2, fx1:fx2]
    frame[fy1:fy2, fx1:fx2] = (alpha * colour_bgr + (1 - alpha) * roi).astype(np.uint8)


def _render_video(out_path, bgr, labels, water_mask, centres, unreachable, robot, path,
                  max_frames=15000):
    """
    Animate *robot* along *path* over *bgr* and write a 10 fps, 640x640 mp4v video.

    Garbage chunks and target dots are dark red (pending), green (collected) or yellow
    (unreachable), in BGR. Water is tinted blue. A target is collected once the robot
    centre comes within 20 px of it. Rendering stops when every target is collected,
    the path is exhausted, or after *max_frames* frames. *path* is consumed in place.
    """
    objs = [{"pos": p, "col": False, "unreachable": False} for p in centres if p not in unreachable]
    objs += [{"pos": p, "col": False, "unreachable": True} for p in unreachable]
    vw = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*"mp4v"), 10.0, (640, 640))
    try:
        base, base_key = None, None
        for _ in range(max_frames):  # safety frames
            # The overlays depend only on which targets are collected, so rebuild them
            # only when that changes instead of re-tracing every contour each frame.
            key = tuple(o["col"] for o in objs)
            if key != base_key:
                base = highlight_chunk_masks_on_frame(
                    bgr, labels, objs,
                    color_uncollected=(0, 0, 128),   # 🔴
                    color_collected=(0, 128, 0),     # 🟢
                    color_unreachable=(0, 255, 255),  # 🟡
                )
                base = highlight_water_mask_on_frame(base, water_mask)  # 🔵 water overlay
                base_key = key
            frame = base.copy()
            # Target dots, coloured like their chunk overlay
            for o in objs:
                if o["unreachable"]:
                    colour = (0, 255, 255)
                elif o["col"]:
                    colour = (0, 128, 0)
                else:
                    colour = (0, 0, 128)
                cv2.circle(frame, tuple(o["pos"]), 6, colour, -1)
            # Robot displacement
            robot.step(path)
            rx, ry = robot.pos
            _paste_sprite(frame, robot.png, rx, ry)
            # Collection check
            for o in objs:
                if not o["col"] and np.hypot(o["pos"][0] - rx, o["pos"][1] - ry) <= 20:
                    o["col"] = True
            vw.write(frame)
            if all(o["col"] for o in objs) or not path:
                break
    finally:
        vw.release()
# ── Run locally (the HF Space starts the app from its Docker image instead) ──
if __name__ == "__main__":
    # Listen on all interfaces, port 7860
    uvicorn.run(app, port=7860, host="0.0.0.0")