|
import gradio as gr |
|
import torch |
|
import torchaudio |
|
from transformers import AutoFeatureExtractor, ASTForAudioClassification |
|
|
|
# Hugging Face checkpoint used for both the classifier and its preprocessor.
model_name = "MIT/ast-finetuned-audioset-10-10-0.4593"

# Inference runs on CPU.
device = torch.device("cpu")

# Preprocessor that turns a raw waveform into the model's input features.
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)

# Load the classifier and place it on the chosen device (.to() returns the
# same module, so binding the result is equivalent to calling it in place).
model = ASTForAudioClassification.from_pretrained(model_name).to(device)
|
|
|
def classify_sound(file_path):
    """Classify an audio file with the AST model and report its top predictions.

    Args:
        file_path: Path to an audio file readable by ``torchaudio.load``.

    Returns:
        A 4-tuple ``(top5, sample_rate, waveform_shape, top20)`` where
        ``top5`` and ``top20`` are dicts mapping class label to probability
        (rounded to 4 decimals), ``sample_rate`` is the file's original
        sample rate as a string, and ``waveform_shape`` is the string form of
        the waveform shape as loaded (before down-mixing or resampling).

    Raises:
        gr.Error: If no file was provided.
    """
    if file_path is None:
        # Gradio passes None when the user submits without uploading a file.
        raise gr.Error("Please upload an audio file.")

    wv, sr = torchaudio.load(file_path)
    original_shape = wv.shape

    # Down-mix multi-channel audio to mono by averaging the channels.
    if wv.shape[0] > 1:
        wv = wv.mean(dim=0, keepdim=True)

    # The extractor must be told the true rate of the samples it receives, so
    # resample to the rate it expects instead of mislabelling the audio.
    target_sr = getattr(feature_extractor, "sampling_rate", 16000)
    if sr != target_sr:
        wv = torchaudio.functional.resample(wv, orig_freq=sr, new_freq=target_sr)

    # squeeze(0) drops only the channel axis; a bare squeeze() would also
    # collapse a one-sample clip into a 0-d array.
    inputs = feature_extractor(
        wv.squeeze(0).numpy(), sampling_rate=target_sr, return_tensors="pt"
    )
    # Keep the inputs on the same device as the model weights.
    inputs = inputs.to(model.device)

    with torch.no_grad():
        logits = model(**inputs).logits

    probs = torch.softmax(logits, dim=-1)[0]

    # topk returns results sorted by descending probability, so the top 5 are
    # simply the first five entries of the top 20.
    top20 = torch.topk(probs, k=min(20, probs.numel()))
    ranked = [
        (model.config.id2label[idx.item()], round(prob.item(), 4))
        for idx, prob in zip(top20.indices, top20.values)
    ]
    top20_probs = dict(ranked)
    top5_probs = dict(ranked[:5])

    return (
        top5_probs,
        str(sr),
        str(original_shape),
        top20_probs,
    )
|
|
|
# Gradio UI: a single audio upload in; top-5 label view, the file's sample
# rate, the loaded waveform shape, and a JSON dump of the 20 most likely
# classes out.
demo = gr.Interface(
    fn=classify_sound,
    inputs=gr.Audio(sources="upload", type="filepath"),
    outputs=[
        gr.Label(label="Top 5 Pred", num_top_classes=5),
        gr.Textbox(label="Sample Rate"),
        gr.Textbox(label="Waveform Shape"),
        # classify_sound returns only the 20 highest-probability classes,
        # so the label says so rather than claiming "all".
        gr.JSON(label="Top 20 Class Probabilities"),
    ],
    title="Audio Classification with AST",
    # The "10-10" in the checkpoint name is not a class count; the model
    # predicts the 527 AudioSet classes.
    description="Upload an audio clip (speech, music, ambient sound, etc.). Model: MIT AST fine-tuned on AudioSet (527 classes).",
    live=False,
)

demo.launch()
|
|
|
|