Spaces:
Running
on
Zero
Running
on
Zero
#!/usr/bin/env python
"""
modular_graph_and_candidates.py
================================
Create **one** rich view that combines

1. The *dependency graph* between existing **modular_*.py** implementations in
   🤗 Transformers (blue/🟡) **and**
2. The list of *missing* modular models (full-red nodes) **plus** similarity
   edges (full-red links) between highly-overlapping modelling files — the
   output of *find_modular_candidates.py* — so you can immediately spot good
   refactor opportunities.

─── Usage ───

```bash
python modular_graph_and_candidates.py /path/to/transformers \
    --multimodal \
    --sim-threshold 0.5 \
    --sim-method jaccard \
    --out graph.html
```

* ``--multimodal``     keep only models whose modelling code mentions
                       "pixel_values" at least ``PIXEL_MIN_HITS`` times (see CONFIG)
* ``--sim-threshold``  similarity cutoff for candidate links (default 0.50)
* ``--sim-method``     ``jaccard`` (token overlap, default) or ``embedding``
                       (cosine similarity of code embeddings)
* ``--out``            output HTML file name (default d3_modular_graph.html)

Colour legend in the generated HTML:
* 🟡 **base model** — has modular shards *imported* by others but no parent
* 🔵 **derived modular model** — has a `modular_*.py` and inherits from ≥ 1 model
* 🔴 **candidate** — no `modular_*.py` yet (and/or very similar to another)
* red edges = high-similarity links (potential to factorise)
"""
from __future__ import annotations | |
import argparse | |
import ast | |
import json | |
import re | |
import tokenize | |
from collections import Counter, defaultdict | |
from itertools import combinations | |
from pathlib import Path | |
from typing import Dict, List, Set, Tuple | |
from sentence_transformers import SentenceTransformer, util | |
from tqdm import tqdm | |
import numpy as np | |
import spaces | |
import torch | |
# ────────────────────────────────────────────────────────────────────────────────
# CONFIG
# ────────────────────────────────────────────────────────────────────────────────
SIM_DEFAULT = 0.5  # default similarity cutoff for candidate (red) links
# Minimum total number of "pixel_values" occurrences in a model's modeling_*.py
# files for it to survive the --multimodal filter. A value of 0 would make
# --multimodal a no-op, since every count satisfies ``hits >= 0``.
PIXEL_MIN_HITS = 3
HTML_DEFAULT = "d3_modular_graph.html"  # default --out file name

# ────────────────────────────────────────────────────────────────────────────────
# 1) Helpers to analyse *modelling* files (for similarity & multimodal filter)
# ────────────────────────────────────────────────────────────────────────────────
def _strip_source(code: str) -> str: | |
"""Remove docβstrings, comments and import lines to keep only the core code.""" | |
code = re.sub(r'("""|\'\'\')(?:.|\n)*?\1', "", code) # docβstrings | |
code = re.sub(r"#.*", "", code) # # comments | |
return "\n".join(ln for ln in code.splitlines() | |
if not re.match(r"\s*(from|import)\s+", ln)) | |
def _tokenise(code: str) -> Set[str]: | |
"""Extract identifiers using regex - more robust than tokenizer for malformed code.""" | |
toks: Set[str] = set() | |
for match in re.finditer(r'\b[a-zA-Z_][a-zA-Z0-9_]*\b', code): | |
toks.add(match.group()) | |
return toks | |
def build_token_bags(models_root: Path) -> Tuple[Dict[str, List[Set[str]]], Dict[str, int]]:
    """Collect one token bag per ``modeling_*.py`` file and count ``pixel_values``.

    Args:
        models_root: directory whose immediate sub-directories are treated as
            models (``src/transformers/models`` in a checkout).

    Returns:
        ``(bags, pixel_hits)``:

        * ``bags[model]`` holds one identifier set for each ``modeling_*.py`` file
          found (recursively) in that model's directory.
        * ``pixel_hits[model]`` is the total number of ``"pixel_values"``
          occurrences in those files.

        Both are ``defaultdict`` instances, so unknown models read as ``[]`` / ``0``.

    Files that cannot be read or decoded are reported and skipped.
    """
    bags: Dict[str, List[Set[str]]] = defaultdict(list)
    pixel_hits: Dict[str, int] = defaultdict(int)
    model_dirs = sorted(p for p in models_root.iterdir() if p.is_dir())
    for mdl_dir in model_dirs:
        for py in mdl_dir.rglob("modeling_*.py"):
            try:
                # Canonical ASCII codec name, matching the other reader in this
                # file. The previous spelling contained a non-ASCII hyphen.
                text = py.read_text(encoding="utf-8")
            except (OSError, UnicodeDecodeError) as e:
                print(f"⚠️ Skipped {py}: {e}")
                continue
            pixel_hits[mdl_dir.name] += text.count("pixel_values")
            bags[mdl_dir.name].append(_tokenise(_strip_source(text)))
    return bags, pixel_hits
def _jaccard(a: Set[str], b: Set[str]) -> float: | |
return 0.0 if (not a or not b) else len(a & b) / len(a | b) | |
def similarity_clusters(bags: Dict[str, List[Set[str]]], thr: float) -> Dict[Tuple[str,str], float]:
    """Return ``{(modelA, modelB): score}`` for model pairs with Jaccard ≥ *thr*.

    Each model is represented by its largest token bag, i.e. the
    ``modeling_*.py`` file with the most distinct identifiers. Models without
    any bag are ignored. In each key, ``modelA`` sorts before ``modelB``.
    """
    biggest_bag: Dict[str, Set[str]] = {}
    for model_name, token_sets in bags.items():
        if token_sets:
            biggest_bag[model_name] = max(token_sets, key=len)
    pairs: Dict[Tuple[str, str], float] = {}
    for left, right in combinations(sorted(biggest_bag), 2):
        score = _jaccard(biggest_bag[left], biggest_bag[right])
        if score >= thr:
            pairs[(left, right)] = score
    return pairs
def _load_code_embedder():
    """Load the CodeSage encoder on CUDA and cap its sequence length at 2048 tokens."""
    model = SentenceTransformer("codesage/codesage-large-v2", device="cuda", trust_remote_code=True)
    try:
        cfg = model[0].auto_model.config
        # Look the limits up lazily, so a config that defines ``n_positions``
        # but not ``max_position_embeddings`` still uses ``n_positions``.
        pos_limit = getattr(cfg, "n_positions", None)
        if pos_limit is None:
            pos_limit = getattr(cfg, "max_position_embeddings")
        pos_limit = int(pos_limit)
    except (AttributeError, IndexError, TypeError, ValueError):
        pos_limit = 1024
    seq_len = min(pos_limit, 2048)
    model.max_seq_length = seq_len
    model[0].max_seq_length = seq_len
    model[0].tokenizer.model_max_length = seq_len
    return model


def _read_modeling_texts(models_root: Path, missing: List[str]) -> Dict[str, str]:
    """Map each model in *missing* to its concatenated, stripped ``modeling_*.py`` source.

    Models listed as causing GPU aborts are skipped. Unreadable files are
    ignored. A model with no usable code maps to a single space, so it still
    gets an embedding.
    """
    texts: Dict[str, str] = {}
    for name in tqdm(missing, desc="Reading modeling files"):
        if any(skip in name.lower() for skip in ["mobilebert", "lxmert"]):
            print(f"Skipping {name} (causes GPU abort)")
            continue
        chunks = []
        for py in (models_root / name).rglob("modeling_*.py"):
            try:
                chunks.append(_strip_source(py.read_text(encoding="utf-8")) + "\n")
            except (OSError, UnicodeDecodeError):
                continue
        texts[name] = "".join(chunks).strip() or " "
    return texts


def _embedding_pairs(unit: np.ndarray, names: List[str], thr: float) -> Dict[Tuple[str, str], float]:
    """Return ``{(names[i], names[j]): unit[i]·unit[j]}`` for every ``i < j`` scoring ≥ *thr*.

    The rows of *unit* are expected to be L2-normalised, so the dot product is
    the cosine similarity. Pairs come out in the same row-major order as a
    nested ``for i: for j > i`` loop. Scores are compared in float64.
    """
    sims = unit @ unit.T
    rows, cols = np.triu_indices(unit.shape[0], k=1)
    scores = sims[rows, cols].astype(np.float64)
    keep = scores >= thr
    return {
        (names[i], names[j]): float(s)
        for i, j, s in zip(rows[keep], cols[keep], scores[keep])
    }


def embedding_similarity_clusters(models_root: Path, missing: List[str], thr: float) -> Dict[Tuple[str, str], float]:
    """Return ``{(modelA, modelB): cosine}`` for pairs of *missing* models with similar code.

    The stripped ``modeling_*.py`` sources of each model are embedded with the
    ``codesage/codesage-large-v2`` SentenceTransformer. Every pair whose cosine
    similarity is ≥ *thr* is returned; keys follow the order of *missing*
    (after skipped models are removed).

    Caching (both files live in the current working directory):

    * ``embeddings_cache.npz`` is reused when its name list matches the current
      one exactly. The model is then never loaded.
    * ``temp_embeddings.npz`` is a checkpoint written after each batch, so an
      interrupted run can resume.

    A batch that fails to encode gets all-zero embeddings for this run, so it
    produces no links. In that case neither cache is advanced past the failure,
    so the next run retries those models instead of reusing the zeros forever.
    """
    texts = _read_modeling_texts(models_root, missing)
    names = list(texts)
    if not names:
        # Nothing to embed; np.vstack([]) below would raise.
        return {}

    temp_cache_path = Path("temp_embeddings.npz")     # for resuming computation
    final_cache_path = Path("embeddings_cache.npz")   # for permanent storage

    # Exact match with the permanent cache: no need to load the model at all.
    if final_cache_path.exists():
        try:
            with np.load(final_cache_path, allow_pickle=True) as cached:
                cached_names = list(cached["names"])
        except Exception as e:  # corrupt or foreign file: fall through to recompute
            print(f"⚠️ Failed to load final cache: {e}")
        else:
            if names == cached_names:
                print(f"✅ Using final embeddings cache ({len(cached_names)} models)")
                return compute_similarities_from_cache(thr)

    model = _load_code_embedder()
    emb_dim = getattr(model, "get_sentence_embedding_dimension", lambda: 768)() or 768
    batch_size = 4
    print(f"Encoding embeddings for {len(names)} models...")

    all_embeddings: List[np.ndarray] = []
    start_idx = 0
    if temp_cache_path.exists():
        try:
            with np.load(temp_cache_path, allow_pickle=True) as cached:
                cached_names = list(cached["names"])
                if names[:len(cached_names)] == cached_names:
                    all_embeddings.append(cached["embeddings"].astype(np.float32))
                    start_idx = len(cached_names)
                    print(f"🔄 Resuming from temp cache: {start_idx}/{len(names)} models")
        except Exception as e:
            print(f"⚠️ Failed to load temp cache: {e}")

    failed: List[str] = []
    for i in tqdm(range(start_idx, len(names), batch_size), desc="Batches", leave=False):
        batch_names = names[i:i + batch_size]
        batch_texts = [texts[name] for name in batch_names]
        try:
            print(f"Processing batch: {batch_names}")
            emb = model.encode(batch_texts, convert_to_numpy=True, show_progress_bar=False)
        except Exception as e:
            print(f"⚠️ GPU worker error for batch {batch_names}: {type(e).__name__}: {e}")
            emb = np.zeros((len(batch_names), emb_dim), dtype=np.float32)
            failed.extend(batch_names)
        all_embeddings.append(emb)

        # Only checkpoint while every batch so far has succeeded. A checkpoint
        # holding placeholder zeros would make the next run skip those models.
        if not failed:
            try:
                cur = np.vstack(all_embeddings).astype(np.float32)
                np.savez(
                    temp_cache_path,
                    embeddings=cur,
                    names=np.array(names[:i + len(batch_names)], dtype=object),
                )
            except Exception as e:
                print(f"⚠️ Failed to write temp cache: {e}")

        if (i - start_idx) % (3 * batch_size) == 0 and torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
            print(f"🧹 Cleared GPU cache after batch {(i - start_idx)//batch_size + 1}")

    embeddings = np.vstack(all_embeddings).astype(np.float32)
    embeddings = embeddings / (np.linalg.norm(embeddings, axis=1, keepdims=True) + 1e-12)
    print("Computing pairwise similarities...")
    out = _embedding_pairs(embeddings, names, thr)

    if failed:
        print(f"⚠️ {len(failed)} model(s) got placeholder embeddings; "
              f"not updating {final_cache_path}: {failed}")
        return out

    try:
        np.savez(final_cache_path, embeddings=embeddings, names=np.array(names, dtype=object))
        print(f"💾 Final embeddings saved to {final_cache_path}")
        if temp_cache_path.exists():
            temp_cache_path.unlink()
            print("🧹 Cleaned up temp cache")
    except Exception as e:
        print(f"⚠️ Failed to save final cache: {e}")
    return out
def compute_similarities_from_cache(threshold: float, cache_path: Path = Path("embeddings_cache.npz")) -> Dict[Tuple[str, str], float]:
    """Return ``{(modelA, modelB): cosine}`` for cached pairs scoring ≥ *threshold*.

    Reads the ``embeddings`` and ``names`` arrays from an ``.npz`` cache (the
    format written by :func:`embedding_similarity_clusters`). It re-normalises
    the rows and keeps every ``i < j`` pair at or above the threshold, without
    re-encoding anything.

    Args:
        threshold: minimum cosine similarity for a pair to be returned.
        cache_path: ``.npz`` file to read; defaults to the permanent cache.

    Returns:
        The pair dict, or ``{}`` when the cache is missing or unreadable.
    """
    if not cache_path.exists():
        return {}
    try:
        # ``names`` is stored as an object array, hence allow_pickle. Only load
        # cache files this script produced itself. The context manager closes
        # the underlying zip file handle.
        with np.load(cache_path, allow_pickle=True) as cached:
            embeddings = cached["embeddings"].astype(np.float32)
            names = list(cached["names"])
        embeddings = embeddings / (np.linalg.norm(embeddings, axis=1, keepdims=True) + 1e-12)
        sims_mat = embeddings @ embeddings.T
        # Upper-triangle pairs in row-major order (same as a nested i<j loop).
        # Scores are compared in float64 so threshold edge cases match float().
        rows, cols = np.triu_indices(len(names), k=1)
        scores = sims_mat[rows, cols].astype(np.float64)
        keep = scores >= threshold
        out = {
            (names[i], names[j]): float(s)
            for i, j, s in zip(rows[keep], cols[keep], scores[keep])
        }
        print(f"⚡ Computed {len(out)} similarities from cache (threshold: {threshold})")
        return out
    except Exception as e:  # best-effort: any unreadable cache means "no cache"
        print(f"⚠️ Failed to compute from cache: {e}")
        return {}
# ────────────────────────────────────────────────────────────────────────────────
# 2) Scan *modular_*.py* files to build an import-dependency graph
#    — only **modeling_*** imports are considered (skip configuration / processing)
# ────────────────────────────────────────────────────────────────────────────────
def modular_files(models_root: Path) -> List[Path]:
    """List every ``modular_*.py`` file found recursively below *models_root*."""
    found: List[Path] = []
    for candidate in models_root.rglob("modular_*.py"):
        if candidate.suffix == ".py":
            found.append(candidate)
    return found
def dependency_graph(modular_files: List[Path], models_root: Path) -> Dict[str, List[Dict[str,str]]]:
    """Map each derived model to the modelling classes its ``modular_*.py`` imports.

    Returns ``{derived_model: [{"source": model, "imported_class": name}, ...]}``,
    where *derived_model* is the name of the directory holding the modular file.

    Filtering rules:

    * Only ``from … import …`` statements whose module contains ``modeling_``
      are kept.
    * Configuration, processing, image-processing and
      ``modeling_attn_mask_utils`` imports are dropped, so the graph focuses on
      modelling code.
    * Imports from the model itself, or from a name that is not a directory
      under *models_root*, are ignored.

    Files that cannot be read or parsed are reported and skipped.
    """
    model_names = {p.name for p in models_root.iterdir() if p.is_dir()}
    deps: Dict[str, List[Dict[str,str]]] = defaultdict(list)
    excluded = ("configuration_", "processing_", "image_processing", "modeling_attn_mask_utils")
    for fp in modular_files:
        derived = fp.parent.name
        try:
            # Canonical ASCII codec name. The previous spelling had a non-ASCII hyphen.
            tree = ast.parse(fp.read_text(encoding="utf-8"), filename=str(fp))
        except (OSError, SyntaxError, ValueError, RecursionError) as e:
            # ValueError also covers UnicodeDecodeError and null bytes in source.
            print(f"⚠️ AST parse failed for {fp}: {e}")
            continue
        for node in ast.walk(tree):
            if not isinstance(node, ast.ImportFrom) or not node.module:
                continue
            mod = node.module
            # keep only *modeling_* imports, drop anything else
            if "modeling_" not in mod or any(tag in mod for tag in excluded):
                continue
            # First meaningful path component is the source model, for both
            # relative (``..llama.modeling_llama``) and absolute
            # (``transformers.models.llama.modeling_llama``) imports.
            parts = re.split(r"[./]", mod)
            src = next((p for p in parts if p not in {"", "models", "transformers"}), "")
            if not src or src == derived or src not in model_names:
                continue
            for alias in node.names:
                deps[derived].append({"source": src, "imported_class": alias.name})
    return dict(deps)
# ────────────────────────────────────────────────────────────────────────────────
# Pipeline helpers used by build_graph_json
# ────────────────────────────────────────────────────────────────────────────────
def get_missing_models(models_root: Path, multimodal: bool = False) -> Tuple[List[str], Dict[str, List[Set[str]]], Dict[str, int]]:
    """Find models that have modelling code but no ``modular_*.py`` yet.

    Returns ``(missing, bags, pix_hits)``:

    * ``missing``: the candidate model names, in the order of ``bags`` (sorted
      directory order).
    * ``bags`` and ``pix_hits``: the token bags and ``pixel_values`` counts from
      :func:`build_token_bags`, returned so callers need not rescan.

    With *multimodal*, only models whose count is ≥ ``PIXEL_MIN_HITS`` are kept.
    """
    bags, pix_hits = build_token_bags(models_root)
    has_modular = {path.parent.name for path in modular_files(models_root)}
    missing = [
        name
        for name in bags
        if name not in has_modular and (not multimodal or pix_hits[name] >= PIXEL_MIN_HITS)
    ]
    return missing, bags, pix_hits
def compute_similarities(models_root: Path, missing: List[str], bags: Dict[str, List[Set[str]]],
                         threshold: float, sim_method: str) -> Dict[Tuple[str, str], float]:
    """Return similar pairs among *missing*, scored by Jaccard overlap or code embeddings.

    ``"jaccard"`` compares the models' token bags. Any other method uses code
    embeddings, as follows:

    * The permanent embeddings cache is tried first.
    * Its pairs are restricted to models still in *missing*, because the cache
      may predate newly added modular files or a ``--multimodal`` filter.
    * If no pair survives, embeddings are (re)computed via
      :func:`embedding_similarity_clusters`.
    """
    if sim_method == "jaccard":
        return similarity_clusters({m: bags[m] for m in missing}, threshold)

    # Try to use cached embeddings first
    if Path("embeddings_cache.npz").exists():
        wanted = set(missing)
        cached_sims = {
            pair: score
            for pair, score in compute_similarities_from_cache(threshold).items()
            if pair[0] in wanted and pair[1] in wanted
        }
        if cached_sims:
            return cached_sims
    # Fallback to full computation
    return embedding_similarity_clusters(models_root, missing, threshold)
def _dependency_links(deps: Dict[str, List[Dict[str, str]]]) -> List[dict]:
    """Return one non-candidate link per (source, derived) pair, labelled with its import count."""
    links: List[dict] = []
    for drv, imports in deps.items():
        # Counter keeps first-seen order, so links follow the import order.
        for src, n_imports in Counter(d["source"] for d in imports).items():
            links.append({
                "source": src,
                "target": drv,
                "label": f"{n_imports} imports",
                "cand": False,
            })
    return links


def _node_list(nodes: Set[str], links: List[dict], missing: List[str]) -> List[dict]:
    """Classify and size every node for the D3 template.

    Classes:

    * ``cand``: a missing model that takes part in no dependency link.
    * ``base``: appears only as the source of dependency links.
    * ``derived``: everything else.

    ``sz`` grows linearly with the node's degree over all links, from 1 to 3.
    """
    deg: Counter = Counter()
    for lk in links:
        deg[lk["source"]] += 1
        deg[lk["target"]] += 1
    max_deg = max(deg.values(), default=1)
    dep_links = [lk for lk in links if not lk["cand"]]
    targets = {lk["target"] for lk in dep_links}
    sources = {lk["source"] for lk in dep_links}
    missing_set = set(missing)
    nodelist = []
    for n in sorted(nodes):
        if n in missing_set and n not in sources and n not in targets:
            cls = "cand"
        elif n in sources and n not in targets:
            cls = "base"
        else:
            cls = "derived"
        nodelist.append({"id": n, "cls": cls, "sz": 1 + 2 * (deg[n] / max_deg)})
    return nodelist


def _graph_from_embedding_cache(models_root: Path, cache_path: Path, threshold: float):
    """Build the graph from cached similarities plus live modular dependencies.

    Returns ``None`` when the cache yields no pair above *threshold*.
    NOTE(review): this fast path takes the cached name list as the set of
    missing models and does not apply the ``--multimodal`` filter.
    """
    cached_sims = compute_similarities_from_cache(threshold)
    print(f"📊 Got {len(cached_sims)} cached similarities")
    if not cached_sims:
        return None
    with np.load(cache_path, allow_pickle=True) as cached_data:
        missing = list(cached_data["names"])

    # Still need to get modular dependencies from repo
    deps = dependency_graph(modular_files(models_root), models_root)
    links = _dependency_links(deps)
    nodes: Set[str] = set(missing)
    for lk in links:
        nodes.update({lk["source"], lk["target"]})
    for (a, b), s in cached_sims.items():
        links.append({"source": a, "target": b, "label": f"{s*100:.1f}%", "cand": True})

    nodelist = _node_list(nodes, links, missing)
    print(f"⚡ Built graph from cache: {len(nodelist)} nodes, {len(links)} links")
    return {"nodes": nodelist, "links": links}


def build_graph_json(
    transformers_dir: Path,
    threshold: float = SIM_DEFAULT,
    multimodal: bool = False,
    sim_method: str = "jaccard",
) -> dict:
    """Return the ``{"nodes": [...], "links": [...]}`` dict consumed by the D3 template.

    Args:
        transformers_dir: root of a local transformers checkout.
        threshold: minimum similarity for a candidate (red) link.
        multimodal: keep only models passing the ``pixel_values`` filter.
        sim_method: ``"jaccard"`` or ``"embedding"``.

    With ``sim_method="embedding"`` and an existing ``embeddings_cache.npz``,
    the graph is built from the cache first. The repository is then only
    scanned for modular dependencies. Otherwise, or when that path yields
    nothing or fails, everything is recomputed from the repository.

    Dependency links are emitted once per (source, derived) pair, labelled
    "N imports". Similarity links carry ``cand: True`` and a percentage label.
    """
    models_root = transformers_dir / "src/transformers/models"
    embeddings_cache = Path("embeddings_cache.npz")
    print(f"🔍 Cache file exists: {embeddings_cache.exists()}, sim_method: {sim_method}")

    if sim_method == "embedding" and embeddings_cache.exists():
        try:
            graph = _graph_from_embedding_cache(models_root, embeddings_cache, threshold)
        except Exception as e:  # best-effort fast path; the full build below is authoritative
            print(f"⚠️ Cache-only build failed: {e}, falling back to full build")
        else:
            if graph is not None:
                return graph

    # Full build with repository access
    missing, bags, _pix_hits = get_missing_models(models_root, multimodal)
    deps = dependency_graph(modular_files(models_root), models_root)
    sims = compute_similarities(models_root, missing, bags, threshold, sim_method)

    links = _dependency_links(deps)
    nodes: Set[str] = set(missing)
    for lk in links:
        nodes.update({lk["source"], lk["target"]})
    for (a, b), s in sims.items():
        links.append({"source": a, "target": b, "label": f"{s*100:.1f}%", "cand": True})
        nodes.update({a, b})

    return {"nodes": _node_list(nodes, links, missing), "links": links}
def generate_html(graph: dict) -> str:
    """Render the standalone HTML page: the CSS, plus the JS with *graph* embedded as compact JSON."""
    payload = json.dumps(graph, separators=(",", ":"))
    script = JS.replace("__GRAPH_DATA__", payload)
    page = HTML.replace("__CSS__", CSS)
    return page.replace("__JS__", script)
# ────────────────────────────────────────────────────────────────────────────────
# 3) HTML (D3.js) boilerplate — CSS + JS templates (unchanged design)
# ────────────────────────────────────────────────────────────────────────────────
# Stylesheet substituted for the "__CSS__" placeholder in HTML by generate_html.
# Colours come from CSS custom properties, with a prefers-color-scheme: dark
# override. Node fills: derived = blue (#1f77b4), candidate = red (#e63946).
# Base nodes show the HF logo <image> (or a yellow circle, added by the JS).
# NOTE(review): `.node.base image` sets width/height to 60px plus a -30px
# translate, while the JS gives the <image> width/height 40 and x/y -20. In
# browsers that apply CSS geometry to SVG images these combine and may push
# the logo off-centre. Confirm in a browser.
CSS = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600&display=swap');
:root{
--bg:#ffffff;
--text:#222222;
--muted:#555555;
--outline:#ffffff;
}
@media (prefers-color-scheme: dark){
:root{
--bg:#0b0d10;
--text:#e8e8e8;
--muted:#c8c8c8;
--outline:#000000;
}
}
body{ margin:0; font-family:'Inter',Arial,sans-serif; background:var(--bg); overflow:hidden; }
svg{ width:100vw; height:100vh; }
.link{ stroke:#999; stroke-opacity:.6; }
.link.cand{ stroke:#e63946; stroke-width:2.5; }
.node-label{
fill:var(--text);
pointer-events:none;
text-anchor:middle;
font-weight:600;
paint-order:stroke fill;
stroke:var(--outline);
stroke-width:3px;
}
.link-label{
fill:var(--muted);
pointer-events:none;
text-anchor:middle;
font-size:10px;
paint-order:stroke fill;
stroke:var(--bg);
stroke-width:2px;
}
.node.base image{ width:60px; height:60px; transform:translate(-30px,-30px); }
.node.derived circle{ fill:#1f77b4; }
.node.cand circle, .node.cand path{ fill:#e63946; }
#legend{
position:fixed; top:18px; left:18px;
background:rgba(255,255,255,.92);
padding:18px 28px; border-radius:10px; border:1.5px solid #bbb;
font-size:18px; box-shadow:0 2px 8px rgba(0,0,0,.08);
}
@media (prefers-color-scheme: dark){
#legend{ background:rgba(20,22,25,.92); color:#e8e8e8; border-color:#444; }
}
"""
# D3 v7 script substituted for "__JS__" in HTML. generate_html first replaces
# "__GRAPH_DATA__" with the JSON {nodes, links} graph from build_graph_json.
# What the script does:
# - Draws a line and a midpoint label per link. Links with `cand` get the
#   "link cand" (red) class.
# - Draws one <g class="node {cls}"> per node. Base nodes get the logo image,
#   and fall back to a yellow circle if the image fails to load. Other nodes
#   get a circle with radius 20*sz.
# - Runs a force simulation (link distance 520, charge -600, collide 50) with
#   zoom and drag.
# - The #toggleRed checkbox hides or shows candidate links, candidate nodes and
#   candidate link labels. updateVisibility is declared before `svg`, but it
#   only runs on later `change` events, after `svg` is initialised.
# NOTE(review): HF_LOGO_URI is a relative path ("./static/hf-logo.png"), so the
# page presumably has to be served next to that file. Confirm for --out usage.
JS = """
function updateVisibility() {
const show = document.getElementById('toggleRed').checked;
svg.selectAll('.link.cand').style('display', show ? null : 'none');
svg.selectAll('.node.cand').style('display', show ? null : 'none');
svg.selectAll('.link-label').filter(d => d.cand).style('display', show ? null : 'none');
}
document.getElementById('toggleRed').addEventListener('change', updateVisibility);
const HF_LOGO_URI = "./static/hf-logo.png";
const graph = __GRAPH_DATA__;
const W = innerWidth, H = innerHeight;
const svg = d3.select('#dependency').call(d3.zoom().on('zoom', e => g.attr('transform', e.transform)));
const g = svg.append('g');
const link = g.selectAll('line')
.data(graph.links)
.join('line')
.attr('class', d => d.cand ? 'link cand' : 'link');
const linkLbl = g.selectAll('text.link-label')
.data(graph.links)
.join('text')
.attr('class', 'link-label')
.text(d => d.label);
const node = g.selectAll('g.node')
.data(graph.nodes)
.join('g')
.attr('class', d => `node ${d.cls}`)
.call(d3.drag().on('start', dragStart).on('drag', dragged).on('end', dragEnd));
const baseSel = node.filter(d => d.cls === 'base');
if (HF_LOGO_URI){
baseSel.append('image')
.attr('href', HF_LOGO_URI)
.attr('width', 40)
.attr('height', 40)
.attr('x', -20)
.attr('y', -20)
.on('error', function() {
console.log('Image failed to load:', HF_LOGO_URI);
// Fallback to circle
d3.select(this.parentNode).append('circle')
.attr('r', 22).attr('fill', '#ffbe0b');
});
console.log('Loading logo from:', HF_LOGO_URI);
}else{
baseSel.append('circle').attr('r', d => 22*d.sz).attr('fill', '#ffbe0b');
}
node.filter(d => d.cls !== 'base').append('circle').attr('r', d => 20*d.sz);
node.append('text')
.attr('class','node-label')
.attr('dy','-2.4em')
.style('font-size', d => d.cls === 'base' ? '32px' : '28px')
.style('font-weight', d => d.cls === 'base' ? 'bold' : 'normal')
.text(d => d.id);
const sim = d3.forceSimulation(graph.nodes)
.force('link', d3.forceLink(graph.links).id(d => d.id).distance(520))
.force('charge', d3.forceManyBody().strength(-600))
.force('center', d3.forceCenter(W / 2, H / 2))
.force('collide', d3.forceCollide(d => 50));
sim.on('tick', () => {
link.attr('x1', d=>d.source.x).attr('y1', d=>d.source.y)
.attr('x2', d=>d.target.x).attr('y2', d=>d.target.y);
linkLbl.attr('x', d=> (d.source.x+d.target.x)/2)
.attr('y', d=> (d.source.y+d.target.y)/2);
node.attr('transform', d=>`translate(${d.x},${d.y})`);
});
function dragStart(e,d){ if(!e.active) sim.alphaTarget(.3).restart(); d.fx=d.x; d.fy=d.y; }
function dragged(e,d){ d.fx=e.x; d.fy=e.y; }
function dragEnd(e,d){ if(!e.active) sim.alphaTarget(0); d.fx=d.fy=null; }
"""
# Page skeleton. generate_html substitutes "__CSS__" with CSS and "__JS__" with
# the graph-filled JS. d3 v7 is loaded from d3js.org before the inline script,
# which renders into <svg id='dependency'>. The #legend box holds the
# #toggleRed checkbox that the JS listens to.
# NOTE(review): the legend always says "high embedding similarity", even when
# the graph was built with --sim-method jaccard.
HTML = """
<!DOCTYPE html>
<html lang='en'><head><meta charset='UTF-8'>
<title>Transformers modular graph</title>
<style>__CSS__</style></head><body>
<div id='legend'>
🟡 base<br>🔵 modular<br>🔴 candidate<br>red edge = high embedding similarity<br><br>
<label><input type="checkbox" id="toggleRed" checked> Show candidates edges and nodes</label>
</div>
<svg id='dependency'></svg>
<script src='https://d3js.org/d3.v7.min.js'></script>
<script>__JS__</script></body></html>
"""
# ────────────────────────────────────────────────────────────────────────────────
# HTML writer
# ────────────────────────────────────────────────────────────────────────────────
def write_html(graph_data: dict, path: Path):
    """Render *graph_data* with :func:`generate_html` and save it to *path* as UTF-8."""
    page = generate_html(graph_data)
    path.write_text(page, encoding="utf-8")
# ────────────────────────────────────────────────────────────────────────────────
# MAIN
# ────────────────────────────────────────────────────────────────────────────────
def main():
    """Command-line entry point: build the graph for a local checkout and write the HTML page."""
    ap = argparse.ArgumentParser(description="Visualise modular dependencies + candidates")
    ap.add_argument("transformers", help="Path to local 🤗 transformers repo root")
    # Derive the advertised threshold from the constant so help and behaviour agree.
    ap.add_argument("--multimodal", action="store_true",
                    help=f"filter to models with ≥{PIXEL_MIN_HITS} 'pixel_values'")
    ap.add_argument("--sim-threshold", type=float, default=SIM_DEFAULT,
                    help=f"similarity cutoff for candidate links (default {SIM_DEFAULT})")
    ap.add_argument("--out", default=HTML_DEFAULT,
                    help=f"output HTML file (default {HTML_DEFAULT})")
    ap.add_argument("--sim-method", choices=["jaccard", "embedding"], default="jaccard",
                    help="Similarity method: 'jaccard' or 'embedding'")
    args = ap.parse_args()

    graph = build_graph_json(
        transformers_dir=Path(args.transformers).expanduser().resolve(),
        threshold=args.sim_threshold,
        multimodal=args.multimodal,
        sim_method=args.sim_method,
    )
    write_html(graph, Path(args.out).expanduser())


if __name__ == "__main__":
    main()