# finalproject.py
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler
from huggingface_hub import snapshot_download
from converter import Generator, denormalize_spectrogram
import matplotlib.pyplot as plt
import numpy as np
from io import BytesIO
import base64
from PIL import Image
# Model Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# float16 weights require a GPU; fall back to float32 on CPU
dtype = torch.float16 if device.type == "cuda" else torch.float32
MODEL_NAME = "auffusion/auffusion-full-no-adapter"
model_path = snapshot_download(MODEL_NAME)
vocoder = Generator.from_pretrained(model_path, subfolder="vocoder").to(device=device, dtype=dtype)
pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=dtype).to(device)
vae = pipe.vae
tokenizer = pipe.tokenizer
text_encoder = pipe.text_encoder
unet = pipe.unet
scheduler = DDIMScheduler.from_pretrained(model_path, subfolder="scheduler")
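# The pipeline is loaded only to reuse its components; the denoising loop
# below drives the VAE, tokenizer, text encoder, and UNet directly so that
# noise estimates from two prompts can be blended at every step.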
def get_text_embeds(prompt):
    input_ids = tokenizer([prompt], padding="max_length", max_length=tokenizer.model_max_length, return_tensors='pt').input_ids.to(device)
    return text_encoder(input_ids)[0]
@torch.no_grad()
def generate_reversible_audio(prompt1: str, prompt2: str, steps: int = 100, guidance_scale=7.5, weight=0.5, seed=42):
    # Text embeddings for classifier-free guidance: [unconditional, conditional]
    cond1 = get_text_embeds(prompt1)
    uncond1 = get_text_embeds("")
    embed1 = torch.cat([uncond1, cond1], dim=0)

    cond2 = get_text_embeds(prompt2)
    uncond2 = get_text_embeds("")
    embed2 = torch.cat([uncond2, cond2], dim=0)

    # Random initial latents at the spectrogram's latent resolution
    generator = torch.Generator(device=device).manual_seed(seed)
    latents = torch.randn((1, unet.config.in_channels, 32, 128), generator=generator, dtype=dtype, device=device)
    latents = latents * scheduler.init_noise_sigma

    scheduler.set_timesteps(steps)
    for t in scheduler.timesteps:
        # prompt1 guides the spectrogram as-is; prompt2 guides its time-reversed
        # view (flip along the width/time axis), so one clip satisfies both
        noise1 = estimate_noise(t, latents, embed1, guidance_scale)
        noise2 = estimate_noise(t, latents.flip(3), embed2, guidance_scale).flip(3)
        noise = weight * noise1 + (1 - weight) * noise2
        latents = scheduler.step(noise, t, latents).prev_sample

    # Decode the latents to a spectrogram and vocode it to audio
    spec = decode_latents(latents).squeeze(0)
    denorm = denormalize_spectrogram(spec)
    audio = vocoder.inference(denorm)

    # Flip the spectrogram in time to get the second ("reversed") rendition
    flipped_spec = spec.flip(dims=[2])
    flipped_denorm = denormalize_spectrogram(flipped_spec)
    flipped_audio = vocoder.inference(flipped_denorm)

    return spec, audio, flipped_spec, flipped_audio
@torch.no_grad()
def estimate_noise(t, latents, embeds, guidance_scale=7.5):
    # Run the UNet on [uncond, cond] in one batch, then apply classifier-free guidance
    input_latents = torch.cat([latents] * 2)
    noise_pred = unet(input_latents, t, encoder_hidden_states=embeds)["sample"]
    uncond, cond = noise_pred.chunk(2)
    return uncond + guidance_scale * (cond - uncond)
@torch.no_grad()
def decode_latents(latents):
    # Undo the VAE scaling, decode, and map from [-1, 1] to [0, 1]
    latents = 1 / vae.config.scaling_factor * latents
    image = vae.decode(latents).sample
    return (image / 2 + 0.5).clamp(0, 1)
def plot_spectrogram(spec):
    # Render the (3, H, W) spectrogram tensor and return it as a base64-encoded PNG
    array = spec.detach().cpu().float().permute(1, 2, 0).numpy()
    fig, ax = plt.subplots()
    ax.imshow(array)
    ax.axis("off")
    buf = BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight")
    plt.close(fig)
    buf.seek(0)
    return base64.b64encode(buf.read()).decode("utf-8")
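

# --- Usage sketch (illustrative, not part of the original file) ---
# Minimal example of generating an "audio anagram" and saving both renditions.
# The prompts, output file names, and the 16 kHz sample rate are assumptions;
# check the vocoder config for the actual sampling rate. Requires `soundfile`.
if __name__ == "__main__":
    import soundfile as sf

    spec, audio, flipped_spec, flipped_audio = generate_reversible_audio(
        "a dog barking",         # heard when the clip is played forward
        "church bells ringing",  # heard when the clip is played in reverse
        steps=100,
        weight=0.5,
        seed=42,
    )

    SAMPLE_RATE = 16000  # assumed vocoder sampling rate
    sf.write("forward.wav", np.asarray(audio).squeeze(), SAMPLE_RATE)
    # flipped_audio approximates what forward.wav sounds like played backwards
    sf.write("reversed.wav", np.asarray(flipped_audio).squeeze(), SAMPLE_RATE)

    # Save a preview image of the forward spectrogram
    with open("spectrogram.png", "wb") as f:
        f.write(base64.b64decode(plot_spectrogram(spec)))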