# finalproject.py
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler
from huggingface_hub import snapshot_download
from converter import Generator, denormalize_spectrogram
import matplotlib.pyplot as plt
import numpy as np
from io import BytesIO
import base64
from PIL import Image
# Model Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# fp16 is only reliable on GPU; fall back to fp32 when running on CPU
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
MODEL_NAME = "auffusion/auffusion-full-no-adapter"
model_path = snapshot_download(MODEL_NAME)
vocoder = Generator.from_pretrained(model_path, subfolder="vocoder").to(device=device, dtype=dtype)
pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=dtype).to(device)
vae = pipe.vae
tokenizer = pipe.tokenizer
text_encoder = pipe.text_encoder
unet = pipe.unet
scheduler = DDIMScheduler.from_pretrained(model_path, subfolder="scheduler")
def get_text_embeds(prompt):
    # Encode a single prompt into CLIP text embeddings, truncated to the model's max length
    input_ids = tokenizer([prompt], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt").input_ids.to(device)
    return text_encoder(input_ids)[0]
@torch.no_grad()
def generate_reversible_audio(prompt1: str, prompt2: str, steps: int = 100, guidance_scale=7.5, weight=0.5, seed=42):
    # Text embeddings for classifier-free guidance: [unconditional, conditional]
    cond1 = get_text_embeds(prompt1)
    uncond1 = get_text_embeds("")
    embed1 = torch.cat([uncond1, cond1], dim=0)

    cond2 = get_text_embeds(prompt2)
    uncond2 = get_text_embeds("")
    embed2 = torch.cat([uncond2, cond2], dim=0)

    # Initial latents (1 x C x 32 x 128; the last axis is time)
    generator = torch.Generator(device=device).manual_seed(seed)
    latents = torch.randn((1, unet.config.in_channels, 32, 128), generator=generator, dtype=dtype, device=device)
    latents = latents * scheduler.init_noise_sigma

    scheduler.set_timesteps(steps)
    for t in scheduler.timesteps:
        # prompt1 guides the forward spectrogram; prompt2 guides the time-reversed one,
        # so its noise estimate is computed on latents flipped along the time axis
        noise1 = estimate_noise(t, latents, embed1, guidance_scale)
        noise2 = estimate_noise(t, latents.flip(3), embed2, guidance_scale).flip(3)
        noise = weight * noise1 + (1 - weight) * noise2
        latents = scheduler.step(noise, t, latents).prev_sample

    # Decode the latents to a spectrogram and vocode it to audio
    spec = decode_latents(latents).squeeze(0)
    denorm = denormalize_spectrogram(spec)
    audio = vocoder.inference(denorm)

    # Flip along the time axis and vocode again: this clip should match prompt2
    flipped_spec = spec.flip(dims=[2])
    flipped_denorm = denormalize_spectrogram(flipped_spec)
    flipped_audio = vocoder.inference(flipped_denorm)

    return spec, audio, flipped_spec, flipped_audio
@torch.no_grad()
def estimate_noise(t, latents, embeds, guidance_scale=7.5):
    # Classifier-free guidance: predict noise for [uncond, cond] in one batch and blend
    input_latents = torch.cat([latents] * 2)
    noise_pred = unet(input_latents, t, encoder_hidden_states=embeds)["sample"]
    uncond, cond = noise_pred.chunk(2)
    return uncond + guidance_scale * (cond - uncond)
@torch.no_grad()
def decode_latents(latents):
    latents = 1 / vae.config.scaling_factor * latents
    image = vae.decode(latents).sample
    return (image / 2 + 0.5).clamp(0, 1)
def plot_spectrogram(spec):
    # Render the spectrogram tensor as an image and return it as a base64-encoded PNG
    array = spec.detach().float().cpu().permute(1, 2, 0).numpy()
    fig, ax = plt.subplots()
    ax.imshow(array)
    ax.axis("off")
    buf = BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight")
    plt.close(fig)
    buf.seek(0)
    return base64.b64encode(buf.read()).decode("utf-8")
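
# Illustrative usage sketch (not part of the original Space): generate a clip that follows
# prompt1 forward and prompt2 when played in reverse. The prompts, output filenames, and
# the 16 kHz sample rate are assumptions; check the vocoder config for the actual rate.
# Requires the optional `soundfile` package.
if __name__ == "__main__":
    import soundfile as sf

    def to_numpy(x):
        # The vocoder output may be a torch tensor or a numpy array; normalize to 1-D float32
        if torch.is_tensor(x):
            x = x.detach().float().cpu().numpy()
        return np.asarray(x, dtype=np.float32).squeeze()

    spec, audio, flipped_spec, flipped_audio = generate_reversible_audio(
        "a dog barking", "church bells ringing", steps=50
    )
    sf.write("forward.wav", to_numpy(audio), 16000)
    sf.write("reversed.wav", to_numpy(flipped_audio), 16000)

    # plot_spectrogram returns a base64-encoded PNG; decode it to save the image
    with open("spectrogram.png", "wb") as f:
        f.write(base64.b64decode(plot_spectrogram(spec)))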