# finalproject.py

import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler
from huggingface_hub import snapshot_download
from converter import Generator, denormalize_spectrogram
import matplotlib.pyplot as plt
import numpy as np
from io import BytesIO
import base64
from PIL import Image

# Model Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float16

MODEL_NAME = "auffusion/auffusion-full-no-adapter"
model_path = snapshot_download(MODEL_NAME)

vocoder = Generator.from_pretrained(model_path, subfolder="vocoder").to(device=device, dtype=dtype)
pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=dtype).to(device)
vae = pipe.vae
tokenizer = pipe.tokenizer
text_encoder = pipe.text_encoder
unet = pipe.unet
scheduler = DDIMScheduler.from_pretrained(model_path, subfolder="scheduler")

def get_text_embeds(prompt):
    # Encode a prompt into CLIP text embeddings, padded (and truncated) to the tokenizer's max length.
    input_ids = tokenizer([prompt], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors='pt').input_ids.to(device)
    return text_encoder(input_ids)[0]

@torch.no_grad()
def generate_reversible_audio(prompt1: str, prompt2: str, steps: int = 100, guidance_scale=7.5, weight=0.5, seed=42):
    """Generate a spectrogram that follows prompt1 when played forward and prompt2 when
    played in reverse, by blending noise estimates for the latents and their time-flipped copy."""
    # Text embeddings (unconditional + conditional pairs for classifier-free guidance)
    cond1 = get_text_embeds(prompt1)
    uncond1 = get_text_embeds("")
    embed1 = torch.cat([uncond1, cond1], dim=0)

    cond2 = get_text_embeds(prompt2)
    uncond2 = get_text_embeds("")
    embed2 = torch.cat([uncond2, cond2], dim=0)

    # Latents
    generator = torch.Generator(device=device).manual_seed(seed)
    latents = torch.randn((1, unet.config.in_channels, 32, 128), generator=generator, dtype=dtype, device=device)
    latents = latents * scheduler.init_noise_sigma
    scheduler.set_timesteps(steps)

    for t in scheduler.timesteps:
        # Denoise with prompt1 on the latents and with prompt2 on the time-reversed latents,
        # then blend the two predictions so the result satisfies both orientations.
        noise1 = estimate_noise(t, latents, embed1, guidance_scale)
        noise2 = estimate_noise(t, latents.flip(3), embed2, guidance_scale).flip(3)
        noise = weight * noise1 + (1 - weight) * noise2
        latents = scheduler.step(noise, t, latents).prev_sample

    spec = decode_latents(latents).squeeze(0)
    denorm = denormalize_spectrogram(spec)
    audio = vocoder.inference(denorm)

    # Time-reversed spectrogram and its audio (what the clip sounds like played backwards)
    flipped_spec = spec.flip(dims=[2])
    flipped_denorm = denormalize_spectrogram(flipped_spec)
    flipped_audio = vocoder.inference(flipped_denorm)

    return spec, audio, flipped_spec, flipped_audio

@torch.no_grad()
def estimate_noise(t, latents, embeds, guidance_scale=7.5):
    # Classifier-free guidance: run the UNet on [uncond, cond] in one batch and
    # push the prediction away from the unconditional estimate.
    input_latents = torch.cat([latents] * 2)
    noise_pred = unet(input_latents, t, encoder_hidden_states=embeds)["sample"]
    uncond, cond = noise_pred.chunk(2)
    return uncond + guidance_scale * (cond - uncond)

@torch.no_grad()
def decode_latents(latents):
    # Undo the VAE scaling factor, decode to image space, and map from [-1, 1] to [0, 1].
    latents = 1 / vae.config.scaling_factor * latents
    image = vae.decode(latents).sample
    return (image / 2 + 0.5).clamp(0, 1)

def plot_spectrogram(spec):
    # Render a (C, H, W) spectrogram tensor as a base64-encoded PNG string.
    array = spec.detach().float().cpu().permute(1, 2, 0).numpy()
    fig, ax = plt.subplots()
    ax.imshow(array)
    ax.axis("off")
    buf = BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight")
    plt.close(fig)
    buf.seek(0)
    return base64.b64encode(buf.read()).decode("utf-8")
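
# --- Example usage (illustrative sketch, not part of the original pipeline) ---
# A hypothetical driver showing how the functions above might be wired together.
# The prompts, output filenames, the 16 kHz sample rate, and the use of the
# optional `soundfile` package are assumptions for illustration only.
if __name__ == "__main__":
    import soundfile as sf

    def _to_numpy(wav):
        # The vocoder may return either a NumPy array or a torch tensor; normalize to NumPy.
        return wav if isinstance(wav, np.ndarray) else wav.detach().cpu().float().numpy()

    spec, audio, flipped_spec, flipped_audio = generate_reversible_audio(
        prompt1="a dog barking",
        prompt2="church bells ringing",
        steps=100,
        guidance_scale=7.5,
        weight=0.5,
        seed=42,
    )

    # Save both playback directions as WAV files (assumed 16 kHz vocoder output).
    sf.write("forward.wav", _to_numpy(audio).squeeze(), samplerate=16000)
    sf.write("reversed.wav", _to_numpy(flipped_audio).squeeze(), samplerate=16000)

    # Base64-encoded PNGs of the spectrograms, e.g. for embedding in an HTML report.
    forward_png_b64 = plot_spectrogram(spec)
    reversed_png_b64 = plot_spectrogram(flipped_spec)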