Spaces:
Sleeping
Sleeping
File size: 2,208 Bytes
1f5d38f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
import laion_clap
import numpy as np
import librosa
import pickle
import os
from sklearn.metrics.pairwise import cosine_similarity
import pandas as pd
import zipfile
import json
# Locations of the bundled dataset artefacts (relative to the working directory).
dataset_zip = "dataset/one_shot_percussive_sounds.zip"
extracted_folder = "dataset/unzipped"
metadata_path = "dataset/licenses.txt"
audio_embeddings_path = "dataset/audio_embeddings.pkl"

# Directory the archive's contents are written into.
_EXTRACT_ROOT = "dataset"


def _extract_dataset_once():
    """Unpack the dataset archive unless ``extracted_folder`` already exists.

    NOTE(review): existence is checked on "dataset/unzipped" while the archive
    is extracted into "dataset"; presumably the zip's top-level folder is
    "unzipped/" — confirm, otherwise extraction repeats on every start.
    """
    if os.path.exists(extracted_folder):
        return
    with zipfile.ZipFile(dataset_zip, mode="r") as archive:
        archive.extractall(path=_EXTRACT_ROOT)


_extract_dataset_once()
# Load the CLAP model once at import time so every request reuses it.
# NOTE(review): enable_fusion=True must match the checkpoint selected by
# model_id=3 — confirm against the laion_clap load_ckpt documentation.
model = laion_clap.CLAP_Module(enable_fusion=True)
model.load_ckpt(model_id=3)

# Load dataset metadata. Despite the .txt extension, the file is parsed as
# JSON and must be an object: each top-level key becomes one DataFrame row.
# JSON text is UTF-8 (RFC 8259), so decode explicitly instead of relying on
# the platform's locale-dependent default encoding.
with open(metadata_path, "r", encoding="utf-8") as file:
    data = json.load(file)
metadata = pd.DataFrame.from_dict(data, orient="index")
# Append ".wav" to every key. find_top_sounds matches these index values
# against the keys of audio_embeddings, so the two naming schemes must agree
# (presumably embeddings are keyed by "<id>.wav" — TODO confirm).
metadata.index = metadata.index.astype(str) + ".wav"

# Load precomputed audio embeddings (to avoid recomputing on every request).
# pickle.load can execute arbitrary code: only ever point this at a trusted,
# locally produced file.
with open(audio_embeddings_path, "rb") as f:
    audio_embeddings = pickle.load(f)
def get_clap_embeddings_from_text(text):
    """Embed a single text prompt with the module-level CLAP ``model``.

    The prompt is wrapped in a one-element batch, and row 0 of the returned
    embedding matrix is handed back.
    """
    prompt_batch = [text]
    batch_embeddings = model.get_text_embedding(prompt_batch)
    return batch_embeddings[0, :]
def find_top_sounds(text_embed, instrument, top_N=4):
    """Return the names of the ``top_N`` sounds most similar to ``text_embed``.

    Candidates are the sounds whose metadata ``Instrument`` column equals
    ``instrument`` and that also have an entry in ``audio_embeddings``.
    They are ranked by cosine similarity to ``text_embed``, highest first.

    Args:
        text_embed: Query embedding (assumed 1-D, same width as the audio
            embeddings — TODO confirm).
        instrument: Value to match in the metadata ``Instrument`` column.
        top_N: Maximum number of names to return; values <= 0 yield ``[]``.

    Returns:
        A list of at most ``top_N`` ``audio_embeddings`` keys, ordered from
        most to least similar. It is empty when no sound matches.
    """
    if top_N <= 0:
        return []

    # Set membership: O(1) per lookup instead of scanning a list each time.
    valid_sounds = set(metadata[metadata["Instrument"] == instrument].index)

    # Keep names and vectors in lockstep. Similarity index i must map back to
    # the i-th *embedded* sound, not the i-th metadata row.
    names = [name for name in audio_embeddings if name in valid_sounds]
    if not names:
        # cosine_similarity rejects an empty sample set; no candidates, no kit.
        return []

    all_embeds = np.array([audio_embeddings[name] for name in names])
    similarities = cosine_similarity([text_embed], all_embeds)[0]

    # Descending order, then take the first top_N (also correct when
    # top_N exceeds the number of candidates).
    top_indices = np.argsort(similarities)[::-1][:top_N]
    return [names[i] for i in top_indices]
def generate_drum_kit(prompt, kit_size=4):
    """Build a drum kit for a free-text ``prompt``.

    The prompt is embedded once. Each instrument category is then filled with
    the sound names ``find_top_sounds`` returns for it, requesting
    ``kit_size`` sounds per category.

    Returns:
        A dict keyed by category name, in the fixed order Kick, Snare, Hi-Hat,
        Tom, Cymbal, Clap, Percussion, Other.
    """
    query_embedding = get_clap_embeddings_from_text(prompt)
    categories = (
        "Kick",
        "Snare",
        "Hi-Hat",
        "Tom",
        "Cymbal",
        "Clap",
        "Percussion",
        "Other",
    )
    return {
        category: find_top_sounds(query_embedding, category, top_N=kit_size)
        for category in categories
    }