thecollabagepatch committed on
Commit
0577e3b
·
1 Parent(s): 30fdbbc

model/select endpoint warmup fix

Browse files
Files changed (1) hide show
  1. app.py +45 -29
app.py CHANGED
@@ -498,33 +498,37 @@ def model_checkpoints(repo_id: str, revision: str = "main"):
498
 
499
  @app.post("/model/select")
500
  def model_select(req: ModelSelect):
501
- global _MRT, _MEAN_EMBED, _CENTROIDS, _ASSETS_REPO_ID
502
-
503
- # Use ModelSelector to validate the request
 
 
 
 
504
  success, validation_result = model_selector.validate_selection(req)
505
  if not success:
506
  if "error" in validation_result:
507
  raise HTTPException(status_code=400, detail=validation_result["error"])
508
  return {"ok": False, **validation_result}
509
-
510
- # Add active_jam status to the validation result
511
  validation_result["active_jam"] = _any_jam_running()
512
-
513
- # If dry run, return the validation result
514
  if req.dry_run:
515
  return {"ok": True, "dry_run": True, **validation_result}
516
 
517
- # Handle jam policy
518
  if _any_jam_running():
519
  if req.stop_active:
520
  _stop_all_jams()
521
  else:
522
  raise HTTPException(status_code=409, detail="A jam is running; retry with stop_active=true")
523
 
524
- # Prepare environment changes
525
  env_changes = model_selector.prepare_env_changes(req, validation_result)
526
-
527
- # Save current environment for rollback
528
  old_env = {
529
  "MRT_SIZE": os.getenv("MRT_SIZE"),
530
  "MRT_CKPT_REPO": os.getenv("MRT_CKPT_REPO"),
@@ -532,52 +536,64 @@ def model_select(req: ModelSelect):
532
  "MRT_CKPT_STEP": os.getenv("MRT_CKPT_STEP"),
533
  "MRT_ASSETS_REPO": os.getenv("MRT_ASSETS_REPO"),
534
  }
535
-
536
  try:
537
- # Apply environment changes atomically
538
  for key, value in env_changes.items():
539
  if value is None:
540
  os.environ.pop(key, None)
541
  else:
542
  os.environ[key] = str(value)
543
 
544
- # Force model rebuild
545
  with _MRT_LOCK:
546
  _MRT = None
 
 
547
 
548
- # Load finetune assets if requested
549
  if req.sync_assets and validation_result.get("assets_repo"):
550
  ok, msg = asset_manager.load_finetune_assets_from_hf(
551
- validation_result["assets_repo"],
552
- get_mrt() if req.prewarm else None
553
  )
554
  if ok:
555
- # Sync globals after successful asset loading
556
  _MEAN_EMBED = asset_manager.mean_embed
557
  _CENTROIDS = asset_manager.centroids
558
  _ASSETS_REPO_ID = asset_manager.assets_repo_id
 
 
559
 
560
- # Optional prewarm to amortize JIT
 
 
561
  if req.prewarm:
562
- get_mrt()
 
 
 
 
 
 
 
 
 
 
563
 
564
- return {"ok": True, **validation_result}
565
-
566
  except Exception as e:
567
- # Rollback on error
568
  for k, v in old_env.items():
569
  if v is None:
570
  os.environ.pop(k, None)
571
  else:
572
  os.environ[k] = v
 
573
  with _MRT_LOCK:
574
  _MRT = None
575
- # Try to restore working state
576
- try:
577
- get_mrt()
578
- except Exception:
579
- pass
580
- raise HTTPException(status_code=500, detail=f"Swap failed: {e}")
581
 
582
 
583
 
 
498
 
499
  @app.post("/model/select")
500
  def model_select(req: ModelSelect):
501
+ """
502
+ Swap model/checkpoint/assets. If req.prewarm is True, run the full bar-aligned warmup
503
+ (_mrt_warmup) synchronously so we only report warmed once the new model is actually ready.
504
+ """
505
+ global _MRT, _MEAN_EMBED, _CENTROIDS, _ASSETS_REPO_ID, _WARMED
506
+
507
+ # 1) Validate the request (no side-effects)
508
  success, validation_result = model_selector.validate_selection(req)
509
  if not success:
510
  if "error" in validation_result:
511
  raise HTTPException(status_code=400, detail=validation_result["error"])
512
  return {"ok": False, **validation_result}
513
+
514
+ # Augment response surface
515
  validation_result["active_jam"] = _any_jam_running()
516
+
517
+ # Dry-run path
518
  if req.dry_run:
519
  return {"ok": True, "dry_run": True, **validation_result}
520
 
521
+ # 2) Handle jam policy
522
  if _any_jam_running():
523
  if req.stop_active:
524
  _stop_all_jams()
525
  else:
526
  raise HTTPException(status_code=409, detail="A jam is running; retry with stop_active=true")
527
 
528
+ # 3) Compute environment changes (no mutation yet)
529
  env_changes = model_selector.prepare_env_changes(req, validation_result)
530
+
531
+ # Keep current env for rollback
532
  old_env = {
533
  "MRT_SIZE": os.getenv("MRT_SIZE"),
534
  "MRT_CKPT_REPO": os.getenv("MRT_CKPT_REPO"),
 
536
  "MRT_CKPT_STEP": os.getenv("MRT_CKPT_STEP"),
537
  "MRT_ASSETS_REPO": os.getenv("MRT_ASSETS_REPO"),
538
  }
539
+
540
  try:
541
+ # 4) Apply env atomically
542
  for key, value in env_changes.items():
543
  if value is None:
544
  os.environ.pop(key, None)
545
  else:
546
  os.environ[key] = str(value)
547
 
548
+ # 5) Force rebuild of the model and reset warmup state
549
  with _MRT_LOCK:
550
  _MRT = None
551
+ with _WARMUP_LOCK:
552
+ _WARMED = False # ← critical: don't leak previous model's warmed state
553
 
554
+ # 6) Load finetune assets if requested (mean/centroids)
555
  if req.sync_assets and validation_result.get("assets_repo"):
556
  ok, msg = asset_manager.load_finetune_assets_from_hf(
557
+ validation_result["assets_repo"],
558
+ None # don't implicitly instantiate model here; we'll do it below
559
  )
560
  if ok:
 
561
  _MEAN_EMBED = asset_manager.mean_embed
562
  _CENTROIDS = asset_manager.centroids
563
  _ASSETS_REPO_ID = asset_manager.assets_repo_id
564
+ else:
565
+ logging.warning("Asset sync skipped/failed: %s", msg)
566
 
567
+ # 7) Prewarm behavior:
568
+ # - If prewarm=True, run the *real* bar-aligned warmup synchronously.
569
+ # - This will instantiate the new MRT and set _WARMED=True on success.
570
  if req.prewarm:
571
+ _mrt_warmup() # builds MRT internally via get_mrt(), runs generate_chunk, sets _WARMED
572
+
573
+ # Optional: if you want to always ensure MRT exists (even without prewarm), uncomment:
574
+ # else:
575
+ # _ = get_mrt()
576
+
577
+ return {
578
+ "ok": True,
579
+ **validation_result,
580
+ "warmup_done": bool(_WARMED),
581
+ }
582
 
 
 
583
  except Exception as e:
584
+ # 8) Roll back env on failure
585
  for k, v in old_env.items():
586
  if v is None:
587
  os.environ.pop(k, None)
588
  else:
589
  os.environ[k] = v
590
+ # Also reset model pointer & warmed flag to a safe state
591
  with _MRT_LOCK:
592
  _MRT = None
593
+ with _WARMUP_LOCK:
594
+ _WARMED = False
595
+ logging.exception("Model select failed: %s", e)
596
+ raise HTTPException(status_code=500, detail=f"Model select failed: {e}")
 
 
597
 
598
 
599