# Spaces: Running on Zero
"""
"""
import torch | |
from huggingface_hub import hf_hub_download | |
from spaces.zero.torch.aoti import ZeroGPUCompiledModel | |
from spaces.zero.torch.aoti import ZeroGPUWeights | |
def aoti_load(module: torch.nn.Module, repo_id: str) -> None:
    """Replace the forward of each repeated block with a downloaded compiled model.

    For every name listed in ``module._repeated_blocks``, the file
    ``<name>.pt2`` is fetched from ``repo_id`` with ``hf_hub_download``.
    Each submodule of ``module`` (including ``module`` itself) whose class
    name equals that name then has its ``forward`` attribute replaced by a
    ``ZeroGPUCompiledModel`` built from the downloaded file and the
    submodule's own ``state_dict()`` wrapped in ``ZeroGPUWeights``.

    Args:
        module: Model to patch. It must expose a ``_repeated_blocks``
            iterable of class names; duplicates are downloaded once.
        repo_id: Repository identifier handed to ``hf_hub_download``.

    Returns:
        None. ``module`` is modified in place.

    Raises:
        AttributeError: If ``module`` has no ``_repeated_blocks`` attribute.
    """
    repeated_blocks = module._repeated_blocks

    # All files are resolved before any submodule is touched, so an error
    # raised while downloading leaves every ``forward`` untouched.
    # NOTE(review): the ``.pt2`` files are presumably ahead-of-time compiled
    # artifacts for the matching block class -- confirm against the repo.
    aoti_files = {
        name: hf_hub_download(repo_id, f'{name}.pt2')
        for name in repeated_blocks
    }

    for block_name, aoti_file in aoti_files.items():
        for block in module.modules():
            # Matching is by bare class name, not by type identity.
            if block.__class__.__name__ == block_name:
                # Each matching instance gets its own weights wrapper, taken
                # from that instance's state dict.
                weights = ZeroGPUWeights(block.state_dict())
                # Instance-level assignment: only this submodule's forward
                # is overridden, the class itself is left alone.
                block.forward = ZeroGPUCompiledModel(aoti_file, weights)