File size: 648 Bytes
f0c48c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
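"""
Helpers for patching a model's repeated blocks with ahead-of-time compiled
(AoTI ``.pt2``) archives fetched from the Hugging Face Hub, for use on
ZeroGPU Spaces.
"""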

import torch
from huggingface_hub import hf_hub_download
from spaces.zero.torch.aoti import ZeroGPUCompiledModel
from spaces.zero.torch.aoti import ZeroGPUWeights


def aoti_load(module: torch.nn.Module, repo_id: str) -> None:
    """Replace the ``forward`` of each repeated block in *module* with an AoTI-compiled one.

    For every class name listed in ``module._repeated_blocks``, the archive
    ``{name}.pt2`` is downloaded from the Hugging Face Hub repository
    *repo_id*. Every submodule of *module* whose class name matches then has
    its ``forward`` replaced by a ``ZeroGPUCompiledModel`` built from that
    archive and the submodule's own ``state_dict()``.

    *module* is modified in place; nothing is returned.

    Args:
        module: Model to patch. It must expose a ``_repeated_blocks``
            attribute listing the class names of its repeated sub-blocks.
        repo_id: Hub repository containing one ``<ClassName>.pt2`` archive
            per entry of ``_repeated_blocks``.

    Raises:
        AttributeError: If *module* has no ``_repeated_blocks`` attribute.
    """
    repeated_blocks = module._repeated_blocks
    # dict.fromkeys drops duplicate names while keeping their order, so each
    # archive is downloaded once.
    aoti_files = {
        name: hf_hub_download(repo_id, f'{name}.pt2')
        for name in dict.fromkeys(repeated_blocks)
    }

    # Traverse the module tree a single time and snapshot the matches before
    # patching. This avoids one full traversal per block name, and ensures
    # ``forward`` is never reassigned while the ``modules()`` generator is
    # still walking the tree.
    targets = [
        (block, aoti_files[block.__class__.__name__])
        for block in module.modules()
        if block.__class__.__name__ in aoti_files
    ]

    for block, aoti_file in targets:
        # Each block instance gets its own weights from its own state_dict.
        weights = ZeroGPUWeights(block.state_dict())
        block.forward = ZeroGPUCompiledModel(aoti_file, weights)