zainulabedin949 committed on
Commit
e9b0e37
·
verified ·
1 Parent(s): 2ab151c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -33
app.py CHANGED
@@ -2,9 +2,12 @@ import gradio as gr
2
  import numpy as np
3
  import torch
4
  import librosa
 
5
  from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
6
  import matplotlib.pyplot as plt
7
  from matplotlib.colors import Normalize
 
 
8
 
9
  # Constants
10
  SAMPLING_RATE = 16000
@@ -15,26 +18,42 @@ DEFAULT_THRESHOLD = 0.7
15
  feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
16
  model = AutoModelForAudioClassification.from_pretrained(MODEL_NAME)
17
 
18
- def analyze_audio(audio_array, threshold=DEFAULT_THRESHOLD):
19
- """
20
- Process audio and detect anomalies
21
- Returns:
22
- - classification result
23
- - confidence score
24
- - spectrogram visualization
25
- """
26
  try:
27
- # Handle different audio input formats
28
- if isinstance(audio_array, tuple):
29
- sr, audio = audio_array
30
- if sr != SAMPLING_RATE:
31
- audio = librosa.resample(audio, orig_sr=sr, target_sr=SAMPLING_RATE)
32
- else:
33
- audio = audio_array
34
-
 
35
  if len(audio.shape) > 1:
36
- audio = librosa.to_mono(audio)
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  # Extract features
39
  inputs = feature_extractor(
40
  audio,
@@ -50,15 +69,15 @@ def analyze_audio(audio_array, threshold=DEFAULT_THRESHOLD):
50
  logits = outputs.logits
51
  probs = torch.softmax(logits, dim=-1)
52
 
53
- # Get predicted class and confidence
54
  predicted_class = "Normal" if probs[0][0] > threshold else "Anomaly"
55
  confidence = probs[0][0].item() if predicted_class == "Normal" else 1 - probs[0][0].item()
56
 
57
- # Create spectrogram visualization
58
  spectrogram = librosa.feature.melspectrogram(
59
  y=audio,
60
  sr=SAMPLING_RATE,
61
- n_mels=64, # Reduced from 128 to avoid warning
62
  fmax=8000
63
  )
64
  db_spec = librosa.power_to_db(spectrogram, ref=np.max)
@@ -75,18 +94,21 @@ def analyze_audio(audio_array, threshold=DEFAULT_THRESHOLD):
75
  fig.colorbar(img, ax=ax, format='%+2.0f dB')
76
  ax.set(title='Mel Spectrogram')
77
  plt.tight_layout()
78
- plt.savefig('spec.png', bbox_inches='tight')
 
 
 
79
  plt.close()
80
 
81
  return (
82
  predicted_class,
83
  f"{confidence:.1%}",
84
- 'spec.png',
85
  str(probs.tolist()[0])
86
  )
87
 
88
  except Exception as e:
89
- return f"Error: {str(e)}", "", "", ""
90
 
91
  # Gradio interface
92
  with gr.Blocks(title="Industrial Audio Analyzer", theme=gr.themes.Soft()) as demo:
@@ -98,16 +120,15 @@ with gr.Blocks(title="Industrial Audio Analyzer", theme=gr.themes.Soft()) as dem
98
  with gr.Row():
99
  with gr.Column():
100
  audio_input = gr.Audio(
101
- label="Upload Equipment Audio Recording",
102
- type="numpy"
103
  )
104
  threshold = gr.Slider(
105
  minimum=0.5,
106
  maximum=0.95,
107
  step=0.05,
108
  value=DEFAULT_THRESHOLD,
109
- label="Anomaly Detection Threshold",
110
- info="Higher values reduce false positives but may miss subtle anomalies"
111
  )
112
  analyze_btn = gr.Button("πŸ” Analyze Sound", variant="primary")
113
 
@@ -127,12 +148,10 @@ with gr.Blocks(title="Industrial Audio Analyzer", theme=gr.themes.Soft()) as dem
127
  )
128
 
129
  gr.Markdown("""
130
- ## How It Works
131
- - Upload audio recordings from industrial equipment
132
- - The AI analyzes sound patterns using spectrogram analysis
133
- - Detects anomalies indicating potential equipment issues
134
-
135
- **Tip**: For best results, use 5-10 second recordings of steady operation
136
  """)
137
 
138
  if __name__ == "__main__":
 
2
  import numpy as np
3
  import torch
4
  import librosa
5
+ import soundfile as sf
6
  from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
7
  import matplotlib.pyplot as plt
8
  from matplotlib.colors import Normalize
9
+ import tempfile
10
+ import os
11
 
12
  # Constants
13
  SAMPLING_RATE = 16000
 
18
  feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
19
  model = AutoModelForAudioClassification.from_pretrained(MODEL_NAME)
20
 
21
def handle_audio_file(audio_file):
    """Decode an uploaded audio stream into a sample array and sample rate.

    The stream's bytes are copied to a temporary ``.wav`` file, decoded with
    ``sf.read``, and the temporary file is removed again whether or not
    copying/decoding succeeds.

    Args:
        audio_file: Binary file-like object exposing ``read()``. It is read
            here but not closed; the caller keeps ownership of the handle.

    Returns:
        A tuple ``(audio, sr)``: the samples returned by ``sf.read`` (averaged
        over axis 1 when the array has more than one dimension) and the
        sample rate reported by ``sf.read``.

    Raises:
        ValueError: If reading, writing, or decoding fails. The original
            exception is chained as ``__cause__``.
    """
    tmp_path = None
    try:
        # delete=False: the file must survive closing the handle so that it
        # can be reopened by path below; it is removed in the finally clause.
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
            tmp_path = tmp.name
            tmp.write(audio_file.read())

        audio, sr = sf.read(tmp_path)

        # Convert to mono if needed
        if audio.ndim > 1:
            audio = np.mean(audio, axis=1)

        return audio, sr
    except Exception as e:
        raise ValueError(f"Error processing audio file: {str(e)}") from e
    finally:
        # Clean up on both the success and the failure path so a bad upload
        # does not leave a stray file in the temp directory.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass  # best-effort cleanup; never mask the real outcome
39
+
40
+ def analyze_audio(audio_input, threshold=DEFAULT_THRESHOLD):
41
+ """Process audio and detect anomalies"""
42
+ try:
43
+ # Handle different input types
44
+ if isinstance(audio_input, str): # File path
45
+ audio, sr = handle_audio_file(open(audio_input, 'rb'))
46
+ elif hasattr(audio_input, 'name'): # Gradio file object
47
+ audio, sr = handle_audio_file(audio_input)
48
+ elif isinstance(audio_input, tuple): # Direct numpy array
49
+ sr, audio = audio_input
50
+ else:
51
+ raise ValueError("Unsupported audio input format")
52
+
53
+ # Resample if needed
54
+ if sr != SAMPLING_RATE:
55
+ audio = librosa.resample(audio, orig_sr=sr, target_sr=SAMPLING_RATE)
56
+
57
  # Extract features
58
  inputs = feature_extractor(
59
  audio,
 
69
  logits = outputs.logits
70
  probs = torch.softmax(logits, dim=-1)
71
 
72
+ # Get results
73
  predicted_class = "Normal" if probs[0][0] > threshold else "Anomaly"
74
  confidence = probs[0][0].item() if predicted_class == "Normal" else 1 - probs[0][0].item()
75
 
76
+ # Create spectrogram
77
  spectrogram = librosa.feature.melspectrogram(
78
  y=audio,
79
  sr=SAMPLING_RATE,
80
+ n_mels=64,
81
  fmax=8000
82
  )
83
  db_spec = librosa.power_to_db(spectrogram, ref=np.max)
 
94
  fig.colorbar(img, ax=ax, format='%+2.0f dB')
95
  ax.set(title='Mel Spectrogram')
96
  plt.tight_layout()
97
+
98
+ # Save to temp file
99
+ spec_path = os.path.join(tempfile.gettempdir(), 'spec.png')
100
+ plt.savefig(spec_path, bbox_inches='tight')
101
  plt.close()
102
 
103
  return (
104
  predicted_class,
105
  f"{confidence:.1%}",
106
+ spec_path,
107
  str(probs.tolist()[0])
108
  )
109
 
110
  except Exception as e:
111
+ return f"Error: {str(e)}", "", None, ""
112
 
113
  # Gradio interface
114
  with gr.Blocks(title="Industrial Audio Analyzer", theme=gr.themes.Soft()) as demo:
 
120
  with gr.Row():
121
  with gr.Column():
122
  audio_input = gr.Audio(
123
+ label="Upload Equipment Audio (.wav)",
124
+ type="filepath"
125
  )
126
  threshold = gr.Slider(
127
  minimum=0.5,
128
  maximum=0.95,
129
  step=0.05,
130
  value=DEFAULT_THRESHOLD,
131
+ label="Anomaly Detection Threshold"
 
132
  )
133
  analyze_btn = gr.Button("πŸ” Analyze Sound", variant="primary")
134
 
 
148
  )
149
 
150
  gr.Markdown("""
151
+ **Instructions:**
152
+ - Upload .wav audio recordings (5-10 seconds recommended)
153
+ - Adjust threshold to control sensitivity
154
+ - Results show Normal/Anomaly classification with confidence
 
 
155
  """)
156
 
157
  if __name__ == "__main__":