Commit
·
22bb5ba
1
Parent(s):
b75ea7b
Update garb cls to render on colors and infer 3 models together
Browse files
app.py
CHANGED
@@ -423,49 +423,70 @@ async def classify_garbage(file: UploadFile = File(...)):
|
|
423 |
img_id = _uid()
|
424 |
img_path = f"{UPLOAD_DIR}/{img_id}_{file.filename}"
|
425 |
out_path = f"{OUTPUT_DIR}/{img_id}_classified.jpg"
|
426 |
-
#
|
427 |
with open(img_path, "wb") as f:
|
428 |
shutil.copyfileobj(file.file, f)
|
429 |
-
# Read
|
430 |
print(f"[Classification] Received image: {img_path}")
|
431 |
image = cv2.imread(img_path)
|
432 |
rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
433 |
pil = Image.fromarray(rgb)
|
434 |
-
#
|
435 |
detections = []
|
436 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
437 |
with torch.no_grad():
|
438 |
-
out = model_detr(**
|
439 |
results = processor_detr.post_process_object_detection(
|
440 |
outputs=out,
|
441 |
target_sizes=torch.tensor([pil.size[::-1]]),
|
442 |
threshold=0.5
|
443 |
)[0]
|
444 |
-
|
445 |
-
|
446 |
-
|
447 |
-
|
448 |
-
|
449 |
-
x1,
|
|
|
|
|
450 |
crop = image[y1:y2, x1:x2]
|
451 |
if crop.shape[0] < 10 or crop.shape[1] < 10:
|
452 |
-
continue
|
453 |
-
#
|
454 |
pil_crop = Image.fromarray(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB))
|
455 |
-
|
|
|
456 |
class_id = int(pred.probs.top1)
|
457 |
class_name = model_garbage_cls.names[class_id]
|
458 |
-
conf = pred.probs.top1conf
|
459 |
-
#
|
460 |
label = f"{class_name} ({conf:.2f})"
|
461 |
-
|
462 |
-
|
463 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
464 |
cv2.imwrite(out_path, image)
|
465 |
print(f"[Classification] Output saved: {out_path}")
|
466 |
return FileResponse(out_path, media_type="image/jpeg")
|
467 |
|
468 |
|
|
|
469 |
# ── Core pipeline (runs in background thread) ───────────────────────────
|
470 |
def _pipeline(uid,img_path):
|
471 |
print(f"▶️ [{uid}] processing")
|
|
|
423 |
img_id = _uid()
|
424 |
img_path = f"{UPLOAD_DIR}/{img_id}_{file.filename}"
|
425 |
out_path = f"{OUTPUT_DIR}/{img_id}_classified.jpg"
|
426 |
+
# Save uploaded file
|
427 |
with open(img_path, "wb") as f:
|
428 |
shutil.copyfileobj(file.file, f)
|
429 |
+
# Read file
|
430 |
print(f"[Classification] Received image: {img_path}")
|
431 |
image = cv2.imread(img_path)
|
432 |
rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
433 |
pil = Image.fromarray(rgb)
|
434 |
+
# ─── Detection from 3 models ─────────────────────────────
|
435 |
detections = []
|
436 |
+
# YOLOv11 (self-trained)
|
437 |
+
for r in model_self(image):
|
438 |
+
detections += [b.xyxy[0].tolist() for b in r.boxes]
|
439 |
+
# YOLOv5
|
440 |
+
r = model_yolo5(image)
|
441 |
+
if hasattr(r, 'pred') and len(r.pred) > 0:
|
442 |
+
detections += [p[:4].tolist() for p in r.pred[0]]
|
443 |
+
# DETR
|
444 |
with torch.no_grad():
|
445 |
+
out = model_detr(**processor_detr(images=pil, return_tensors="pt"))
|
446 |
results = processor_detr.post_process_object_detection(
|
447 |
outputs=out,
|
448 |
target_sizes=torch.tensor([pil.size[::-1]]),
|
449 |
threshold=0.5
|
450 |
)[0]
|
451 |
+
detections += [b.tolist() for b in results["boxes"]]
|
452 |
+
print(f"[Classification] Total detections from 3 models: {len(detections)}")
|
453 |
+
# ─── Classification & Rendering ─────────────────────────
|
454 |
+
for box in detections:
|
455 |
+
x1, y1, x2, y2 = map(int, box)
|
456 |
+
x1, x2 = max(0, min(x1, 639)), max(0, min(x2, 639))
|
457 |
+
y1, y2 = max(0, min(y1, 639)), max(0, min(y2, 639))
|
458 |
+
# Stack all crops
|
459 |
crop = image[y1:y2, x1:x2]
|
460 |
if crop.shape[0] < 10 or crop.shape[1] < 10:
|
461 |
+
continue
|
462 |
+
# Image processing
|
463 |
pil_crop = Image.fromarray(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB))
|
464 |
+
with torch.no_grad():
|
465 |
+
pred = model_garbage_cls(pil_crop, verbose=False)[0]
|
466 |
class_id = int(pred.probs.top1)
|
467 |
class_name = model_garbage_cls.names[class_id]
|
468 |
+
conf = float(pred.probs.top1conf)
|
469 |
+
# Label format
|
470 |
label = f"{class_name} ({conf:.2f})"
|
471 |
+
# Dynamic color coding
|
472 |
+
if conf < 0.4:
|
473 |
+
color = (0, 0, 255) # Red
|
474 |
+
elif conf < 0.6:
|
475 |
+
color = (0, 255, 0) # Green
|
476 |
+
elif conf < 0.8:
|
477 |
+
color = (255, 255, 0) # Sky Blue
|
478 |
+
else:
|
479 |
+
color = (255, 0, 255) # Purple
|
480 |
+
# Labelling
|
481 |
+
cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
|
482 |
+
cv2.putText(image, label, (x1, y1 - 6), cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
|
483 |
+
# Save result
|
484 |
cv2.imwrite(out_path, image)
|
485 |
print(f"[Classification] Output saved: {out_path}")
|
486 |
return FileResponse(out_path, media_type="image/jpeg")
|
487 |
|
488 |
|
489 |
+
|
490 |
# ── Core pipeline (runs in background thread) ───────────────────────────
|
491 |
def _pipeline(uid,img_path):
|
492 |
print(f"▶️ [{uid}] processing")
|