"""Gradio demo: classify an uploaded audio clip with the MIT AST AudioSet model."""

import gradio as gr
import torch
import torchaudio
from transformers import AutoFeatureExtractor, ASTForAudioClassification

model_name = "MIT/ast-finetuned-audioset-10-10-0.4593"
model = ASTForAudioClassification.from_pretrained(model_name)
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)

device = torch.device("cpu")
model.to(device)


def _load_mono_waveform(file_path, target_sr):
    """Load an audio file, downmix it to mono and resample it to ``target_sr``.

    Args:
        file_path: Path to an audio file readable by ``torchaudio.load``.
        target_sr: Sample rate (Hz) the returned waveform should be at.

    Returns:
        Tuple ``(waveform, original_sr, original_shape)`` where ``waveform`` is
        a 1-D float tensor at ``target_sr``, ``original_sr`` is the file's
        native sample rate and ``original_shape`` is the shape of the tensor
        returned by ``torchaudio.load`` before any processing.
    """
    wv, sr = torchaudio.load(file_path)
    original_shape = wv.shape
    # Average all channels into one (keepdim keeps a leading channel axis).
    if wv.shape[0] > 1:
        wv = wv.mean(dim=0, keepdim=True)
    # The feature extractor only validates the *declared* sampling rate; the
    # samples themselves must really be at target_sr or features are wrong.
    if sr != target_sr:
        wv = torchaudio.functional.resample(wv, orig_freq=sr, new_freq=target_sr)
    return wv[0], sr, original_shape


def classify_sound(file_path):
    """Run the AST classifier on an uploaded audio file.

    Args:
        file_path: Filesystem path supplied by ``gr.Audio(type="filepath")``;
            ``None`` when the user submits without uploading anything.

    Returns:
        A 4-tuple matching the Interface outputs:
        ``(top5 {label: prob}, original sample rate as str,
        original waveform shape as str, top20 {label: prob})``.
        Probabilities are softmax scores rounded to 4 decimals, ordered from
        most to least likely.

    Raises:
        gr.Error: If no file was provided.
    """
    if file_path is None:
        raise gr.Error("Please upload an audio file first.")

    target_sr = feature_extractor.sampling_rate
    wv, sr, original_shape = _load_mono_waveform(file_path, target_sr)

    inputs = feature_extractor(
        wv.numpy(), sampling_rate=target_sr, return_tensors="pt"
    )
    # Keep inputs on the same device as the model.
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    with torch.no_grad():
        logits = model(**inputs).logits

    probs = torch.softmax(logits, dim=-1)[0]
    # One topk call; the top-5 view is just the first five of the top-20.
    top = torch.topk(probs, k=min(20, probs.numel()))
    ranked = [
        (model.config.id2label[idx.item()], round(prob.item(), 4))
        for idx, prob in zip(top.indices, top.values)
    ]

    return (
        dict(ranked[:5]),
        str(sr),
        str(original_shape),
        dict(ranked),
    )


demo = gr.Interface(
    fn=classify_sound,
    inputs=gr.Audio(sources="upload", type="filepath"),
    outputs=[
        gr.Label(label="Top 5 Pred", num_top_classes=5),
        gr.Textbox(label="Sample Rate"),
        gr.Textbox(label="Waveform Shape"),
        gr.JSON(label="Top 20 Class Probabilities"),
    ],
    title="Audio Classification with AST",
    description="Upload an audio clip (speech, music, ambient sound, etc.). Model: MIT AST fine-tuned on AudioSet (527 classes).",
    live=False,
)

if __name__ == "__main__":
    demo.launch()