import gradio as gr
import numpy as np
import torch
import librosa
import librosa.display  # required for librosa.display.specshow below
import matplotlib.pyplot as plt
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification

# Constants
SAMPLING_RATE = 16000
MODEL_NAME = "MIT/ast-finetuned-audioset-10-10-0.4593"
DEFAULT_THRESHOLD = 0.7
# Load model and feature extractor
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
model = AutoModelForAudioClassification.from_pretrained(MODEL_NAME)
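
# Note: this checkpoint is a general-purpose AudioSet classifier (527 sound classes),
# not a model trained specifically for anomaly detection; the app layers a simple
# threshold heuristic on top of it. The label space can be inspected with, e.g.:
#     print(model.config.num_labels)    # number of output classes
#     print(model.config.id2label[0])   # name of the class used below as the "normal" score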
def analyze_audio(audio_array, threshold=DEFAULT_THRESHOLD):
    """
    Process audio and detect anomalies.

    Returns:
        - classification result ("Normal" or "Anomaly")
        - confidence score
        - spectrogram visualization (path to a PNG)
        - raw class probabilities
    """
    try:
        # Handle the (sample_rate, samples) tuple returned by Gradio's "numpy" audio type
        if isinstance(audio_array, tuple):
            sr, audio = audio_array
        else:
            sr, audio = SAMPLING_RATE, audio_array

        # Gradio delivers integer PCM by default; scale to float32 in [-1, 1] for librosa
        if np.issubdtype(audio.dtype, np.integer):
            audio = audio.astype(np.float32) / np.iinfo(audio.dtype).max
        else:
            audio = audio.astype(np.float32)

        # Collapse stereo to mono (librosa expects channels first, Gradio gives channels last)
        if audio.ndim > 1:
            audio = librosa.to_mono(audio.T)

        # Resample to the rate the AST feature extractor expects
        if sr != SAMPLING_RATE:
            audio = librosa.resample(audio, orig_sr=sr, target_sr=SAMPLING_RATE)
        # Extract log-mel features for the AST model
        inputs = feature_extractor(
            audio,
            sampling_rate=SAMPLING_RATE,
            return_tensors="pt",
            padding=True,
            return_attention_mask=True
        )

        # Run inference
        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits
            probs = torch.softmax(logits, dim=-1)

        # Heuristic: treat the probability of the model's first class as the "normal"
        # score and flag an anomaly when it falls below the threshold
        predicted_class = "Normal" if probs[0][0] > threshold else "Anomaly"
        confidence = probs[0][0].item() if predicted_class == "Normal" else 1 - probs[0][0].item()
        # Create spectrogram visualization
        spectrogram = librosa.feature.melspectrogram(
            y=audio,
            sr=SAMPLING_RATE,
            n_mels=64,  # reduced from 128 to avoid a librosa warning
            fmax=8000
        )
        db_spec = librosa.power_to_db(spectrogram, ref=np.max)

        fig, ax = plt.subplots(figsize=(10, 4))
        img = librosa.display.specshow(
            db_spec,
            x_axis='time',
            y_axis='mel',
            sr=SAMPLING_RATE,
            fmax=8000,
            ax=ax
        )
        fig.colorbar(img, ax=ax, format='%+2.0f dB')
        ax.set(title='Mel Spectrogram')
        plt.tight_layout()
        plt.savefig('spec.png', bbox_inches='tight')
        plt.close(fig)

        return (
            predicted_class,
            f"{confidence:.1%}",
            'spec.png',
            str(probs.tolist()[0])
        )
    except Exception as e:
        return f"Error: {str(e)}", "", "", ""
# Gradio interface
with gr.Blocks(title="Industrial Audio Analyzer", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🏭 Industrial Equipment Sound Analyzer
    ### Powered by Audio Spectrogram Transformer (AST)
    """)

    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(
                label="Upload Equipment Audio Recording",
                type="numpy"
            )
            threshold = gr.Slider(
                minimum=0.5,
                maximum=0.95,
                step=0.05,
                value=DEFAULT_THRESHOLD,
                label="Anomaly Detection Threshold",
                info="Higher values reduce false positives but may miss subtle anomalies"
            )
            analyze_btn = gr.Button("🔍 Analyze Sound", variant="primary")

        with gr.Column():
            result_label = gr.Label(label="Detection Result")
            confidence = gr.Textbox(label="Confidence Score")
            spectrogram = gr.Image(label="Spectrogram Visualization")
            raw_probs = gr.Textbox(
                label="Model Output Probabilities",
                visible=False
            )

    analyze_btn.click(
        fn=analyze_audio,
        inputs=[audio_input, threshold],
        outputs=[result_label, confidence, spectrogram, raw_probs]
    )
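
    # Note: analyze_audio returns (label, confidence string, spectrogram path,
    # raw probabilities), matching the order of the `outputs` list above.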
    gr.Markdown("""
    ## How It Works
    - Upload audio recordings from industrial equipment
    - The AI analyzes sound patterns using spectrogram analysis
    - Detects anomalies indicating potential equipment issues

    **Tip**: For best results, use 5-10 second recordings of steady operation
    """)

if __name__ == "__main__":
    demo.launch()
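    # When serving from a container or remote host, you may need something like
    # demo.launch(server_name="0.0.0.0", server_port=7860) instead (assumes the
    # default Gradio port 7860 is the one exposed).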