Commit
·
f0823f1
1
Parent(s):
5b910df
finetune assets loading fix
Browse files
app.py
CHANGED
@@ -134,11 +134,12 @@ _CENTROIDS: np.ndarray | None = None # shape (K, D) dtype float32
|
|
134 |
asset_manager = AssetManager()
|
135 |
model_selector = ModelSelector(CheckpointManager(), asset_manager)
|
136 |
|
137 |
-
|
138 |
-
#
|
139 |
-
|
140 |
-
|
141 |
-
|
|
|
142 |
|
143 |
def _any_jam_running() -> bool:
|
144 |
with jam_lock:
|
@@ -335,15 +336,20 @@ def get_mrt():
|
|
335 |
if _MRT is None:
|
336 |
with _MRT_LOCK:
|
337 |
if _MRT is None:
|
338 |
-
|
339 |
-
ckpt_dir = CheckpointManager.resolve_checkpoint_dir() # ← Updated call
|
340 |
_MRT = system.MagentaRT(
|
341 |
tag=os.getenv("MRT_SIZE", "large"),
|
342 |
guidance_weight=5.0,
|
343 |
device="gpu",
|
344 |
checkpoint_dir=ckpt_dir,
|
345 |
-
lazy=False
|
346 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
347 |
return _MRT
|
348 |
|
349 |
_WARMED = False
|
@@ -420,9 +426,18 @@ def _mrt_warmup():
|
|
420 |
# startup and model selection
|
421 |
# ----------------------------
|
422 |
|
423 |
-
# Kick it off in the background on server start
|
424 |
@app.on_event("startup")
|
425 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
426 |
if os.getenv("MRT_WARMUP", "1") != "0":
|
427 |
threading.Thread(target=_mrt_warmup, name="mrt-warmup", daemon=True).start()
|
428 |
|
@@ -556,6 +571,8 @@ def model_select(req: ModelSelect):
|
|
556 |
if "error" in validation_result:
|
557 |
raise HTTPException(status_code=400, detail=validation_result["error"])
|
558 |
return {"ok": False, **validation_result}
|
|
|
|
|
559 |
|
560 |
# Augment response surface
|
561 |
validation_result["active_jam"] = _any_jam_running()
|
@@ -563,6 +580,10 @@ def model_select(req: ModelSelect):
|
|
563 |
# Dry-run path
|
564 |
if req.dry_run:
|
565 |
return {"ok": True, "dry_run": True, **validation_result}
|
|
|
|
|
|
|
|
|
566 |
|
567 |
# 2) Handle jam policy
|
568 |
if _any_jam_running():
|
|
|
134 |
asset_manager = AssetManager()
|
135 |
model_selector = ModelSelector(CheckpointManager(), asset_manager)
|
136 |
|
137 |
+
def _sync_assets_globals_from_manager():
    """Mirror the asset manager's current state into the module-level globals.

    Rebinds ``_MEAN_EMBED``, ``_CENTROIDS`` and ``_ASSETS_REPO_ID`` to whatever
    ``asset_manager`` holds at the moment of the call, so that readers of those
    globals (the /model/config endpoint, per the original comment) report the
    same values the manager has.
    """
    global _MEAN_EMBED, _CENTROIDS, _ASSETS_REPO_ID

    manager = asset_manager
    _MEAN_EMBED = manager.mean_embed
    _CENTROIDS = manager.centroids
    _ASSETS_REPO_ID = manager.assets_repo_id
|
143 |
|
144 |
def _any_jam_running() -> bool:
|
145 |
with jam_lock:
|
|
|
336 |
if _MRT is None:
|
337 |
with _MRT_LOCK:
|
338 |
if _MRT is None:
|
339 |
+
ckpt_dir = CheckpointManager.resolve_checkpoint_dir() # uses MRT_CKPT_REPO/STEP if present
|
|
|
340 |
_MRT = system.MagentaRT(
|
341 |
tag=os.getenv("MRT_SIZE", "large"),
|
342 |
guidance_weight=5.0,
|
343 |
device="gpu",
|
344 |
checkpoint_dir=ckpt_dir,
|
345 |
+
lazy=False
|
346 |
)
|
347 |
+
# If no assets loaded yet, and a repo is configured, load them now.
|
348 |
+
if asset_manager.mean_embed is None and asset_manager.centroids is None:
|
349 |
+
repo = os.getenv("MRT_ASSETS_REPO") or os.getenv("MRT_CKPT_REPO")
|
350 |
+
if repo:
|
351 |
+
asset_manager.load_finetune_assets_from_hf(repo, None)
|
352 |
+
_sync_assets_globals_from_manager()
|
353 |
return _MRT
|
354 |
|
355 |
_WARMED = False
|
|
|
426 |
# startup and model selection
|
427 |
# ----------------------------
|
428 |
|
|
|
429 |
@app.on_event("startup")
def _boot():
    """Server startup hook: preload finetune assets, then start model warmup.

    1) If ``MRT_ASSETS_REPO`` (or, failing that, ``MRT_CKPT_REPO``) is set,
       load the finetune assets from that repo and mirror them into the
       module-level globals so /model/config stays in sync.
    2) Unless ``MRT_WARMUP`` is ``"0"``, start ``_mrt_warmup`` on a daemon
       thread so startup is not blocked by the warmup itself.

    A failed or crashing asset load is logged and does not abort startup:
    ``get_mrt()`` attempts the same load again when no assets are present.
    """
    # 1) Load finetune assets up front (only if envs are present)
    repo = os.getenv("MRT_ASSETS_REPO") or os.getenv("MRT_CKPT_REPO")
    if not repo:
        logging.info("Startup asset load: no repo env set; skipping.")
    else:
        try:
            ok, msg = asset_manager.load_finetune_assets_from_hf(repo, None)
        except Exception:
            # Top-level boundary: a download/parse error must not keep the
            # server from coming up; record the traceback and move on.
            logging.exception(
                "Startup asset load from %s raised; continuing without finetune assets",
                repo,
            )
        else:
            _sync_assets_globals_from_manager()  # keep /model/config in sync
            if ok:
                logging.info("Startup asset load from %s: %s", repo, "ok")
            else:
                # Surface failures above INFO so they are not lost in level-filtered logs.
                logging.warning("Startup asset load from %s: %s", repo, msg)

    # 2) Start warmup in the background (unchanged behavior)
    if os.getenv("MRT_WARMUP", "1") != "0":
        threading.Thread(target=_mrt_warmup, name="mrt-warmup", daemon=True).start()
|
443 |
|
|
|
571 |
if "error" in validation_result:
|
572 |
raise HTTPException(status_code=400, detail=validation_result["error"])
|
573 |
return {"ok": False, **validation_result}
|
574 |
+
|
575 |
+
|
576 |
|
577 |
# Augment response surface
|
578 |
validation_result["active_jam"] = _any_jam_running()
|
|
|
580 |
# Dry-run path
|
581 |
if req.dry_run:
|
582 |
return {"ok": True, "dry_run": True, **validation_result}
|
583 |
+
|
584 |
+
if req.ckpt_step == "none": # user asked for stock base
|
585 |
+
asset_manager.clear_assets() # implement .clear_assets() to set embeds/centroids to None
|
586 |
+
_sync_assets_globals_from_manager()
|
587 |
|
588 |
# 2) Handle jam policy
|
589 |
if _any_jam_running():
|