wanghaofan's picture
10x demo speedup (#1)
f0c48c3 verified
raw
history blame contribute delete
648 Bytes
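"""Load per-block ``.pt2`` AOTI archives from the Hugging Face Hub into a model.

Each archive is wrapped in a ZeroGPU compiled-model object and installed as the
``forward`` of the matching repeated submodules.
"""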
import torch
from huggingface_hub import hf_hub_download
from spaces.zero.torch.aoti import ZeroGPUCompiledModel
from spaces.zero.torch.aoti import ZeroGPUWeights
def aoti_load(module: torch.nn.Module, repo_id: str):
    """Replace the ``forward`` of each repeated block in *module* with an AOTI archive.

    For every class name listed in ``module._repeated_blocks``, the file
    ``<name>.pt2`` is downloaded from the Hugging Face Hub repository
    *repo_id*. Every submodule of *module* (including *module* itself, since
    ``modules()`` yields the root) whose class name equals that name then gets
    its ``forward`` attribute overwritten by a ``ZeroGPUCompiledModel``. That
    object is built from the downloaded archive path and a ``ZeroGPUWeights``
    wrapping the submodule's own ``state_dict()``.

    Args:
        module: Root model. It must expose a ``_repeated_blocks`` iterable of
            class-name strings (presumably the diffusers ``ModelMixin``
            attribute used for regional compilation; confirm against callers).
        repo_id: Hub repository expected to contain one ``<name>.pt2`` file per
            entry of ``_repeated_blocks``.

    Notes:
        All archives are downloaded before any submodule is patched. A name
        that matches no submodule is downloaded but otherwise ignored, with no
        error raised.
    """
    # Resolve every archive up front; a repeated name is re-downloaded and the
    # later path wins, while the key keeps its first-seen position.
    archives = {}
    for class_name in module._repeated_blocks:
        archives[class_name] = hf_hub_download(repo_id, f'{class_name}.pt2')

    for class_name, archive_path in archives.items():
        # Re-walk the module tree for each name so blocks are patched grouped
        # by name, in the order the names were listed.
        targets = (
            sub for sub in module.modules()
            if sub.__class__.__name__ == class_name
        )
        for sub in targets:
            # The instance attribute shadows the class-level forward method,
            # so later calls on this submodule go through the compiled wrapper.
            sub.forward = ZeroGPUCompiledModel(
                archive_path, ZeroGPUWeights(sub.state_dict())
            )