#!/usr/bin/env python
r"""
modular_graph_and_candidates.py
================================

Create **one** rich view that combines

1. the *dependency graph* between existing **modular_*.py** implementations in
   🤗 Transformers (🔵 blue / 🟡 yellow nodes) **and**
2. the list of *missing* modular models (full-red nodes) **plus** similarity
   edges (full-red links) between highly-overlapping modelling files – the
   output of *find_modular_candidates.py* – so you can immediately spot good
   refactor opportunities.

––– Usage –––

```bash
python modular_graph_and_candidates.py /path/to/transformers \
    --multimodal \
    --sim-method jaccard \
    --sim-threshold 0.5 \
    --out graph.html
```

* ``--multimodal``     keep only models whose modelling code mentions
                       "pixel_values" at least ``PIXEL_MIN_HITS`` (3) times
* ``--sim-method``     ``jaccard`` (token-set overlap, default) or
                       ``embedding`` (code-embedding cosine similarity)
* ``--sim-threshold``  similarity cutoff for red edges (default 0.78)
* ``--out``            output HTML file name (default ``d3_modular_graph.html``)

Colour legend in the generated HTML:

* 🟡 **base model** — its modelling code is *imported* by other modular files,
  but it has no modular parent itself
* 🔵 **derived modular model** — has a `modular_*.py` that imports from ≥ 1 model
* 🔴 **candidate** — no `modular_*.py` yet (and/or very similar to another)
* red edges = high-similarity links (potential to factorise)
"""
from __future__ import annotations

import argparse
import ast
import json
import re
import tokenize
from collections import Counter, defaultdict
from itertools import combinations
from pathlib import Path
from typing import Dict, List, Set, Tuple

import numpy as np
import spaces
from sentence_transformers import SentenceTransformer, util
from tqdm import tqdm

# ────────────────────────────────────────────────────────────────────────────────
# CONFIG
# ────────────────────────────────────────────────────────────────────────────────
SIM_DEFAULT = 0.78  # default similarity threshold for red edges (Jaccard or cosine)
# `--multimodal` keeps only models whose modelling files mention "pixel_values" at
# least this many times (a value of 0 would make the filter a no-op).
PIXEL_MIN_HITS = 3
HTML_DEFAULT = "d3_modular_graph.html"

# ────────────────────────────────────────────────────────────────────────────────
# 1) Helpers to analyse *modelling* files (for similarity & multimodal filter)
# ────────────────────────────────────────────────────────────────────────────────

# Pre-compiled patterns for the source-stripping / tokenising heuristics.
_DOCSTRING_RE = re.compile(r'("""|\'\'\').*?\1', re.DOTALL)  # any triple-quoted literal
_COMMENT_RE = re.compile(r"#.*")
_IMPORT_LINE_RE = re.compile(r"\s*(from|import)\s+")
# `from x import (\n    a,\n    b,\n)` is removed as a whole: the line-based filter
# only sees its first line and would leave the continuation lines behind.
_IMPORT_PAREN_RE = re.compile(
    r"^[ \t]*from[ \t]+\S+[ \t]+import[ \t]*\([^)]*\)[^\n]*", re.MULTILINE
)
_IDENT_RE = re.compile(r"[^\W\d]\w*")  # identifier-like words (fallback tokeniser)


def _strip_source(code: str) -> str:
    """Remove doc-strings, comments and import statements to keep only the core code.

    This is a regex heuristic, not a parser: *every* triple-quoted literal is
    dropped (not only docstrings), and ``#`` starts a "comment" even inside a
    string. Parenthesised multi-line ``from … import (…)`` statements are removed
    in full. Otherwise their continuation lines would leak imported names into the
    token bag, and the orphaned ``)`` can make :mod:`tokenize` raise (the C-based
    tokenizer of Python 3.12+ rejects unmatched brackets).
    """
    code = _DOCSTRING_RE.sub("", code)  # doc-strings / triple-quoted literals
    code = _COMMENT_RE.sub("", code)  # # comments
    code = _IMPORT_PAREN_RE.sub("", code)  # multi-line parenthesised imports
    return "\n".join(ln for ln in code.splitlines() if not _IMPORT_LINE_RE.match(ln))


def _tokenise(code: str) -> Set[str]:
    """Return the set of NAME tokens (identifiers and keywords) in *code*.

    If :mod:`tokenize` rejects the (heuristically stripped) code, fall back to a
    regex scan for identifier-like words. The fallback may also pick up words
    from string literals, but it avoids dropping the whole file.
    """
    toks: Set[str] = set()
    try:
        for tok in tokenize.generate_tokens(iter(code.splitlines(keepends=True)).__next__):
            if tok.type == tokenize.NAME:
                toks.add(tok.string)
    except (tokenize.TokenError, SyntaxError):  # SyntaxError covers IndentationError
        return set(_IDENT_RE.findall(code))
    return toks


def build_token_bags(models_root: Path) -> Tuple[Dict[str, List[Set[str]]], Dict[str, int]]:
    """Return token-bags of every `modeling_*.py` plus a pixel-value counter.

    Returns:
        bags: model-dir name -> one identifier set per ``modeling_*.py`` found
            (recursively) under that directory.
        pixel_hits: model-dir name -> total number of ``"pixel_values"``
            occurrences across those files.

    Files that cannot be read or processed are reported and skipped.
    """
    bags: Dict[str, List[Set[str]]] = defaultdict(list)
    pixel_hits: Dict[str, int] = defaultdict(int)
    for mdl_dir in sorted(p for p in models_root.iterdir() if p.is_dir()):
        for py in mdl_dir.rglob("modeling_*.py"):
            try:
                text = py.read_text(encoding="utf-8")
                pixel_hits[mdl_dir.name] += text.count("pixel_values")
                bags[mdl_dir.name].append(_tokenise(_strip_source(text)))
            except Exception as e:  # best-effort scan: report and keep going
                print(f"⚠️ Skipped {py}: {e}")
    return bags, pixel_hits


def _jaccard(a: Set[str], b: Set[str]) -> float:
    """Jaccard index |a ∩ b| / |a ∪ b|; 0.0 if either set is empty."""
    return 0.0 if (not a or not b) else len(a & b) / len(a | b)


def similarity_clusters(bags: Dict[str, List[Set[str]]], thr: float) -> Dict[Tuple[str, str], float]:
    """Return {(modelA, modelB): score} for pairs with Jaccard ≥ *thr*.

    Each model is represented by its largest token bag (its biggest modelling
    file); models with no bags are ignored. Keys are ordered pairs with
    ``modelA < modelB``.
    """
    largest = {m: max(ts, key=len) for m, ts in bags.items() if ts}
    out: Dict[Tuple[str, str], float] = {}
    for m1, m2 in combinations(sorted(largest), 2):
        s = _jaccard(largest[m1], largest[m2])
        if s >= thr:
            out[(m1, m2)] = s
    return out


@spaces.GPU
def embedding_similarity_clusters(models_root: Path, missing: List[str], thr: float) -> Dict[Tuple[str, str], float]:
    """Return {(modelA, modelB): cosine} for pairs of *missing* models with cosine ≥ *thr*.

    For each model, all of its ``modeling_*.py`` files are stripped with
    :func:`_strip_source` and concatenated. The texts are embedded with
    ``codesage/codesage-large-v2`` (inputs longer than ``max_seq_length`` are
    truncated) and compared pairwise. Keys follow the order of *missing*.
    """
    if len(missing) < 2:
        return {}  # no pairs to compare (and np.vstack([]) would raise)

    model = SentenceTransformer("codesage/codesage-large-v2", trust_remote_code=True)
    model.max_seq_length = 4096  # truncate overly long modeling files

    texts: Dict[str, str] = {}
    for name in tqdm(missing, desc="Reading modeling files"):
        code = ""
        for py in (models_root / name).rglob("modeling_*.py"):
            try:
                code += _strip_source(py.read_text(encoding="utf-8")) + "\n"
            except (OSError, UnicodeDecodeError) as e:
                print(f"⚠️ Skipped {py}: {e}")
        texts[name] = code.strip() or " "  # never encode an empty string

    names = list(texts)
    all_embeddings = []
    print("Encoding embeddings...")
    batch_size = 8  # lower (e.g. 2) if GPU memory is tight
    for start in tqdm(range(0, len(names), batch_size), desc="Batches", leave=False):
        batch = [texts[n] for n in names[start:start + batch_size]]
        emb = model.encode(
            batch,
            convert_to_numpy=True,
            normalize_embeddings=True,  # unit-length rows -> dot product == cosine
            show_progress_bar=False,
        )
        all_embeddings.append(emb)
    embeddings = np.vstack(all_embeddings)  # [N, D]

    print("Computing pairwise similarities...")
    sims = embeddings @ embeddings.T  # [N, N] cosine similarities
    out: Dict[Tuple[str, str], float] = {}
    for i in range(len(names)):
        for j in range(i + 1, len(names)):
            s = sims[i, j]
            if s >= thr:
                out[(names[i], names[j])] = float(s)
    return out


# ────────────────────────────────────────────────────────────────────────────────
# 2) Scan *modular_*.py* files to build an import-dependency graph
#    – only **modeling_*** imports are considered (skip configuration / processing)
# ────────────────────────────────────────────────────────────────────────────────

# Module-name fragments whose imports are not drawn as dependency edges.
_SKIPPED_IMPORT_PARTS = ("configuration_", "processing_", "image_processing", "modeling_attn_mask_utils")


def modular_files(models_root: Path) -> List[Path]:
    """Return every ``modular_*.py`` under *models_root*, sorted for a stable graph."""
    return sorted(p for p in models_root.rglob("modular_*.py") if p.suffix == ".py")


def dependency_graph(modular_files: List[Path], models_root: Path) -> Dict[str, List[Dict[str, str]]]:
    """Return {derived_model: [{source, imported_class}, ...]}

    *derived_model* is the directory name of each modular file. Only
    ``from …modeling_… import …`` statements are kept; anything coming from
    configuration/processing/image* utils is ignored, so the visual graph focuses
    strictly on modelling code. Self-imports, and sources whose model name is not
    a directory under *models_root*, are excluded. Files that fail to parse are
    reported and skipped.
    """
    model_names = {p.name for p in models_root.iterdir() if p.is_dir()}
    deps: Dict[str, List[Dict[str, str]]] = defaultdict(list)

    for fp in modular_files:
        derived = fp.parent.name
        try:
            tree = ast.parse(fp.read_text(encoding="utf-8"), filename=str(fp))
        except Exception as e:  # best-effort scan: report and keep going
            print(f"⚠️ AST parse failed for {fp}: {e}")
            continue

        for node in ast.walk(tree):
            if not isinstance(node, ast.ImportFrom) or not node.module:
                continue
            mod = node.module
            # keep only *modeling_* imports, drop anything else
            if "modeling_" not in mod or any(part in mod for part in _SKIPPED_IMPORT_PARTS):
                continue
            # e.g. "..llama.modeling_llama" -> module "llama.modeling_llama" -> "llama"
            parts = re.split(r"[./]", mod)
            src = next((p for p in parts if p not in {"", "models", "transformers"}), "")
            if not src or src == derived or src not in model_names:
                continue
            for alias in node.names:
                deps[derived].append({"source": src, "imported_class": alias.name})
    return dict(deps)


# modular_graph_and_candidates.py (top-level)
def build_graph_json(
    transformers_dir: Path,
    threshold: float = SIM_DEFAULT,
    multimodal: bool = False,
    sim_method: str = "jaccard",
) -> dict:
    """Return the {nodes, links} dict that D3 needs.

    Args:
        transformers_dir: root of a local transformers checkout; models are read
            from ``src/transformers/models``.
        threshold: minimum similarity for a red candidate edge.
        multimodal: keep only candidates whose modelling code mentions
            ``"pixel_values"`` at least ``PIXEL_MIN_HITS`` times.
        sim_method: ``"jaccard"`` (token-set overlap) or ``"embedding"``
            (code-embedding cosine similarity).

    Returns:
        ``{"nodes": [{"id", "cls", "sz"}, ...], "links": [{"source", "target",
        "label", "cand"}, ...]}``. ``cls`` is ``"base"``, ``"derived"`` or
        ``"cand"``, and ``sz`` lies in [1, 3], scaled by node degree. There is one
        dependency link per (source, derived) pair, labelled with its import count.

    Raises:
        ValueError: if *sim_method* is not one of the two supported methods.
    """
    models_root = transformers_dir / "src/transformers/models"
    bags, pix_hits = build_token_bags(models_root)
    mod_files = modular_files(models_root)
    deps = dependency_graph(mod_files, models_root)

    # Candidates = models with modelling files but no modular_*.py yet.
    models_with_modular = {p.parent.name for p in mod_files}
    missing = [m for m in bags if m not in models_with_modular]
    if multimodal:
        missing = [m for m in missing if pix_hits[m] >= PIXEL_MIN_HITS]

    if sim_method == "jaccard":
        sims = similarity_clusters({m: bags[m] for m in missing}, threshold)
    elif sim_method == "embedding":
        sims = embedding_similarity_clusters(models_root, missing, threshold)
    else:
        raise ValueError(f"unknown sim_method {sim_method!r}; expected 'jaccard' or 'embedding'")

    # ---- links ----
    links: List[dict] = []
    for drv, lst in deps.items():
        # One edge per (source, derived) pair. Emitting one per imported class
        # would stack identical lines/labels and inflate node degrees.
        for src, n_imports in Counter(d["source"] for d in lst).items():
            links.append({"source": src, "target": drv, "label": f"{n_imports} imports", "cand": False})
    for (a, b), s in sims.items():
        links.append({"source": a, "target": b, "label": f"{s*100:.1f}%", "cand": True})

    # ---- nodes ----
    nodes: Set[str] = set(missing)
    deg: Counter = Counter()
    for lk in links:
        nodes.update((lk["source"], lk["target"]))
        deg[lk["source"]] += 1
        deg[lk["target"]] += 1
    max_deg = max(deg.values(), default=1)

    # Classification only looks at dependency (non-candidate) edges.
    targets = {lk["target"] for lk in links if not lk["cand"]}
    sources = {lk["source"] for lk in links if not lk["cand"]}
    missing_only = {m for m in missing if m not in sources and m not in targets}

    nodelist = []
    for n in sorted(nodes):
        if n in missing_only:
            cls = "cand"
        elif n in sources and n not in targets:
            cls = "base"
        else:
            cls = "derived"
        nodelist.append({"id": n, "cls": cls, "sz": 1 + 2 * (deg[n] / max_deg)})

    return {"nodes": nodelist, "links": links}


def generate_html(graph: dict) -> str:
    """Return the full HTML string with inlined CSS/JS + graph JSON."""
    js = JS.replace("__GRAPH_DATA__", json.dumps(graph, separators=(",", ":")))
    return HTML.replace("__CSS__", CSS).replace("__JS__", js)


# ────────────────────────────────────────────────────────────────────────────────
# 3) HTML (D3.js) boilerplate – CSS + JS templates (unchanged design)
# ────────────────────────────────────────────────────────────────────────────────

CSS = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600&display=swap');

:root { --base: 60px; }
body { margin:0; font-family:'Inter',Arial,sans-serif; background:transparent; overflow:hidden; }
svg { width:100vw; height:100vh; }
.link { stroke:#999; stroke-opacity:.6; }
.link.cand { stroke:#e63946; stroke-width:2.5; }
.node-label { fill:#333; pointer-events:none; text-anchor:middle; font-weight:600; }
.link-label { fill:#555; font-size:10px; pointer-events:none; text-anchor:middle; }
.node.base path { fill:#ffbe0b; }
.node.derived circle { fill:#1f77b4; }
.node.cand circle, .node.cand path { fill:#e63946; }
#legend { position:fixed; top:18px; left:18px;
  background:rgba(255,255,255,.92); padding:18px 28px;
  border-radius:10px; border:1.5px solid #bbb;
  font-size:18px; box-shadow:0 2px 8px rgba(0,0,0,.08); }
"""

JS = """

function updateVisibility() {
  const show = document.getElementById('toggleRed').checked;
  svg.selectAll('.link.cand').style('display', show ? null : 'none');
  svg.selectAll('.node.cand').style('display', show ? null : 'none');
  svg.selectAll('.link-label')
     .filter(d => d.cand)
     .style('display', show ? null : 'none');
}

document.getElementById('toggleRed').addEventListener('change', updateVisibility);


const graph = __GRAPH_DATA__;
const W = innerWidth, H = innerHeight;

const svg = d3.select('#dependency').call(d3.zoom().on('zoom', e => g.attr('transform', e.transform)));
const g = svg.append('g');

const link = g.selectAll('line')
  .data(graph.links)
  .join('line')
  .attr('class', d => d.cand ? 'link cand' : 'link');

const linkLbl = g.selectAll('text.link-label')
  .data(graph.links)
  .join('text')
  .attr('class', 'link-label')
  .text(d => d.label);

const node = g.selectAll('g.node')
  .data(graph.nodes)
  .join('g')
  .attr('class', d => `node ${d.cls}`)
  .call(d3.drag().on('start', dragStart).on('drag', dragged).on('end', dragEnd));

node.filter(d => d.cls==='base').append('image')
  .attr('xlink:href', 'hf-logo.svg').attr('x', -30).attr('y', -30).attr('width', 60).attr('height', 60);

node.filter(d => d.cls!=='base').append('circle').attr('r', d => 20*d.sz);

node.append('text').attr('class','node-label').attr('dy','-2.4em').text(d => d.id);

const sim = d3.forceSimulation(graph.nodes)
  .force('link', d3.forceLink(graph.links).id(d => d.id).distance(520)) // tighter links
  .force('charge', d3.forceManyBody().strength(-600))   // weaker repulsion
  .force('center', d3.forceCenter(W / 2, H / 2))
  .force('collide', d3.forceCollide(d => d.cls === 'base' ? 50 : 50)); // smaller bubble spacing

sim.on('tick', () => {
  link.attr('x1', d=>d.source.x).attr('y1', d=>d.source.y)
      .attr('x2', d=>d.target.x).attr('y2', d=>d.target.y);
  linkLbl.attr('x', d=> (d.source.x+d.target.x)/2)
         .attr('y', d=> (d.source.y+d.target.y)/2);
  node.attr('transform', d=>`translate(${d.x},${d.y})`);
});

function dragStart(e,d){ if(!e.active) sim.alphaTarget(.3).restart(); d.fx=d.x; d.fy=d.y; }
function dragged(e,d){ d.fx=e.x; d.fy=e.y; }
function dragEnd(e,d){ if(!e.active) sim.alphaTarget(0); d.fx=d.fy=null; }
"""

# The page must provide what JS relies on: an <svg id="dependency">, a checkbox
# #toggleRed, and d3 v7 (zoom/drag callbacks take the event as first argument).
# __CSS__ / __JS__ are substituted by generate_html().
HTML = """
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="utf-8">
  <title>Transformers modular graph</title>
  <style>__CSS__</style>
</head>
<body>
  <div id="legend">
    🟡 base<br>
    🔵 modular<br>
    🔴 candidate<br>
    <span style="color:#e63946;">red edge</span> = high similarity<br>
    <label><input type="checkbox" id="toggleRed" checked> show candidates</label>
  </div>
  <svg id="dependency"></svg>
  <script src="https://d3js.org/d3.v7.min.js"></script>
  <script>__JS__</script>
</body>
</html>
"""

# ────────────────────────────────────────────────────────────────────────────────
# HTML writer
# ────────────────────────────────────────────────────────────────────────────────


def write_html(graph_data: dict, path: Path):
    """Render *graph_data* with :func:`generate_html` and write it to *path* (UTF-8)."""
    path.write_text(generate_html(graph_data), encoding="utf-8")


# ────────────────────────────────────────────────────────────────────────────────
# MAIN
# ────────────────────────────────────────────────────────────────────────────────


def main():
    """CLI entry point: build the graph for a transformers checkout and write the HTML."""
    ap = argparse.ArgumentParser(description="Visualise modular dependencies + candidates")
    ap.add_argument("transformers", help="Path to local 🤗 transformers repo root")
    ap.add_argument("--multimodal", action="store_true",
                    help=f"filter to models with ≥{PIXEL_MIN_HITS} 'pixel_values'")
    ap.add_argument("--sim-threshold", type=float, default=SIM_DEFAULT,
                    help=f"similarity cutoff for candidate (red) edges (default {SIM_DEFAULT})")
    ap.add_argument("--out", default=HTML_DEFAULT,
                    help=f"output HTML file (default {HTML_DEFAULT})")
    ap.add_argument("--sim-method", choices=["jaccard", "embedding"], default="jaccard",
                    help="Similarity method: 'jaccard' or 'embedding'")
    args = ap.parse_args()

    graph = build_graph_json(
        transformers_dir=Path(args.transformers).expanduser().resolve(),
        threshold=args.sim_threshold,
        multimodal=args.multimodal,
        sim_method=args.sim_method,
    )
    write_html(graph, Path(args.out).expanduser())


if __name__ == "__main__":
    main()