LiamKhoaLe committed on
Commit
22bb5ba
·
1 Parent(s): b75ea7b

Update garb cls to render on colors and infer 3 models together

Browse files
Files changed (1) hide show
  1. app.py +40 -19
app.py CHANGED
@@ -423,49 +423,70 @@ async def classify_garbage(file: UploadFile = File(...)):
423
  img_id = _uid()
424
  img_path = f"{UPLOAD_DIR}/{img_id}_{file.filename}"
425
  out_path = f"{OUTPUT_DIR}/{img_id}_classified.jpg"
426
- # Load file
427
  with open(img_path, "wb") as f:
428
  shutil.copyfileobj(file.file, f)
429
- # Read image
430
  print(f"[Classification] Received image: {img_path}")
431
  image = cv2.imread(img_path)
432
  rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
433
  pil = Image.fromarray(rgb)
434
- # DETR for garbage detection boxes
435
  detections = []
436
- inp = processor_detr(images=pil, return_tensors="pt")
 
 
 
 
 
 
 
437
  with torch.no_grad():
438
- out = model_detr(**inp)
439
  results = processor_detr.post_process_object_detection(
440
  outputs=out,
441
  target_sizes=torch.tensor([pil.size[::-1]]),
442
  threshold=0.5
443
  )[0]
444
- # Bbox return
445
- boxes = results["boxes"]
446
- print(f"[Classification] {len(boxes)} garbage objects detected by DETR.")
447
- # Mapping in between
448
- for i, box in enumerate(boxes):
449
- x1, y1, x2, y2 = map(int, box.tolist())
 
 
450
  crop = image[y1:y2, x1:x2]
451
  if crop.shape[0] < 10 or crop.shape[1] < 10:
452
- continue # skip tiny crops
453
- # Convert crop to RGB and classify
454
  pil_crop = Image.fromarray(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB))
455
- pred = model_garbage_cls(pil_crop, verbose=False)[0]
 
456
  class_id = int(pred.probs.top1)
457
  class_name = model_garbage_cls.names[class_id]
458
- conf = pred.probs.top1conf
459
- # Labelling on output image
460
  label = f"{class_name} ({conf:.2f})"
461
- cv2.rectangle(image, (x1, y1), (x2, y2), (0, 165, 255), 2)
462
- cv2.putText(image, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 165, 255), 2)
463
- # Write image on render
 
 
 
 
 
 
 
 
 
 
464
  cv2.imwrite(out_path, image)
465
  print(f"[Classification] Output saved: {out_path}")
466
  return FileResponse(out_path, media_type="image/jpeg")
467
 
468
 
 
469
  # ── Core pipeline (runs in background thread) ───────────────────────────
470
  def _pipeline(uid,img_path):
471
  print(f"▶️ [{uid}] processing")
 
423
  img_id = _uid()
424
  img_path = f"{UPLOAD_DIR}/{img_id}_{file.filename}"
425
  out_path = f"{OUTPUT_DIR}/{img_id}_classified.jpg"
426
+ # Save uploaded file
427
  with open(img_path, "wb") as f:
428
  shutil.copyfileobj(file.file, f)
429
+ # Read file
430
  print(f"[Classification] Received image: {img_path}")
431
  image = cv2.imread(img_path)
432
  rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
433
  pil = Image.fromarray(rgb)
434
+ # ─── Detection from 3 models ─────────────────────────────
435
  detections = []
436
+ # YOLOv11 (self-trained)
437
+ for r in model_self(image):
438
+ detections += [b.xyxy[0].tolist() for b in r.boxes]
439
+ # YOLOv5
440
+ r = model_yolo5(image)
441
+ if hasattr(r, 'pred') and len(r.pred) > 0:
442
+ detections += [p[:4].tolist() for p in r.pred[0]]
443
+ # DETR
444
  with torch.no_grad():
445
+ out = model_detr(**processor_detr(images=pil, return_tensors="pt"))
446
  results = processor_detr.post_process_object_detection(
447
  outputs=out,
448
  target_sizes=torch.tensor([pil.size[::-1]]),
449
  threshold=0.5
450
  )[0]
451
+ detections += [b.tolist() for b in results["boxes"]]
452
+ print(f"[Classification] Total detections from 3 models: {len(detections)}")
453
+ # ─── Classification & Rendering ─────────────────────────
454
+ for box in detections:
455
+ x1, y1, x2, y2 = map(int, box)
456
+ x1, x2 = max(0, min(x1, 639)), max(0, min(x2, 639))
457
+ y1, y2 = max(0, min(y1, 639)), max(0, min(y2, 639))
458
+ # Stack all crops
459
  crop = image[y1:y2, x1:x2]
460
  if crop.shape[0] < 10 or crop.shape[1] < 10:
461
+ continue
462
+ # Image processing
463
  pil_crop = Image.fromarray(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB))
464
+ with torch.no_grad():
465
+ pred = model_garbage_cls(pil_crop, verbose=False)[0]
466
  class_id = int(pred.probs.top1)
467
  class_name = model_garbage_cls.names[class_id]
468
+ conf = float(pred.probs.top1conf)
469
+ # Label format
470
  label = f"{class_name} ({conf:.2f})"
471
+ # Dynamic color coding
472
+ if conf < 0.4:
473
+ color = (0, 0, 255) # Red
474
+ elif conf < 0.6:
475
+ color = (0, 255, 0) # Green
476
+ elif conf < 0.8:
477
+ color = (255, 255, 0) # Sky Blue
478
+ else:
479
+ color = (255, 0, 255) # Purple
480
+ # Labelling
481
+ cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
482
+ cv2.putText(image, label, (x1, y1 - 6), cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
483
+ # Save result
484
  cv2.imwrite(out_path, image)
485
  print(f"[Classification] Output saved: {out_path}")
486
  return FileResponse(out_path, media_type="image/jpeg")
487
 
488
 
489
+
490
  # ── Core pipeline (runs in background thread) ───────────────────────────
491
  def _pipeline(uid,img_path):
492
  print(f"▶️ [{uid}] processing")