File size: 3,109 Bytes
f2865dc a51a160 f2865dc a51a160 f2865dc a51a160 f2865dc a51a160 f2865dc a51a160 f2865dc a51a160 f2865dc a51a160 f2865dc a51a160 f2865dc a51a160 f2865dc a51a160 f2865dc a51a160 f2865dc a51a160 f2865dc a51a160 f2865dc a51a160 f2865dc a51a160 f2865dc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
import hashlib
import os
from glob import glob
import laion_clap
from diskcache import Cache
from qdrant_client import QdrantClient
from qdrant_client.http import models
from tqdm import tqdm
# Utiliser les variables d'environnement pour la configuration
QDRANT_HOST = os.getenv('QDRANT_HOST', 'localhost')
QDRANT_PORT = int(os.getenv('QDRANT_PORT', 6333))
# Functions utils
def get_md5(fpath):
with open(fpath, "rb") as f:
file_hash = hashlib.md5()
while chunk := f.read(8192):
file_hash.update(chunk)
return file_hash.hexdigest()
# PARAMETERS
CACHE_FOLDER = '/home/nahia/data/audio/'
KAGGLE_TRAIN_PATH = '/home/nahia/Documents/audio/actor/Actor_01/'
# Charger le modèle CLAP
print("[INFO] Loading the model...")
model_name = 'music_speech_epoch_15_esc_89.25.pt'
model = laion_clap.CLAP_Module(enable_fusion=False)
model.load_ckpt() # télécharger le checkpoint préentraîné par défaut
# Initialiser le cache
os.makedirs(CACHE_FOLDER, exist_ok=True)
cache = Cache(CACHE_FOLDER)
# Embarquer les fichiers audio
audio_files = [p for p in glob(os.path.join(KAGGLE_TRAIN_PATH, '*.wav'))]
audio_embeddings = []
chunk_size = 100
total_chunks = int(len(audio_files) / chunk_size)
# Utiliser tqdm pour une barre de progression
for i in tqdm(range(0, len(audio_files), chunk_size), total=total_chunks):
chunk = audio_files[i:i + chunk_size] # Obtenir un chunk de fichiers audio
chunk_embeddings = []
for audio_file in chunk:
# Calculer un hash unique pour le fichier audio
file_key = get_md5(audio_file)
if file_key in cache:
# Si l'embedding pour ce fichier est en cache, le récupérer
embedding = cache[file_key]
else:
# Sinon, calculer l'embedding et le mettre en cache
embedding = model.get_audio_embedding_from_filelist(x=[audio_file], use_tensor=False)[
0] # Assumer que le modèle retourne une liste
cache[file_key] = embedding
chunk_embeddings.append(embedding)
audio_embeddings.extend(chunk_embeddings)
# Fermer le cache quand terminé
cache.close()
# Créer une collection qdrant
client = QdrantClient(QDRANT_HOST, port=QDRANT_PORT)
print("[INFO] Client created...")
print("[INFO] Creating qdrant data collection...")
client.create_collection(
collection_name="demo_db7",
vectors_config=models.VectorParams(
size=audio_embeddings[0].shape[0],
distance=models.Distance.COSINE
),
)
# Créer des enregistrements Qdrant à partir des embeddings
records = []
for idx, (audio_path, embedding) in enumerate(zip(audio_files, audio_embeddings)):
record = models.PointStruct(
id=idx,
vector=embedding,
payload={"audio_path": audio_path, "style": audio_path.split('/')[-2]}
)
records.append(record)
# Téléverser les enregistrements dans la collection Qdrant
print("[INFO] Uploading data records to data collection...")
client.upload_points(collection_name="demo_db7", points=records)
print("[INFO] Successfully uploaded data records to data collection!")
|