# File size: 7,788 Bytes
# 1835c00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
import streamlit as st
import numpy as np
import librosa
import librosa.display
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import plotly.express as px
import soundfile as sf
from scipy.signal import stft

# Dummy CNN Model for Audio
class AudioCNN(nn.Module):
    """Small two-convolution classifier over single-channel 2-D inputs.

    Nothing in this class loads trained weights, so its class scores are not
    meaningful; it is useful for inspecting intermediate activations.

    Input is expected as (batch, 1, H, W). The adaptive pooling step resizes
    the second conv's feature map to a fixed 8 x 32 grid, so ``fc1`` always
    receives 32 * 8 * 32 features regardless of H and W.
    """

    def __init__(self):
        super().__init__()
        # Keep this registration order: it determines the state_dict key order
        # and the order in which default initialisers draw from the RNG.
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(32 * 32 * 8, 128)  # 32 channels x (8 x 32) pooled grid
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return ``(logits, conv1_activations)``.

        ``logits`` has shape (batch, 10). ``conv1_activations`` is the
        post-ReLU output of the first convolution (16 channels; padding=1 with
        a 3x3 kernel keeps the spatial size of ``x``).
        """
        first_act = F.relu(self.conv1(x))
        pooled = F.adaptive_avg_pool2d(F.relu(self.conv2(first_act)), (8, 32))
        hidden = F.relu(self.fc1(pooled.flatten(start_dim=1)))
        return self.fc2(hidden), first_act

# Audio processing functions
def load_audio(file):
    """Decode *file* into a mono waveform at its native sample rate.

    Args:
        file: a path or file-like object accepted by ``librosa.load``.

    Returns:
        ``(audio, sr)``: the decoded samples and their sample rate.
    """
    # sr=None asks librosa to keep the file's own rate (no resampling);
    # mono=True downmixes multi-channel audio to one channel.
    samples, native_rate = librosa.load(file, mono=True, sr=None)
    return samples, native_rate

def apply_fft(audio):
    """Compute the full two-sided discrete Fourier transform of *audio*.

    Returns:
        ``(fft, magnitude, phase)``:
        - ``fft`` is the complex spectrum from ``np.fft.fft``;
        - ``magnitude`` is its element-wise absolute value;
        - ``phase`` is its element-wise angle in radians.
    """
    spectrum = np.fft.fft(audio)
    return spectrum, np.abs(spectrum), np.angle(spectrum)

def filter_fft(fft, percentage):
    """Keep only the strongest *percentage* % of FFT coefficients; zero the rest.

    Coefficients are ranked by magnitude, and the top
    ``int(len(fft) * percentage / 100)`` are copied through unchanged.

    Args:
        fft: 1-D array-like of (typically complex) FFT coefficients.
        percentage: share of coefficients to keep, in percent. It is clamped to
            [0, 100], so a negative value keeps nothing and a value above 100
            keeps everything.

    Returns:
        A new array with the same shape and dtype as ``np.asarray(fft)``.
    """
    fft = np.asarray(fft)
    magnitude = np.abs(fft)
    # Strongest coefficients first.
    sorted_indices = np.argsort(magnitude)[::-1]
    # Clamp so a negative percentage cannot become a negative slice bound
    # (which would keep almost everything instead of nothing).
    percentage = min(max(percentage, 0.0), 100.0)
    num_keep = int(len(sorted_indices) * percentage / 100)
    keep = sorted_indices[:num_keep]
    # Copy the kept bins instead of multiplying by a 0/1 mask, so inf/NaN in
    # dropped bins becomes 0 rather than NaN.
    filtered = np.zeros_like(fft)
    filtered[keep] = fft[keep]
    return filtered

def create_spectrogram(audio, sr, n_fft=2048, hop_length=512):
    """Return the STFT magnitude spectrogram of *audio*.

    Args:
        audio: 1-D waveform.
        sr: sample rate. Not used in the computation; kept for interface
            compatibility with existing callers.
        n_fft: FFT window length passed to ``librosa.stft`` (default 2048).
        hop_length: hop between successive frames passed to ``librosa.stft``
            (default 512).

    Returns:
        ``(spectrogram, n_fft, hop_length)``: the magnitude ``|STFT|`` plus the
        parameters actually used, so callers can label plots consistently.
    """
    # Named `stft_matrix` rather than `stft` so the local does not shadow the
    # module-level `stft` import.
    stft_matrix = librosa.stft(audio, n_fft=n_fft, hop_length=hop_length)
    return np.abs(stft_matrix), n_fft, hop_length

# Visualization functions
def plot_waveform(audio, sr, title):
    """Build a Plotly line chart of *audio* against time in seconds.

    Args:
        audio: 1-D sample array.
        sr: sample rate in Hz, used to convert sample indices to seconds.
        title: figure title.

    Returns:
        A ``go.Figure`` holding one line trace.
    """
    seconds = np.arange(len(audio)) / sr
    figure = go.Figure()
    figure.add_trace(go.Scatter(x=seconds, y=audio, mode='lines'))
    figure.update_layout(
        title=title,
        xaxis_title='Time (s)',
        yaxis_title='Amplitude',
    )
    return figure

def plot_fft(magnitude, phase, sr):
    """Plot magnitude and phase spectra on two vertically stacked subplots.

    Args:
        magnitude: ``|X[k]|`` of a two-sided FFT, in ``np.fft.fft`` bin order.
        phase: ``angle(X[k])`` in radians, same length and order as *magnitude*.
        sr: sample rate in Hz, used to convert bin indices to frequencies.

    Returns:
        A Plotly figure with magnitude in row 1 and phase in row 2. Both are
        plotted against frequency in ascending order, with negative
        frequencies on the left.
    """
    # np.fft.fft bin order is [0, positive..., negative...]. Drawn as-is with
    # mode='lines', the trace jumps from the highest positive frequency back to
    # the most negative one and paints a spurious line across the plot. Shift
    # all three arrays together so the x axis is monotonic.
    freq = np.fft.fftshift(np.fft.fftfreq(len(magnitude), 1 / sr))
    magnitude = np.fft.fftshift(magnitude)
    phase = np.fft.fftshift(phase)

    fig = make_subplots(rows=2, cols=1, subplot_titles=('Magnitude Spectrum', 'Phase Spectrum'))

    fig.add_trace(go.Scatter(x=freq, y=magnitude, mode='lines', name='Magnitude'), row=1, col=1)
    fig.add_trace(go.Scatter(x=freq, y=phase, mode='lines', name='Phase'), row=2, col=1)

    fig.update_xaxes(title_text='Frequency (Hz)', row=1, col=1)
    fig.update_xaxes(title_text='Frequency (Hz)', row=2, col=1)
    fig.update_yaxes(title_text='Magnitude', row=1, col=1)
    fig.update_yaxes(title_text='Phase (radians)', row=2, col=1)

    return fig

def plot_3d_fft(magnitude, phase, sr):
    """Scatter every FFT bin in 3-D as (frequency, magnitude, phase).

    Markers are size 5 with opacity 0.8, coloured by phase using the Viridis
    colour scale.

    Args:
        magnitude: per-bin magnitudes in ``np.fft.fft`` order.
        phase: per-bin phases in radians, same length as *magnitude*.
        sr: sample rate in Hz, used to compute bin frequencies.

    Returns:
        A ``go.Figure`` with one ``Scatter3d`` trace.
    """
    frequencies = np.fft.fftfreq(len(magnitude), 1/sr)
    marker_style = dict(size=5, color=phase, colorscale='Viridis', opacity=0.8)
    points = go.Scatter3d(
        x=frequencies,
        y=magnitude,
        z=phase,
        mode='markers',
        marker=marker_style,
    )
    axis_titles = dict(
        xaxis_title='Frequency (Hz)',
        yaxis_title='Magnitude',
        zaxis_title='Phase (radians)',
    )
    fig = go.Figure(data=[points])
    fig.update_layout(scene=axis_titles)
    return fig

def plot_spectrogram(spectrogram, sr, hop_length):
    """Render a magnitude spectrogram in dB as a Matplotlib figure.

    Amplitudes are converted with ``librosa.amplitude_to_db(..., ref=np.max)``,
    so 0 dB is the spectrogram's peak. The x axis shows time and the y axis
    shows log-scaled frequency, and a colour bar in dB is attached.

    The figure is created with ``plt.subplots`` and is not closed here; the
    caller owns it.
    """
    fig, ax = plt.subplots()
    decibels = librosa.amplitude_to_db(spectrogram, ref=np.max)
    mesh = librosa.display.specshow(
        decibels,
        sr=sr,
        hop_length=hop_length,
        x_axis='time',
        y_axis='log',
        ax=ax,
    )
    fig.colorbar(mesh, ax=ax, format='%+2.0f dB')
    ax.set_title('Spectrogram')
    return fig

def create_fft_table(magnitude, phase, sr):
    """Tabulate FFT bins as a DataFrame with frequency, magnitude and phase.

    Frequencies come from ``np.fft.fftfreq``, so rows follow ``np.fft.fft``
    bin order (non-negative frequencies first, then negative).

    Returns:
        A ``pd.DataFrame`` with columns 'Frequency (Hz)', 'Magnitude' and
        'Phase (radians)'.
    """
    columns = {
        'Frequency (Hz)': np.fft.fftfreq(len(magnitude), 1/sr),
        'Magnitude': magnitude,
        'Phase (radians)': phase,
    }
    return pd.DataFrame(columns)

# Streamlit UI
st.set_page_config(layout="wide")
st.title("Audio Frequency Analysis with CNN")

# Initialize session state: create each key as None on first run only, so
# values stored by later code survive Streamlit reruns.
for _state_key in ('audio_data', 'sr', 'fft'):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None

# File uploader
uploaded_file = st.file_uploader("Upload an audio file", type=['wav', 'mp3', 'ogg'])

if uploaded_file is not None:
    # Load and process audio. Decoding can fail for corrupt or unsupported
    # files; report that in the UI and halt this run instead of crashing.
    try:
        audio, sr = load_audio(uploaded_file)
    except Exception as exc:  # UI boundary: decoder error types vary by backend
        st.error(f"Could not decode the uploaded audio file: {exc}")
        st.stop()
    st.session_state.audio_data = audio
    st.session_state.sr = sr

    # Display original waveform
    st.subheader("Original Audio Waveform")
    st.plotly_chart(plot_waveform(audio, sr, "Original Waveform"), use_container_width=True)

    # Apply FFT
    fft, magnitude, phase = apply_fft(audio)
    st.session_state.fft = fft

    # Display FFT results
    st.subheader("Frequency Domain Analysis")
    st.plotly_chart(plot_fft(magnitude, phase, sr), use_container_width=True)

    # 3D FFT Plot
    st.subheader("3D Frequency Domain Analysis")
    st.plotly_chart(plot_3d_fft(magnitude, phase, sr), use_container_width=True)

    # FFT Table
    st.subheader("FFT Values Table")
    fft_table = create_fft_table(magnitude, phase, sr)
    st.dataframe(fft_table)

    # Frequency filtering: keep the strongest N% of FFT bins, then invert.
    percentage = st.slider("Percentage of frequencies to retain:", 0.1, 100.0, 10.0, 0.1)

    if st.button("Apply Frequency Filter"):
        filtered_fft = filter_fft(st.session_state.fft, percentage)
        # Only the real part of the inverse transform is kept for plotting/playback.
        reconstructed = np.fft.ifft(filtered_fft).real

        # Display reconstructed waveform
        st.subheader("Reconstructed Audio")
        st.plotly_chart(plot_waveform(reconstructed, sr, "Filtered Waveform"), use_container_width=True)

        # Play audio
        st.audio(reconstructed, sample_rate=sr)

    # Spectrogram creation
    st.subheader("Spectrogram Analysis")
    spectrogram, n_fft, hop_length = create_spectrogram(audio, sr)
    spectrogram_fig = plot_spectrogram(spectrogram, sr, hop_length)
    st.pyplot(spectrogram_fig)
    # Streamlit re-runs this script on every widget interaction, and pyplot
    # keeps every open figure alive until it is closed. Release each figure
    # once it has been rendered.
    plt.close(spectrogram_fig)

    # CNN Processing
    if st.button("Process with CNN"):
        # Add leading batch and channel axes to the 2-D spectrogram.
        spec_tensor = torch.tensor(spectrogram[np.newaxis, np.newaxis, ...], dtype=torch.float32)

        # Dummy model: freshly constructed with no weights loaded, so the
        # outputs below are illustrative only.
        model = AudioCNN()
        with torch.no_grad():
            output, activations = model(spec_tensor)

        # Visualize activations
        st.subheader("CNN Layer Activations")

        # Input spectrogram
        st.write("### Input Spectrogram")
        fig_input, ax = plt.subplots()
        ax.imshow(spectrogram, aspect='auto', origin='lower')
        st.pyplot(fig_input)
        plt.close(fig_input)

        # First conv layer activations: one panel per output channel.
        st.write("### First Convolution Layer Activations")
        activation = activations.detach().numpy()[0]

        cols = 4
        rows = 4
        fig, axs = plt.subplots(rows, cols, figsize=(20, 20))
        for i, (ax, channel_map) in enumerate(zip(axs.flat, activation)):
            ax.imshow(channel_map, aspect='auto', origin='lower')
            ax.set_title(f'Channel {i+1}')
        fig.tight_layout()
        st.pyplot(fig)
        plt.close(fig)

        # Classification results
        st.write("### Classification Output")
        probabilities = F.softmax(output, dim=1).numpy()[0]
        classes = [f"Class {i}" for i in range(len(probabilities))]
        df = pd.DataFrame({"Class": classes, "Probability": probabilities})
        prob_fig = px.bar(df, x="Class", y="Probability", color="Probability")
        st.plotly_chart(prob_fig)

# Add some styling: raw CSS that restyles Streamlit buttons (padding, font
# size, green #4CAF50 background, white text) and the slider, injected as HTML.
_CUSTOM_CSS = """
    <style>
        .stButton>button {
            padding: 10px 20px;
            font-size: 16px;
            background-color: #4CAF50;
            color: white;
        }
        .stSlider>div>div>div>div {
            background-color: #4CAF50;
        }
    </style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)