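# NOTE(review): the original first lines read "Spaces: Sleeping Sleeping" -- this appears to be page-scrape residue, not program text.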
import streamlit as st | |
import numpy as np | |
import librosa | |
import librosa.display | |
import plotly.graph_objects as go | |
from plotly.subplots import make_subplots | |
import pandas as pd | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import matplotlib.pyplot as plt | |
import plotly.express as px | |
import soundfile as sf | |
from scipy.signal import stft | |
# Dummy CNN Model for Audio
class AudioCNN(nn.Module):
    """Small demonstration CNN over single-channel 2-D inputs.

    Two 3x3 convolutions (padding 1) are followed by an adaptive average
    pool to a fixed 8x32 grid, so the first linear layer always receives
    32 * 8 * 32 features whatever the spatial size of the input is.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        # 32 channels x (8 x 32) pooled grid produced in forward().
        self.fc1 = nn.Linear(in_features=32 * 32 * 8, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Run the network.

        Args:
            x: Input tensor passed straight to ``conv1`` (one input channel).

        Returns:
            A tuple ``(logits, first_activation)``: the 10-way output of
            ``fc2`` and the ReLU output of the first convolution, which the
            caller can use for visualisation.
        """
        first_activation = F.relu(self.conv1(x))
        second_activation = F.relu(self.conv2(first_activation))
        pooled = F.adaptive_avg_pool2d(second_activation, output_size=(8, 32))
        # Collapse everything except the batch dimension.
        flattened = pooled.view(pooled.size(0), -1)
        hidden = F.relu(self.fc1(flattened))
        return self.fc2(hidden), first_activation
# Audio processing functions
def load_audio(file):
    """Load an audio file through librosa.

    Args:
        file: Whatever ``librosa.load`` accepts as its source argument.

    Returns:
        The ``(audio, sr)`` pair produced by ``librosa.load`` when called
        with ``sr=None`` and ``mono=True``.
    """
    samples, sample_rate = librosa.load(file, mono=True, sr=None)
    return samples, sample_rate
def apply_fft(audio):
    """Compute the discrete Fourier transform of ``audio``.

    Args:
        audio: Sequence of samples accepted by ``np.fft.fft``.

    Returns:
        A tuple ``(fft, magnitude, phase)``: the complex spectrum, its
        absolute value, and its angle in radians.
    """
    spectrum = np.fft.fft(audio)
    return spectrum, np.abs(spectrum), np.angle(spectrum)
def filter_fft(fft, percentage):
    """Keep only the strongest ``percentage`` percent of FFT bins.

    Bins are ranked by magnitude; the top ones are kept and every other
    bin is set to zero.

    Args:
        fft: 1-D array of FFT coefficients.
        percentage: Share of bins to retain, in percent. Values outside
            [0, 100] are clamped to that range.

    Returns:
        An array shaped like ``fft`` with the discarded bins zeroed.
    """
    fft = np.asarray(fft)
    # Clamp first: a negative value would otherwise become a negative slice
    # bound below and keep almost every bin instead of none.
    percentage = min(max(percentage, 0), 100)

    strongest_first = np.argsort(np.abs(fft))[::-1]
    num_keep = int(len(strongest_first) * percentage / 100)

    mask = np.zeros_like(fft)
    mask[strongest_first[:num_keep]] = 1
    return fft * mask
def create_spectrogram(audio, sr):
    """Compute a magnitude spectrogram with a fixed STFT configuration.

    Args:
        audio: Sample array handed to ``librosa.stft``.
        sr: Sample rate; accepted for interface symmetry but not used here.

    Returns:
        A tuple ``(spectrogram, n_fft, hop_length)`` where ``spectrogram``
        is the absolute value of the STFT and the other two values are the
        window size (2048) and hop (512) that were used.
    """
    n_fft, hop_length = 2048, 512
    # Local name chosen so it does not shadow the module-level ``stft`` import.
    complex_stft = librosa.stft(audio, n_fft=n_fft, hop_length=hop_length)
    return np.abs(complex_stft), n_fft, hop_length
# Visualization functions
def plot_waveform(audio, sr, title):
    """Build a Plotly line chart of amplitude against time.

    Args:
        audio: Sequence of samples.
        sr: Sample rate used to convert sample indices to seconds.
        title: Chart title.

    Returns:
        The Plotly figure.
    """
    seconds = np.arange(len(audio)) / sr
    fig = go.Figure(data=[go.Scatter(x=seconds, y=audio, mode='lines')])
    fig.update_layout(
        title=title,
        xaxis_title='Time (s)',
        yaxis_title='Amplitude',
    )
    return fig
def plot_fft(magnitude, phase, sr):
    """Build a two-row Plotly figure of magnitude and phase spectra.

    Args:
        magnitude: Per-bin magnitudes.
        phase: Per-bin phases in radians.
        sr: Sample rate used to label the frequency axis.

    Returns:
        The Plotly figure, magnitude in row 1 and phase in row 2.
    """
    freq = np.fft.fftfreq(len(magnitude), 1/sr)
    fig = make_subplots(
        rows=2,
        cols=1,
        subplot_titles=('Magnitude Spectrum', 'Phase Spectrum'),
    )

    # (values, legend name, y-axis label) for rows 1 and 2 respectively.
    panels = (
        (magnitude, 'Magnitude', 'Magnitude'),
        (phase, 'Phase', 'Phase (radians)'),
    )
    for row, (values, trace_name, y_label) in enumerate(panels, start=1):
        fig.add_trace(
            go.Scatter(x=freq, y=values, mode='lines', name=trace_name),
            row=row, col=1,
        )
        fig.update_xaxes(title_text='Frequency (Hz)', row=row, col=1)
        fig.update_yaxes(title_text=y_label, row=row, col=1)
    return fig
def plot_3d_fft(magnitude, phase, sr):
    """Build a 3-D Plotly scatter of frequency, magnitude and phase.

    Args:
        magnitude: Per-bin magnitudes (y axis).
        phase: Per-bin phases in radians (z axis and marker colour).
        sr: Sample rate used to compute the frequency axis.

    Returns:
        The Plotly figure.
    """
    freq = np.fft.fftfreq(len(magnitude), 1/sr)

    # Markers are coloured by phase on the Viridis scale.
    marker_style = dict(
        size=5,
        color=phase,
        colorscale='Viridis',
        opacity=0.8,
    )
    points = go.Scatter3d(
        x=freq,
        y=magnitude,
        z=phase,
        mode='markers',
        marker=marker_style,
    )

    fig = go.Figure(data=[points])
    axis_titles = dict(
        xaxis_title='Frequency (Hz)',
        yaxis_title='Magnitude',
        zaxis_title='Phase (radians)',
    )
    fig.update_layout(scene=axis_titles)
    return fig
def plot_spectrogram(spectrogram, sr, hop_length):
    """Render a dB-scaled spectrogram on a new Matplotlib figure.

    Args:
        spectrogram: Magnitude spectrogram; converted to dB relative to its
            maximum before display.
        sr: Sample rate, forwarded to ``librosa.display.specshow``.
        hop_length: STFT hop, forwarded to ``librosa.display.specshow``.

    Returns:
        The Matplotlib figure holding the spectrogram and its colorbar.
    """
    fig, ax = plt.subplots()
    spectrogram_db = librosa.amplitude_to_db(spectrogram, ref=np.max)
    img = librosa.display.specshow(
        spectrogram_db,
        sr=sr,
        hop_length=hop_length,
        x_axis='time',
        y_axis='log',
        ax=ax,
    )
    # Use the figure/axes objects directly instead of pyplot's implicit
    # "current figure" state, so the colorbar and title always land on the
    # figure created above.
    fig.colorbar(img, ax=ax, format='%+2.0f dB')
    ax.set_title('Spectrogram')
    return fig
def create_fft_table(magnitude, phase, sr):
    """Tabulate FFT results, one row per frequency bin.

    Args:
        magnitude: Per-bin magnitudes.
        phase: Per-bin phases in radians.
        sr: Sample rate used to compute the bin frequencies.

    Returns:
        A DataFrame with columns 'Frequency (Hz)', 'Magnitude' and
        'Phase (radians)', in that order.
    """
    bin_frequencies = np.fft.fftfreq(len(magnitude), 1/sr)
    columns = {
        'Frequency (Hz)': bin_frequencies,
        'Magnitude': magnitude,
        'Phase (radians)': phase,
    }
    return pd.DataFrame(columns)
# Streamlit UI
st.set_page_config(layout="wide")
st.title("Audio Frequency Analysis with CNN")

# Initialize session state: each slot starts as None the first time the
# script runs in a session and is left untouched on later reruns.
for state_key in ('audio_data', 'sr', 'fft'):
    if state_key not in st.session_state:
        st.session_state[state_key] = None
# File uploader
uploaded_file = st.file_uploader("Upload an audio file", type=['wav', 'mp3', 'ogg'])
if uploaded_file is not None:
    # Load and process audio
    audio, sr = load_audio(uploaded_file)
    st.session_state.audio_data = audio
    st.session_state.sr = sr

    # Display original waveform
    st.subheader("Original Audio Waveform")
    st.plotly_chart(plot_waveform(audio, sr, "Original Waveform"), use_container_width=True)

    # Apply FFT
    fft, magnitude, phase = apply_fft(audio)
    st.session_state.fft = fft

    # Display FFT results
    st.subheader("Frequency Domain Analysis")
    st.plotly_chart(plot_fft(magnitude, phase, sr), use_container_width=True)

    # 3D FFT Plot
    st.subheader("3D Frequency Domain Analysis")
    st.plotly_chart(plot_3d_fft(magnitude, phase, sr), use_container_width=True)

    # FFT Table
    st.subheader("FFT Values Table")
    fft_table = create_fft_table(magnitude, phase, sr)
    st.dataframe(fft_table)

    # Frequency filtering
    percentage = st.slider("Percentage of frequencies to retain:", 0.1, 100.0, 10.0, 0.1)
    if st.button("Apply Frequency Filter"):
        filtered_fft = filter_fft(st.session_state.fft, percentage)
        # Only the real part of the inverse transform is kept for display
        # and playback.
        reconstructed = np.fft.ifft(filtered_fft).real

        # Display reconstructed waveform
        st.subheader("Reconstructed Audio")
        st.plotly_chart(plot_waveform(reconstructed, sr, "Filtered Waveform"), use_container_width=True)

        # Play audio
        st.audio(reconstructed, sample_rate=sr)

    # Spectrogram creation
    st.subheader("Spectrogram Analysis")
    spectrogram, _n_fft, hop_length = create_spectrogram(audio, sr)
    spectrogram_fig = plot_spectrogram(spectrogram, sr, hop_length)
    st.pyplot(spectrogram_fig)
    # Close once rendered: pyplot keeps every figure it creates registered
    # until it is closed, and this script is re-executed on each interaction.
    plt.close(spectrogram_fig)

    # CNN Processing
    if st.button("Process with CNN"):
        # Convert spectrogram to tensor, adding leading batch and channel axes
        spec_tensor = torch.tensor(spectrogram[np.newaxis, np.newaxis, ...], dtype=torch.float32)
        # A fresh AudioCNN is instantiated here on every click; no weights
        # are loaded.
        model = AudioCNN()
        with torch.no_grad():
            output, activations = model(spec_tensor)

        # Visualize activations
        st.subheader("CNN Layer Activations")

        # Input spectrogram
        st.write("### Input Spectrogram")
        fig_input, ax = plt.subplots()
        ax.imshow(spectrogram, aspect='auto', origin='lower')
        st.pyplot(fig_input)
        plt.close(fig_input)

        # First conv layer activations: one panel per channel of the first
        # batch item, laid out on a 4x4 grid (16 channels).
        st.write("### First Convolution Layer Activations")
        activation = activations.detach().numpy()[0]
        cols = 4
        rows = 4
        fig, axs = plt.subplots(rows, cols, figsize=(20, 20))
        for channel_number, (ax, channel_map) in enumerate(zip(axs.flat, activation), start=1):
            ax.imshow(channel_map, aspect='auto', origin='lower')
            ax.set_title(f'Channel {channel_number}')
        fig.tight_layout()
        st.pyplot(fig)
        plt.close(fig)

        # Classification results
        st.write("### Classification Output")
        probabilities = F.softmax(output, dim=1).numpy()[0]
        classes = [f"Class {i}" for i in range(10)]
        df = pd.DataFrame({"Class": classes, "Probability": probabilities})
        fig = px.bar(df, x="Class", y="Probability", color="Probability")
        st.plotly_chart(fig)
# Add some styling: the CSS below targets Streamlit's button and slider
# elements and is injected as raw HTML.
custom_css = """
<style>
.stButton>button {
padding: 10px 20px;
font-size: 16px;
background-color: #4CAF50;
color: white;
}
.stSlider>div>div>div>div {
background-color: #4CAF50;
}
</style>
"""
st.markdown(custom_css, unsafe_allow_html=True)