# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import contextlib
import logging
import os
from pathlib import Path

import torch.distributed
import wandb
import xformers.profiler
from pydantic import BaseModel
from torch.profiler.profiler import profile
from xformers.profiler import MemSnapshotsProfiler, PyTorchProfiler

from bytelatent.distributed import get_is_master
class ProfilerArgs(BaseModel):
    """Profiler settings: each profiler starts at its *_warmup step and runs
    for the following *_steps steps (see the schedule in maybe_run_profiler)."""

    run: bool = False
    trace_folder: str = "profiling"
    mem_warmup: int = 100
    mem_steps: int = 2
    profile_warmup: int = 102
    profile_steps: int = 2


logger = logging.getLogger()
def perfetto_to_html(json_file, html_file):
    """Embed a (possibly gzipped) Perfetto/Chrome trace JSON into a standalone
    HTML page using the trace viewer templates shipped with viztracer."""
    import gzip
    import string

    import viztracer

    root = os.path.dirname(viztracer.__file__)
    sub = {}
    json_file = gzip.open(json_file) if ".gz" in str(json_file) else open(json_file)
    with open(
        os.path.join(root, "html/trace_viewer_embedder.html"), encoding="utf-8"
    ) as f:
        tmpl = f.read()
    with open(os.path.join(root, "html/trace_viewer_full.html"), encoding="utf-8") as f:
        sub["trace_viewer_full"] = f.read()
    with json_file as j:
        content = j.read()
    if isinstance(content, bytes):
        content = content.decode("utf-8")
    sub["json_data"] = content.replace("</script>", "<\\/script>")  # type: ignore
    with open(html_file, "w+", encoding="utf-8") as output_file:
        output_file.write(string.Template(tmpl).substitute(sub))
class PyTorchProfilerWandb(PyTorchProfiler):
    """PyTorch profiler that also logs the rendered trace to Weights & Biases."""

    def __init__(self, main_profiler) -> None:
        self.main_profiler = main_profiler
        self.num_steps = 0
        self.pytorch_profiler = torch.profiler.profile(
            on_trace_ready=self._on_trace,
            profile_memory=True,
            record_shapes=True,
            # with_stack gives huge profile traces and bugs out because of some
            # non-ASCII character somewhere in pytorch
            with_stack=False,
            with_flops=True,
            activities=self.ACTIVITIES,
        )

    def _analyze_trace(self, prof: profile):
        logger.info("Begin analyze trace")
        super()._analyze_trace(prof)
        logger.info("End analyze trace")

    def _on_trace(self, prof: torch.profiler.profiler.profile) -> None:
        super()._on_trace(prof)
        if get_is_master() and wandb.run is not None:
            filename = list(
                Path(self.main_profiler.output_dir).glob(
                    "profile_CPU_CUDA*/*.pt.trace.json*"
                )
            )[0]
            html_path = str(filename).replace(".json", ".html")
            perfetto_to_html(filename, html_path)
            wandb.log({"profile_trace": wandb.Html(html_path)})


class MemSnapshotsProfilerWandb(MemSnapshotsProfiler):
    """Memory snapshot profiler that also logs the memory trace plot to Weights & Biases."""

    def __exit__(self, exc_type, exc_val, exc_tb):
        super().__exit__(exc_type, exc_val, exc_tb)
        if get_is_master() and wandb.run is not None:
            filename = list(
                Path(self.main_profiler.output_dir).glob("memory_trace_plot/*.html")
            )[0]
            wandb.log({"memory_trace": wandb.Html(open(filename), inject=False)})
@contextlib.contextmanager
def maybe_run_profiler(dump_dir, module, config: ProfilerArgs):
    """Yield an active xformers profiler when config.run is set, otherwise None."""
    # get user defined profiler settings
    if config.run:
        trace_dir = os.path.join(dump_dir, config.trace_folder)

        logger.info(f"Profiling active. Traces will be saved at {trace_dir}")

        if get_is_master() and not os.path.exists(trace_dir):
            os.makedirs(trace_dir)
        if torch.distributed.is_initialized():
            torch.distributed.barrier()

        with xformers.profiler.profile(
            output_dir=trace_dir,
            module=module,
            schedule=[
                (
                    MemSnapshotsProfilerWandb,
                    config.mem_warmup,
                    config.mem_warmup + config.mem_steps,
                ),
                (
                    PyTorchProfilerWandb,
                    config.profile_warmup,
                    config.profile_warmup + config.profile_steps,
                ),
            ],
        ) as profiler:
            yield profiler
    else:
        yield None
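

# --- Hypothetical usage sketch (not part of the original module) ------------
# A minimal illustration of how maybe_run_profiler might be entered from a
# training loop. The dump directory, model, ProfilerArgs values, and step
# count are made-up assumptions; profiler.step() advances the xformers
# profiler schedule once per training step.
def _example_profiling_loop():  # pragma: no cover - illustration only
    import torch
    from torch import nn

    model = nn.Linear(16, 16)
    args = ProfilerArgs(
        run=True, mem_warmup=0, mem_steps=1, profile_warmup=2, profile_steps=1
    )
    with contextlib.ExitStack() as stack:
        profiler = stack.enter_context(
            maybe_run_profiler("/tmp/profiling_demo", model, args)
        )
        for _ in range(4):
            model(torch.randn(2, 16))
            if profiler is not None:
                profiler.step()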