thecollabagepatch committed on
Commit
2cbee4c
·
1 Parent(s): 147fd8b

resolving zarray files

Browse files
Files changed (1) hide show
  1. app.py +27 -17
app.py CHANGED
@@ -70,44 +70,54 @@ import re
70
  from pathlib import Path
71
 
72
  def _resolve_checkpoint_dir() -> str | None:
73
- """
74
- Returns a local directory path for MagentaRT(checkpoint_dir=...),
75
- using a Hugging Face model repo that contains subfolders like:
76
- checkpoint_1861001/, checkpoint_1862001/, ...
77
- """
78
  repo_id = os.getenv("MRT_CKPT_REPO")
79
  if not repo_id:
80
- return None # fall back to builtin 'base'/'large' assets
81
 
82
  step = os.getenv("MRT_CKPT_STEP") # e.g., "1863001"
 
 
83
  allow = None
84
  if step:
85
- # only pull that step + optional centroid files
86
- allow = [f"checkpoint_{step}/**", "cluster_centroids.npy", "mean_style_embed.npy"]
 
 
 
 
 
 
87
 
88
- from huggingface_hub import snapshot_download
89
  local = snapshot_download(
90
  repo_id=repo_id,
91
  repo_type="model",
 
92
  local_dir="/home/appuser/.cache/mrt_ckpt/repo",
93
  local_dir_use_symlinks=False,
94
  allow_patterns=allow or ["*"], # whole repo if no step provided
95
  )
96
  root = Path(local)
97
 
98
- # If a step is specified, return that subfolder
99
  if step:
100
- cand = root / f"checkpoint_{step}"
101
- if cand.is_dir():
102
- return str(cand)
103
-
104
- # Otherwise pick the numerically latest checkpoint_* folder
 
 
 
 
 
105
  step_dirs = [d for d in root.iterdir() if d.is_dir() and re.match(r"checkpoint_\\d+$", d.name)]
106
  if step_dirs:
107
- pick = max(step_dirs, key=lambda d: int(d.name.split("_")[-1]))
 
 
 
 
108
  return str(pick)
109
 
110
- # Fallback: repo itself might already be a single checkpoint directory
111
  return str(root)
112
 
113
 
 
70
  from pathlib import Path
71
 
72
  def _resolve_checkpoint_dir() -> str | None:
 
 
 
 
 
73
  repo_id = os.getenv("MRT_CKPT_REPO")
74
  if not repo_id:
75
+ return None
76
 
77
  step = os.getenv("MRT_CKPT_STEP") # e.g., "1863001"
78
+
79
+ from huggingface_hub import snapshot_download
80
  allow = None
81
  if step:
82
+ base = f"checkpoint_{step}"
83
+ # include everything under the step *including dotfiles*
84
+ allow = [
85
+ f"{base}/**", # all regular files
86
+ f"{base}/**/.*", # dotfiles like .zarray / .zattrs
87
+ f"{base}/**/.zarray",
88
+ f"{base}/**/.zattrs",
89
+ ]
90
 
 
91
  local = snapshot_download(
92
  repo_id=repo_id,
93
  repo_type="model",
94
+ revision=os.getenv("MRT_CKPT_REV", "main"),
95
  local_dir="/home/appuser/.cache/mrt_ckpt/repo",
96
  local_dir_use_symlinks=False,
97
  allow_patterns=allow or ["*"], # whole repo if no step provided
98
  )
99
  root = Path(local)
100
 
 
101
  if step:
102
+ step_dir = root / f"checkpoint_{step}"
103
+ # sanity check: make sure dotfiles arrived
104
+ if not any(step_dir.rglob(".zarray")):
105
+ raise RuntimeError(
106
+ f"Checkpoint appears incomplete (no .zarray files under {step_dir}). "
107
+ "Ensure allow_patterns includes dotfiles or re-upload preserving dotfiles."
108
+ )
109
+ return str(step_dir)
110
+
111
+ # otherwise pick latest checkpoint_* directory
112
  step_dirs = [d for d in root.iterdir() if d.is_dir() and re.match(r"checkpoint_\\d+$", d.name)]
113
  if step_dirs:
114
+ pick = max(step_dirs, key=lambda d: int(d.name.split('_')[-1]))
115
+ if not any(pick.rglob(".zarray")):
116
+ raise RuntimeError(
117
+ f"Checkpoint appears incomplete (no .zarray files under {pick})."
118
+ )
119
  return str(pick)
120
 
 
121
  return str(root)
122
 
123