CantusSVS-hf / basics /base_exporter.py
liampond
Clean deploy snapshot
c42fe7e
raw
history blame
1.95 kB
import json
from pathlib import Path
from typing import Union
import torch
import torch.nn as nn
from utils.hparams import hparams
class BaseExporter:
    """
    Base class for exporters that turn a trained model and its related files into deployable artifacts.

    This base only sets up the target device and cache directory and knows how to load the speaker map;
    subclasses are expected to implement :meth:`build_model`, :meth:`export_model`,
    :meth:`export_attachments` and :meth:`export` (each raises ``NotImplementedError`` here).
    """

    def __init__(
            self,
            device: Union[str, torch.device, None] = None,
            cache_dir: Union[str, Path, None] = None,
            **kwargs
    ):
        """
        :param device: target device. When None, ``torch.device('cuda')`` is used if CUDA is available,
            otherwise ``torch.device('cpu')``. A non-None value is stored as given (a str stays a str).
        :param cache_dir: cache directory, given as a str or Path. It is resolved to an absolute path.
            Defaults to ``deployment/cache`` under the parent of this file's directory.
            The directory (including missing parents) is created if it does not exist.
        :param kwargs: accepted and ignored here (presumably so subclasses can share one constructor call).
        """
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = device
        # Path(...) lets callers pass a plain str; resolving the default as well keeps
        # self.cache_dir consistently absolute regardless of which branch was taken.
        if cache_dir is not None:
            self.cache_dir: Path = Path(cache_dir).resolve()
        else:
            self.cache_dir: Path = (Path(__file__).parent.parent / 'deployment' / 'cache').resolve()
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    # noinspection PyMethodMayBeStatic
    def build_spk_map(self) -> dict:
        """
        Load the speaker map used for multi-speaker models.

        When ``hparams['use_spk_id']`` is truthy, reads ``spk_map.json`` from ``hparams['work_dir']``.
        The map's values are treated as speaker ids and must be unique.
        Otherwise returns an empty dict without touching the filesystem.

        :return: the parsed speaker map, or ``{}`` when speaker ids are disabled
        :raises AssertionError: if the file does not contain a non-empty JSON object,
            or if two entries share the same speaker id
        """
        if not hparams['use_spk_id']:
            return {}
        with open(Path(hparams['work_dir']) / 'spk_map.json', 'r', encoding='utf8') as f:
            spk_map = json.load(f)
        # Explicit raises instead of `assert`: the file is external input and `assert`
        # statements are removed under `python -O`. AssertionError is kept so callers
        # observe the same exception type and messages as before.
        if not isinstance(spk_map, dict) or len(spk_map) == 0:
            raise AssertionError('Invalid or empty speaker map!')
        if len(spk_map) != len(set(spk_map.values())):
            raise AssertionError('Duplicate speaker id in speaker map!')
        return spk_map

    def build_model(self) -> nn.Module:
        """
        Create an instance of nn.Module and load its state dict on the target device.

        Must be implemented by subclasses.
        """
        raise NotImplementedError()

    def export_model(self, path: Path):
        """
        Export the model to ONNX format.

        Must be implemented by subclasses.

        :param path: the target model path
        """
        raise NotImplementedError()

    def export_attachments(self, path: Path):
        """
        Export related files and configs (e.g. the dictionary) to the target directory.

        Must be implemented by subclasses.

        :param path: the target directory
        """
        raise NotImplementedError()

    def export(self, path: Path):
        """
        Export all the artifacts to the target directory.

        Must be implemented by subclasses.

        :param path: the target directory
        """
        raise NotImplementedError()