#!/usr/bin/env python
r"""
modular_graph_and_candidates.py
================================
Create **one** rich view that combines

1. The *dependency graph* between existing **modular_*.py** implementations in
   🤗 Transformers (blue/🟡) **and**
2. The list of *missing* modular models (full-red nodes) **plus** similarity
   edges (full-red links) between highly-overlapping modelling files - the
   output of *find_modular_candidates.py* - so you can immediately spot good
   refactor opportunities.

––– Usage –––

```bash
python modular_graph_and_candidates.py /path/to/transformers \
    --multimodal           # keep only models whose modelling code mentions
                           # "pixel_values" ≥ 3 times
    --sim-threshold 0.5    # Jaccard cutoff (default 0.50)
    --out graph.html       # output HTML file name
```

Colour legend in the generated HTML:

* 🟡 **base model** — has modular shards *imported* by others but no parent
* 🔵 **derived modular model** — has a `modular_*.py` and inherits from ≥ 1 model
* 🔴 **candidate** — no `modular_*.py` yet (and/or very similar to another)
* red edges = high-Jaccard similarity links (potential to factorise)
"""
from __future__ import annotations

import argparse
import ast
import json
import re
import tokenize
from collections import Counter, defaultdict
from itertools import combinations
from pathlib import Path
from typing import Dict, List, Set, Tuple

import numpy as np
import spaces
import torch
from sentence_transformers import SentenceTransformer, util
from tqdm import tqdm

# ──────────────────────────────────────────────────────────────────────────────
# CONFIG
# ──────────────────────────────────────────────────────────────────────────────
SIM_DEFAULT = 0.5  # similarity threshold
# NOTE(review): the usage text and the --multimodal help both promise a cutoff
# of "≥ 3" occurrences, but with the value 0 the `hits >= PIXEL_MIN_HITS` test
# is always true, so the multimodal filter keeps every model. Confirm which of
# the two is intended before changing the value.
PIXEL_MIN_HITS = 0  # multimodal trigger ("pixel_values")
HTML_DEFAULT = "d3_modular_graph.html"

# ──────────────────────────────────────────────────────────────────────────────
# 1) Helpers to analyse *modelling* files (for similarity & multimodal filter)
# ──────────────────────────────────────────────────────────────────────────────
# On-disk embedding caches, resolved against the current working directory.
_FINAL_EMBEDDINGS_CACHE = Path("embeddings_cache.npz")  # complete run, reusable
_TEMP_EMBEDDINGS_CACHE = Path("temp_embeddings.npz")  # partial run, for resuming


def _strip_source(code: str) -> str:
    """Remove doc-strings, comments and import lines to keep only the core code.

    The stripping is purely regex based: every triple-quoted string and
    everything after a ``#`` on a line is dropped, then lines that start with
    ``from``/``import`` are filtered out.
    """
    code = re.sub(r'("""|\'\'\')(?:.|\n)*?\1', "", code)  # doc-strings
    code = re.sub(r"#.*", "", code)  # comments
    return "\n".join(
        ln for ln in code.splitlines() if not re.match(r"\s*(from|import)\s+", ln)
    )


def _tokenise(code: str) -> Set[str]:
    """Return the set of identifier-like words found in *code*.

    A regex is used instead of the ``tokenize`` module so that malformed
    source never raises.
    """
    return set(re.findall(r"\b[a-zA-Z_][a-zA-Z0-9_]*\b", code))


def build_token_bags(models_root: Path) -> Tuple[Dict[str, List[Set[str]]], Dict[str, int]]:
    """Return token-bags of every `modeling_*.py` plus a pixel-value counter.

    Returns:
        ``(bags, pixel_hits)`` where ``bags[model_dir]`` holds one identifier
        set per modelling file and ``pixel_hits[model_dir]`` counts the
        occurrences of the literal ``"pixel_values"`` across those files.
        Files that cannot be read are reported and skipped.
    """
    bags: Dict[str, List[Set[str]]] = defaultdict(list)
    pixel_hits: Dict[str, int] = defaultdict(int)
    for mdl_dir in sorted(p for p in models_root.iterdir() if p.is_dir()):
        for py in mdl_dir.rglob("modeling_*.py"):
            try:
                # Plain ASCII "utf-8": the previous spelling used a
                # non-breaking hyphen and only worked through codec-name
                # normalisation.
                text = py.read_text(encoding="utf-8")
                pixel_hits[mdl_dir.name] += text.count("pixel_values")
                bags[mdl_dir.name].append(_tokenise(_strip_source(text)))
            except Exception as e:  # best effort: one bad file must not abort the scan
                print(f"⚠️ Skipped {py}: {e}")
    return bags, pixel_hits


def _jaccard(a: Set[str], b: Set[str]) -> float:
    """Jaccard index of two sets; 0.0 when either set is empty."""
    return 0.0 if (not a or not b) else len(a & b) / len(a | b)


def similarity_clusters(bags: Dict[str, List[Set[str]]], thr: float) -> Dict[Tuple[str, str], float]:
    """Return {(modelA, modelB): score} for pairs with Jaccard ≥ *thr*.

    Each model is represented by its largest token bag; pairs are enumerated
    in sorted name order.
    """
    largest = {m: max(ts, key=len) for m, ts in bags.items() if ts}
    out: Dict[Tuple[str, str], float] = {}
    for m1, m2 in combinations(sorted(largest), 2):
        s = _jaccard(largest[m1], largest[m2])
        if s >= thr:
            out[(m1, m2)] = s
    return out


def _normalise_rows(embeddings: np.ndarray) -> np.ndarray:
    """Scale every row to unit L2 norm (epsilon keeps all-zero rows finite)."""
    return embeddings / (np.linalg.norm(embeddings, axis=1, keepdims=True) + 1e-12)


def _pairs_above_threshold(unit_embeddings: np.ndarray, names: List[str], thr: float) -> Dict[Tuple[str, str], float]:
    """Return {(names[i], names[j]): cosine} for all i < j with cosine ≥ *thr*.

    *unit_embeddings* must already be row-normalised. Pairs come out in
    row-major order, the same order a nested ``for i / for j`` loop yields.
    """
    if unit_embeddings.size == 0:
        return {}
    sims = unit_embeddings @ unit_embeddings.T
    rows, cols = np.nonzero(np.triu(sims >= thr, k=1))  # strict upper triangle
    return {
        (names[i], names[j]): float(sims[i, j])
        for i, j in zip(rows.tolist(), cols.tolist())
    }


@spaces.GPU
def embedding_similarity_clusters(models_root: Path, missing: List[str], thr: float) -> Dict[Tuple[str, str], float]:
    """Return {(modelA, modelB): cosine} for *missing* models with cosine ≥ *thr*.

    The stripped modelling code of every model is embedded with
    ``codesage/codesage-large-v2`` and compared pairwise. Two caches in the
    working directory are used: a temporary one rewritten after every batch
    (so an interrupted run can resume) and a final one written on completion.
    Models whose name contains ``mobilebert`` or ``lxmert`` are skipped.
    """
    texts: Dict[str, str] = {}
    for name in tqdm(missing, desc="Reading modeling files"):
        if any(skip in name.lower() for skip in ("mobilebert", "lxmert")):
            print(f"Skipping {name} (causes GPU abort)")
            continue
        chunks: List[str] = []
        for py in (models_root / name).rglob("modeling_*.py"):
            try:
                chunks.append(_strip_source(py.read_text(encoding="utf-8")) + "\n")
            except Exception as e:  # unreadable file: report it and keep going
                print(f"⚠️ Skipped {py}: {e}")
        texts[name] = "".join(chunks).strip() or " "

    names = list(texts)
    if not names:
        # Nothing to embed; stacking an empty list of batches would raise.
        return {}

    # Consult the final cache *before* loading the model: when it matches, no
    # GPU work is needed at all.
    # NOTE: allow_pickle=True unpickles the stored object array of names, so
    # only load cache files that this script wrote itself.
    if _FINAL_EMBEDDINGS_CACHE.exists():
        try:
            with np.load(_FINAL_EMBEDDINGS_CACHE, allow_pickle=True) as cached:
                cached_names = list(cached["names"])
            if names == cached_names:
                print(f"✅ Using final embeddings cache ({len(cached_names)} models)")
                return compute_similarities_from_cache(thr)
        except Exception as e:
            print(f"⚠️ Failed to load final cache: {e}")

    model = SentenceTransformer("codesage/codesage-large-v2", device="cuda", trust_remote_code=True)

    try:
        cfg = model[0].auto_model.config
        pos_limit = getattr(cfg, "n_positions", None)
        if pos_limit is None:
            pos_limit = cfg.max_position_embeddings
        pos_limit = int(pos_limit)
    except Exception:
        pos_limit = 1024

    seq_len = min(pos_limit, 2048)
    model.max_seq_length = seq_len
    model[0].max_seq_length = seq_len
    model[0].tokenizer.model_max_length = seq_len

    print(f"Encoding embeddings for {len(names)} models...")
    batch_size = 4
    emb_dim = getattr(model, "get_sentence_embedding_dimension", lambda: 768)() or 768

    # Resume from the temp cache when it holds a prefix of the current names.
    all_embeddings: List[np.ndarray] = []
    start_idx = 0
    if _TEMP_EMBEDDINGS_CACHE.exists():
        try:
            with np.load(_TEMP_EMBEDDINGS_CACHE, allow_pickle=True) as cached:
                cached_names = list(cached["names"])
                loaded = cached["embeddings"].astype(np.float32)
            if (
                cached_names
                and names[: len(cached_names)] == cached_names
                and loaded.ndim == 2
                and loaded.shape[0] == len(cached_names)  # rows must line up with names
            ):
                all_embeddings.append(loaded)
                start_idx = len(cached_names)
                print(f"🔄 Resuming from temp cache: {start_idx}/{len(names)} models")
        except Exception as e:
            print(f"⚠️ Failed to load temp cache: {e}")

    for i in tqdm(range(start_idx, len(names), batch_size), desc="Batches", leave=False):
        batch_names = names[i:i + batch_size]
        batch_texts = [texts[name] for name in batch_names]

        try:
            print(f"Processing batch: {batch_names}")
            emb = model.encode(batch_texts, convert_to_numpy=True, show_progress_bar=False)
        except Exception as e:
            print(f"⚠️ GPU worker error for batch {batch_names}: {type(e).__name__}: {e}")
            # Zero rows stand in for the failed batch; reuse the width of the
            # batches already collected so the final stack stays rectangular.
            width = all_embeddings[-1].shape[1] if all_embeddings else emb_dim
            emb = np.zeros((len(batch_names), width), dtype=np.float32)

        all_embeddings.append(emb)

        # Save to the temp cache after each batch (for resume).
        try:
            np.savez(
                _TEMP_EMBEDDINGS_CACHE,
                embeddings=np.vstack(all_embeddings).astype(np.float32),
                names=np.array(names[:i + len(batch_names)], dtype=object),
            )
        except Exception as e:
            print(f"⚠️ Failed to write temp cache: {e}")

        if (i - start_idx) % (3 * batch_size) == 0 and torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
            print(f"🧹 Cleared GPU cache after batch {(i - start_idx)//batch_size + 1}")

    embeddings = _normalise_rows(np.vstack(all_embeddings).astype(np.float32))

    print("Computing pairwise similarities...")
    out = _pairs_above_threshold(embeddings, names[: embeddings.shape[0]], thr)

    # Promote to the final cache and drop the temp one.
    try:
        np.savez(_FINAL_EMBEDDINGS_CACHE, embeddings=embeddings, names=np.array(names, dtype=object))
        print(f"💾 Final embeddings saved to {_FINAL_EMBEDDINGS_CACHE}")
        if _TEMP_EMBEDDINGS_CACHE.exists():
            _TEMP_EMBEDDINGS_CACHE.unlink()
            print("🧹 Cleaned up temp cache")
    except Exception as e:
        print(f"⚠️ Failed to save final cache: {e}")

    return out


def compute_similarities_from_cache(threshold: float) -> Dict[Tuple[str, str], float]:
    """Compute similarities from cached embeddings without reprocessing.

    Returns an empty dict when the final cache file is absent or unusable.
    """
    if not _FINAL_EMBEDDINGS_CACHE.exists():
        return {}

    try:
        # NOTE: allow_pickle=True unpickles the stored names; trusted files only.
        with np.load(_FINAL_EMBEDDINGS_CACHE, allow_pickle=True) as cached:
            embeddings = cached["embeddings"].astype(np.float32)
            names = list(cached["names"])

        out = _pairs_above_threshold(_normalise_rows(embeddings), names, threshold)
        print(f"⚡ Computed {len(out)} similarities from cache (threshold: {threshold})")
        return out
    except Exception as e:
        print(f"⚠️ Failed to compute from cache: {e}")
        return {}


# ──────────────────────────────────────────────────────────────────────────────
# 2) Scan *modular_*.py* files to build an import-dependency graph
#    – only **modeling_*** imports are considered (skip configuration / processing)
# ──────────────────────────────────────────────────────────────────────────────

def modular_files(models_root: Path) -> List[Path]:
    """Return every ``modular_*.py`` file below *models_root*, sorted by path.

    Sorting makes the order of the resulting graph links reproducible across
    file systems.
    """
    return sorted(p for p in models_root.rglob("modular_*.py") if p.is_file())


def dependency_graph(modular_files: List[Path], models_root: Path) -> Dict[str, List[Dict[str, str]]]:
    """Return {derived_model: [{source, imported_class}, ...]}

    Only `modeling_*` imports are kept; anything coming from
    configuration/processing/image* utils is ignored so the visual graph
    focuses strictly on modelling code. Edges to sources whose name is not a
    directory of *models_root* are excluded, as are self-imports. Files that
    cannot be read or parsed are reported and skipped.
    """
    model_names = {p.name for p in models_root.iterdir() if p.is_dir()}
    ignored_parts = {"", "models", "transformers"}

    deps: Dict[str, List[Dict[str, str]]] = defaultdict(list)
    for fp in modular_files:
        derived = fp.parent.name
        try:
            tree = ast.parse(fp.read_text(encoding="utf-8"), filename=str(fp))
        except Exception as e:
            print(f"⚠️ AST parse failed for {fp}: {e}")
            continue
        for node in ast.walk(tree):
            if not isinstance(node, ast.ImportFrom) or not node.module:
                continue
            mod = node.module
            # keep only *modeling_* imports, drop anything else
            if ("modeling_" not in mod or
                    "configuration_" in mod or
                    "processing_" in mod or
                    "image_processing" in mod or
                    "modeling_attn_mask_utils" in mod):
                continue
            # The first path component that is not package scaffolding is
            # taken as the source model's directory name.
            parts = re.split(r"[./]", mod)
            src = next((p for p in parts if p not in ignored_parts), "")
            if not src or src == derived or src not in model_names:
                continue
            for alias in node.names:
                deps[derived].append({"source": src, "imported_class": alias.name})
    return dict(deps)


def get_missing_models(models_root: Path, multimodal: bool = False) -> Tuple[List[str], Dict[str, List[Set[str]]], Dict[str, int]]:
    """Get list of models missing modular implementations.

    Returns ``(missing, bags, pixel_hits)``; *missing* lists the model
    directories that have modelling files but no ``modular_*.py``. With
    *multimodal* set, only models with at least ``PIXEL_MIN_HITS``
    occurrences of ``"pixel_values"`` are kept.
    """
    bags, pix_hits = build_token_bags(models_root)
    models_with_modular = {p.parent.name for p in modular_files(models_root)}
    missing = [m for m in bags if m not in models_with_modular]

    if multimodal:
        missing = [m for m in missing if pix_hits[m] >= PIXEL_MIN_HITS]

    return missing, bags, pix_hits


def compute_similarities(models_root: Path, missing: List[str], bags: Dict[str, List[Set[str]]],
                         threshold: float, sim_method: str) -> Dict[Tuple[str, str], float]:
    """Compute similarities between missing models using specified method.

    ``"jaccard"`` compares token bags; any other value uses embeddings,
    preferring the final on-disk cache and falling back to a full GPU run.
    """
    if sim_method == "jaccard":
        return similarity_clusters({m: bags.get(m, []) for m in missing}, threshold)

    if _FINAL_EMBEDDINGS_CACHE.exists():
        # A cache written for another model selection may hold extra models;
        # keep only pairs the caller actually asked about.
        wanted = set(missing)
        cached_sims = {
            pair: s
            for pair, s in compute_similarities_from_cache(threshold).items()
            if pair[0] in wanted and pair[1] in wanted
        }
        if cached_sims:  # cache exists and produced usable pairs
            return cached_sims

    # Fallback to full computation
    return embedding_similarity_clusters(models_root, missing, threshold)


def _dependency_links(deps: Dict[str, List[Dict[str, str]]]) -> Tuple[List[dict], Counter]:
    """Turn *deps* into one link per (source, derived) pair plus a degree count.

    The link label carries the number of imported names; the returned counter
    adds that number to both endpoints so that heavily-imported models still
    get larger nodes.
    """
    links: List[dict] = []
    degree: Counter = Counter()
    for derived, imports in deps.items():
        per_source = Counter(d["source"] for d in imports)
        for source, n_imports in per_source.items():
            links.append({
                "source": source,
                "target": derived,
                "label": f"{n_imports} imports",
                "cand": False,
            })
            degree[source] += n_imports
            degree[derived] += n_imports
    return links, degree


def _similarity_link(a: str, b: str, score: float) -> dict:
    """Build one candidate (red) link labelled with the score as a percentage."""
    return {"source": a, "target": b, "label": f"{score*100:.1f}%", "cand": True}


def _classify_nodes(nodes: Set[str], links: List[dict], missing: List[str]) -> Dict[str, str]:
    """Map each node (in sorted order) to ``"cand"``, ``"base"`` or ``"derived"``.

    Only dependency links count: a missing model with no dependency link is a
    candidate, a node that is imported from but imports nothing is a base,
    everything else is derived.
    """
    targets = {lk["target"] for lk in links if not lk["cand"]}
    sources = {lk["source"] for lk in links if not lk["cand"]}
    candidates = {m for m in missing if m not in sources and m not in targets}

    classes: Dict[str, str] = {}
    for n in sorted(nodes):
        if n in candidates:
            classes[n] = "cand"
        elif n in sources and n not in targets:
            classes[n] = "base"
        else:
            classes[n] = "derived"
    return classes


def build_graph_json(
    transformers_dir: Path,
    threshold: float = SIM_DEFAULT,
    multimodal: bool = False,
    sim_method: str = "jaccard",
) -> dict:
    """Return the {nodes, links} dict that D3 needs.

    With ``sim_method="embedding"`` and a usable final cache, the candidate
    list and similarity links come from the cache (every node gets ``sz`` 1
    and *multimodal* is not applied); otherwise the repository is scanned.

    Each (source, derived) dependency produces exactly one link labelled
    "N imports". In the full build, node size grows with degree, where a
    dependency link counts N times.
    """
    models_root = transformers_dir / "src/transformers/models"

    print(f"🔍 Cache file exists: {_FINAL_EMBEDDINGS_CACHE.exists()}, sim_method: {sim_method}")

    if sim_method == "embedding" and _FINAL_EMBEDDINGS_CACHE.exists():
        try:
            cached_sims = compute_similarities_from_cache(threshold)
            print(f"🔍 Got {len(cached_sims)} cached similarities")

            if cached_sims:
                with np.load(_FINAL_EMBEDDINGS_CACHE, allow_pickle=True) as cached:
                    missing = [str(n) for n in cached["names"]]

                # Modular dependencies still come from the repository.
                deps = dependency_graph(modular_files(models_root), models_root)
                links, _ = _dependency_links(deps)
                links.extend(_similarity_link(a, b, s) for (a, b), s in cached_sims.items())

                nodes = set(missing)
                for lk in links:
                    nodes.update((lk["source"], lk["target"]))

                classes = _classify_nodes(nodes, links, missing)
                nodelist = [{"id": n, "cls": cls, "sz": 1} for n, cls in classes.items()]

                print(f"⚡ Built graph from cache: {len(nodelist)} nodes, {len(links)} links")
                return {"nodes": nodelist, "links": links}
        except Exception as e:
            print(f"⚠️ Cache-only build failed: {e}, falling back to full build")

    # Full build with repository access
    missing, bags, pix_hits = get_missing_models(models_root, multimodal)
    deps = dependency_graph(modular_files(models_root), models_root)
    sims = compute_similarities(models_root, missing, bags, threshold, sim_method)

    # ---- assemble nodes & links ----
    links, deg = _dependency_links(deps)
    for (a, b), s in sims.items():
        links.append(_similarity_link(a, b, s))
        deg[a] += 1
        deg[b] += 1

    nodes: Set[str] = set(missing)
    for lk in links:
        nodes.update((lk["source"], lk["target"]))

    max_deg = max(deg.values(), default=1)
    classes = _classify_nodes(nodes, links, missing)
    nodelist = [
        {"id": n, "cls": cls, "sz": 1 + 2 * (deg[n] / max_deg)}
        for n, cls in classes.items()
    ]

    return {"nodes": nodelist, "links": links}


def generate_html(graph: dict) -> str:
    """Return the full HTML string with inlined CSS/JS + graph JSON."""
    js = JS.replace("__GRAPH_DATA__", json.dumps(graph, separators=(",", ":")))
    return HTML.replace("__CSS__", CSS).replace("__JS__", js)
# ──────────────────────────────────────────────────────────────────────────────
# 3) HTML (D3.js) boilerplate – CSS + JS templates (unchanged design)
# ──────────────────────────────────────────────────────────────────────────────
# Stylesheet inlined into the page in place of the __CSS__ placeholder.
CSS = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600&display=swap');

:root{
  --bg:#ffffff;
  --text:#222222;
  --muted:#555555;
  --outline:#ffffff;
}
@media (prefers-color-scheme: dark){
  :root{
    --bg:#0b0d10;
    --text:#e8e8e8;
    --muted:#c8c8c8;
    --outline:#000000;
  }
}

body{
  margin:0;
  font-family:'Inter',Arial,sans-serif;
  background:var(--bg);
  overflow:hidden;
}
svg{ width:100vw; height:100vh; }

.link{ stroke:#999; stroke-opacity:.6; }
.link.cand{ stroke:#e63946; stroke-width:2.5; }

.node-label{
  fill:var(--text);
  pointer-events:none;
  text-anchor:middle;
  font-weight:600;
  paint-order:stroke fill;
  stroke:var(--outline);
  stroke-width:3px;
}
.link-label{
  fill:var(--muted);
  pointer-events:none;
  text-anchor:middle;
  font-size:10px;
  paint-order:stroke fill;
  stroke:var(--bg);
  stroke-width:2px;
}

.node.base image{ width:60px; height:60px; transform:translate(-30px,-30px); }
.node.derived circle{ fill:#1f77b4; }
.node.cand circle, .node.cand path{ fill:#e63946; }

#legend{
  position:fixed; top:18px; left:18px;
  background:rgba(255,255,255,.92);
  padding:18px 28px;
  border-radius:10px;
  border:1.5px solid #bbb;
  font-size:18px;
  box-shadow:0 2px 8px rgba(0,0,0,.08);
}
@media (prefers-color-scheme: dark){
  #legend{ background:rgba(20,22,25,.92); color:#e8e8e8; border-color:#444; }
}
"""

# Script inlined in place of __JS__; __GRAPH_DATA__ is replaced by the graph
# JSON. It expects d3 to be loaded and the #toggleRed checkbox and #dependency
# <svg> to exist before it runs.
JS = """
function updateVisibility() {
  const show = document.getElementById('toggleRed').checked;
  svg.selectAll('.link.cand').style('display', show ? null : 'none');
  svg.selectAll('.node.cand').style('display', show ? null : 'none');
  svg.selectAll('.link-label').filter(d => d.cand).style('display', show ? null : 'none');
}
document.getElementById('toggleRed').addEventListener('change', updateVisibility);

const HF_LOGO_URI = "./static/hf-logo.png";
const graph = __GRAPH_DATA__;
const W = innerWidth, H = innerHeight;

const svg = d3.select('#dependency').call(d3.zoom().on('zoom', e => g.attr('transform', e.transform)));
const g = svg.append('g');

const link = g.selectAll('line')
  .data(graph.links)
  .join('line')
  .attr('class', d => d.cand ? 'link cand' : 'link');

const linkLbl = g.selectAll('text.link-label')
  .data(graph.links)
  .join('text')
  .attr('class', 'link-label')
  .text(d => d.label);

const node = g.selectAll('g.node')
  .data(graph.nodes)
  .join('g')
  .attr('class', d => `node ${d.cls}`)
  .call(d3.drag().on('start', dragStart).on('drag', dragged).on('end', dragEnd));

const baseSel = node.filter(d => d.cls === 'base');
if (HF_LOGO_URI){
  baseSel.append('image')
    .attr('href', HF_LOGO_URI)
    .attr('width', 40)
    .attr('height', 40)
    .attr('x', -20)
    .attr('y', -20)
    .on('error', function() {
      console.log('Image failed to load:', HF_LOGO_URI);
      // Fallback to circle
      d3.select(this.parentNode).append('circle')
        .attr('r', 22).attr('fill', '#ffbe0b');
    });
  console.log('Loading logo from:', HF_LOGO_URI);
}else{
  baseSel.append('circle').attr('r', d => 22*d.sz).attr('fill', '#ffbe0b');
}
node.filter(d => d.cls !== 'base').append('circle').attr('r', d => 20*d.sz);

node.append('text')
  .attr('class','node-label')
  .attr('dy','-2.4em')
  .style('font-size', d => d.cls === 'base' ? '32px' : '28px')
  .style('font-weight', d => d.cls === 'base' ? 'bold' : 'normal')
  .text(d => d.id);

const sim = d3.forceSimulation(graph.nodes)
  .force('link', d3.forceLink(graph.links).id(d => d.id).distance(520))
  .force('charge', d3.forceManyBody().strength(-600))
  .force('center', d3.forceCenter(W / 2, H / 2))
  .force('collide', d3.forceCollide(d => 50));

sim.on('tick', () => {
  link.attr('x1', d=>d.source.x).attr('y1', d=>d.source.y)
      .attr('x2', d=>d.target.x).attr('y2', d=>d.target.y);
  linkLbl.attr('x', d=> (d.source.x+d.target.x)/2)
         .attr('y', d=> (d.source.y+d.target.y)/2);
  node.attr('transform', d=>`translate(${d.x},${d.y})`);
});

function dragStart(e,d){ if(!e.active) sim.alphaTarget(.3).restart(); d.fx=d.x; d.fy=d.y; }
function dragged(e,d){ d.fx=e.x; d.fy=e.y; }
function dragEnd(e,d){ if(!e.active) sim.alphaTarget(0); d.fx=d.fy=null; }
"""

# Page skeleton. The markup had been stripped from this template, leaving only
# its text; it is rebuilt here around what the rest of the file requires: the
# __CSS__ / __JS__ placeholders substituted by generate_html(), the #legend box
# styled by CSS, and the #toggleRed checkbox and #dependency <svg> that JS
# looks up. The inline script comes last so those elements and d3 exist when it
# runs.
# NOTE(review): the checkbox label wording is a reconstruction — confirm.
HTML = """
<!DOCTYPE html>
<html lang='en'>
<head>
<meta charset='UTF-8'>
<title>Transformers modular graph</title>
<style>__CSS__</style>
</head>
<body>
<div id='legend'>
  🟡 base<br>
  🔵 modular<br>
  🔴 candidate<br>
  red edge = high embedding similarity<br><br>
  <label><input type='checkbox' id='toggleRed' checked> Show candidates</label>
</div>
<svg id='dependency'></svg>
<script src='https://d3js.org/d3.v7.min.js'></script>
<script>__JS__</script>
</body>
</html>
"""

# ──────────────────────────────────────────────────────────────────────────────
# HTML writer
# ──────────────────────────────────────────────────────────────────────────────

def write_html(graph_data: dict, path: Path):
    """Render *graph_data* to a standalone HTML page and write it to *path*."""
    path.write_text(generate_html(graph_data), encoding="utf-8")


# ──────────────────────────────────────────────────────────────────────────────
# MAIN
# ──────────────────────────────────────────────────────────────────────────────

def main():
    """Command-line entry point: build the graph and write the HTML file."""
    ap = argparse.ArgumentParser(description="Visualise modular dependencies + candidates")
    ap.add_argument("transformers", help="Path to local 🤗 transformers repo root")
    ap.add_argument("--multimodal", action="store_true", help="filter to models with ≥3 'pixel_values'")
    ap.add_argument("--sim-threshold", type=float, default=SIM_DEFAULT)
    ap.add_argument("--out", default=HTML_DEFAULT)
    ap.add_argument("--sim-method", choices=["jaccard", "embedding"], default="jaccard",
                    help="Similarity method: 'jaccard' or 'embedding'")
    args = ap.parse_args()

    graph = build_graph_json(
        transformers_dir=Path(args.transformers).expanduser().resolve(),
        threshold=args.sim_threshold,
        multimodal=args.multimodal,
        sim_method=args.sim_method,
    )
    write_html(graph, Path(args.out).expanduser())


if __name__ == "__main__":
    main()