Commit
·
0577e3b
1
Parent(s):
30fdbbc
model/select endpoint warmup fix
Browse files
app.py
CHANGED
@@ -498,33 +498,37 @@ def model_checkpoints(repo_id: str, revision: str = "main"):
|
|
498 |
|
499 |
@app.post("/model/select")
|
500 |
def model_select(req: ModelSelect):
|
501 |
-
|
502 |
-
|
503 |
-
|
|
|
|
|
|
|
|
|
504 |
success, validation_result = model_selector.validate_selection(req)
|
505 |
if not success:
|
506 |
if "error" in validation_result:
|
507 |
raise HTTPException(status_code=400, detail=validation_result["error"])
|
508 |
return {"ok": False, **validation_result}
|
509 |
-
|
510 |
-
#
|
511 |
validation_result["active_jam"] = _any_jam_running()
|
512 |
-
|
513 |
-
#
|
514 |
if req.dry_run:
|
515 |
return {"ok": True, "dry_run": True, **validation_result}
|
516 |
|
517 |
-
# Handle jam policy
|
518 |
if _any_jam_running():
|
519 |
if req.stop_active:
|
520 |
_stop_all_jams()
|
521 |
else:
|
522 |
raise HTTPException(status_code=409, detail="A jam is running; retry with stop_active=true")
|
523 |
|
524 |
-
#
|
525 |
env_changes = model_selector.prepare_env_changes(req, validation_result)
|
526 |
-
|
527 |
-
#
|
528 |
old_env = {
|
529 |
"MRT_SIZE": os.getenv("MRT_SIZE"),
|
530 |
"MRT_CKPT_REPO": os.getenv("MRT_CKPT_REPO"),
|
@@ -532,52 +536,64 @@ def model_select(req: ModelSelect):
|
|
532 |
"MRT_CKPT_STEP": os.getenv("MRT_CKPT_STEP"),
|
533 |
"MRT_ASSETS_REPO": os.getenv("MRT_ASSETS_REPO"),
|
534 |
}
|
535 |
-
|
536 |
try:
|
537 |
-
# Apply
|
538 |
for key, value in env_changes.items():
|
539 |
if value is None:
|
540 |
os.environ.pop(key, None)
|
541 |
else:
|
542 |
os.environ[key] = str(value)
|
543 |
|
544 |
-
# Force model
|
545 |
with _MRT_LOCK:
|
546 |
_MRT = None
|
|
|
|
|
547 |
|
548 |
-
# Load finetune assets if requested
|
549 |
if req.sync_assets and validation_result.get("assets_repo"):
|
550 |
ok, msg = asset_manager.load_finetune_assets_from_hf(
|
551 |
-
validation_result["assets_repo"],
|
552 |
-
|
553 |
)
|
554 |
if ok:
|
555 |
-
# Sync globals after successful asset loading
|
556 |
_MEAN_EMBED = asset_manager.mean_embed
|
557 |
_CENTROIDS = asset_manager.centroids
|
558 |
_ASSETS_REPO_ID = asset_manager.assets_repo_id
|
|
|
|
|
559 |
|
560 |
-
#
|
|
|
|
|
561 |
if req.prewarm:
|
562 |
-
get_mrt()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
563 |
|
564 |
-
return {"ok": True, **validation_result}
|
565 |
-
|
566 |
except Exception as e:
|
567 |
-
#
|
568 |
for k, v in old_env.items():
|
569 |
if v is None:
|
570 |
os.environ.pop(k, None)
|
571 |
else:
|
572 |
os.environ[k] = v
|
|
|
573 |
with _MRT_LOCK:
|
574 |
_MRT = None
|
575 |
-
|
576 |
-
|
577 |
-
|
578 |
-
|
579 |
-
pass
|
580 |
-
raise HTTPException(status_code=500, detail=f"Swap failed: {e}")
|
581 |
|
582 |
|
583 |
|
|
|
498 |
|
499 |
@app.post("/model/select")
|
500 |
def model_select(req: ModelSelect):
|
501 |
+
"""
|
502 |
+
Swap model/checkpoint/assets. If req.prewarm is True, run the full bar-aligned warmup
|
503 |
+
(_mrt_warmup) synchronously so we only report warmed once the new model is actually ready.
|
504 |
+
"""
|
505 |
+
global _MRT, _MEAN_EMBED, _CENTROIDS, _ASSETS_REPO_ID, _WARMED
|
506 |
+
|
507 |
+
# 1) Validate the request (no side-effects)
|
508 |
success, validation_result = model_selector.validate_selection(req)
|
509 |
if not success:
|
510 |
if "error" in validation_result:
|
511 |
raise HTTPException(status_code=400, detail=validation_result["error"])
|
512 |
return {"ok": False, **validation_result}
|
513 |
+
|
514 |
+
# Augment response surface
|
515 |
validation_result["active_jam"] = _any_jam_running()
|
516 |
+
|
517 |
+
# Dry-run path
|
518 |
if req.dry_run:
|
519 |
return {"ok": True, "dry_run": True, **validation_result}
|
520 |
|
521 |
+
# 2) Handle jam policy
|
522 |
if _any_jam_running():
|
523 |
if req.stop_active:
|
524 |
_stop_all_jams()
|
525 |
else:
|
526 |
raise HTTPException(status_code=409, detail="A jam is running; retry with stop_active=true")
|
527 |
|
528 |
+
# 3) Compute environment changes (no mutation yet)
|
529 |
env_changes = model_selector.prepare_env_changes(req, validation_result)
|
530 |
+
|
531 |
+
# Keep current env for rollback
|
532 |
old_env = {
|
533 |
"MRT_SIZE": os.getenv("MRT_SIZE"),
|
534 |
"MRT_CKPT_REPO": os.getenv("MRT_CKPT_REPO"),
|
|
|
536 |
"MRT_CKPT_STEP": os.getenv("MRT_CKPT_STEP"),
|
537 |
"MRT_ASSETS_REPO": os.getenv("MRT_ASSETS_REPO"),
|
538 |
}
|
539 |
+
|
540 |
try:
|
541 |
+
# 4) Apply env changes key by key (not atomic; the except block below restores old_env on failure)
|
542 |
for key, value in env_changes.items():
|
543 |
if value is None:
|
544 |
os.environ.pop(key, None)
|
545 |
else:
|
546 |
os.environ[key] = str(value)
|
547 |
|
548 |
+
# 5) Force rebuild of the model and reset warmup state
|
549 |
with _MRT_LOCK:
|
550 |
_MRT = None
|
551 |
+
with _WARMUP_LOCK:
|
552 |
+
_WARMED = False # ← critical: don't leak previous model's warmed state
|
553 |
|
554 |
+
# 6) Load finetune assets if requested (mean/centroids)
|
555 |
if req.sync_assets and validation_result.get("assets_repo"):
|
556 |
ok, msg = asset_manager.load_finetune_assets_from_hf(
|
557 |
+
validation_result["assets_repo"],
|
558 |
+
None # don't implicitly instantiate model here; we'll do it below
|
559 |
)
|
560 |
if ok:
|
|
|
561 |
_MEAN_EMBED = asset_manager.mean_embed
|
562 |
_CENTROIDS = asset_manager.centroids
|
563 |
_ASSETS_REPO_ID = asset_manager.assets_repo_id
|
564 |
+
else:
|
565 |
+
logging.warning("Asset sync skipped/failed: %s", msg)
|
566 |
|
567 |
+
# 7) Prewarm behavior:
|
568 |
+
# - If prewarm=True, run the *real* bar-aligned warmup synchronously.
|
569 |
+
# - This will instantiate the new MRT and set _WARMED=True on success.
|
570 |
if req.prewarm:
|
571 |
+
_mrt_warmup() # builds MRT internally via get_mrt(), runs generate_chunk, sets _WARMED
|
572 |
+
|
573 |
+
# Optional: if you want to always ensure MRT exists (even without prewarm), uncomment:
|
574 |
+
# else:
|
575 |
+
# _ = get_mrt()
|
576 |
+
|
577 |
+
return {
|
578 |
+
"ok": True,
|
579 |
+
**validation_result,
|
580 |
+
"warmup_done": bool(_WARMED),
|
581 |
+
}
|
582 |
|
|
|
|
|
583 |
except Exception as e:
|
584 |
+
# 8) Roll back env on failure
|
585 |
for k, v in old_env.items():
|
586 |
if v is None:
|
587 |
os.environ.pop(k, None)
|
588 |
else:
|
589 |
os.environ[k] = v
|
590 |
+
# Also reset model pointer & warmed flag to a safe state
|
591 |
with _MRT_LOCK:
|
592 |
_MRT = None
|
593 |
+
with _WARMUP_LOCK:
|
594 |
+
_WARMED = False
|
595 |
+
logging.exception("Model select failed: %s", e)
|
596 |
+
raise HTTPException(status_code=500, detail=f"Model select failed: {e}")
|
|
|
|
|
597 |
|
598 |
|
599 |
|