transformers-modular-refactor / modular_graph_and_candidates.py
Molbap's picture
Molbap HF Staff
update
a9aba5d
raw
history blame
24.3 kB
#!/usr/bin/env python
"""
modular_graph_and_candidates.py
================================
Create **one** rich view that combines

1. the *dependency graph* between existing **modular_*.py** implementations in
   🤗 Transformers (yellow 🟡 base / blue 🔵 derived nodes), **and**
2. the list of *missing* modular models (full-red nodes) **plus** similarity
   edges (full-red links) between highly overlapping modelling files – the
   output of *find_modular_candidates.py* – so you can immediately spot good
   refactor opportunities.

––– Usage –––
```bash
python modular_graph_and_candidates.py /path/to/transformers \
    --multimodal \
    --sim-method jaccard \
    --sim-threshold 0.5 \
    --out graph.html
```
Options:
* ``--multimodal``     keep only models whose modelling code mentions
  "pixel_values" at least ``PIXEL_MIN_HITS`` times
* ``--sim-method``     ``jaccard`` (identifier overlap, default) or ``embedding``
  (cosine similarity of code embeddings)
* ``--sim-threshold``  similarity cutoff for red links (default 0.5)
* ``--out``            output HTML file name (default ``d3_modular_graph.html``)

Colour legend in the generated HTML:
* 🟡 **base model** — has modular shards *imported* by others but no parent
* 🔵 **derived modular model** — has a `modular_*.py` and inherits from ≥ 1 model
* 🔴 **candidate** — no `modular_*.py` yet (and/or very similar to another)
* red edges = high-similarity links (potential to factorise)
"""
from __future__ import annotations
import argparse
import ast
import json
import re
import tokenize
from collections import Counter, defaultdict
from itertools import combinations
from pathlib import Path
from typing import Dict, List, Set, Tuple
from sentence_transformers import SentenceTransformer, util
from tqdm import tqdm
import numpy as np
import spaces
import torch
# ────────────────────────────────────────────────────────────────────────────────
# CONFIG
# ────────────────────────────────────────────────────────────────────────────────
SIM_DEFAULT = 0.5  # default similarity threshold for candidate (red) links
# Minimum number of "pixel_values" mentions in a model's modelling files for it to
# pass the --multimodal filter. A value of 0 lets every model through, which would
# make the flag a no-op.
PIXEL_MIN_HITS = 3
HTML_DEFAULT = "d3_modular_graph.html"  # default --out file name
# ────────────────────────────────────────────────────────────────────────────────
# 1) Helpers to analyse *modelling* files (for similarity & multimodal filter)
# ────────────────────────────────────────────────────────────────────────────────
def _strip_source(code: str) -> str:
"""Remove doc‑strings, comments and import lines to keep only the core code."""
code = re.sub(r'("""|\'\'\')(?:.|\n)*?\1', "", code) # doc‑strings
code = re.sub(r"#.*", "", code) # # comments
return "\n".join(ln for ln in code.splitlines()
if not re.match(r"\s*(from|import)\s+", ln))
def _tokenise(code: str) -> Set[str]:
"""Extract identifiers using regex - more robust than tokenizer for malformed code."""
toks: Set[str] = set()
for match in re.finditer(r'\b[a-zA-Z_][a-zA-Z0-9_]*\b', code):
toks.add(match.group())
return toks
def build_token_bags(models_root: Path) -> Tuple[Dict[str, List[Set[str]]], Dict[str, int]]:
    """Tokenise every `modeling_*.py` below each model directory of *models_root*.

    Returns:
        bags: model dir name -> one identifier set per modelling file (docstrings,
            comments and import lines are stripped before tokenising).
        pixel_hits: model dir name -> total number of "pixel_values" occurrences
            in the raw (unstripped) modelling sources.

    Files that cannot be read or decoded as UTF-8 are reported and skipped.
    """
    bags: Dict[str, List[Set[str]]] = defaultdict(list)
    pixel_hits: Dict[str, int] = defaultdict(int)
    for mdl_dir in sorted(p for p in models_root.iterdir() if p.is_dir()):
        # rglob order depends on the filesystem; sorting makes the bag order (and so
        # which bag similarity_clusters keeps on a max(key=len) tie) reproducible.
        for py in sorted(mdl_dir.rglob("modeling_*.py")):
            try:
                text = py.read_text(encoding="utf-8")
            except (OSError, UnicodeDecodeError) as e:
                print(f"⚠️ Skipped {py}: {e}")
                continue
            pixel_hits[mdl_dir.name] += text.count("pixel_values")
            bags[mdl_dir.name].append(_tokenise(_strip_source(text)))
    return bags, pixel_hits
def _jaccard(a: Set[str], b: Set[str]) -> float:
return 0.0 if (not a or not b) else len(a & b) / len(a | b)
def similarity_clusters(bags: Dict[str, List[Set[str]]], thr: float) -> Dict[Tuple[str,str], float]:
    """Score model pairs by the Jaccard index of their largest token set.

    For each model only its biggest token set (first one on a size tie) is
    compared; models with no sets are ignored. Pairs are produced in sorted
    name order and kept only when their score is >= *thr*.
    """
    biggest: Dict[str, Set[str]] = {}
    for model_name, token_sets in bags.items():
        if token_sets:
            biggest[model_name] = max(token_sets, key=len)
    scored = (
        (pair, _jaccard(biggest[pair[0]], biggest[pair[1]]))
        for pair in combinations(sorted(biggest), 2)
    )
    return {pair: score for pair, score in scored if score >= thr}
def _read_missing_model_texts(models_root: Path, missing: List[str]) -> Dict[str, str]:
    """Return ``{model: stripped source}`` for every model in *missing* that is not skipped.

    All `modeling_*.py` files of a model are passed through `_strip_source` and
    concatenated in sorted path order, so the text (and its embedding) is stable
    across runs. Unreadable files are reported and skipped.
    """
    texts: Dict[str, str] = {}
    for name in tqdm(missing, desc="Reading modeling files"):
        if any(skip in name.lower() for skip in ["mobilebert", "lxmert"]):
            print(f"Skipping {name} (causes GPU abort)")
            continue
        parts = []
        for py in sorted((models_root / name).rglob("modeling_*.py")):
            try:
                parts.append(_strip_source(py.read_text(encoding="utf-8")) + "\n")
            except Exception as e:
                print(f"⚠️ Skipped {py}: {e}")
        # NOTE(review): the lone space presumably keeps the encoder from seeing "" — confirm.
        texts[name] = "".join(parts).strip() or " "
    return texts


def _load_code_embedder():
    """Load the CodeSage encoder on CUDA with its max sequence length capped at 2048."""
    model = SentenceTransformer("codesage/codesage-large-v2", device="cuda", trust_remote_code=True)
    try:
        cfg = model[0].auto_model.config
        # Look the attributes up one at a time: getattr's default argument is
        # evaluated eagerly, so nesting the calls raises when only n_positions exists.
        pos_limit = getattr(cfg, "n_positions", None) or getattr(cfg, "max_position_embeddings", None)
        pos_limit = int(pos_limit) if pos_limit else 1024
    except Exception:
        pos_limit = 1024
    seq_len = min(pos_limit, 2048)
    model.max_seq_length = seq_len
    model[0].max_seq_length = seq_len
    model[0].tokenizer.model_max_length = seq_len
    return model


def _load_cached_embedding_prefix(path: Path, names: List[str]) -> Tuple[List[np.ndarray], int]:
    """Reuse cached embeddings when the cache covers a prefix of *names*.

    Returns ``([embeddings], n_cached)`` when *path* holds one row per name for
    exactly ``names[:n_cached]``; otherwise ``([], 0)``.
    """
    if not path.exists():
        return [], 0
    try:
        # allow_pickle is required because `names` is saved as an object array;
        # only load cache files this script wrote itself (unpickling runs code).
        with np.load(path, allow_pickle=True) as cached:
            cached_names = list(cached["names"])
            loaded = cached["embeddings"].astype(np.float32)
    except Exception as e:
        print(f"⚠️ Failed to load cached embeddings: {type(e).__name__}: {e}")
        return [], 0
    if (
        not cached_names
        or names[:len(cached_names)] != cached_names
        or loaded.ndim != 2
        or loaded.shape[0] != len(cached_names)
    ):
        return [], 0
    print(f"πŸ“¦ Using cached embeddings for {len(cached_names)}/{len(names)} models")
    return [loaded], len(cached_names)


@spaces.GPU
def embedding_similarity_clusters(models_root: Path, missing: List[str], thr: float) -> Dict[Tuple[str, str], float]:
    """Return ``{(model_a, model_b): cosine}`` for *missing* models whose code embeddings score >= *thr*.

    Stripped modelling code is embedded with CodeSage in batches of 4. Embeddings
    are written to ``embeddings_cache.npz`` (current directory) after every batch,
    and a cache covering a prefix of the model list is reused, so an interrupted
    run resumes where it stopped. A batch that fails to encode contributes zero
    vectors, which score 0 against everything. Returns ``{}`` when there is
    nothing to embed.
    """
    texts = _read_missing_model_texts(models_root, missing)
    names = list(texts)
    if not names:
        print("No modeling code to embed; returning no similarities.")
        return {}

    embeddings_path = Path("embeddings_cache.npz")
    all_embeddings, start_idx = _load_cached_embedding_prefix(embeddings_path, names)
    saved = False
    if start_idx < len(names):
        model = _load_code_embedder()
        # The dimension getter may be absent or return None; 768 is the fallback width
        # used to zero-fill batches that fail to encode.
        emb_dim = getattr(model, "get_sentence_embedding_dimension", lambda: 768)() or 768
        print(f"Encoding embeddings for {len(names)} models...")
        batch_size = 4
        for i in tqdm(range(start_idx, len(names), batch_size), desc="Batches", leave=False):
            batch_names = names[i:i + batch_size]
            batch_texts = [texts[name] for name in batch_names]
            try:
                print(f"Processing batch: {batch_names}")
                emb = model.encode(batch_texts, convert_to_numpy=True, show_progress_bar=False)
            except Exception as e:
                print(f"⚠️ GPU worker error for batch {batch_names}: {type(e).__name__}: {e}")
                # NOTE: these zero rows are cached like real embeddings, so the batch
                # keeps similarity 0 on later runs until the cache is rebuilt.
                emb = np.zeros((len(batch_names), emb_dim), dtype=np.float32)
            all_embeddings.append(emb)
            # Persist after every batch so a crashed/interrupted run can resume.
            try:
                np.savez(
                    embeddings_path,
                    embeddings=np.vstack(all_embeddings).astype(np.float32),
                    names=np.array(names[:i + len(batch_names)], dtype=object),
                )
                saved = True
            except Exception as e:
                print(f"⚠️ Failed to write embeddings cache: {type(e).__name__}: {e}")
            if (i - start_idx) % (3 * batch_size) == 0 and torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
                print(f"πŸ§Ή Cleared GPU cache after batch {(i - start_idx)//batch_size + 1}")

    embeddings = np.vstack(all_embeddings).astype(np.float32)
    # Cosine similarity: L2-normalise each row (the epsilon keeps zero rows at 0), then dot.
    embeddings = embeddings / (np.linalg.norm(embeddings, axis=1, keepdims=True) + 1e-12)
    print("Computing pairwise similarities...")
    sims_mat = embeddings @ embeddings.T
    processed_names = names[:embeddings.shape[0]]
    out: Dict[Tuple[str, str], float] = {}
    for i, j in combinations(range(len(processed_names)), 2):
        s = float(sims_mat[i, j])
        if s >= thr:
            out[(processed_names[i], processed_names[j])] = s
    if saved:
        print(f"πŸ’Ύ Embeddings saved to {embeddings_path}")
    return out
def compute_similarities_from_cache(threshold: float) -> Dict[Tuple[str, str], float]:
    """Score all cached model pairs by cosine similarity without re-encoding anything.

    Reads ``embeddings_cache.npz`` from the current working directory and returns
    ``{(name_i, name_j): score}`` (i < j in cache order) for every pair whose score
    is >= *threshold*. Returns ``{}`` when the cache is missing, unreadable, or
    has a different number of embedding rows than names.
    """
    embeddings_path = Path("embeddings_cache.npz")
    if not embeddings_path.exists():
        return {}
    try:
        # allow_pickle: `names` is stored as an object array. Only load caches this
        # script wrote itself; unpickling an untrusted file can execute code.
        with np.load(embeddings_path, allow_pickle=True) as cached:
            embeddings = cached["embeddings"].astype(np.float32)
            names = list(cached["names"])
        if embeddings.ndim != 2 or embeddings.shape[0] != len(names):
            raise ValueError(
                f"cache has embeddings of shape {embeddings.shape} for {len(names)} names"
            )
        # Cosine similarity: L2-normalise rows (epsilon guards all-zero rows), then dot.
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True) + 1e-12
        embeddings = embeddings / norms
        sims_mat = embeddings @ embeddings.T
        out: Dict[Tuple[str, str], float] = {}
        for i, j in combinations(range(len(names)), 2):
            s = float(sims_mat[i, j])
            if s >= threshold:
                out[(names[i], names[j])] = s
        print(f"⚑ Computed {len(out)} similarities from cache (threshold: {threshold})")
        return out
    except Exception as e:
        print(f"⚠️ Failed to compute from cache: {e}")
        return {}
# ────────────────────────────────────────────────────────────────────────────────
# 2) Scan *modular_*.py* files to build an import‑dependency graph
# – only **modeling_*** imports are considered (skip configuration / processing)
# ────────────────────────────────────────────────────────────────────────────────
def modular_files(models_root: Path) -> List[Path]:
    """Return every `modular_*.py` file below *models_root*, in sorted path order.

    Sorting makes the dependency graph (and thus the generated HTML) identical
    across runs and filesystems. Directories that happen to match the pattern
    are skipped.
    """
    return sorted(p for p in models_root.rglob("modular_*.py") if p.is_file())
def dependency_graph(modular_files: List[Path], models_root: Path) -> Dict[str, List[Dict[str,str]]]:
    """Map each derived model to the names its `modular_*.py` imports from other models.

    Returns ``{derived_model: [{"source": model, "imported_class": name}, ...]}``
    with one entry per imported name (repeats included); the derived model is the
    parent directory name of each file.

    Only ``from ... import`` statements whose module path contains ``modeling_``
    are kept; configuration, processing, image-processing and
    ``modeling_attn_mask_utils`` modules are ignored so the graph focuses strictly
    on modelling code. The source model is the first ``.``/``/``-separated
    component of the module path other than ``models``/``transformers``; it must
    be a directory under *models_root* and differ from the derived model.
    Files that fail to read or parse are reported and skipped.
    """
    model_names = {p.name for p in models_root.iterdir() if p.is_dir()}
    deps: Dict[str, List[Dict[str,str]]] = defaultdict(list)
    for fp in modular_files:
        derived = fp.parent.name
        try:
            tree = ast.parse(fp.read_text(encoding="utf-8"), filename=str(fp))
        except Exception as e:
            print(f"⚠️ AST parse failed for {fp}: {e}")
            continue
        for node in ast.walk(tree):
            if not isinstance(node, ast.ImportFrom) or not node.module:
                continue
            mod = node.module
            # keep only *modeling_* imports, drop anything else
            if ("modeling_" not in mod or
                    "configuration_" in mod or
                    "processing_" in mod or
                    "image_processing" in mod or
                    "modeling_attn_mask_utils" in mod):
                continue
            parts = re.split(r"[./]", mod)
            src = next((p for p in parts if p not in {"", "models", "transformers"}), "")
            if not src or src == derived or src not in model_names:
                continue
            for alias in node.names:
                deps[derived].append({"source": src, "imported_class": alias.name})
    return dict(deps)
# modular_graph_and_candidates.py (top-level)
def get_missing_models(models_root: Path, multimodal: bool = False) -> Tuple[List[str], Dict[str, List[Set[str]]], Dict[str, int]]:
    """Find models that have modelling code but no `modular_*.py` file.

    Returns ``(missing, bags, pix_hits)``: the candidate model names (in the
    order `build_token_bags` produced them) plus that function's token bags and
    "pixel_values" counters. With *multimodal*, models with fewer than
    ``PIXEL_MIN_HITS`` "pixel_values" mentions are dropped from *missing*.
    """
    bags, pix_hits = build_token_bags(models_root)
    has_modular = set()
    for modular_path in modular_files(models_root):
        has_modular.add(modular_path.parent.name)
    missing = [name for name in bags if name not in has_modular]
    if not multimodal:
        return missing, bags, pix_hits
    multimodal_missing = [name for name in missing if pix_hits[name] >= PIXEL_MIN_HITS]
    return multimodal_missing, bags, pix_hits
def compute_similarities(models_root: Path, missing: List[str], bags: Dict[str, List[Set[str]]],
                         threshold: float, sim_method: str) -> Dict[Tuple[str, str], float]:
    """Return ``{(model_a, model_b): score}`` for pairs of *missing* models scoring >= *threshold*.

    ``"jaccard"`` compares token bags. Any other *sim_method* uses code embeddings:
    ``embeddings_cache.npz`` is tried first and used when it yields at least one
    pair among *missing*; otherwise the embedding computation is run.
    """
    if sim_method == "jaccard":
        return similarity_clusters({m: bags[m] for m in missing}, threshold)
    # The cache may hold models that are no longer candidates (they gained a
    # modular file, or were dropped by the multimodal filter): like the Jaccard
    # branch, only report pairs whose two ends are both in *missing*.
    wanted = set(missing)
    if Path("embeddings_cache.npz").exists():
        cached_sims = {
            pair: score
            for pair, score in compute_similarities_from_cache(threshold).items()
            if pair[0] in wanted and pair[1] in wanted
        }
        if cached_sims:
            return cached_sims
    # Fallback to full computation
    return embedding_similarity_clusters(models_root, missing, threshold)
def _cached_candidate_graph(threshold: float) -> dict:
    """Build a candidates-only graph straight from ``embeddings_cache.npz``.

    Every cached model becomes a ``"cand"`` node and every cached pair scoring
    >= *threshold* a candidate link. Returns ``{}`` when there is no cache, it
    yields no pairs, or reading it fails, so the caller can do a full build.
    """
    embeddings_cache = Path("embeddings_cache.npz")
    if not embeddings_cache.exists():
        return {}
    try:
        cached_sims = compute_similarities_from_cache(threshold)
        if not cached_sims:
            return {}
        # Close the archive promptly; it stays open otherwise.
        with np.load(embeddings_cache, allow_pickle=True) as cached_data:
            names = list(cached_data["names"])
        nodes = [{"id": name, "cls": "cand", "sz": 1} for name in names]
        links = [
            {"source": a, "target": b, "label": f"{s*100:.1f}%", "cand": True}
            for (a, b), s in cached_sims.items()
        ]
        print(f"⚑ Built graph from cache: {len(nodes)} nodes, {len(links)} links")
        return {"nodes": nodes, "links": links}
    except Exception as e:
        print(f"⚠️ Cache-only build failed: {e}, falling back to full build")
        return {}


def build_graph_json(
    transformers_dir: Path,
    threshold: float = SIM_DEFAULT,
    multimodal: bool = False,
    sim_method: str = "jaccard",
) -> dict:
    """Return the ``{"nodes": [...], "links": [...]}`` dict consumed by the D3 page.

    With ``sim_method="embedding"`` and a usable ``embeddings_cache.npz`` the graph
    is built from the cache alone (candidate nodes and similarity links only; the
    repository, *multimodal* and import edges are not consulted). Otherwise the
    models under ``<transformers_dir>/src/transformers/models`` are scanned and
    modular import links (``"cand": False``) are combined with similarity links
    between missing models (``"cand": True``).

    Node ``cls`` is ``"cand"`` for missing models that have no import links,
    ``"base"`` for models that only appear as import sources, ``"derived"``
    otherwise. Node ``sz`` scales from 1 to 3 with the node's link degree.
    """
    if sim_method == "embedding":
        cached_graph = _cached_candidate_graph(threshold)
        if cached_graph:
            return cached_graph

    # Full build with repository access
    models_root = transformers_dir / "src/transformers/models"
    missing, bags, _pix_hits = get_missing_models(models_root, multimodal)
    deps = dependency_graph(modular_files(models_root), models_root)
    sims = compute_similarities(models_root, missing, bags, threshold, sim_method)

    # ---- assemble nodes & links ----
    nodes: Set[str] = set(missing)
    links: List[dict] = []
    for drv, imports in deps.items():
        # One link per (source, derived) pair labelled with its import count;
        # emitting one per imported name would stack identical lines and labels
        # and inflate the degree used for node sizes.
        for src, n_imports in Counter(d["source"] for d in imports).items():
            links.append({
                "source": src,
                "target": drv,
                "label": f"{n_imports} imports",
                "cand": False,
            })
            nodes.update((src, drv))
    for (a, b), s in sims.items():
        links.append({"source": a, "target": b, "label": f"{s*100:.1f}%", "cand": True})
        nodes.update((a, b))

    deg: Counter = Counter()
    for lk in links:
        deg[lk["source"]] += 1
        deg[lk["target"]] += 1
    max_deg = max(deg.values(), default=1)

    targets = {lk["target"] for lk in links if not lk["cand"]}
    sources = {lk["source"] for lk in links if not lk["cand"]}
    missing_only = {m for m in missing if m not in sources and m not in targets}

    nodelist = []
    for n in sorted(nodes):
        if n in missing_only:
            cls = "cand"
        elif n in sources and n not in targets:
            cls = "base"
        else:
            cls = "derived"
        nodelist.append({"id": n, "cls": cls, "sz": 1 + 2 * (deg[n] / max_deg)})
    return {"nodes": nodelist, "links": links}
def generate_html(graph: dict) -> str:
    """Render the standalone D3 page for *graph*.

    The graph is serialised to compact JSON and substituted for the
    ``__GRAPH_DATA__`` placeholder in the JS template; the CSS and that script
    are then inlined into the HTML skeleton.
    """
    payload = json.dumps(graph, separators=(",", ":"))
    script = JS.replace("__GRAPH_DATA__", payload)
    page = HTML.replace("__CSS__", CSS)
    return page.replace("__JS__", script)
# ────────────────────────────────────────────────────────────────────────────────
# 3) HTML (D3.js) boilerplate – CSS + JS templates (unchanged design)
# ────────────────────────────────────────────────────────────────────────────────
# Stylesheet inlined into the HTML template's __CSS__ placeholder by generate_html().
# Colours match the legend: derived nodes #1f77b4 (blue), candidate nodes and
# similarity links #e63946 (red); base nodes are sized as an <image> (see JS).
# A prefers-color-scheme media query supplies dark-mode variants.
CSS = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600&display=swap');
:root{
--bg:#ffffff;
--text:#222222;
--muted:#555555;
--outline:#ffffff;
}
@media (prefers-color-scheme: dark){
:root{
--bg:#0b0d10;
--text:#e8e8e8;
--muted:#c8c8c8;
--outline:#000000;
}
}
body{ margin:0; font-family:'Inter',Arial,sans-serif; background:var(--bg); overflow:hidden; }
svg{ width:100vw; height:100vh; }
.link{ stroke:#999; stroke-opacity:.6; }
.link.cand{ stroke:#e63946; stroke-width:2.5; }
.node-label{
fill:var(--text);
pointer-events:none;
text-anchor:middle;
font-weight:600;
paint-order:stroke fill;
stroke:var(--outline);
stroke-width:3px;
}
.link-label{
fill:var(--muted);
pointer-events:none;
text-anchor:middle;
font-size:10px;
paint-order:stroke fill;
stroke:var(--bg);
stroke-width:2px;
}
.node.base image{ width:60px; height:60px; transform:translate(-30px,-30px); }
.node.derived circle{ fill:#1f77b4; }
.node.cand circle, .node.cand path{ fill:#e63946; }
#legend{
position:fixed; top:18px; left:18px;
background:rgba(255,255,255,.92);
padding:18px 28px; border-radius:10px; border:1.5px solid #bbb;
font-size:18px; box-shadow:0 2px 8px rgba(0,0,0,.08);
}
@media (prefers-color-scheme: dark){
#legend{ background:rgba(20,22,25,.92); color:#e8e8e8; border-color:#444; }
}
"""
# D3 client script inlined into the HTML template's __JS__ placeholder by
# generate_html(), which first replaces __GRAPH_DATA__ with the graph JSON
# ({nodes, links}). It draws into the #dependency <svg>, wires the #toggleRed
# checkbox to hide/show candidate nodes, links and link labels, and runs a force
# simulation with drag and zoom.
# NOTE(review): HF_LOGO_URI is a relative path, so base nodes show the logo only
# when the page is served next to static/hf-logo.svg; the yellow-circle fallback
# runs only if the constant is set to an empty string.
JS = """
function updateVisibility() {
const show = document.getElementById('toggleRed').checked;
svg.selectAll('.link.cand').style('display', show ? null : 'none');
svg.selectAll('.node.cand').style('display', show ? null : 'none');
svg.selectAll('.link-label').filter(d => d.cand).style('display', show ? null : 'none');
}
document.getElementById('toggleRed').addEventListener('change', updateVisibility);
const HF_LOGO_URI = "static/hf-logo.svg";
const graph = __GRAPH_DATA__;
const W = innerWidth, H = innerHeight;
const svg = d3.select('#dependency').call(d3.zoom().on('zoom', e => g.attr('transform', e.transform)));
const g = svg.append('g');
const link = g.selectAll('line')
.data(graph.links)
.join('line')
.attr('class', d => d.cand ? 'link cand' : 'link');
const linkLbl = g.selectAll('text.link-label')
.data(graph.links)
.join('text')
.attr('class', 'link-label')
.text(d => d.label);
const node = g.selectAll('g.node')
.data(graph.nodes)
.join('g')
.attr('class', d => `node ${d.cls}`)
.call(d3.drag().on('start', dragStart).on('drag', dragged).on('end', dragEnd));
const baseSel = node.filter(d => d.cls === 'base');
if (HF_LOGO_URI){
baseSel.append('image').attr('href', HF_LOGO_URI);
}else{
baseSel.append('circle').attr('r', d => 22*d.sz).attr('fill', '#ffbe0b');
}
node.filter(d => d.cls !== 'base').append('circle').attr('r', d => 20*d.sz);
node.append('text').attr('class','node-label').attr('dy','-2.4em').text(d => d.id);
const sim = d3.forceSimulation(graph.nodes)
.force('link', d3.forceLink(graph.links).id(d => d.id).distance(520))
.force('charge', d3.forceManyBody().strength(-600))
.force('center', d3.forceCenter(W / 2, H / 2))
.force('collide', d3.forceCollide(d => 50));
sim.on('tick', () => {
link.attr('x1', d=>d.source.x).attr('y1', d=>d.source.y)
.attr('x2', d=>d.target.x).attr('y2', d=>d.target.y);
linkLbl.attr('x', d=> (d.source.x+d.target.x)/2)
.attr('y', d=> (d.source.y+d.target.y)/2);
node.attr('transform', d=>`translate(${d.x},${d.y})`);
});
function dragStart(e,d){ if(!e.active) sim.alphaTarget(.3).restart(); d.fx=d.x; d.fy=d.y; }
function dragged(e,d){ d.fx=e.x; d.fy=e.y; }
function dragEnd(e,d){ if(!e.active) sim.alphaTarget(0); d.fx=d.fy=null; }
"""
# Page skeleton: generate_html() substitutes __CSS__ and __JS__. The script
# expects the #toggleRed checkbox and the #dependency <svg> defined here and
# loads D3 v7 from d3js.org. Legend emoji are written as HTML character
# references so the page source stays pure ASCII and cannot be mangled by a
# wrong text encoding (🟡 base, 🔵 modular, 🔴 candidate).
HTML = """
<!DOCTYPE html>
<html lang='en'><head><meta charset='UTF-8'>
<title>Transformers modular graph</title>
<style>__CSS__</style></head><body>
<div id='legend'>
&#x1F7E1; base<br>&#x1F535; modular<br>&#x1F534; candidate<br>red edge = high similarity<br><br>
<label><input type="checkbox" id="toggleRed" checked> Show candidate edges and nodes</label>
</div>
<svg id='dependency'></svg>
<script src='https://d3js.org/d3.v7.min.js'></script>
<script>__JS__</script></body></html>
"""
# ────────────────────────────────────────────────────────────────────────────────
# HTML writer
# ────────────────────────────────────────────────────────────────────────────────
def write_html(graph_data: dict, path: Path):
    """Render *graph_data* as a standalone HTML page and write it to *path* (UTF-8)."""
    page = generate_html(graph_data)
    path.write_text(page, encoding="utf-8")
# ────────────────────────────────────────────────────────────────────────────────
# MAIN
# ────────────────────────────────────────────────────────────────────────────────
def main():
    """Parse CLI arguments, build the graph for a local checkout and write the HTML page."""
    ap = argparse.ArgumentParser(description="Visualise modular dependencies + candidates")
    ap.add_argument("transformers", help="Path to the root of a local transformers repository")
    ap.add_argument(
        "--multimodal",
        action="store_true",
        # Built from the constant so the help cannot drift from the actual filter.
        help=f"filter to models with >= {PIXEL_MIN_HITS} 'pixel_values' mentions",
    )
    ap.add_argument(
        "--sim-threshold",
        type=float,
        default=SIM_DEFAULT,
        help="similarity cutoff for candidate links (default: %(default)s)",
    )
    ap.add_argument("--out", default=HTML_DEFAULT, help="output HTML file (default: %(default)s)")
    ap.add_argument("--sim-method", choices=["jaccard", "embedding"], default="jaccard",
                    help="Similarity method: 'jaccard' or 'embedding'")
    args = ap.parse_args()
    graph = build_graph_json(
        transformers_dir=Path(args.transformers).expanduser().resolve(),
        threshold=args.sim_threshold,
        multimodal=args.multimodal,
        sim_method=args.sim_method,
    )
    write_html(graph, Path(args.out).expanduser())


if __name__ == "__main__":
    main()