import torch
import numpy as np
import torch.nn.functional as F
from dataclasses import dataclass
from pytorch_lightning import seed_everything
from src.model.pipeline import AudioLDMPipeline, TangoPipeline
from src.utils.utils import (
    process_move,
    process_paste,
    process_remove,
)
from src.utils.audio_processing import TacotronSTFT, wav_to_fbank, maybe_add_dimension
from src.utils.factory import slerp, fill_with_neighbor, optimize_neighborhood_points
# NUM_DDIM_STEPS = 50
# Feature-map scale factor per UNet up-block level (keyed by up_ft_index)
SIZES = {
    0: 4,
    1: 2,
    2: 1,
    3: 1,
}


@dataclass
class SoundEditorOutput:
    waveform: torch.Tensor
    mel_spectrogram: torch.Tensor


class AudioMorphix:
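    """Mask-guided audio editor built on latent diffusion.

    Wraps an AudioLDM or Tango pipeline and exposes edit operations on
    mel-spectrogram latents: move, paste, mix, remove, prompt-driven
    generation, style transfer, and a DDIM-inversion round trip. Method
    docstrings summarize behavior as inferred from the code paths.
    """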

    def __init__(
        self,
        pretrained_model_path,
        num_ddim_steps=50,
        device="cuda" if torch.cuda.is_available() else "cpu",
    ):
        self.ip_scale = 0.1
        self.precision = torch.float32  # torch.float16
        if "audioldm" in pretrained_model_path:
            _pipe_cls = AudioLDMPipeline
        elif "tango" in pretrained_model_path:
            _pipe_cls = TangoPipeline
        else:
            raise ValueError(f"Unsupported pretrained model: {pretrained_model_path}")
        self.editor = _pipe_cls(
            sd_id=pretrained_model_path,
            NUM_DDIM_STEPS=num_ddim_steps,
            precision=self.precision,
            ip_scale=self.ip_scale,
            device=device,
        )
        self.up_ft_index = [2, 3]  # fixed in gradio demo
        self.up_scale = 2  # fixed in gradio demo
        self.device = device
        self.num_ddim_steps = num_ddim_steps

    def to(self, device):
        self.editor.pipe = self.editor.pipe.to(device)
        self.editor.pipe._device = device
        self.editor.device = device
        self.device = device

    def run_move(
        self,
        fbank_org,
        mask,
        dx,
        dy,
        mask_ref,
        prompt,
        resize_scale_x,
        resize_scale_y,
        w_edit,
        w_content,
        w_contrast,
        w_inpaint,
        seed,
        guidance_scale,
        energy_scale,
        SDE_strength,
        mask_keep=None,
        ip_scale=None,
        save_kv=False,
        disable_tangent_proj=False,
        scale_denoised=True,
    ):
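        """Move (and optionally rescale) the masked region of a spectrogram.

        Sketch of the flow as implemented below: DDIM-invert `fbank_org`,
        shift/resize the masked latent region by (dx, dy) and
        (resize_scale_x, resize_scale_y), then run guided denoising. The
        `w_*` weights are forwarded to process_move, which defines their
        exact semantics.
        """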
        seed_everything(seed)
        energy_scale = energy_scale * 1e3
        # Prepare input spec and mask
        input_scale = 1
        fbank_org = maybe_add_dimension(fbank_org, 4).to(
            self.device, dtype=self.precision
        )  # shape = (B,C,T,F)
        f, t = fbank_org.shape[-1], fbank_org.shape[-2]
        if save_kv:
            self.editor.load_adapter()
        # FIXME: expand mask_ref to three channels (image-style mask)
        if mask_ref is not None and np.sum(mask_ref) != 0:
            mask_ref = np.repeat(mask_ref[:, :, None], 3, 2)
        else:
            mask_ref = None
        latent = self.editor.fbank2latent(fbank_org)
        ddim_latents = self.editor.ddim_inv(latent=latent, prompt=prompt)
        latent_in = ddim_latents[-1].squeeze(2)
        scale = 4 * SIZES[max(self.up_ft_index)] / self.up_scale
        edit_kwargs = process_move(
            path_mask=mask,
            h=f,
            w=t,
            dx=dx,
            dy=dy,
            scale=scale,
            input_scale=input_scale,
            resize_scale_x=resize_scale_x,
            resize_scale_y=resize_scale_y,
            up_scale=self.up_scale,
            up_ft_index=self.up_ft_index,
            w_edit=w_edit,
            w_content=w_content,
            w_contrast=w_contrast,
            w_inpaint=w_inpaint,
            precision=self.precision,
            path_mask_ref=mask_ref,
            path_mask_keep=mask_keep,
        )
        # Pre-process zT: resize, shift, and center-fit the masked latent region
        target_h = int(latent_in.shape[-2] * resize_scale_y)
        target_w = int(latent_in.shape[-1] * resize_scale_x)
        mask_tmp = (
            (F.interpolate(mask.unsqueeze(0).unsqueeze(0), (target_h, target_w)) > 0)
            .float()
            .to(latent_in.device, dtype=latent_in.dtype)
        )
        latent_tmp = F.interpolate(latent_in, (target_h, target_w))
        shift = (
            int(dy / (t / latent_in.shape[-2]) * resize_scale_y),
            int(dx / (t / latent_in.shape[-2]) * resize_scale_x),
        )
        mask_tmp = torch.roll(mask_tmp, shift, (-2, -1))
        latent_tmp = torch.roll(latent_tmp, shift, (-2, -1))
        _mask_temp = torch.zeros(1, 1, latent_in.shape[-2], latent_in.shape[-1]).to(
            latent_in.device, dtype=latent_in.dtype
        )
        _latent_temp = torch.zeros_like(latent_in)
        pad_x = (_mask_temp.shape[-1] - mask_tmp.shape[-1]) // 2
        pad_y = (_mask_temp.shape[-2] - mask_tmp.shape[-2]) // 2
        px_tmp, py_tmp = max(pad_x, 0), max(pad_y, 0)
        px_tar, py_tar = max(-pad_x, 0), max(-pad_y, 0)
        _mask_temp[
            :, :, py_tmp : mask_tmp.shape[-2] + py_tmp, px_tmp : mask_tmp.shape[-1] + px_tmp
        ] = mask_tmp[
            :, :, py_tar : _mask_temp.shape[-2] + py_tar, px_tar : _mask_temp.shape[-1] + px_tar
        ]
        _latent_temp[
            :, :, py_tmp : latent_tmp.shape[-2] + py_tmp, px_tmp : latent_tmp.shape[-1] + px_tmp
        ] = latent_tmp[
            :, :, py_tar : _latent_temp.shape[-2] + py_tar, px_tar : _latent_temp.shape[-1] + px_tar
        ]
        mask_tmp = (_mask_temp > 0.5).float()
        latent_tmp = _latent_temp
        if edit_kwargs["mask_keep"] is not None:
            mask_keep = edit_kwargs["mask_keep"]
            mask_keep = (
                (F.interpolate(mask_keep, (latent_in.shape[-2], latent_in.shape[-1])) > 0)
                .float()
                .to(latent_in.device, dtype=latent_in.dtype)
            )
        else:
            mask_keep = 1 - mask_tmp
        latent_in = (latent_in * mask_keep + latent_tmp * mask_tmp).to(
            dtype=latent_in.dtype
        )
        latent_rec = self.editor.pipe.edit(
            mode="move",
            latent=latent_in,
            prompt=prompt,
            guidance_scale=guidance_scale,
            energy_scale=energy_scale,
            latent_noise_ref=ddim_latents,
            SDE_strength=SDE_strength,
            edit_kwargs=edit_kwargs,
            disable_tangent_proj=disable_tangent_proj,
        )
        # Scale output latent
        if scale_denoised:
            _max = torch.max(torch.abs(latent_rec))
            latent_rec = latent_rec * 5 / _max
        spec_rec = self.editor.decode_latents(latent_rec)
        wav_rc = self.editor.mel_spectrogram_to_waveform(spec_rec)
        torch.cuda.empty_cache()
        return SoundEditorOutput(wav_rc, spec_rec)

    def run_paste(
        self,
        fbank_bg,
        mask_bg,
        fbank_fg,
        prompt,
        prompt_replace,
        w_edit,
        w_content,
        seed,
        guidance_scale,
        energy_scale,
        dx,
        dy,
        resize_scale_x,
        resize_scale_y,
        SDE_strength,
        save_kv=False,
        disable_tangent_proj=False,
        scale_denoised=True,
    ):
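        """Paste the foreground spectrogram into the masked background region.

        Background and foreground fbanks are DDIM-inverted jointly; the
        foreground noise latent is rolled by (dx, dy) and overwrites the
        background latent under the mask before guided denoising.
        """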
        seed_everything(seed)
        energy_scale = energy_scale * 1e3
        # Prepare input spec and mask
        input_scale = 1
        fbank_bg = maybe_add_dimension(fbank_bg, 4).to(
            self.device, dtype=self.precision
        )  # shape = (B,C,T,F)
        f, t = fbank_bg.shape[-1], fbank_bg.shape[-2]
        fbank_fg = maybe_add_dimension(fbank_fg, 4).to(
            self.device, dtype=self.precision
        )
        if save_kv:
            self.editor.load_adapter()
        latent_base = self.editor.fbank2latent(fbank_bg)
        # Resize the foreground and center-fit it back to its original extent
        hr, wr = fbank_fg.shape[-2], fbank_fg.shape[-1]
        fbank_tmp = torch.zeros_like(fbank_fg)
        fbank_fg = F.interpolate(
            fbank_fg, (int(hr * resize_scale_y), int(wr * resize_scale_x))
        )
        pad_x = (wr - fbank_fg.shape[-1]) // 2
        pad_y = (hr - fbank_fg.shape[-2]) // 2
        px_tmp, py_tmp = max(pad_x, 0), max(pad_y, 0)
        px_tar, py_tar = max(-pad_x, 0), max(-pad_y, 0)
        fbank_tmp[
            :, :, py_tmp : fbank_fg.shape[-2] + py_tmp, px_tmp : fbank_fg.shape[-1] + px_tmp
        ] = fbank_fg[
            :, :, py_tar : fbank_tmp.shape[-2] + py_tar, px_tar : fbank_tmp.shape[-1] + px_tar
        ]
        fbank_fg = fbank_tmp
        latent_replace = self.editor.fbank2latent(fbank_fg)
        ddim_latents = self.editor.ddim_inv(
            latent=torch.cat([latent_base, latent_replace]),
            prompt=[prompt, prompt_replace],
        )
        latent_in = ddim_latents[-1][:1].squeeze(2)  # latent_base_noise
        scale = 8 * SIZES[max(self.up_ft_index)] / self.up_scale / 2
        edit_kwargs = process_paste(
            path_mask=mask_bg,
            h=f,
            w=t,
            dx=dx,
            dy=dy,
            scale=scale,
            input_scale=input_scale,
            up_scale=self.up_scale,
            up_ft_index=self.up_ft_index,
            w_edit=w_edit,
            w_content=w_content,
            precision=self.precision,
            resize_scale_x=resize_scale_x,
            resize_scale_y=resize_scale_y,
        )
        mask_tmp = (
            F.interpolate(
                edit_kwargs["mask_base_cur"].float(),
                (latent_in.shape[-2], latent_in.shape[-1]),
            )
            > 0
        ).float()
        # latent_replace_noise with rolling
        latent_tmp = torch.roll(
            ddim_latents[-1][1:].squeeze(2),
            (int(dy / (t / latent_in.shape[-2])), int(dx / (t / latent_in.shape[-2]))),
            (-2, -1),
        )
        # blended latent
        latent_in = (latent_in * (1 - mask_tmp) + latent_tmp * mask_tmp).to(
            dtype=latent_in.dtype
        )
        latent_rec = self.editor.pipe.edit(
            mode="paste",
            latent=latent_in,
            prompt=prompt,
            guidance_scale=guidance_scale,
            energy_scale=energy_scale,
            latent_noise_ref=ddim_latents,
            SDE_strength=SDE_strength,
            edit_kwargs=edit_kwargs,
            disable_tangent_proj=disable_tangent_proj,
        )
        # Scale output latent
        if scale_denoised:
            _max = torch.max(torch.abs(latent_rec))
            latent_rec = latent_rec * 5 / _max
        spec_rec = self.editor.decode_latents(latent_rec)
        wav_rc = self.editor.mel_spectrogram_to_waveform(spec_rec)
        torch.cuda.empty_cache()
        return SoundEditorOutput(wav_rc, spec_rec)

    def run_mix(
        self,
        fbank_bg,
        mask_bg,
        fbank_fg,
        prompt,
        prompt_replace,
        w_edit,
        w_content,
        seed,
        guidance_scale,
        energy_scale,
        dx,
        dy,
        resize_scale_x,
        resize_scale_y,
        SDE_strength,
        save_kv=False,
        bg_to_fg_ratio=0.7,
        disable_tangent_proj=False,
        scale_denoised=False,
    ):
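        """Mix the foreground into the masked background region.

        Like run_paste, except the masked latents are combined by spherical
        interpolation (slerp) with `bg_to_fg_ratio` rather than overwritten.
        """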
        seed_everything(seed)
        energy_scale = energy_scale * 1e3
        # Prepare input spec and mask
        input_scale = 1
        fbank_bg = maybe_add_dimension(fbank_bg, 4).to(
            self.device, dtype=self.precision
        )  # shape = (B,C,T,F)
        f, t = fbank_bg.shape[-1], fbank_bg.shape[-2]
        fbank_fg = maybe_add_dimension(fbank_fg, 4).to(
            self.device, dtype=self.precision
        )
        if save_kv:
            self.editor.load_adapter()
        latent_base = self.editor.fbank2latent(fbank_bg)
        # Resize the foreground and center-fit it back to its original extent
        hr, wr = fbank_fg.shape[-2], fbank_fg.shape[-1]
        fbank_tmp = torch.zeros_like(fbank_fg)
        fbank_fg = F.interpolate(
            fbank_fg, (int(hr * resize_scale_y), int(wr * resize_scale_x))
        )
        pad_x = (wr - fbank_fg.shape[-1]) // 2
        pad_y = (hr - fbank_fg.shape[-2]) // 2
        px_tmp, py_tmp = max(pad_x, 0), max(pad_y, 0)
        px_tar, py_tar = max(-pad_x, 0), max(-pad_y, 0)
        fbank_tmp[
            :, :, py_tmp : fbank_fg.shape[-2] + py_tmp, px_tmp : fbank_fg.shape[-1] + px_tmp
        ] = fbank_fg[
            :, :, py_tar : fbank_tmp.shape[-2] + py_tar, px_tar : fbank_tmp.shape[-1] + px_tar
        ]
        fbank_fg = fbank_tmp
        latent_replace = self.editor.fbank2latent(fbank_fg)
        ddim_latents = self.editor.ddim_inv(
            latent=torch.cat([latent_base, latent_replace]),
            prompt=[prompt, prompt_replace],
        )
        latent_in = ddim_latents[-1][:1].squeeze(2)  # latent_base_noise
        # TODO: adapt it to different Gen models
        scale = 4 * SIZES[max(self.up_ft_index)] / self.up_scale
        edit_kwargs = process_paste(
            path_mask=mask_bg,
            h=f,
            w=t,
            dx=dx,
            dy=dy,
            scale=scale,
            input_scale=input_scale,
            up_scale=self.up_scale,
            up_ft_index=self.up_ft_index,
            w_edit=w_edit,
            w_content=w_content,
            precision=self.precision,
            resize_scale_x=resize_scale_x,
            resize_scale_y=resize_scale_y,
        )
        mask_tmp = (
            F.interpolate(
                edit_kwargs["mask_base_cur"].float(),
                (latent_in.shape[-2], latent_in.shape[-1]),
            )
            > 0
        ).float()
        # latent_replace_noise with rolling
        latent_tmp = torch.roll(
            ddim_latents[-1][1:].squeeze(2),
            (int(dy / (t / latent_in.shape[-2])), int(dx / (t / latent_in.shape[-2]))),
            (-2, -1),
        )
        # Spherically interpolate background and foreground noise inside the mask
        latent_mix = slerp(bg_to_fg_ratio, latent_in, latent_tmp)
        latent_in = (latent_in * (1 - mask_tmp) + latent_mix * mask_tmp).to(
            dtype=latent_in.dtype
        )
        latent_rec = self.editor.pipe.edit(
            mode="mix",
            latent=latent_in,
            prompt=prompt,  # NOTE: empirically, reconstructing with the base prompt works best
            guidance_scale=guidance_scale,
            energy_scale=energy_scale,
            latent_noise_ref=ddim_latents,
            SDE_strength=SDE_strength,
            edit_kwargs=edit_kwargs,
            disable_tangent_proj=disable_tangent_proj,
        )
        # Scale output latent
        if scale_denoised:
            _max = torch.max(torch.abs(latent_rec))
            latent_rec = latent_rec * 5 / _max
        spec_rec = self.editor.decode_latents(latent_rec)
        wav_rc = self.editor.mel_spectrogram_to_waveform(spec_rec)
        torch.cuda.empty_cache()
        return SoundEditorOutput(wav_rc, spec_rec)

    def run_remove(
        self,
        fbank_bg,
        mask_bg,
        fbank_fg,
        prompt,
        prompt_replace,
        w_edit,
        w_contrast,
        w_content,
        seed,
        guidance_scale,
        energy_scale,
        dx,
        dy,
        resize_scale_x,
        resize_scale_y,
        SDE_strength,
        save_kv=False,
        bg_to_fg_ratio=0.5,
        iterations=50,
        enable_penalty=True,
        disable_tangent_proj=False,
        scale_denoised=True,
    ):
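        """Remove the masked foreground event from the background.

        The masked latent region is re-initialized from random noise that
        optimize_neighborhood_points pulls toward the background latent and
        away from the foreground latent, then denoised in "remove" mode.
        """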
        seed_everything(seed)
        energy_scale = energy_scale * 1e3
        # Prepare input spec and mask
        input_scale = 1
        fbank_bg = maybe_add_dimension(fbank_bg, 4).to(
            self.device, dtype=self.precision
        )  # shape = (B,C,T,F)
        f, t = fbank_bg.shape[-1], fbank_bg.shape[-2]
        fbank_fg = maybe_add_dimension(fbank_fg, 4).to(
            self.device, dtype=self.precision
        )
        if save_kv:
            self.editor.load_adapter()
        latent_base = self.editor.fbank2latent(fbank_bg)
        # Resize the foreground and center-fit it back to its original extent
        hr, wr = fbank_fg.shape[-2], fbank_fg.shape[-1]
        fbank_tmp = torch.zeros_like(fbank_fg)
        fbank_fg = F.interpolate(
            fbank_fg, (int(hr * resize_scale_y), int(wr * resize_scale_x))
        )
        pad_x = (wr - fbank_fg.shape[-1]) // 2
        pad_y = (hr - fbank_fg.shape[-2]) // 2
        px_tmp, py_tmp = max(pad_x, 0), max(pad_y, 0)
        px_tar, py_tar = max(-pad_x, 0), max(-pad_y, 0)
        fbank_tmp[
            :, :, py_tmp : fbank_fg.shape[-2] + py_tmp, px_tmp : fbank_fg.shape[-1] + px_tmp
        ] = fbank_fg[
            :, :, py_tar : fbank_tmp.shape[-2] + py_tar, px_tar : fbank_tmp.shape[-1] + px_tar
        ]
        fbank_fg = fbank_tmp
        latent_replace = self.editor.fbank2latent(fbank_fg)
        ddim_latents = self.editor.ddim_inv(
            latent=torch.cat([latent_base, latent_replace]),
            prompt=[prompt, prompt_replace],
        )
        latent_in = ddim_latents[-1][:1].squeeze(2)
        # TODO: adapt it to different Gen models
        scale = 4 * SIZES[max(self.up_ft_index)] / self.up_scale
        edit_kwargs = process_remove(
            path_mask=mask_bg,
            h=f,
            w=t,
            dx=dx,
            dy=dy,
            scale=scale,
            input_scale=input_scale,
            up_scale=self.up_scale,
            up_ft_index=self.up_ft_index,
            w_edit=w_edit,
            w_contrast=w_contrast,
            w_content=w_content,
            precision=self.precision,
            resize_scale_x=resize_scale_x,
            resize_scale_y=resize_scale_y,
        )
        mask_tmp = (
            F.interpolate(
                edit_kwargs["mask_base_cur"].float(),
                (latent_in.shape[-2], latent_in.shape[-1]),
            )
            > 0
        ).float()
        latent_tmp = torch.roll(
            ddim_latents[-1][1:].squeeze(2),
            (int(dy / (t / latent_in.shape[-2])), int(dx / (t / latent_in.shape[-2]))),
            (-2, -1),
        )
        # Re-initialize the masked region from random noise plus a small
        # perturbation. (A commented-out alternative derived the new latent's
        # statistics from F(B) <- F(M) - a * F(A) and filled the region from
        # neighboring frames via fill_with_neighbor.)
        latent_neighbor = torch.randn_like(latent_in.squeeze(0)) * 0.9
        latent_neighbor = latent_neighbor + torch.randn_like(latent_neighbor) * 1e-3
        latent_neighbor, _ = optimize_neighborhood_points(
            latent_neighbor * mask_tmp,
            latent_tmp * mask_tmp,
            latent_in * mask_tmp,
            t=bg_to_fg_ratio,
            iterations=iterations,
            enable_penalty=enable_penalty,
            enable_tangent_proj=True,  # TODO: try turning off tangent projection
        )
        latent_in = (latent_in * (1 - mask_tmp) + latent_neighbor * mask_tmp).to(
            dtype=latent_in.dtype
        )
        latent_rec = self.editor.pipe.edit(
            mode="remove",
            latent=latent_in,
            prompt=prompt,
            guidance_scale=guidance_scale,
            energy_scale=energy_scale,
            latent_noise_ref=ddim_latents,
            SDE_strength=SDE_strength,
            edit_kwargs=edit_kwargs,
            num_inference_steps=self.num_ddim_steps,
            start_time=self.num_ddim_steps,
            disable_tangent_proj=disable_tangent_proj,
        )
        # Scale output latent
        if scale_denoised:
            _max = torch.max(torch.abs(latent_rec))
            latent_rec = latent_rec * 5 / _max
        spec_rec = self.editor.decode_latents(latent_rec)
        wav_rc = self.editor.mel_spectrogram_to_waveform(spec_rec)
        torch.cuda.empty_cache()
        return SoundEditorOutput(wav_rc, spec_rec)

    def run_audio_generation(
        self,
        fbank_bg,
        mask_bg,
        fbank_fg,
        prompt,
        prompt_replace,
        w_edit,
        w_content,
        seed,
        guidance_scale,
        energy_scale,
        dx,
        dy,
        resize_scale_x,
        resize_scale_y,
        SDE_strength,
        save_kv=False,
        disable_tangent_proj=False,
        scale_denoised=True,
    ):
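        """Generate audio for the masked region from `prompt_replace`.

        Denoising starts from fresh noise whose statistics are matched to the
        inverted background latent and runs with energy_scale=0, i.e. without
        guidance energy.
        """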
        seed_everything(seed)
        energy_scale = energy_scale * 1e3
        # Prepare input spec and mask
        input_scale = 1
        fbank_bg = maybe_add_dimension(fbank_bg, 4).to(
            self.device, dtype=self.precision
        )  # shape = (B,C,T,F)
        f, t = fbank_bg.shape[-1], fbank_bg.shape[-2]
        fbank_fg = maybe_add_dimension(fbank_fg, 4).to(
            self.device, dtype=self.precision
        )
        if save_kv:
            self.editor.load_adapter()
        latent_base = self.editor.fbank2latent(fbank_bg)
        # Resize the foreground and center-fit it back to its original extent
        hr, wr = fbank_fg.shape[-2], fbank_fg.shape[-1]
        fbank_tmp = torch.zeros_like(fbank_fg)
        fbank_fg = F.interpolate(
            fbank_fg, (int(hr * resize_scale_y), int(wr * resize_scale_x))
        )
        pad_x = (wr - fbank_fg.shape[-1]) // 2
        pad_y = (hr - fbank_fg.shape[-2]) // 2
        px_tmp, py_tmp = max(pad_x, 0), max(pad_y, 0)
        px_tar, py_tar = max(-pad_x, 0), max(-pad_y, 0)
        fbank_tmp[
            :, :, py_tmp : fbank_fg.shape[-2] + py_tmp, px_tmp : fbank_fg.shape[-1] + px_tmp
        ] = fbank_fg[
            :, :, py_tar : fbank_tmp.shape[-2] + py_tar, px_tar : fbank_tmp.shape[-1] + px_tar
        ]
        fbank_fg = fbank_tmp
        ddim_latents = self.editor.ddim_inv(
            latent=torch.cat([latent_base, latent_base]), prompt=[prompt, prompt]
        )
        latent_in = ddim_latents[-1][:1].squeeze(2)
        # TODO: adapt it to different Gen models
        scale = 4 * SIZES[max(self.up_ft_index)] / self.up_scale
        edit_kwargs = process_paste(
            path_mask=mask_bg,
            h=f,
            w=t,
            dx=dx,
            dy=dy,
            scale=scale,
            input_scale=input_scale,
            up_scale=self.up_scale,
            up_ft_index=self.up_ft_index,
            w_edit=w_edit,
            w_content=w_content,
            precision=self.precision,
            resize_scale_x=resize_scale_x,
            resize_scale_y=resize_scale_y,
        )
        # Start from fresh noise whose mean/std along the last axis are matched
        # to the inverted background latent
        latent_tmp = torch.randn_like(latent_in)
        mean, std = latent_in.mean(dim=-1, keepdim=True), latent_in.std(
            dim=-1, keepdim=True
        )
        m_ori, s_ori = latent_tmp.mean(dim=-1, keepdim=True), latent_tmp.std(
            dim=-1, keepdim=True
        )
        latent_tmp = (latent_tmp - m_ori) / s_ori * std + mean
        latent_in = latent_tmp
        latent_rec = self.editor.pipe.edit(
            mode="generate",
            latent=latent_in,
            prompt=prompt_replace,
            guidance_scale=guidance_scale,
            energy_scale=0,
            latent_noise_ref=ddim_latents,
            SDE_strength=SDE_strength,
            edit_kwargs=edit_kwargs,
            num_inference_steps=self.num_ddim_steps,
            start_time=self.num_ddim_steps,
            alg="D",
            disable_tangent_proj=disable_tangent_proj,
        )
        # Scale output latent
        if scale_denoised:
            _max = torch.max(torch.abs(latent_rec))
            latent_rec = latent_rec * 5 / _max
        spec_rec = self.editor.decode_latents(latent_rec)
        wav_rc = self.editor.mel_spectrogram_to_waveform(spec_rec)
        torch.cuda.empty_cache()
        return SoundEditorOutput(wav_rc, spec_rec)

    def run_style_transferring(
        self,
        fbank_bg,
        mask_bg,
        fbank_fg,
        prompt,
        prompt_replace,
        w_edit,
        w_content,
        seed,
        guidance_scale,
        energy_scale,
        dx,
        dy,
        resize_scale_x,
        resize_scale_y,
        SDE_strength,
        save_kv=True,
        disable_tangent_proj=False,
        scale_denoised=True,
    ):
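        """Re-render the whole spectrogram under `prompt_replace`.

        DDIM inversion caches key/value features (save_kv=True,
        mode="style_transfer") so that the denoising pass can keep the
        source structure while following the new prompt.
        """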
        seed_everything(seed)
        energy_scale = energy_scale * 1e3
        # Prepare input spec and mask
        input_scale = 1
        fbank_bg = maybe_add_dimension(fbank_bg, 4).to(
            self.device, dtype=self.precision
        )  # shape = (B,C,T,F)
        f, t = fbank_bg.shape[-1], fbank_bg.shape[-2]
        if save_kv:
            self.editor.load_adapter()
        latent_base = self.editor.fbank2latent(fbank_bg)
        ddim_latents = self.editor.ddim_inv(
            latent=latent_base,
            prompt=prompt,
            save_kv=True,
            mode="style_transfer",
        )
        latent_in = ddim_latents[-1].squeeze(2)
        scale = 4 * SIZES[max(self.up_ft_index)] / self.up_scale
        edit_kwargs = process_paste(
            path_mask=mask_bg,
            h=f,
            w=t,
            dx=dx,
            dy=dy,
            scale=scale,
            input_scale=input_scale,
            up_scale=self.up_scale,
            up_ft_index=self.up_ft_index,
            w_edit=w_edit,
            w_content=w_content,
            precision=self.precision,
            resize_scale_x=resize_scale_x,
            resize_scale_y=resize_scale_y,
        )
        latent_rec = self.editor.pipe.edit(
            mode="style_transfer",
            latent=latent_in,
            prompt=prompt_replace,
            guidance_scale=guidance_scale,
            energy_scale=energy_scale,
            latent_noise_ref=ddim_latents,
            SDE_strength=SDE_strength,
            edit_kwargs=edit_kwargs,
            num_inference_steps=self.num_ddim_steps,
            start_time=self.num_ddim_steps,
            alg="D",
            disable_tangent_proj=disable_tangent_proj,
        )
        # Scale output latent
        if scale_denoised:
            _max = torch.max(torch.abs(latent_rec))
            latent_rec = latent_rec * 5 / _max
        spec_rec = self.editor.decode_latents(latent_rec)
        wav_rc = self.editor.mel_spectrogram_to_waveform(spec_rec)
        torch.cuda.empty_cache()
        return SoundEditorOutput(wav_rc, spec_rec)

    def run_ddim_inversion(
        self,
        fbank_bg,
        mask_bg,
        fbank_fg,
        prompt,
        prompt_replace,
        w_edit,
        w_content,
        seed,
        guidance_scale,
        energy_scale,
        dx,
        dy,
        resize_scale_x,
        resize_scale_y,
        SDE_strength,
        save_kv=False,
        disable_tangent_proj=False,
        scale_denoised=True,
    ):
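        """DDIM-invert and reconstruct the background (round-trip check).

        With energy_scale=0 in "generate" mode, the output should roughly
        reproduce `fbank_bg`; useful for verifying inversion quality.
        """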
        seed_everything(seed)
        energy_scale = energy_scale * 1e3
        # Prepare input spec and mask
        input_scale = 1
        fbank_bg = maybe_add_dimension(fbank_bg, 4).to(
            self.device, dtype=self.precision
        )  # shape = (B,C,T,F)
        f, t = fbank_bg.shape[-1], fbank_bg.shape[-2]
        fbank_fg = maybe_add_dimension(fbank_fg, 4).to(
            self.device, dtype=self.precision
        )
        if save_kv:
            self.editor.load_adapter()
        latent_base = self.editor.fbank2latent(fbank_bg)
        # Resize the foreground and center-fit it back to its original extent
        # (per-axis, matching the other edit modes)
        hr, wr = fbank_fg.shape[-2], fbank_fg.shape[-1]
        fbank_tmp = torch.zeros_like(fbank_fg)
        fbank_fg = F.interpolate(
            fbank_fg, (int(hr * resize_scale_y), int(wr * resize_scale_x))
        )
        pad_x = (wr - fbank_fg.shape[-1]) // 2
        pad_y = (hr - fbank_fg.shape[-2]) // 2
        px_tmp, py_tmp = max(pad_x, 0), max(pad_y, 0)
        px_tar, py_tar = max(-pad_x, 0), max(-pad_y, 0)
        fbank_tmp[
            :, :, py_tmp : fbank_fg.shape[-2] + py_tmp, px_tmp : fbank_fg.shape[-1] + px_tmp
        ] = fbank_fg[
            :, :, py_tar : fbank_tmp.shape[-2] + py_tar, px_tar : fbank_tmp.shape[-1] + px_tar
        ]
        fbank_fg = fbank_tmp
        # latent_replace = self.editor.fbank2latent(fbank_fg)
        ddim_latents = self.editor.ddim_inv(
            latent=torch.cat([latent_base, latent_base]), prompt=[prompt, prompt]
        )
        latent_in = ddim_latents[-1][:1].squeeze(2)
        # TODO: adapt it to different Gen models
        scale = 4 * SIZES[max(self.up_ft_index)] / self.up_scale
        edit_kwargs = process_paste(
            path_mask=mask_bg,
            h=f,
            w=t,
            dx=dx,
            dy=dy,
            scale=scale,
            input_scale=input_scale,
            up_scale=self.up_scale,
            up_ft_index=self.up_ft_index,
            w_edit=w_edit,
            w_content=w_content,
            precision=self.precision,
            resize_scale_x=resize_scale_x,
            resize_scale_y=resize_scale_y,
        )
        latent_rec = self.editor.pipe.edit(
            mode="generate",
            latent=latent_in,
            prompt=prompt_replace,
            guidance_scale=guidance_scale,
            energy_scale=0,
            latent_noise_ref=ddim_latents,
            SDE_strength=SDE_strength,
            edit_kwargs=edit_kwargs,
            num_inference_steps=self.num_ddim_steps,
            start_time=self.num_ddim_steps,
            alg="D",
            disable_tangent_proj=disable_tangent_proj,
        )
        # Scale output latent
        if scale_denoised:
            _max = torch.max(torch.abs(latent_rec))
            latent_rec = latent_rec * 5 / _max
        spec_rec = self.editor.decode_latents(latent_rec)
        wav_rc = self.editor.mel_spectrogram_to_waveform(spec_rec)
        torch.cuda.empty_cache()
        return SoundEditorOutput(wav_rc, spec_rec)


if __name__ == "__main__":
    mdl = AudioMorphix(
        "cvssp/audioldm-l-full", num_ddim_steps=50
    )  # "cvssp/audioldm-l-full" | "declare-lab/tango"
    print(mdl.__dict__)
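
    # A minimal, hypothetical smoke test of an edit call. The fbank/mask shapes
    # and weight values below are illustrative assumptions, not values from the
    # original demo; uncomment once a real (T, F) mel fbank and mask exist
    # (e.g. produced via wav_to_fbank).
    #
    # fbank = torch.randn(1024, 64)  # (T, F) mel filterbank  [hypothetical shape]
    # mask = torch.zeros(1024, 64)
    # mask[200:400, 10:40] = 1       # region to move
    # out = mdl.run_move(
    #     fbank, mask, dx=50, dy=0, mask_ref=None, prompt="a dog barking",
    #     resize_scale_x=1.0, resize_scale_y=1.0, w_edit=4.0, w_content=6.0,
    #     w_contrast=0.2, w_inpaint=0.8, seed=42, guidance_scale=3.0,
    #     energy_scale=1.5, SDE_strength=0.4,
    # )
    # print(out.waveform.shape, out.mel_spectrogram.shape)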