import logging
import tempfile

import gradio as gr
import librosa
import librosa.display  # submodule is NOT loaded by `import librosa`; needed for specshow
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.colors import Normalize  # noqa: F401 -- currently unused in this file
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification

logger = logging.getLogger(__name__)

# Constants
SAMPLING_RATE = 16000  # all audio is resampled to this rate before features/plots
MODEL_NAME = "MIT/ast-finetuned-audioset-10-10-0.4593"
DEFAULT_THRESHOLD = 0.7

# Load model and feature extractor once at import time (shared by all requests).
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
model = AutoModelForAudioClassification.from_pretrained(MODEL_NAME)
model.eval()  # inference only


def _to_float_mono(audio, channels_last):
    """Return *audio* as a 1-D float32 waveform.

    Integer PCM is scaled by its dtype's range to roughly [-1, 1), because
    librosa only accepts floating-point audio. Multi-channel input is averaged
    to mono; ``channels_last=True`` means ``(samples, channels)``, ``False``
    means librosa's ``(channels, samples)`` layout.
    """
    audio = np.asarray(audio)
    src_dtype = audio.dtype
    if np.issubdtype(src_dtype, np.integer):
        audio = audio.astype(np.float32) / (np.iinfo(src_dtype).max + 1)
    else:
        audio = audio.astype(np.float32, copy=False)
    if audio.ndim > 1:
        audio = audio.mean(axis=-1 if channels_last else 0)
    return audio


def _prepare_audio(audio_array):
    """Convert a Gradio ``(sample_rate, data)`` tuple or a raw array to mono float32 at SAMPLING_RATE."""
    if isinstance(audio_array, tuple):
        sr, data = audio_array
        # Gradio numpy audio is (samples,) or (samples, channels): downmix
        # BEFORE resampling so resampling runs along the time axis.
        audio = _to_float_mono(data, channels_last=True)
        if sr != SAMPLING_RATE:
            audio = librosa.resample(audio, orig_sr=sr, target_sr=SAMPLING_RATE)
    else:
        # Raw arrays are assumed to already be at SAMPLING_RATE and, if
        # multi-channel, in librosa's (channels, samples) layout.
        audio = _to_float_mono(audio_array, channels_last=False)
    return audio


def _render_spectrogram(audio):
    """Plot a mel spectrogram of *audio* to a unique temporary PNG and return its path."""
    spectrogram = librosa.feature.melspectrogram(
        y=audio,
        sr=SAMPLING_RATE,
        n_mels=64,  # Reduced from 128 to avoid warning
        fmax=8000,
    )
    db_spec = librosa.power_to_db(spectrogram, ref=np.max)

    fig, ax = plt.subplots(figsize=(10, 4))
    try:
        img = librosa.display.specshow(
            db_spec,
            x_axis="time",
            y_axis="mel",
            sr=SAMPLING_RATE,
            fmax=8000,
            ax=ax,
        )
        fig.colorbar(img, ax=ax, format="%+2.0f dB")
        ax.set(title="Mel Spectrogram")
        fig.tight_layout()
        # One file per call so concurrent requests cannot overwrite each
        # other's image; delete=False keeps the path valid after returning
        # (nothing in this function removes it).
        with tempfile.NamedTemporaryFile(
            suffix=".png", prefix="spec_", delete=False
        ) as tmp:
            fig.savefig(tmp, format="png", bbox_inches="tight")
            path = tmp.name
    finally:
        # Always release the figure, even if plotting failed.
        plt.close(fig)
    return path


def analyze_audio(audio_array, threshold=DEFAULT_THRESHOLD):
    """
    Process audio and detect anomalies.

    Args:
        audio_array: Gradio ``(sample_rate, data)`` tuple, or a raw waveform
            array assumed to already be at ``SAMPLING_RATE``.
        threshold: minimum probability of model class 0 required for the
            recording to be labelled "Normal"; otherwise it is "Anomaly".

    Returns:
        Tuple of (classification result, confidence score text,
        spectrogram image path, stringified class probabilities).
        On failure the first element is ``"Error: ..."`` and the rest are "".
    """
    if audio_array is None:
        return "Error: no audio provided", "", "", ""

    try:
        audio = _prepare_audio(audio_array)
        if audio.size == 0:
            return "Error: audio is empty", "", "", ""

        # Extract features
        inputs = feature_extractor(
            audio,
            sampling_rate=SAMPLING_RATE,
            return_tensors="pt",
            padding=True,
            return_attention_mask=True,
        )

        # Run inference
        with torch.no_grad():
            logits = model(**inputs).logits
        # NOTE(review): AudioSet is a multi-label dataset, so per-class sigmoid
        # scores may be more appropriate than softmax -- confirm with the
        # model card / model.config.
        probs = torch.softmax(logits, dim=-1)

        # NOTE(review): index 0 of this AudioSet checkpoint is presumably an
        # ordinary AudioSet class, not a trained "normal equipment" class --
        # verify model.config.id2label[0] before relying on this decision rule.
        normal_prob = probs[0, 0].item()
        predicted_class = "Normal" if normal_prob > threshold else "Anomaly"
        confidence = normal_prob if predicted_class == "Normal" else 1 - normal_prob

        spec_path = _render_spectrogram(audio)

        return (
            predicted_class,
            f"{confidence:.1%}",
            spec_path,
            str(probs[0].tolist()),
        )
    except Exception as e:
        # UI boundary: show the error in the interface, keep the traceback in logs.
        logger.exception("Audio analysis failed")
        return f"Error: {str(e)}", "", "", ""


# Gradio interface
with gr.Blocks(title="Industrial Audio Analyzer", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🏭 Industrial Equipment Sound Analyzer
    ### Powered by Audio Spectrogram Transformer (AST)
    """)

    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(
                label="Upload Equipment Audio Recording",
                type="numpy",
            )
            threshold = gr.Slider(
                minimum=0.5,
                maximum=0.95,
                step=0.05,
                value=DEFAULT_THRESHOLD,
                label="Anomaly Detection Threshold",
                # A recording is "Normal" only if its score exceeds this value,
                # so raising it flags MORE recordings as anomalies.
                info="Higher values flag more recordings as anomalies: fewer missed anomalies, but more false alarms",
            )
            analyze_btn = gr.Button("🔍 Analyze Sound", variant="primary")

        with gr.Column():
            result_label = gr.Label(label="Detection Result")
            confidence = gr.Textbox(label="Confidence Score")
            spectrogram = gr.Image(label="Spectrogram Visualization")
            raw_probs = gr.Textbox(
                label="Model Output Probabilities",
                visible=False,
            )

    analyze_btn.click(
        fn=analyze_audio,
        inputs=[audio_input, threshold],
        outputs=[result_label, confidence, spectrogram, raw_probs],
    )

    gr.Markdown("""
    ## How It Works
    - Upload audio recordings from industrial equipment
    - The AI analyzes sound patterns using spectrogram analysis
    - Detects anomalies indicating potential equipment issues

    **Tip**: For best results, use 5-10 second recordings of steady operation
    """)

if __name__ == "__main__":
    demo.launch()