# CantusSVS-hf / inference / pipeline.py
# liampond
# Clean deploy snapshot
# c42fe7e
# inference/pipeline.py
import os
import json
import sys
from pathlib import Path
from typing import Optional
from utils.hparams import set_hparams, hparams
from inference.ds_variance import DiffSingerVarianceInfer
from inference.ds_acoustic import DiffSingerAcousticInfer
from utils.infer_utils import parse_commandline_spk_mix, trans_key
from webapp.services.parsing.ds_validator import validate_ds
# Directory one level above the folder that contains this file; run_inference()
# looks for "checkpoints/<variance_exp>/config.yaml" underneath it.
PROJECT_ROOT: Path = Path(__file__).resolve().parent.parent
# Base directory from which run_inference() builds the
# "<exp>/config.yaml" paths for the variance and acoustic stages.
HF_CHECKPOINTS_DIR: str = "/tmp/cantussvs_v1/checkpoints"
def _load_hparams(config_path, exp_name: str) -> None:
    """Reload the global ``hparams`` for ``exp_name``.

    The arguments are handed over through ``sys.argv`` (set_hparams() is
    called without explicit config arguments, so it is expected to pick them
    up from the command line). Restoring ``sys.argv`` is the caller's job.
    """
    sys.argv = [
        "",
        "--config", str(config_path),
        "--exp_name", exp_name,
        "--infer"
    ]
    set_hparams(print_hparams=False)


def _load_ds_segments(ds_file: Path, error_prefix: str) -> list:
    """Read a DS JSON file, wrap a single object into a list, validate each item.

    Raises:
        ValueError: if ``validate_ds`` rejects an item; the message is
            ``"<error_prefix>: <original error>"`` and the original exception
            is attached as ``__cause__``.
    """
    with open(ds_file, "r", encoding="utf-8") as f:
        segments = json.load(f)
    if not isinstance(segments, list):
        segments = [segments]
    for segment in segments:
        try:
            validate_ds(segment)
        except Exception as e:
            raise ValueError(f"{error_prefix}: {e}") from e
    return segments


def run_inference(
    ds_path: Path,
    output_dir: Path,
    title: str,
    *,
    variance_exp: str = "regular_variance_v1",
    acoustic_exp: str = "debug_test",
    seed: int = 42,
    num_runs: int = 1,
    key_shift: int = 0,
    gender: Optional[float] = None
) -> Path:
    """
    Runs the full pipeline: variance model => acoustic model;
    returns the path to the generated WAV.

    Args:
        ds_path: Input DS file (JSON: one object or a list of objects).
        output_dir: Directory handed to both inference stages as ``out_dir``;
            the intermediate ``<title>.ds`` and the final ``<title>.wav`` are
            looked up here.
        title: Base name (without extension) of the produced files.
        variance_exp: Experiment name of the variance model.
        acoustic_exp: Experiment name of the acoustic model.
        seed: Seed forwarded to both inference stages.
        num_runs: Forwarded to the acoustic stage only; the variance stage
            always runs once.
        key_shift: Transposition passed to ``trans_key``; 0 skips the call.
        gender: Stored on every segment as ``"gender"`` when not ``None`` and
            the loaded hparams enable ``use_key_shift_embed``.

    Returns:
        ``output_dir / f"{title}.wav"``.

    Raises:
        FileNotFoundError: ``ds_path`` does not exist.
        ValueError: the input DS, or the DS written by the variance stage,
            fails ``validate_ds``.
        RuntimeError: a stage did not produce its expected output file.
    """
    # Fail fast, before touching sys.argv or loading any model configuration.
    if not ds_path.exists():
        raise FileNotFoundError(f"Input DS file not found: {ds_path}")

    # hparams are (re)loaded through sys.argv below; remember the caller's
    # argv so the process is not left with our fake command line.
    saved_argv = sys.argv
    try:
        # NOTE(review): this first load reads the variance config from
        # PROJECT_ROOT/checkpoints, while the stage loads below read from
        # HF_CHECKPOINTS_DIR. Kept as-is; confirm which location is intended.
        _load_hparams(
            PROJECT_ROOT / "checkpoints" / variance_exp / "config.yaml",
            variance_exp,
        )

        params = _load_ds_segments(ds_path, "Invalid input DS file")

        # Ensure ph_seq present: fall back to the characters of "text"
        # (spaces removed), separated by single spaces.
        for p in params:
            if "ph_seq" not in p:
                text = p.get("text", "")
                p["ph_seq"] = " ".join(text.replace(" ", ""))

        # Transpose
        if key_shift != 0:
            params = trans_key(params, key_shift)

        # Speaker mix
        spk_mix = parse_commandline_spk_mix(None) if hparams.get("use_spk_id") else None
        for p in params:
            if gender is not None and hparams.get("use_key_shift_embed"):
                p["gender"] = gender
            if spk_mix is not None:
                p["spk_mix"] = spk_mix

        # ==== Variance Inference ==== #
        print(f"[pipeline] Loading variance exp: {variance_exp}")
        _load_hparams(
            os.path.join(HF_CHECKPOINTS_DIR, variance_exp, "config.yaml"),
            variance_exp,
        )
        print("[pipeline] Variance hparams keys:", sorted(hparams.keys()))
        var_infer = DiffSingerVarianceInfer(ckpt_steps=None, predictions={"dur", "pitch"})
        ds_out = output_dir / f"{title}.ds"
        var_infer.run_inference(params, out_dir=output_dir, title=title, num_runs=1, seed=seed)
        if not ds_out.exists():
            raise RuntimeError(f"Variance inference failed; missing {ds_out}")

        # The acoustic stage consumes what the variance stage wrote to disk,
        # not the in-memory params.
        params = _load_ds_segments(ds_out, "Invalid DS after variance inference")

        # ==== Acoustic Inference ==== #
        print(f"[pipeline] Loading acoustic exp: {acoustic_exp}")
        _load_hparams(
            os.path.join(HF_CHECKPOINTS_DIR, acoustic_exp, "config.yaml"),
            acoustic_exp,
        )
        print("[pipeline] Acoustic hparams keys:", sorted(hparams.keys()))
        ac_infer = DiffSingerAcousticInfer(load_vocoder=True, ckpt_steps=None)
        ac_infer.run_inference(params, out_dir=output_dir, title=title, num_runs=num_runs, seed=seed)
        wav_out = output_dir / f"{title}.wav"
        if not wav_out.exists():
            raise RuntimeError(f"Acoustic inference failed; missing {wav_out}")
        return wav_out
    finally:
        sys.argv = saved_argv
if __name__ == "__main__":
    # Command-line entry point: parse the options and run the whole pipeline.
    from argparse import ArgumentParser

    cli = ArgumentParser(description="Run full DiffSinger inference pipeline")

    # Required positionals, both parsed as paths.
    for positional in ("ds_path", "output_dir"):
        cli.add_argument(positional, type=Path)

    # Optional flags as (flag, type, default); the defaults mirror those of
    # run_inference().
    option_table = (
        ("--title", str, None),
        ("--variance_exp", str, "regular_variance_v1"),
        ("--acoustic_exp", str, "debug_test"),
        ("--seed", int, 42),
        ("--num_runs", int, 1),
        ("--key_shift", int, 0),
        ("--gender", float, None),
    )
    for flag, caster, fallback in option_table:
        cli.add_argument(flag, type=caster, default=fallback)

    opts = cli.parse_args()

    # Without an explicit --title the DS file name (minus extension) is used.
    run_inference(
        opts.ds_path,
        opts.output_dir,
        opts.title or opts.ds_path.stem,
        variance_exp=opts.variance_exp,
        acoustic_exp=opts.acoustic_exp,
        seed=opts.seed,
        num_runs=opts.num_runs,
        key_shift=opts.key_shift,
        gender=opts.gender,
    )