Spaces:
Running
on
Zero
Running
on
Zero
File size: 648 Bytes
f0c48c3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 |
"""
"""
import torch
from huggingface_hub import hf_hub_download
from spaces.zero.torch.aoti import ZeroGPUCompiledModel
from spaces.zero.torch.aoti import ZeroGPUWeights
def aoti_load(module: torch.nn.Module, repo_id: str):
    """Replace ``forward`` on repeated sub-blocks with AoTI-compiled ZeroGPU models.

    For every class name listed in ``module._repeated_blocks``, the archive
    ``<name>.pt2`` is downloaded from the Hub repository ``repo_id``. Every
    sub-module of ``module`` (``module`` itself included, since
    ``nn.Module.modules()`` yields the root too) whose class name matches one
    of those names then has its ``forward`` replaced. The replacement is a
    ``ZeroGPUCompiledModel`` built from that archive plus a ``ZeroGPUWeights``
    that wraps the sub-module's own ``state_dict()``.

    Args:
        module: Root model to patch in place. It must expose a
            ``_repeated_blocks`` attribute that iterates over class names.
        repo_id: Hub repository that holds one ``<class name>.pt2`` file per
            entry of ``module._repeated_blocks``.

    Returns:
        None. ``module`` is modified in place.
    """
    # Download each archive once, keyed by the class name it applies to.
    # Duplicate names collapse into a single dict entry.
    aoti_files = {
        name: hf_hub_download(repo_id, f'{name}.pt2')
        for name in module._repeated_blocks
    }

    # Walk the module tree once. Each sub-module has exactly one class name,
    # so a dict lookup patches the same modules as scanning the whole tree
    # once per block name, without the repeated traversals.
    for block in module.modules():
        block_name = block.__class__.__name__
        if block_name not in aoti_files:
            continue
        weights = ZeroGPUWeights(block.state_dict())
        # Assigning on the instance shadows the class-level ``forward``.
        # ``nn.Module.__call__`` dispatches through ``self.forward``, so
        # ``block(...)`` now runs the compiled model.
        block.forward = ZeroGPUCompiledModel(aoti_files[block_name], weights)
|