thecollabagepatch commited on
Commit
f0823f1
·
1 Parent(s): 5b910df

finetune assets loading fix

Browse files
Files changed (1) hide show
  1. app.py +31 -10
app.py CHANGED
@@ -134,11 +134,12 @@ _CENTROIDS: np.ndarray | None = None # shape (K, D) dtype float32
134
  asset_manager = AssetManager()
135
  model_selector = ModelSelector(CheckpointManager(), asset_manager)
136
 
137
- # Sync asset manager with existing globals
138
- # def _sync_asset_manager():
139
- # asset_manager.mean_embed = _MEAN_EMBED
140
- # asset_manager.centroids = _CENTROIDS
141
- # asset_manager.assets_repo_id = _ASSETS_REPO_ID
 
142
 
143
  def _any_jam_running() -> bool:
144
  with jam_lock:
@@ -335,15 +336,20 @@ def get_mrt():
335
  if _MRT is None:
336
  with _MRT_LOCK:
337
  if _MRT is None:
338
- from model_management import CheckpointManager
339
- ckpt_dir = CheckpointManager.resolve_checkpoint_dir() # ← Updated call
340
  _MRT = system.MagentaRT(
341
  tag=os.getenv("MRT_SIZE", "large"),
342
  guidance_weight=5.0,
343
  device="gpu",
344
  checkpoint_dir=ckpt_dir,
345
- lazy=False,
346
  )
 
 
 
 
 
 
347
  return _MRT
348
 
349
  _WARMED = False
@@ -420,9 +426,18 @@ def _mrt_warmup():
420
  # startup and model selection
421
  # ----------------------------
422
 
423
- # Kick it off in the background on server start
424
  @app.on_event("startup")
425
- def _kickoff_warmup():
 
 
 
 
 
 
 
 
 
 
426
  if os.getenv("MRT_WARMUP", "1") != "0":
427
  threading.Thread(target=_mrt_warmup, name="mrt-warmup", daemon=True).start()
428
 
@@ -556,6 +571,8 @@ def model_select(req: ModelSelect):
556
  if "error" in validation_result:
557
  raise HTTPException(status_code=400, detail=validation_result["error"])
558
  return {"ok": False, **validation_result}
 
 
559
 
560
  # Augment response surface
561
  validation_result["active_jam"] = _any_jam_running()
@@ -563,6 +580,10 @@ def model_select(req: ModelSelect):
563
  # Dry-run path
564
  if req.dry_run:
565
  return {"ok": True, "dry_run": True, **validation_result}
 
 
 
 
566
 
567
  # 2) Handle jam policy
568
  if _any_jam_running():
 
134
  asset_manager = AssetManager()
135
  model_selector = ModelSelector(CheckpointManager(), asset_manager)
136
 
137
+ def _sync_assets_globals_from_manager():
138
+ # Keeps /model/config in sync with what the asset manager has
139
+ global _MEAN_EMBED, _CENTROIDS, _ASSETS_REPO_ID
140
+ _MEAN_EMBED = asset_manager.mean_embed
141
+ _CENTROIDS = asset_manager.centroids
142
+ _ASSETS_REPO_ID = asset_manager.assets_repo_id
143
 
144
  def _any_jam_running() -> bool:
145
  with jam_lock:
 
336
  if _MRT is None:
337
  with _MRT_LOCK:
338
  if _MRT is None:
339
+ ckpt_dir = CheckpointManager.resolve_checkpoint_dir() # uses MRT_CKPT_REPO/STEP if present
 
340
  _MRT = system.MagentaRT(
341
  tag=os.getenv("MRT_SIZE", "large"),
342
  guidance_weight=5.0,
343
  device="gpu",
344
  checkpoint_dir=ckpt_dir,
345
+ lazy=False
346
  )
347
+ # If no assets loaded yet, and a repo is configured, load them now.
348
+ if asset_manager.mean_embed is None and asset_manager.centroids is None:
349
+ repo = os.getenv("MRT_ASSETS_REPO") or os.getenv("MRT_CKPT_REPO")
350
+ if repo:
351
+ asset_manager.load_finetune_assets_from_hf(repo, None)
352
+ _sync_assets_globals_from_manager()
353
  return _MRT
354
 
355
  _WARMED = False
 
426
  # startup and model selection
427
  # ----------------------------
428
 
 
429
  @app.on_event("startup")
430
+ def _boot():
431
+ # 1) Load finetune assets up front (only if envs are present)
432
+ repo = os.getenv("MRT_ASSETS_REPO") or os.getenv("MRT_CKPT_REPO")
433
+ if repo:
434
+ ok, msg = asset_manager.load_finetune_assets_from_hf(repo, None)
435
+ _sync_assets_globals_from_manager() # keep /model/config in sync
436
+ logging.info("Startup asset load from %s: %s", repo, "ok" if ok else msg)
437
+ else:
438
+ logging.info("Startup asset load: no repo env set; skipping.")
439
+
440
+ # 2) Start warmup in the background (unchanged behavior)
441
  if os.getenv("MRT_WARMUP", "1") != "0":
442
  threading.Thread(target=_mrt_warmup, name="mrt-warmup", daemon=True).start()
443
 
 
571
  if "error" in validation_result:
572
  raise HTTPException(status_code=400, detail=validation_result["error"])
573
  return {"ok": False, **validation_result}
574
+
575
+
576
 
577
  # Augment response surface
578
  validation_result["active_jam"] = _any_jam_running()
 
580
  # Dry-run path
581
  if req.dry_run:
582
  return {"ok": True, "dry_run": True, **validation_result}
583
+
584
+ if req.ckpt_step == "none": # user asked for stock base
585
+ asset_manager.clear_assets() # implement .clear_assets() to set embeds/centroids to None
586
+ _sync_assets_globals_from_manager()
587
 
588
  # 2) Handle jam policy
589
  if _any_jam_running():