thecollabagepatch committed on
Commit
80346b6
·
1 Parent(s): c3769b0

using tar in model repo

Browse files
Files changed (1) hide show
  1. app.py +44 -34
app.py CHANGED
@@ -66,59 +66,69 @@ except Exception:
66
  class ClientDisconnected(Exception): # fallback
67
  pass
68
 
69
- import re
70
  from pathlib import Path
 
71
 
72
  def _resolve_checkpoint_dir() -> str | None:
73
  repo_id = os.getenv("MRT_CKPT_REPO")
74
  if not repo_id:
75
  return None
 
76
 
77
- step = os.getenv("MRT_CKPT_STEP") # e.g., "1863001"
78
-
79
- from huggingface_hub import snapshot_download
80
- allow = None
81
- if step:
82
- base = f"checkpoint_{step}"
83
- # include everything under the step *including dotfiles*
84
- allow = [
85
- f"{base}/**", # all regular files
86
- f"{base}/**/.*", # dotfiles like .zarray / .zattrs
87
- f"{base}/**/.zarray",
88
- f"{base}/**/.zattrs",
89
- ]
90
-
91
- local = snapshot_download(
92
  repo_id=repo_id,
93
  repo_type="model",
94
  revision=os.getenv("MRT_CKPT_REV", "main"),
95
  local_dir="/home/appuser/.cache/mrt_ckpt/repo",
96
  local_dir_use_symlinks=False,
97
- # <- no allow_patterns, we grab everything to avoid dotfile issues
98
- )
99
- root = Path(local)
100
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  if step:
102
- step_dir = root / f"checkpoint_{step}"
103
- # sanity check: make sure dotfiles arrived
104
- if not any(step_dir.rglob(".zarray")):
105
- raise RuntimeError(
106
- f"Checkpoint appears incomplete (no .zarray files under {step_dir}). "
107
- "Ensure allow_patterns includes dotfiles or re-upload preserving dotfiles."
108
- )
109
- return str(step_dir)
110
-
111
- # otherwise pick latest checkpoint_* directory
112
  step_dirs = [d for d in root.iterdir() if d.is_dir() and re.match(r"checkpoint_\\d+$", d.name)]
113
  if step_dirs:
114
  pick = max(step_dirs, key=lambda d: int(d.name.split('_')[-1]))
115
  if not any(pick.rglob(".zarray")):
116
- raise RuntimeError(
117
- f"Checkpoint appears incomplete (no .zarray files under {pick})."
118
- )
119
  return str(pick)
120
 
121
- return str(root)
122
 
123
 
124
  async def send_json_safe(ws: WebSocket, obj) -> bool:
 
66
  class ClientDisconnected(Exception): # fallback
67
  pass
68
 
69
+ import re, tarfile
70
  from pathlib import Path
71
+ from huggingface_hub import snapshot_download
72
 
73
  def _resolve_checkpoint_dir() -> str | None:
74
  repo_id = os.getenv("MRT_CKPT_REPO")
75
  if not repo_id:
76
  return None
77
+ step = os.getenv("MRT_CKPT_STEP") # e.g. "1863001"
78
 
79
+ root = Path(snapshot_download(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  repo_id=repo_id,
81
  repo_type="model",
82
  revision=os.getenv("MRT_CKPT_REV", "main"),
83
  local_dir="/home/appuser/.cache/mrt_ckpt/repo",
84
  local_dir_use_symlinks=False,
85
+ ))
86
+
87
+ # Prefer an archive if present (more reliable for Zarr/T5X)
88
+ arch_names = [
89
+ f"checkpoint_{step}.tgz",
90
+ f"checkpoint_{step}.tar.gz",
91
+ f"archives/checkpoint_{step}.tgz",
92
+ f"archives/checkpoint_{step}.tar.gz",
93
+ ] if step else []
94
+
95
+ cache_root = Path("/home/appuser/.cache/mrt_ckpt/extracted")
96
+ cache_root.mkdir(parents=True, exist_ok=True)
97
+ for name in arch_names:
98
+ arch = root / name
99
+ if arch.is_file():
100
+ out_dir = cache_root / f"checkpoint_{step}"
101
+ marker = out_dir.with_suffix(".ok")
102
+ if not marker.exists():
103
+ out_dir.mkdir(parents=True, exist_ok=True)
104
+ with tarfile.open(arch, "r:*") as tf:
105
+ tf.extractall(out_dir)
106
+ marker.write_text("ok")
107
+ # sanity: require .zarray to exist inside the extracted tree
108
+ if not any(out_dir.rglob(".zarray")):
109
+ raise RuntimeError(f"Extracted archive missing .zarray files: {out_dir}")
110
+ return str(out_dir / f"checkpoint_{step}") if (out_dir / f"checkpoint_{step}").exists() else str(out_dir)
111
+
112
+ # No archive; try raw folder from repo and sanity check.
113
  if step:
114
+ raw = root / f"checkpoint_{step}"
115
+ if raw.is_dir():
116
+ if not any(raw.rglob(".zarray")):
117
+ raise RuntimeError(
118
+ f"Downloaded checkpoint_{step} appears incomplete (no .zarray). "
119
+ "Upload as a .tgz or push via git from a Unix shell."
120
+ )
121
+ return str(raw)
122
+
123
+ # Pick latest if no step
124
  step_dirs = [d for d in root.iterdir() if d.is_dir() and re.match(r"checkpoint_\\d+$", d.name)]
125
  if step_dirs:
126
  pick = max(step_dirs, key=lambda d: int(d.name.split('_')[-1]))
127
  if not any(pick.rglob(".zarray")):
128
+ raise RuntimeError(f"Downloaded {pick} appears incomplete (no .zarray).")
 
 
129
  return str(pick)
130
 
131
+ return None
132
 
133
 
134
  async def send_json_safe(ws: WebSocket, obj) -> bool: