transformers-modular-refactor / modular_graph_and_candidates.py
Molbap's picture
Molbap HF Staff
double fonts
69e2272
#!/usr/bin/env python
"""
modular_graph_and_candidates.py
================================
Create **one** rich view that combines
1. The *dependency graph* between existing **modular_*.py** implementations in
🤗 Transformers (blue/🟡 nodes) **and**
2. The list of *missing* modular models (full-red nodes) **plus** similarity
edges (full-red links) between highly overlapping modelling files – the
output of *find_modular_candidates.py* – so you can immediately spot good
refactor opportunities.
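––– Usage –––
```bash
python modular_graph_and_candidates.py /path/to/transformers \
    --multimodal \
    --sim-method jaccard \
    --sim-threshold 0.5 \
    --out graph.html
```
Options:
* ``--multimodal``     keep only candidates whose modelling code mentions
  "pixel_values" at least ``PIXEL_MIN_HITS`` times
* ``--sim-method``     ``jaccard`` (token overlap, default) or ``embedding``
  (code-embedding cosine similarity)
* ``--sim-threshold``  similarity cutoff for red edges (default 0.50)
* ``--out``            output HTML file name (default d3_modular_graph.html)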
Colour legend in the generated HTML:
* 🟡 **base model** — other modular models import from it, but it imports from none
* 🔵 **derived modular model** — has a `modular_*.py` that imports from ≥ 1 other model
* 🔴 **candidate** — no `modular_*.py` yet and no dependency edge (may still have red similarity edges)
* red edges = high-similarity links (Jaccard or embedding, see ``--sim-method``): potential to factorise
"""
from __future__ import annotations
import argparse
import ast
import json
import re
import tokenize
from collections import Counter, defaultdict
from itertools import combinations
from pathlib import Path
from typing import Dict, List, Set, Tuple
from sentence_transformers import SentenceTransformer, util
from tqdm import tqdm
import numpy as np
import spaces
import torch
# ────────────────────────────────────────────────────────────────────────────────
# CONFIG
# ────────────────────────────────────────────────────────────────────────────────
SIM_DEFAULT = 0.5  # default similarity cutoff for red edges (Jaccard or cosine, see --sim-method)
# Minimum number of "pixel_values" occurrences across a model's modeling_*.py
# files for it to survive the --multimodal filter. Must be > 0: with 0 the
# filter `count >= PIXEL_MIN_HITS` is always true and keeps every model.
PIXEL_MIN_HITS = 3
HTML_DEFAULT = "d3_modular_graph.html"  # default --out file name
# ────────────────────────────────────────────────────────────────────────────────
# 1) Helpers to analyse *modelling* files (for similarity & multimodal filter)
# ────────────────────────────────────────────────────────────────────────────────
def _strip_source(code: str) -> str:
"""Remove doc‑strings, comments and import lines to keep only the core code."""
code = re.sub(r'("""|\'\'\')(?:.|\n)*?\1', "", code) # doc‑strings
code = re.sub(r"#.*", "", code) # # comments
return "\n".join(ln for ln in code.splitlines()
if not re.match(r"\s*(from|import)\s+", ln))
def _tokenise(code: str) -> Set[str]:
"""Extract identifiers using regex - more robust than tokenizer for malformed code."""
toks: Set[str] = set()
for match in re.finditer(r'\b[a-zA-Z_][a-zA-Z0-9_]*\b', code):
toks.add(match.group())
return toks
def build_token_bags(models_root: Path) -> Tuple[Dict[str, List[Set[str]]], Dict[str, int]]:
    """Tokenise every ``modeling_*.py`` file below each model directory.

    Args:
        models_root: directory whose sub-directories are individual models
            (``src/transformers/models`` in a transformers checkout).

    Returns:
        ``(bags, pixel_hits)``. ``bags[model]`` holds one identifier set per
        readable ``modeling_*.py`` file found (recursively) in that model's
        directory. ``pixel_hits[model]`` is the total number of
        ``"pixel_values"`` occurrences in those files. Files that cannot be
        read are skipped with a warning.
    """
    bags: Dict[str, List[Set[str]]] = defaultdict(list)
    pixel_hits: Dict[str, int] = defaultdict(int)
    for mdl_dir in sorted(p for p in models_root.iterdir() if p.is_dir()):
        # Sorted so each model's bag order (and therefore tie-breaking when the
        # largest bag is picked) does not depend on filesystem ordering.
        for py in sorted(mdl_dir.rglob("modeling_*.py")):
            try:
                # Plain ASCII codec name, as used elsewhere in this file.
                text = py.read_text(encoding="utf-8")
            except (OSError, UnicodeDecodeError) as e:
                print(f"⚠️ Skipped {py}: {e}")
                continue
            pixel_hits[mdl_dir.name] += text.count("pixel_values")
            bags[mdl_dir.name].append(_tokenise(_strip_source(text)))
    return bags, pixel_hits
def _jaccard(a: Set[str], b: Set[str]) -> float:
return 0.0 if (not a or not b) else len(a & b) / len(a | b)
def similarity_clusters(bags: Dict[str, List[Set[str]]], thr: float) -> Dict[Tuple[str,str], float]:
    """Score model pairs by Jaccard similarity of their largest token bag.

    Only the biggest bag (one per ``modeling_*.py`` file) of each model is
    compared, and models without any bag are ignored. Returns
    ``{(model_a, model_b): score}`` with ``model_a < model_b`` alphabetically,
    for every pair whose score is at least *thr*.
    """
    representative: Dict[str, Set[str]] = {}
    for model, token_sets in bags.items():
        if token_sets:
            representative[model] = max(token_sets, key=len)

    scores: Dict[Tuple[str, str], float] = {}
    for left, right in combinations(sorted(representative), 2):
        score = _jaccard(representative[left], representative[right])
        if score >= thr:
            scores[(left, right)] = score
    return scores
@spaces.GPU
def embedding_similarity_clusters(models_root: Path, missing: List[str], thr: float) -> Dict[Tuple[str, str], float]:
    """Cosine-similarity pairs between models, computed from code embeddings.

    Each model in *missing* is represented by the concatenation of its stripped
    ``modeling_*.py`` files, embedded with ``codesage/codesage-large-v2`` on CUDA.

    Caching (both files live in the current working directory):
      * ``embeddings_cache.npz``: final cache. When its names exactly match
        the models to embed, similarities are computed from it and the
        embedding model is never loaded.
      * ``temp_embeddings.npz``: checkpoint written after every batch so an
        interrupted run can resume. It is deleted once the final cache is written.

    Returns ``{(model_a, model_b): cosine}`` for every pair with cosine >= *thr*
    (``model_a`` precedes ``model_b`` in *missing* order). Returns ``{}`` when
    there is nothing to embed. Batches that fail to encode get zero vectors
    (similarity 0 with everything).
    NOTE(review): those zero vectors are also written to the final cache, so a
    transient GPU failure persists until the cache is deleted.
    """
    # ── 1) collect the stripped modelling source of every model ──────────────
    texts: Dict[str, str] = {}
    for name in tqdm(missing, desc="Reading modeling files"):
        # Excluded per the log message below (encoding them aborted the GPU worker).
        if any(skip in name.lower() for skip in ["mobilebert", "lxmert"]):
            print(f"Skipping {name} (causes GPU abort)")
            continue
        code = ""
        for py in (models_root / name).rglob("modeling_*.py"):
            try:
                code += _strip_source(py.read_text(encoding="utf-8")) + "\n"
            except Exception:
                continue
        texts[name] = code.strip() or " "  # never embed an empty string

    names = list(texts)
    if not names:
        # Nothing to embed; np.vstack([]) further down would raise.
        return {}
    print(f"Encoding embeddings for {len(names)} models...")

    # ── 2) two-stage caching: temp (for resume) + permanent (for reuse) ──────
    temp_cache_path = Path("temp_embeddings.npz")    # For resuming computation
    final_cache_path = Path("embeddings_cache.npz")  # For permanent storage

    # Check the permanent cache *before* loading the (large) embedding model.
    if final_cache_path.exists():
        try:
            with np.load(final_cache_path, allow_pickle=True) as cached:
                cached_names = list(cached["names"])
            if names == cached_names:  # Exact match - use final cache
                print(f"βœ… Using final embeddings cache ({len(cached_names)} models)")
                return compute_similarities_from_cache(thr)
        except Exception as e:
            print(f"⚠️ Failed to load final cache: {e}")

    # ── 3) load and configure the embedding model ────────────────────────────
    model = SentenceTransformer("codesage/codesage-large-v2", device="cuda", trust_remote_code=True)
    try:
        cfg = model[0].auto_model.config
        # Prefer n_positions, then max_position_embeddings. Both getattr calls
        # take defaults, so a config that has only one of them is still honoured.
        raw_limit = getattr(cfg, "n_positions", None) or getattr(cfg, "max_position_embeddings", None)
        pos_limit = int(raw_limit) if raw_limit else 1024
    except Exception:
        pos_limit = 1024
    seq_len = min(pos_limit, 2048)
    model.max_seq_length = seq_len
    model[0].max_seq_length = seq_len
    model[0].tokenizer.model_max_length = seq_len
    # Fallback width for zero vectors, also used when the model reports None.
    emb_dim = getattr(model, "get_sentence_embedding_dimension", lambda: 768)() or 768

    # ── 4) resume from the per-batch checkpoint when it covers a prefix ──────
    all_embeddings: List[np.ndarray] = []
    start_idx = 0
    batch_size = 4
    if temp_cache_path.exists():
        try:
            with np.load(temp_cache_path, allow_pickle=True) as cached:
                cached_names = list(cached["names"])
                if names[:len(cached_names)] == cached_names:
                    all_embeddings.append(cached["embeddings"].astype(np.float32))
                    start_idx = len(cached_names)
                    print(f"πŸ”„ Resuming from temp cache: {start_idx}/{len(names)} models")
        except Exception as e:
            print(f"⚠️ Failed to load temp cache: {e}")

    # ── 5) encode the remaining models batch by batch ─────────────────────────
    for i in tqdm(range(start_idx, len(names), batch_size), desc="Batches", leave=False):
        batch_names = names[i:i + batch_size]
        batch_texts = [texts[name] for name in batch_names]
        try:
            print(f"Processing batch: {batch_names}")
            emb = model.encode(batch_texts, convert_to_numpy=True, show_progress_bar=False)
        except Exception as e:
            print(f"⚠️ GPU worker error for batch {batch_names}: {type(e).__name__}: {e}")
            emb = np.zeros((len(batch_names), emb_dim), dtype=np.float32)
        all_embeddings.append(emb)

        # Checkpoint after each batch so an interrupted run can resume.
        try:
            cur = np.vstack(all_embeddings).astype(np.float32)
            np.savez(
                temp_cache_path,
                embeddings=cur,
                names=np.array(names[:i + len(batch_names)], dtype=object),
            )
        except Exception as e:
            print(f"⚠️ Failed to write temp cache: {e}")

        if (i - start_idx) % (3 * batch_size) == 0 and torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
            print(f"🧹 Cleared GPU cache after batch {(i - start_idx)//batch_size + 1}")

    # ── 6) cosine similarity on L2-normalised rows ────────────────────────────
    embeddings = np.vstack(all_embeddings).astype(np.float32)
    embeddings = embeddings / (np.linalg.norm(embeddings, axis=1, keepdims=True) + 1e-12)
    print("Computing pairwise similarities...")
    sims_mat = embeddings @ embeddings.T

    out: Dict[Tuple[str, str], float] = {}
    matrix_size = embeddings.shape[0]
    processed_names = names[:matrix_size]
    for i in range(matrix_size):
        for j in range(i + 1, matrix_size):
            s = float(sims_mat[i, j])
            if s >= thr:
                out[(processed_names[i], processed_names[j])] = s

    # ── 7) promote the run to the permanent cache ─────────────────────────────
    try:
        np.savez(final_cache_path, embeddings=embeddings, names=np.array(names, dtype=object))
        print(f"πŸ’Ύ Final embeddings saved to {final_cache_path}")
        if temp_cache_path.exists():
            temp_cache_path.unlink()
            print("🧹 Cleaned up temp cache")
    except Exception as e:
        print(f"⚠️ Failed to save final cache: {e}")
    return out
def compute_similarities_from_cache(threshold: float) -> Dict[Tuple[str, str], float]:
    """Recompute pairwise cosine similarities from ``embeddings_cache.npz``.

    The cache is read from the current working directory. Its embeddings are
    L2-normalised, and the function returns ``{(name_i, name_j): cosine}`` for
    every ``i < j`` (cache order) whose cosine is >= *threshold*. Returns ``{}``
    when the cache is absent or cannot be processed.
    """
    cache_file = Path("embeddings_cache.npz")
    if not cache_file.exists():
        return {}
    try:
        # `names` is stored as an object array, which needs allow_pickle;
        # only caches written by this script are expected here.
        data = np.load(cache_file, allow_pickle=True)
        vectors = data["embeddings"].astype(np.float32)
        labels = list(data["names"])

        vectors = vectors / (np.linalg.norm(vectors, axis=1, keepdims=True) + 1e-12)
        similarity = vectors @ vectors.T

        # Upper triangle without the diagonal, in row-major (i, j) order.
        rows, cols = np.triu_indices(len(labels), k=1)
        result: Dict[Tuple[str, str], float] = {}
        for r, c in zip(rows.tolist(), cols.tolist()):
            value = float(similarity[r, c])
            if value >= threshold:
                result[(labels[r], labels[c])] = value

        print(f"⚑ Computed {len(result)} similarities from cache (threshold: {threshold})")
        return result
    except Exception as e:
        print(f"⚠️ Failed to compute from cache: {e}")
        return {}
# ────────────────────────────────────────────────────────────────────────────────
# 2) Scan *modular_*.py* files to build an import‑dependency graph
# – only **modeling_*** imports are considered (skip configuration / processing)
# ────────────────────────────────────────────────────────────────────────────────
def modular_files(models_root: Path) -> List[Path]:
    """List every ``modular_*.py`` file found recursively below *models_root*."""
    found: List[Path] = []
    for candidate in models_root.rglob("modular_*.py"):
        if candidate.suffix == ".py":
            found.append(candidate)
    return found
def dependency_graph(modular_files: List[Path], models_root: Path) -> Dict[str, List[Dict[str,str]]]:
    """Map each model that has a ``modular_*.py`` to the modelling names it imports.

    Returns ``{derived_model: [{"source": parent_model, "imported_class": name}, ...]}``.
    *derived_model* is the directory containing the modular file.

    Only ``from ... import ...`` statements whose module path contains
    ``modeling_`` are kept. Configuration, processing and image-processing
    modules, as well as ``modeling_attn_mask_utils``, are ignored so the graph
    focuses on modelling code. An import is dropped when its source is the model
    itself or is not a directory under *models_root* (e.g. shared
    ``modeling_utils``). Files that fail to read or parse are skipped with a
    warning.

    Note: the ``modular_files`` parameter shadows the module-level function of
    the same name inside this function.
    """
    model_names = {p.name for p in models_root.iterdir() if p.is_dir()}
    deps: Dict[str, List[Dict[str, str]]] = defaultdict(list)
    # Module-path fragments that mark non-modelling imports.
    excluded = ("configuration_", "processing_", "image_processing", "modeling_attn_mask_utils")

    for fp in modular_files:
        derived = fp.parent.name
        try:
            # Plain ASCII codec name, as used elsewhere in this file.
            tree = ast.parse(fp.read_text(encoding="utf-8"), filename=str(fp))
        except Exception as e:
            print(f"⚠️ AST parse failed for {fp}: {e}")
            continue

        for node in ast.walk(tree):
            if not isinstance(node, ast.ImportFrom) or not node.module:
                continue
            mod = node.module
            if "modeling_" not in mod or any(tag in mod for tag in excluded):
                continue
            # Relative `..llama.modeling_llama` and absolute
            # `transformers.models.llama.modeling_llama` both yield "llama".
            parts = re.split(r"[./]", mod)
            src = next((p for p in parts if p not in {"", "models", "transformers"}), "")
            if not src or src == derived or src not in model_names:
                continue
            for alias in node.names:
                deps[derived].append({"source": src, "imported_class": alias.name})
    return dict(deps)
# ── Top-level graph-building API: missing models → similarities → D3 graph JSON ──
def get_missing_models(models_root: Path, multimodal: bool = False) -> Tuple[List[str], Dict[str, List[Set[str]]], Dict[str, int]]:
    """Find models that have ``modeling_*.py`` code but no ``modular_*.py`` yet.

    Returns ``(missing, bags, pix_hits)``:
      * the candidate model names, in ``bags`` order;
      * the token bags from ``build_token_bags``;
      * the per-model ``"pixel_values"`` counts.

    With *multimodal*, only candidates whose count is at least ``PIXEL_MIN_HITS``
    are kept.
    """
    bags, pix_hits = build_token_bags(models_root)
    already_modular = set()
    for path in modular_files(models_root):
        already_modular.add(path.parent.name)
    missing = [
        name for name in bags
        if name not in already_modular
        and (not multimodal or pix_hits[name] >= PIXEL_MIN_HITS)
    ]
    return missing, bags, pix_hits
def compute_similarities(models_root: Path, missing: List[str], bags: Dict[str, List[Set[str]]],
                         threshold: float, sim_method: str) -> Dict[Tuple[str, str], float]:
    """Return ``{(model_a, model_b): score}`` for candidate pairs above *threshold*.

    ``sim_method == "jaccard"`` compares token bags. Any other value uses code
    embeddings: ``embeddings_cache.npz`` is preferred, and a full (GPU)
    embedding run happens only when the cache is absent, unreadable, or yields
    no pair at all.

    Cached pairs are restricted to models listed in *missing*. The cache may
    have been built from a different model set, e.g. before ``--multimodal``
    filtering or before a model gained a ``modular_*.py``.
    """
    if sim_method == "jaccard":
        return similarity_clusters({m: bags[m] for m in missing}, threshold)

    # Returns {} by itself when the cache file does not exist.
    cached_sims = compute_similarities_from_cache(threshold)
    if cached_sims:
        wanted = set(missing)
        return {
            pair: s for pair, s in cached_sims.items()
            if pair[0] in wanted and pair[1] in wanted
        }
    return embedding_similarity_clusters(models_root, missing, threshold)
def _dependency_links(deps: Dict[str, List[Dict[str, str]]]) -> Tuple[List[dict], Set[str]]:
    """Build one grey link per (parent, derived) model pair, plus the models involved."""
    links: List[dict] = []
    nodes: Set[str] = set()
    for drv, lst in deps.items():
        # `lst` has one record per imported name. Collapse them into a single
        # edge per parent model, labelled with how many names come from it.
        for src, n_imports in Counter(d["source"] for d in lst).items():
            links.append({
                "source": src,
                "target": drv,
                "label": f"{n_imports} imports",
                "cand": False,
            })
            nodes.update({src, drv})
    return links, nodes


def _similarity_links(sims: Dict[Tuple[str, str], float]) -> Tuple[List[dict], Set[str]]:
    """Build one red candidate link per similar pair (label = score in %), plus its endpoints."""
    links = [
        {"source": a, "target": b, "label": f"{s*100:.1f}%", "cand": True}
        for (a, b), s in sims.items()
    ]
    nodes = {model for pair in sims for model in pair}
    return links, nodes


def _node_list(nodes: Set[str], links: List[dict], missing: List[str], sized: bool) -> List[dict]:
    """Classify and size every node for the D3 front-end (sorted by id).

    * ``cand``: listed in *missing* and not touched by any dependency link
    * ``base``: source of a dependency link but never its target
    * ``derived``: everything else

    With *sized*, ``sz`` grows linearly from 1 to 3 with the node's link degree.
    Otherwise every node gets ``sz = 1``.
    """
    missing_set = set(missing)
    dep_links = [lk for lk in links if not lk["cand"]]
    targets = {lk["target"] for lk in dep_links}
    sources = {lk["source"] for lk in dep_links}

    deg: Counter = Counter()
    for lk in links:
        deg[lk["source"]] += 1
        deg[lk["target"]] += 1
    max_deg = max(deg.values() or [1])

    nodelist = []
    for n in sorted(nodes):
        if n in missing_set and n not in sources and n not in targets:
            cls = "cand"
        elif n in sources and n not in targets:
            cls = "base"
        else:
            cls = "derived"
        sz = 1 + 2 * (deg[n] / max_deg) if sized else 1
        nodelist.append({"id": n, "cls": cls, "sz": sz})
    return nodelist


def build_graph_json(
    transformers_dir: Path,
    threshold: float = SIM_DEFAULT,
    multimodal: bool = False,
    sim_method: str = "jaccard",
) -> dict:
    """Return the ``{"nodes": [...], "links": [...]}`` dict consumed by the D3 page.

    Args:
        transformers_dir: root of a local transformers checkout.
        threshold: minimum similarity for a red candidate link.
        multimodal: keep only candidates with enough ``pixel_values`` hits.
            This applies to the full build only; the cache-only path does not use it.
        sim_method: ``"jaccard"`` or ``"embedding"``.

    With ``sim_method="embedding"`` and an ``embeddings_cache.npz`` in the
    working directory, candidates and similarity links come straight from the
    cache, and only the modular dependency graph is read from the repository
    (all nodes get ``sz = 1``). Otherwise, or if that path fails or yields no
    pair, the repository is fully scanned and nodes are sized by degree.
    Dependency links are emitted once per (parent, derived) pair.
    """
    embeddings_cache = Path("embeddings_cache.npz")
    models_root = transformers_dir / "src/transformers/models"
    print(f"πŸ” Cache file exists: {embeddings_cache.exists()}, sim_method: {sim_method}")

    if sim_method == "embedding" and embeddings_cache.exists():
        try:
            cached_sims = compute_similarities_from_cache(threshold)
            print(f"πŸ” Got {len(cached_sims)} cached similarities")
            if cached_sims:
                # The candidate set is every model stored in the cache.
                with np.load(embeddings_cache, allow_pickle=True) as cached_data:
                    missing = list(cached_data["names"])
                deps = dependency_graph(modular_files(models_root), models_root)

                links, nodes = _dependency_links(deps)
                sim_links, sim_nodes = _similarity_links(cached_sims)
                links += sim_links
                nodes |= sim_nodes
                nodes.update(missing)

                nodelist = _node_list(nodes, links, missing, sized=False)
                print(f"⚑ Built graph from cache: {len(nodelist)} nodes, {len(links)} links")
                return {"nodes": nodelist, "links": links}
        except Exception as e:
            print(f"⚠️ Cache-only build failed: {e}, falling back to full build")

    # Full build with repository access.
    missing, bags, _pix_hits = get_missing_models(models_root, multimodal)
    deps = dependency_graph(modular_files(models_root), models_root)
    sims = compute_similarities(models_root, missing, bags, threshold, sim_method)

    links, nodes = _dependency_links(deps)
    sim_links, sim_nodes = _similarity_links(sims)
    links += sim_links
    nodes |= sim_nodes
    nodes.update(missing)

    return {"nodes": _node_list(nodes, links, missing, sized=True), "links": links}
def generate_html(graph: dict) -> str:
    """Render the standalone HTML page for *graph*.

    The graph is serialised as compact JSON into the JS template. The CSS
    template and the filled-in JS are then spliced into the HTML shell.
    """
    graph_json = json.dumps(graph, separators=(",", ":"))
    filled_js = JS.replace("__GRAPH_DATA__", graph_json)
    page = HTML.replace("__CSS__", CSS)
    return page.replace("__JS__", filled_js)
# ────────────────────────────────────────────────────────────────────────────────
# 3) HTML (D3.js) boilerplate – CSS + JS templates (unchanged design)
# ────────────────────────────────────────────────────────────────────────────────
# Stylesheet inlined into the page by `generate_html` (it replaces "__CSS__").
# Light and dark palettes are switched with `prefers-color-scheme`.
# Node styling: derived -> blue circle, cand -> red circle/path. Base nodes
# get an <image> (or a #ffbe0b fallback circle) created by the JS template.
# NOTE(review): `.node.base image` sets a 60x60 size and a -30px CSS translate,
# while the JS also sets width/height=40 and x/y=-20 attributes on the same
# <image>. In browsers that treat SVG width/height as CSS properties the logo
# may end up off-centre; confirm visually.
CSS = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600&display=swap');
:root{
--bg:#ffffff;
--text:#222222;
--muted:#555555;
--outline:#ffffff;
}
@media (prefers-color-scheme: dark){
:root{
--bg:#0b0d10;
--text:#e8e8e8;
--muted:#c8c8c8;
--outline:#000000;
}
}
body{ margin:0; font-family:'Inter',Arial,sans-serif; background:var(--bg); overflow:hidden; }
svg{ width:100vw; height:100vh; }
.link{ stroke:#999; stroke-opacity:.6; }
.link.cand{ stroke:#e63946; stroke-width:2.5; }
.node-label{
fill:var(--text);
pointer-events:none;
text-anchor:middle;
font-weight:600;
paint-order:stroke fill;
stroke:var(--outline);
stroke-width:3px;
}
.link-label{
fill:var(--muted);
pointer-events:none;
text-anchor:middle;
font-size:10px;
paint-order:stroke fill;
stroke:var(--bg);
stroke-width:2px;
}
.node.base image{ width:60px; height:60px; transform:translate(-30px,-30px); }
.node.derived circle{ fill:#1f77b4; }
.node.cand circle, .node.cand path{ fill:#e63946; }
#legend{
position:fixed; top:18px; left:18px;
background:rgba(255,255,255,.92);
padding:18px 28px; border-radius:10px; border:1.5px solid #bbb;
font-size:18px; box-shadow:0 2px 8px rgba(0,0,0,.08);
}
@media (prefers-color-scheme: dark){
#legend{ background:rgba(20,22,25,.92); color:#e8e8e8; border-color:#444; }
}
"""
# D3 v7 front-end, inlined by `generate_html`. "__GRAPH_DATA__" is replaced
# with the graph JSON: {nodes:[{id,cls,sz}], links:[{source,target,label,cand}]}.
# What it draws and handles:
#   - one <line> and one mid-point <text> label per link;
#   - a logo <image> for `base` nodes and a circle of radius 20*sz for all others;
#   - a force simulation, plus zoom/pan on the <svg> and drag on nodes;
#   - the #toggleRed checkbox, which shows/hides candidate links, nodes and labels.
# NOTE(review): HF_LOGO_URI is a non-empty constant, so the `else` branch is never
# taken. Base nodes only get a fallback circle when the image fails to load.
# `updateVisibility` refers to `svg`, which is declared later with `const`; this
# is fine only because the function runs on a change event, after the script
# has finished.
JS = """
function updateVisibility() {
const show = document.getElementById('toggleRed').checked;
svg.selectAll('.link.cand').style('display', show ? null : 'none');
svg.selectAll('.node.cand').style('display', show ? null : 'none');
svg.selectAll('.link-label').filter(d => d.cand).style('display', show ? null : 'none');
}
document.getElementById('toggleRed').addEventListener('change', updateVisibility);
const HF_LOGO_URI = "./static/hf-logo.png";
const graph = __GRAPH_DATA__;
const W = innerWidth, H = innerHeight;
const svg = d3.select('#dependency').call(d3.zoom().on('zoom', e => g.attr('transform', e.transform)));
const g = svg.append('g');
const link = g.selectAll('line')
.data(graph.links)
.join('line')
.attr('class', d => d.cand ? 'link cand' : 'link');
const linkLbl = g.selectAll('text.link-label')
.data(graph.links)
.join('text')
.attr('class', 'link-label')
.text(d => d.label);
const node = g.selectAll('g.node')
.data(graph.nodes)
.join('g')
.attr('class', d => `node ${d.cls}`)
.call(d3.drag().on('start', dragStart).on('drag', dragged).on('end', dragEnd));
const baseSel = node.filter(d => d.cls === 'base');
if (HF_LOGO_URI){
baseSel.append('image')
.attr('href', HF_LOGO_URI)
.attr('width', 40)
.attr('height', 40)
.attr('x', -20)
.attr('y', -20)
.on('error', function() {
console.log('Image failed to load:', HF_LOGO_URI);
// Fallback to circle
d3.select(this.parentNode).append('circle')
.attr('r', 22).attr('fill', '#ffbe0b');
});
console.log('Loading logo from:', HF_LOGO_URI);
}else{
baseSel.append('circle').attr('r', d => 22*d.sz).attr('fill', '#ffbe0b');
}
node.filter(d => d.cls !== 'base').append('circle').attr('r', d => 20*d.sz);
node.append('text')
.attr('class','node-label')
.attr('dy','-2.4em')
.style('font-size', d => d.cls === 'base' ? '32px' : '28px')
.style('font-weight', d => d.cls === 'base' ? 'bold' : 'normal')
.text(d => d.id);
const sim = d3.forceSimulation(graph.nodes)
.force('link', d3.forceLink(graph.links).id(d => d.id).distance(520))
.force('charge', d3.forceManyBody().strength(-600))
.force('center', d3.forceCenter(W / 2, H / 2))
.force('collide', d3.forceCollide(d => 50));
sim.on('tick', () => {
link.attr('x1', d=>d.source.x).attr('y1', d=>d.source.y)
.attr('x2', d=>d.target.x).attr('y2', d=>d.target.y);
linkLbl.attr('x', d=> (d.source.x+d.target.x)/2)
.attr('y', d=> (d.source.y+d.target.y)/2);
node.attr('transform', d=>`translate(${d.x},${d.y})`);
});
function dragStart(e,d){ if(!e.active) sim.alphaTarget(.3).restart(); d.fx=d.x; d.fy=d.y; }
function dragged(e,d){ d.fx=e.x; d.fy=e.y; }
function dragEnd(e,d){ if(!e.active) sim.alphaTarget(0); d.fx=d.fy=null; }
"""
# Page shell: "__CSS__" and "__JS__" are substituted by `generate_html`.
# D3 is loaded from the d3js.org CDN, so viewing the page needs network access.
# NOTE(review): the legend always says "high embedding similarity", even when
# the graph was built with --sim-method jaccard.
HTML = """
<!DOCTYPE html>
<html lang='en'><head><meta charset='UTF-8'>
<title>Transformers modular graph</title>
<style>__CSS__</style></head><body>
<div id='legend'>
🟑 base<br>πŸ”΅ modular<br>πŸ”΄ candidate<br>red edgeΒ = high embedding similarity<br><br>
<label><input type="checkbox" id="toggleRed" checked> Show candidates edges and nodes</label>
</div>
<svg id='dependency'></svg>
<script src='https://d3js.org/d3.v7.min.js'></script>
<script>__JS__</script></body></html>
"""
# ────────────────────────────────────────────────────────────────────────────────
# HTML writer
# ────────────────────────────────────────────────────────────────────────────────
def write_html(graph_data: dict, path: Path):
    """Render *graph_data* as a standalone HTML page and save it at *path* (UTF-8)."""
    html_text = generate_html(graph_data)
    path.write_text(html_text, encoding="utf-8")
# ────────────────────────────────────────────────────────────────────────────────
# MAIN
# ────────────────────────────────────────────────────────────────────────────────
def main():
    """CLI entry point: parse arguments, build the graph JSON and write the HTML page."""
    parser = argparse.ArgumentParser(description="Visualise modular dependencies + candidates")
    parser.add_argument("transformers", help="Path to local πŸ€— transformers repo root")
    parser.add_argument("--multimodal", action="store_true", help="filter to models with β‰₯3 'pixel_values'")
    parser.add_argument("--sim-threshold", type=float, default=SIM_DEFAULT)
    parser.add_argument("--out", default=HTML_DEFAULT)
    parser.add_argument(
        "--sim-method",
        choices=["jaccard", "embedding"],
        default="jaccard",
        help="Similarity method: 'jaccard' or 'embedding'",
    )
    opts = parser.parse_args()

    repo_root = Path(opts.transformers).expanduser().resolve()
    graph = build_graph_json(
        transformers_dir=repo_root,
        threshold=opts.sim_threshold,
        multimodal=opts.multimodal,
        sim_method=opts.sim_method,
    )
    out_path = Path(opts.out).expanduser()
    write_html(graph, out_path)


if __name__ == "__main__":
    main()