# finalproject.py
import base64
from io import BytesIO

import matplotlib.pyplot as plt
import numpy as np
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler
from huggingface_hub import snapshot_download

from converter import Generator, denormalize_spectrogram

# Model setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# fp16 halves memory on GPU; most fp16 ops are unsupported on CPU, so fall back to fp32 there
dtype = torch.float16 if device.type == "cuda" else torch.float32

MODEL_NAME = "auffusion/auffusion-full-no-adapter"
model_path = snapshot_download(MODEL_NAME)

# Vocoder that converts generated spectrograms back into waveforms
vocoder = Generator.from_pretrained(model_path, subfolder="vocoder").to(device=device, dtype=dtype)
pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=dtype).to(device)
vae = pipe.vae
tokenizer = pipe.tokenizer
text_encoder = pipe.text_encoder
unet = pipe.unet
scheduler = DDIMScheduler.from_pretrained(model_path, subfolder="scheduler")

def get_text_embeds(prompt):
    # Tokenize to the encoder's max length and return the CLIP text embeddings
    input_ids = tokenizer([prompt], padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt").input_ids.to(device)
    return text_encoder(input_ids)[0]

def generate_reversible_audio(prompt1: str, prompt2: str, steps: int = 100, guidance_scale=7.5, weight=0.5, seed=42):
    # Classifier-free guidance needs paired unconditional/conditional embeddings for each prompt
    cond1 = get_text_embeds(prompt1)
    uncond1 = get_text_embeds("")
    embed1 = torch.cat([uncond1, cond1], dim=0)

    cond2 = get_text_embeds(prompt2)
    uncond2 = get_text_embeds("")
    embed2 = torch.cat([uncond2, cond2], dim=0)

    # Initial noise: a 32x128 latent grid (the VAE upsamples x8, so it decodes to a 256x1024 spectrogram)
    generator = torch.Generator(device=device).manual_seed(seed)
    latents = torch.randn((1, unet.config.in_channels, 32, 128), generator=generator, dtype=dtype, device=device)
    latents = latents * scheduler.init_noise_sigma

    scheduler.set_timesteps(steps)
    for t in scheduler.timesteps:
        # prompt1 guides the latents in the forward direction; prompt2 guides the
        # time-reversed latents (dim 3 is the time axis), flipped back so both
        # noise estimates share the same orientation before averaging.
        noise1 = estimate_noise(t, latents, embed1, guidance_scale)
        noise2 = estimate_noise(t, latents.flip(3), embed2, guidance_scale).flip(3)
        noise = weight * noise1 + (1 - weight) * noise2
        latents = scheduler.step(noise, t, latents).prev_sample

    # Decode the latents to a spectrogram and vocode it to audio
    spec = decode_latents(latents).squeeze(0)
    denorm = denormalize_spectrogram(spec)
    audio = vocoder.inference(denorm)

    # Flip the time axis (dim 2 of the (C, F, T) spectrogram): played forward,
    # this version should sound like prompt2
    flipped_spec = spec.flip(dims=[2])
    flipped_denorm = denormalize_spectrogram(flipped_spec)
    flipped_audio = vocoder.inference(flipped_denorm)

    return spec, audio, flipped_spec, flipped_audio

def estimate_noise(t, latents, embeds, guidance_scale=7.5):
    # Batch the latents twice so the unconditional and conditional passes share one UNet call
    input_latents = torch.cat([latents] * 2)
    input_latents = scheduler.scale_model_input(input_latents, t)  # no-op for DDIM, but keeps the loop scheduler-agnostic
    noise_pred = unet(input_latents, t, encoder_hidden_states=embeds)["sample"]
    uncond, cond = noise_pred.chunk(2)
    # Classifier-free guidance: move the prediction away from the unconditional estimate
    return uncond + guidance_scale * (cond - uncond)

def decode_latents(latents):
    # Undo the VAE scaling, decode, and map the image from [-1, 1] to [0, 1]
    latents = 1 / vae.config.scaling_factor * latents
    image = vae.decode(latents).sample
    return (image / 2 + 0.5).clamp(0, 1)
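
# A small companion helper (a sketch, not in the original file): it mirrors the
# base64-PNG flow in plot_spectrogram below, but for audio. It assumes the
# `soundfile` package is installed, that vocoder.inference returns something
# convertible to a NumPy waveform, and a 16 kHz sample rate; check the vocoder
# config for the real value.
import soundfile as sf

def audio_to_base64_wav(audio, sample_rate: int = 16000) -> str:
    # Write the waveform into an in-memory WAV container, then base64-encode it
    buf = BytesIO()
    sf.write(buf, np.asarray(audio, dtype=np.float32).squeeze(), sample_rate, format="WAV")
    buf.seek(0)
    return base64.b64encode(buf.read()).decode("utf-8")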

def plot_spectrogram(spec):
    # (C, F, T) -> (F, T, C) for imshow; cast to fp32 first since fp16 arrays can trip up matplotlib
    array = spec.detach().float().cpu().permute(1, 2, 0).numpy()
    fig, ax = plt.subplots()
    ax.imshow(array)
    ax.axis("off")
    buf = BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight")
    plt.close(fig)
    buf.seek(0)
    # Return the PNG as base64 so it can be embedded directly in a web page
    return base64.b64encode(buf.read()).decode("utf-8")
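
# Example entry point (illustrative only): the prompts and file names are made
# up, and the 16 kHz rate is the same assumption as in audio_to_base64_wav. If
# vocoder.inference returns a torch tensor rather than a NumPy array, move it
# to the CPU and convert it before saving.
if __name__ == "__main__":
    spec, audio, flipped_spec, flipped_audio = generate_reversible_audio(
        "a dog barking",           # heard when the clip plays forward
        "church bells ringing",    # heard when the clip plays in reverse
        steps=100,
        weight=0.5,
    )
    sf.write("forward.wav", np.asarray(audio, dtype=np.float32).squeeze(), 16000)
    sf.write("reversed.wav", np.asarray(flipped_audio, dtype=np.float32).squeeze(), 16000)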