diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..b591e62dcbe280caf1fd6d090c3184364634bbcf 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,68 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/TIGER.png filter=lfs diff=lfs merge=lfs -text
+assets/dnr-demo/sample1/dialog.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/dnr-demo/sample1/effect.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/dnr-demo/sample1/mixture.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/dnr-demo/sample1/music.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/dnr-demo/sample2/dialog.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/dnr-demo/sample2/effect.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/dnr-demo/sample2/mixture.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/dnr-demo/sample2/music.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/dnr-demo/sample3/dialog.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/dnr-demo/sample3/effect.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/dnr-demo/sample3/mixture.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/dnr-demo/sample3/music.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/dnr.png filter=lfs diff=lfs merge=lfs -text
+assets/efficiency.png filter=lfs diff=lfs merge=lfs -text
+assets/logo.png filter=lfs diff=lfs merge=lfs -text
+assets/result.png filter=lfs diff=lfs merge=lfs -text
+assets/sample1/GroundTruth/mix.wav filter=lfs diff=lfs merge=lfs -text
+assets/sample1/GroundTruth/s1.wav filter=lfs diff=lfs merge=lfs -text
+assets/sample1/GroundTruth/s2.wav filter=lfs diff=lfs merge=lfs -text
+assets/sample1/TFGNet/s1.wav filter=lfs diff=lfs merge=lfs -text
+assets/sample1/TFGNet/s2.wav filter=lfs diff=lfs merge=lfs -text
+assets/sample1/TIGER/s1.wav filter=lfs diff=lfs merge=lfs -text
+assets/sample1/TIGER/s2.wav filter=lfs diff=lfs merge=lfs -text
+assets/sample1/spec/TFGNet_s1.png filter=lfs diff=lfs merge=lfs -text
+assets/sample1/spec/TFGNet_s2.png filter=lfs diff=lfs merge=lfs -text
+assets/sample1/spec/TIGER_s1.png filter=lfs diff=lfs merge=lfs -text
+assets/sample1/spec/TIGER_s2.png filter=lfs diff=lfs merge=lfs -text
+assets/sample1/spec/ground_truth_s1.png filter=lfs diff=lfs merge=lfs -text
+assets/sample1/spec/ground_truth_s2.png filter=lfs diff=lfs merge=lfs -text
+assets/sample2/GroundTruth/mix.wav filter=lfs diff=lfs merge=lfs -text
+assets/sample2/GroundTruth/s1.wav filter=lfs diff=lfs merge=lfs -text
+assets/sample2/GroundTruth/s2.wav filter=lfs diff=lfs merge=lfs -text
+assets/sample2/TFGNet/s1.wav filter=lfs diff=lfs merge=lfs -text
+assets/sample2/TFGNet/s2.wav filter=lfs diff=lfs merge=lfs -text
+assets/sample2/TIGER/s1.wav filter=lfs diff=lfs merge=lfs -text
+assets/sample2/TIGER/s2.wav filter=lfs diff=lfs merge=lfs -text
+assets/sample2/spec/TFGNet_s1.png filter=lfs diff=lfs merge=lfs -text
+assets/sample2/spec/TFGNet_s2.png filter=lfs diff=lfs merge=lfs -text
+assets/sample2/spec/TIGER_s1.png filter=lfs diff=lfs merge=lfs -text
+assets/sample2/spec/TIGER_s2.png filter=lfs diff=lfs merge=lfs -text
+assets/sample2/spec/ground_truth_s1.png filter=lfs diff=lfs merge=lfs -text
+assets/sample2/spec/ground_truth_s2.png filter=lfs diff=lfs merge=lfs -text
+assets/sample3/GroundTruth/mix.wav filter=lfs diff=lfs merge=lfs -text
+assets/sample3/GroundTruth/s1.wav filter=lfs diff=lfs merge=lfs -text
+assets/sample3/GroundTruth/s2.wav filter=lfs diff=lfs merge=lfs -text
+assets/sample3/TFGNet/s1.wav filter=lfs diff=lfs merge=lfs -text
+assets/sample3/TFGNet/s2.wav filter=lfs diff=lfs merge=lfs -text
+assets/sample3/TIGER/s1.wav filter=lfs diff=lfs merge=lfs -text
+assets/sample3/TIGER/s2.wav filter=lfs diff=lfs merge=lfs -text
+assets/sample3/spec/TFGNet_s1.png filter=lfs diff=lfs merge=lfs -text
+assets/sample3/spec/TFGNet_s2.png filter=lfs diff=lfs merge=lfs -text
+assets/sample3/spec/TIGER_s1.png filter=lfs diff=lfs merge=lfs -text
+assets/sample3/spec/TIGER_s2.png filter=lfs diff=lfs merge=lfs -text
+assets/sample3/spec/ground_truth_s1.png filter=lfs diff=lfs merge=lfs -text
+assets/sample3/spec/ground_truth_s2.png filter=lfs diff=lfs merge=lfs -text
+test/mix.wav filter=lfs diff=lfs merge=lfs -text
+test/s1.wav filter=lfs diff=lfs merge=lfs -text
+test/s2.wav filter=lfs diff=lfs merge=lfs -text
+test/spk1.wav filter=lfs diff=lfs merge=lfs -text
+test/spk2.wav filter=lfs diff=lfs merge=lfs -text
+test/test_mixture_466.wav filter=lfs diff=lfs merge=lfs -text
+test/test_target_dialog_466.wav filter=lfs diff=lfs merge=lfs -text
+test/test_target_effect_466.wav filter=lfs diff=lfs merge=lfs -text
+test/test_target_music_466.wav filter=lfs diff=lfs merge=lfs -text
diff --git a/DataPreProcess/preprocess_lrs2_audio.py b/DataPreProcess/preprocess_lrs2_audio.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b51deeb067dae98ba2d860af933991ade5245be
--- /dev/null
+++ b/DataPreProcess/preprocess_lrs2_audio.py
@@ -0,0 +1,71 @@
+import argparse
+import json
+import os
+import soundfile as sf
+from tqdm import tqdm
+
+
+def get_mouth_path(in_mouth_dir, wav_file, spk, data_type):
+ wav_file = wav_file.replace(".wav", "").split("_")
+ if spk == "s1":
+ file_path = os.path.join(
+ in_mouth_dir, "{}_{}.npz".format(wav_file[0], wav_file[1])
+ )
+ else:
+ file_path = os.path.join(
+ in_mouth_dir, "{}_{}.npz".format(wav_file[3], wav_file[4])
+ )
+ return file_path
+
+
+def preprocess_one_dir(in_data_dir, out_dir, data_type, spk):
+ """Create .json file for one condition."""
+ file_infos = []
+ in_dir = os.path.abspath(os.path.join(in_data_dir, data_type, spk))
+ wav_list = os.listdir(in_dir)
+ wav_list.sort()
+ for wav_file in tqdm(wav_list):
+ if not wav_file.endswith(".wav"):
+ continue
+ wav_path = os.path.join(in_dir, wav_file)
+ samples = sf.SoundFile(wav_path)
+ if spk == "mix":
+ file_infos.append((wav_path, len(samples)))
+ else:
+ file_infos.append(
+ (
+ wav_path,
+ # get_mouth_path(os.path.join(in_data_dir, data_type, 'mouths'), wav_file, spk, data_type),
+ len(samples),
+ )
+ )
+ if not os.path.exists(os.path.join(out_dir, data_type)):
+ os.makedirs(os.path.join(out_dir, data_type))
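+    # Each <spk>.json entry is a [wav_path, num_samples] pair; num_samples is the
+    # frame count reported by soundfile, not a duration in seconds.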
+ with open(os.path.join(out_dir, data_type, spk + ".json"), "w") as f:
+ json.dump(file_infos, f, indent=4)
+
+
+def preprocess_lrs2_audio(inp_args):
+ """Create .json files for all conditions."""
+ speaker_list = ["mix", "s1", "s2"]
+ for data_type in ["tr", "cv", "tt"]:
+ for spk in speaker_list:
+ preprocess_one_dir(
+ inp_args.in_dir, inp_args.out_dir, data_type, spk,
+ )
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser("LRS2 audio data preprocessing")
+ parser.add_argument(
+ "--in_dir",
+ type=str,
+ default=None,
+ help="Directory path of audio including tr, cv and tt",
+ )
+ parser.add_argument(
+ "--out_dir", type=str, default=None, help="Directory path to put output files"
+ )
+ args = parser.parse_args()
+ print(args)
+ preprocess_lrs2_audio(args)
diff --git a/DataPreProcess/process_echoset.py b/DataPreProcess/process_echoset.py
new file mode 100644
index 0000000000000000000000000000000000000000..a384042e7e415dce72281f2dd4b65109b873053b
--- /dev/null
+++ b/DataPreProcess/process_echoset.py
@@ -0,0 +1,75 @@
+import argparse
+import json
+import os
+import soundfile as sf
+from tqdm import tqdm
+from rich import print
+
+
+def preprocess_one_dir(in_data_dir, out_dir, data_type):
+ """Create .json file for one condition."""
+ mix_infos = []
+ s1_infos = []
+ s2_infos = []
+ in_dir = os.path.abspath(os.path.join(in_data_dir, data_type))
+ print("Process {} set...".format(data_type))
+ for root, dirs, files in os.walk(in_dir):
+ for file in files:
+ if file.endswith(".wav") and file.startswith("mix"):
+ file_path = os.path.join(root, file)
+ audio, _ = sf.read(file_path)
+ mix_infos.append((
+ file_path,
+ len(audio),
+ ))
+
+ file_path = file_path.replace("mix", "spk1_reverb")
+ audio, _ = sf.read(file_path)
+ s1_infos.append((
+ file_path,
+ len(audio),
+ ))
+
+ file_path = file_path.replace("spk1_reverb", "spk2_reverb")
+ audio, _ = sf.read(file_path)
+ s2_infos.append((
+ file_path,
+ len(audio),
+ ))
+ print("Process num: {}".format(len(mix_infos)), end="\r")
+
+ if not os.path.exists(os.path.join(out_dir, data_type)):
+ os.makedirs(os.path.join(out_dir, data_type))
+ with open(os.path.join(out_dir, data_type, "mix.json"), "w") as f:
+ json.dump(mix_infos, f, indent=4)
+
+ with open(os.path.join(out_dir, data_type, "s1.json"), "w") as f:
+ json.dump(s1_infos, f, indent=4)
+
+ with open(os.path.join(out_dir, data_type, "s2.json"), "w") as f:
+ json.dump(s2_infos, f, indent=4)
+
+
+def preprocess_echoset_audio(inp_args):
+ """Create .json files for all conditions."""
+ for data_type in ["train", "val", "test"]:
+ preprocess_one_dir(
+ inp_args.in_dir, inp_args.out_dir, data_type
+ )
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser("EchoSet audio data preprocessing")
+ parser.add_argument(
+ "--in_dir",
+ type=str,
+ default=None,
+        help="Directory path of audio including train, val and test",
+ )
+ parser.add_argument(
+ "--out_dir", type=str, default=None, help="Directory path to put output files"
+ )
+ args = parser.parse_args()
+ print(args)
+    preprocess_echoset_audio(args)
+
diff --git a/DataPreProcess/process_librimix.py b/DataPreProcess/process_librimix.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5ce29411890fd2e84717f5cb1e3b1462787e2f9
--- /dev/null
+++ b/DataPreProcess/process_librimix.py
@@ -0,0 +1,57 @@
+import argparse
+import json
+import os
+import soundfile as sf
+from tqdm import tqdm
+
+
+def preprocess_one_dir(in_data_dir, out_dir, data_type, spk):
+ """Create .json file for one condition."""
+ file_infos = []
+ in_dir = os.path.abspath(os.path.join(in_data_dir, data_type, spk))
+ wav_list = os.listdir(in_dir)
+ wav_list.sort()
+ for wav_file in tqdm(wav_list):
+ if not wav_file.endswith(".wav"):
+ continue
+ wav_path = os.path.join(in_dir, wav_file)
+ samples = sf.SoundFile(wav_path)
+ if spk == "mix":
+ file_infos.append((wav_path, len(samples)))
+ else:
+ file_infos.append(
+ (
+ wav_path,
+ len(samples),
+ )
+ )
+ if not os.path.exists(os.path.join(out_dir, data_type)):
+ os.makedirs(os.path.join(out_dir, data_type))
+ with open(os.path.join(out_dir, data_type, spk + ".json"), "w") as f:
+ json.dump(file_infos, f, indent=4)
+
+
+def preprocess_librimix_audio(inp_args):
+ """Create .json files for all conditions."""
+ speaker_list = ["mix_both", "s1", "s2"]
+ for data_type in ["train-100", "dev", "test"]:
+ for spk in speaker_list:
+ preprocess_one_dir(
+ inp_args.in_dir, inp_args.out_dir, data_type, spk,
+ )
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser("Librimix audio data preprocessing")
+ parser.add_argument(
+ "--in_dir",
+ type=str,
+ default=None,
+        help="Directory path of audio including train-100, dev and test",
+ )
+ parser.add_argument(
+ "--out_dir", type=str, default=None, help="Directory path to put output files"
+ )
+ args = parser.parse_args()
+ print(args)
+ preprocess_librimix_audio(args)
\ No newline at end of file
diff --git a/ORIGINAL_README.md b/ORIGINAL_README.md
new file mode 100644
index 0000000000000000000000000000000000000000..2c0fb6c20818d396bcc6ac97eb42d8b14238bd6e
--- /dev/null
+++ b/ORIGINAL_README.md
@@ -0,0 +1,95 @@
+# TIGER: Time-frequency Interleaved Gain Extraction and Reconstruction for Efficient Speech Separation
+
+Mohan Xu\*, Kai Li\*, Guo Chen, Xiaolin Hu
+
+Tsinghua University, Beijing, China
+
+\* Equal contribution
+
+📜 ICLR 2025 | 🎶 Demo | 🤗 Dataset
+
+> TIGER is a lightweight speech separation model that effectively extracts key acoustic features through frequency band-splitting, multi-scale modeling, and full-frequency-frame modeling.
+
+## 💥 News
+
+- **[2025-01-23]** We release the code and pre-trained model of TIGER! 🚀
+- **[2025-01-23]** We release the TIGER model and the EchoSet dataset! 🚀
+
+## 📜 Abstract
+
+In this paper, we propose a speech separation model with significantly reduced parameter size and computational cost: Time-Frequency Interleaved Gain Extraction and Reconstruction Network (TIGER). TIGER leverages prior knowledge to divide frequency bands and applies compression on frequency information. We employ a multi-scale selective attention (MSA) module to extract contextual features, while introducing a full-frequency-frame attention (F^3A) module to capture both temporal and frequency contextual information. Additionally, to more realistically evaluate the performance of speech separation models in complex acoustic environments, we introduce a novel dataset called EchoSet. This dataset includes noise and more realistic reverberation (e.g., considering object occlusions and material properties), with speech from two speakers overlapping at random proportions. Experimental results demonstrated that TIGER significantly outperformed state-of-the-art (SOTA) model TF-GridNet on the EchoSet dataset in both inference speed and separation quality, while reducing the number of parameters by 94.3% and the MACs by 95.3%. These results indicate that by utilizing frequency band-split and interleaved modeling structures, TIGER achieves a substantial reduction in parameters and computational costs while maintaining high performance. Notably, TIGER is the first speech separation model with fewer than 1 million parameters that achieves performance close to the SOTA model.
+
+## TIGER
+
+Overall pipeline of the TIGER model architecture and its modules.
+
+
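+As a rough illustration of the interleaved time-frequency modeling described in the abstract (this is *not* the released TIGER implementation: the band count, feature dimension, and module names below are made up for illustration, and the real model additionally uses the MSA and F^3A modules), a block that alternates attention along the sub-band axis and the frame axis could look like this:
+
+```python
+import torch
+import torch.nn as nn
+
+
+class InterleavedTFBlock(nn.Module):
+    """Toy block: model across frequency sub-bands, then across time frames."""
+
+    def __init__(self, dim: int, heads: int = 4):
+        super().__init__()
+        self.freq_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
+        self.time_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
+        self.norm_f = nn.LayerNorm(dim)
+        self.norm_t = nn.LayerNorm(dim)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # x: (batch, bands, frames, dim) -- a band-split time-frequency representation.
+        b, k, t, d = x.shape
+        # Frequency path: attend across the K sub-bands within each frame.
+        xf = x.permute(0, 2, 1, 3).reshape(b * t, k, d)
+        xf = self.norm_f(xf + self.freq_attn(xf, xf, xf, need_weights=False)[0])
+        x = xf.reshape(b, t, k, d).permute(0, 2, 1, 3)
+        # Time path: attend across the T frames within each sub-band.
+        xt = x.reshape(b * k, t, d)
+        xt = self.norm_t(xt + self.time_attn(xt, xt, xt, need_weights=False)[0])
+        return xt.reshape(b, k, t, d)
+
+
+if __name__ == "__main__":
+    block = InterleavedTFBlock(dim=64)
+    feats = torch.randn(2, 8, 100, 64)  # 2 mixtures, 8 sub-bands, 100 frames
+    print(block(feats).shape)           # torch.Size([2, 8, 100, 64])
+```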
+
+## Results
+
+Performance comparisons of TIGER and other existing separation models on ***Libri2Mix, LRS2-2Mix, and EchoSet***. Bold indicates the best performance, and italics indicate the second-best.
+
+
+
+Efficiency comparisons of TIGER and other models.
+
+
+
+Comparison of performance and efficiency of cinematic sound separation models on DnR. '*' indicates results taken from the original DnR paper.
+
+
+
+## 📦 Installation
+
+```bash
+git clone https://github.com/JusperLee/TIGER.git
+cd TIGER
+pip install -r requirements.txt
+```
+
+## 🚀 Quick Start
+
+### Test with Pre-trained Model
+
+```bash
+# Test using speech
+python inference_speech.py --audio_path test/mix.wav
+
+# Test using DnR
+python inference_dnr.py --audio_path test/test_mixture_466.wav
+```
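+
+You can also call the pre-trained models directly from Python. The sketch below follows `inference_speech.py` and `audio_test.py` in this repository: the Hugging Face ID `JusperLee/TIGER-speech`, the 16 kHz sample rate, and the `model(mix[None])` calling convention come from those scripts, while the mono down-mix and the output file names are only illustrative.
+
+```python
+import torch
+import torchaudio
+import look2hear.models
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# Two-speaker separation model (expects 16 kHz audio).
+model = look2hear.models.TIGER.from_pretrained("JusperLee/TIGER-speech", cache_dir="cache")
+model.to(device).eval()
+
+mix, sr = torchaudio.load("test/mix.wav")
+mix = mix.mean(dim=0).to(device)      # down-mix to mono, shape (samples,)
+
+with torch.no_grad():
+    est_sources = model(mix[None])    # shape (1, n_src, samples)
+
+for i, src in enumerate(est_sources.squeeze(0)):
+    torchaudio.save(f"separated_spk{i + 1}.wav", src.unsqueeze(0).cpu(), 16000)
+```
+
+The cinematic model is used the same way via `look2hear.models.TIGERDNR.from_pretrained("JusperLee/TIGER-DnR", cache_dir="cache")`, which returns dialog, effect, and music tracks (see `inference_dnr.py`).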
+
+### Train with EchoSet
+
+```bash
+python audio_train.py --conf_dir configs/tiger.yml
+```
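+
+The other configuration files added here, `configs/tiger-small.yml` and `configs/tiger-large.yml`, share the same structure and differ mainly in `num_blocks` (4 vs. 8).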
+
+### Evaluate with EchoSet
+
+```bash
+python audio_test.py --conf_dir configs/tiger.yml
+```
+
+## 📖 Citation
+
+```bibtex
+@article{xu2024tiger,
+ title={TIGER: Time-frequency Interleaved Gain Extraction and Reconstruction for Efficient Speech Separation},
+ author={Xu, Mohan and Li, Kai and Chen, Guo and Hu, Xiaolin},
+ journal={arXiv preprint arXiv:2410.01469},
+ year={2024}
+}
+```
+
+## 📧 Contact
+
+If you have any questions, please feel free to contact us via `tsinghua.kaili@gmail.com`.
diff --git a/assets/TIGER.png b/assets/TIGER.png
new file mode 100644
index 0000000000000000000000000000000000000000..21458a248c047762f08f03caf28fd5fe1c5bfac1
--- /dev/null
+++ b/assets/TIGER.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ce1139d21f7b363f2210dc91f4ff6fce1aaebb4cb7432462096d53e76dfd806f
+size 1502140
diff --git a/assets/dnr-demo/sample1/dialog.mp4 b/assets/dnr-demo/sample1/dialog.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..0589d35e1752dc4a91fcec76e3fd5fef3a5b1eeb
--- /dev/null
+++ b/assets/dnr-demo/sample1/dialog.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6f01fa3aeddcca598793810ead2ecf0140f82ebb43e96800578ec52af71f6e35
+size 2810085
diff --git a/assets/dnr-demo/sample1/effect.mp4 b/assets/dnr-demo/sample1/effect.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..c4988833ab761672be742e3ea42ab3ac30f191bb
--- /dev/null
+++ b/assets/dnr-demo/sample1/effect.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a7cbf696ebb01eda225c9ebf77856677afe3ccd983b0e2bc45547ee96cd074fb
+size 2939960
diff --git a/assets/dnr-demo/sample1/mixture.mp4 b/assets/dnr-demo/sample1/mixture.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..3f0caaadda63bb806802a03a2e184b772a323a80
--- /dev/null
+++ b/assets/dnr-demo/sample1/mixture.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a66b7b61dc5a1e4639321b6255a23cc0e504a1492b9396c7e287d28d39627f35
+size 2918024
diff --git a/assets/dnr-demo/sample1/music.mp4 b/assets/dnr-demo/sample1/music.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..2d4c20cbd4bf72a53b4e89db0cdd4c90133a05cf
--- /dev/null
+++ b/assets/dnr-demo/sample1/music.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:413d85706158662a0c8852612cb4536be9b62fb4717cd4cbbc7c6a868194a0ca
+size 2884375
diff --git a/assets/dnr-demo/sample2/dialog.mp4 b/assets/dnr-demo/sample2/dialog.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..c445f5774d8e01213747928d915061653c81bc93
--- /dev/null
+++ b/assets/dnr-demo/sample2/dialog.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2677b02f921915a8bafcd9149a297f54657cafbc3464bec833743a6c7ad7779c
+size 1718587
diff --git a/assets/dnr-demo/sample2/effect.mp4 b/assets/dnr-demo/sample2/effect.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..4c45bba6884c07487dad674c5ec3ccddf1c66eb1
--- /dev/null
+++ b/assets/dnr-demo/sample2/effect.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:784a90eadcb0cae93d83ce739d95c06163423d8742448b7ffd373676a2f51a01
+size 1752110
diff --git a/assets/dnr-demo/sample2/mixture.mp4 b/assets/dnr-demo/sample2/mixture.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..15f0c6a76fbf408dd1a687bd116e501f6936a928
--- /dev/null
+++ b/assets/dnr-demo/sample2/mixture.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ab05854c7d9f4bbee331c9e8655b4bdadb4723050332baba027a06b33391fe36
+size 1781103
diff --git a/assets/dnr-demo/sample2/music.mp4 b/assets/dnr-demo/sample2/music.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..320d3f3a7dfb017c9a3dc17e4bebabd03bf36dae
--- /dev/null
+++ b/assets/dnr-demo/sample2/music.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a0b92919ad64ee5e11b14b8a8ab58361fc05c36ba4fe0873d6616f0d9eeead39
+size 1778868
diff --git a/assets/dnr-demo/sample3/dialog.mp4 b/assets/dnr-demo/sample3/dialog.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..da900ea201281156cd09c2139a4d627c8d69f023
--- /dev/null
+++ b/assets/dnr-demo/sample3/dialog.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bc1e9c2336a61630f7f48561d21fc51c5f68731cf2a247e640b2ca8a5a576da2
+size 3988060
diff --git a/assets/dnr-demo/sample3/effect.mp4 b/assets/dnr-demo/sample3/effect.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..8bcb8140cdd6a2fd1b178414508dca16d85dbbcf
--- /dev/null
+++ b/assets/dnr-demo/sample3/effect.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:efaf9714d91b8caf53b7f7b4f3e39a811d6e20393fc48bcbc5d23930ec66742a
+size 4035872
diff --git a/assets/dnr-demo/sample3/mixture.mp4 b/assets/dnr-demo/sample3/mixture.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..6e318d706e0337a3a232a898b8e291237190387a
--- /dev/null
+++ b/assets/dnr-demo/sample3/mixture.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ffb8dee7b106e5c7d1da78608b1992b93bb764825bdc9c22c286f6c04b020fa5
+size 4053761
diff --git a/assets/dnr-demo/sample3/music.mp4 b/assets/dnr-demo/sample3/music.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..8b347505bc68e442047bdb1a2fcdebd13f799338
--- /dev/null
+++ b/assets/dnr-demo/sample3/music.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:90b24a2e5aaf1a4c880c83f68d25179ebb119905658a8f28e8cdc3d33cd24f10
+size 4062344
diff --git a/assets/dnr.png b/assets/dnr.png
new file mode 100644
index 0000000000000000000000000000000000000000..ac81432da61acc90ab2b687a9cd0cd788f0d1cf3
--- /dev/null
+++ b/assets/dnr.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8fce2988567b8637f68dce9d25aa17d17b98fc9b6f0ce4f80ea78144b3eeda80
+size 194016
diff --git a/assets/efficiency.png b/assets/efficiency.png
new file mode 100644
index 0000000000000000000000000000000000000000..fec7fa6b39d910b84343e05ef7f33a323d3bc8e2
--- /dev/null
+++ b/assets/efficiency.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9ccc5b69e2ccbb48b03f44fdbbd88b99366b124307f84b78538996a5ba91d171
+size 405116
diff --git a/assets/logo.png b/assets/logo.png
new file mode 100644
index 0000000000000000000000000000000000000000..d22794ec0ac8766f5e40b0615f3ce606d1421cb7
--- /dev/null
+++ b/assets/logo.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e93c060d66ca24d7dcd8f30fab9546e0eea7b09027d8af6b67f72c76f66a88f2
+size 486031
diff --git a/assets/result.png b/assets/result.png
new file mode 100644
index 0000000000000000000000000000000000000000..4047b5bdc4da3b2bc49fde1a1730426e3da4ffb9
--- /dev/null
+++ b/assets/result.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dd2d93eec2d4f7c44151a231b830fc687fba3bfd1081a220ac9ab8460eeffcf3
+size 469267
diff --git a/assets/sample1/GroundTruth/mix.wav b/assets/sample1/GroundTruth/mix.wav
new file mode 100644
index 0000000000000000000000000000000000000000..eb063704154aff42d7a52bc909f5a04ac315ac3c
--- /dev/null
+++ b/assets/sample1/GroundTruth/mix.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:96096869c977120503b2ef4b53429a4bce510b344d6a274c3371457be854f38e
+size 192044
diff --git a/assets/sample1/GroundTruth/s1.wav b/assets/sample1/GroundTruth/s1.wav
new file mode 100644
index 0000000000000000000000000000000000000000..ce773338cc58908074a345f7bef11c4e6b49ef24
--- /dev/null
+++ b/assets/sample1/GroundTruth/s1.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6128c53d4d57c61eae0b1d7030a155b91db64fb6ca5ec72c982c4f3b6ac3e071
+size 192044
diff --git a/assets/sample1/GroundTruth/s2.wav b/assets/sample1/GroundTruth/s2.wav
new file mode 100644
index 0000000000000000000000000000000000000000..b4c142414daedfb0cbba44923768936d0e17b384
--- /dev/null
+++ b/assets/sample1/GroundTruth/s2.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:342f9c15fc0b0445ff82b1026551b33fbafe3ee89c060990ed402ad4ad2a62ee
+size 192044
diff --git a/assets/sample1/TFGNet/s1.wav b/assets/sample1/TFGNet/s1.wav
new file mode 100644
index 0000000000000000000000000000000000000000..bd5bccd81064a97e67fcef30dabe24302dbfbf9c
--- /dev/null
+++ b/assets/sample1/TFGNet/s1.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8676fdb6412ed3123d025e145abdd86bbaf2956e55cbc3be5034857ee03b562f
+size 384080
diff --git a/assets/sample1/TFGNet/s2.wav b/assets/sample1/TFGNet/s2.wav
new file mode 100644
index 0000000000000000000000000000000000000000..e94d0d8da0001c4a33b4c36503be5e5d223fd7ee
--- /dev/null
+++ b/assets/sample1/TFGNet/s2.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ac533785786cac0cacc9f5dde2a32c90c7b057da759be4837b44aa7ddf586b94
+size 384080
diff --git a/assets/sample1/TIGER/s1.wav b/assets/sample1/TIGER/s1.wav
new file mode 100644
index 0000000000000000000000000000000000000000..745df78928513718302157293d6ef5ed69a00d6f
--- /dev/null
+++ b/assets/sample1/TIGER/s1.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f8ec094c7c7d9262f734b651735ae209eeaf70f1f7d0e18085ab2db330cdc53f
+size 384080
diff --git a/assets/sample1/TIGER/s2.wav b/assets/sample1/TIGER/s2.wav
new file mode 100644
index 0000000000000000000000000000000000000000..02f697df0a89fd6d89332d3e74c46dbe2efd86b2
--- /dev/null
+++ b/assets/sample1/TIGER/s2.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:92837bf98498e4f7f23d5f918cf042f60f05c534eb1cc4fd27fadce170e38132
+size 384080
diff --git a/assets/sample1/spec/TFGNet_s1.png b/assets/sample1/spec/TFGNet_s1.png
new file mode 100644
index 0000000000000000000000000000000000000000..a0c3c1c621ce9519265a0058e999e9e6c390fd9a
--- /dev/null
+++ b/assets/sample1/spec/TFGNet_s1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f629a31eebf229ed4aca06fdd15a2da516e2095967cba8590194a54d8719e10b
+size 774751
diff --git a/assets/sample1/spec/TFGNet_s2.png b/assets/sample1/spec/TFGNet_s2.png
new file mode 100644
index 0000000000000000000000000000000000000000..ca3f6fc9f5fd0c88a102153577f620aad7e7dc7e
--- /dev/null
+++ b/assets/sample1/spec/TFGNet_s2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3e9ab62f73ad6721f1b9edd570f364158f95e6ee0e446a4aebd9325aaea284d2
+size 767865
diff --git a/assets/sample1/spec/TIGER_s1.png b/assets/sample1/spec/TIGER_s1.png
new file mode 100644
index 0000000000000000000000000000000000000000..680dc9ee965f198b0eabb0c6cc6871c3e8279714
--- /dev/null
+++ b/assets/sample1/spec/TIGER_s1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:34eac91edfbf96ce850c6474e60a5add51151579bc423d63d27593b07df64c9e
+size 771902
diff --git a/assets/sample1/spec/TIGER_s2.png b/assets/sample1/spec/TIGER_s2.png
new file mode 100644
index 0000000000000000000000000000000000000000..ed03dfa6de727dd286c2b6479dfe09cb03672a55
--- /dev/null
+++ b/assets/sample1/spec/TIGER_s2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:019899feacd4df56a66f4052352e5e4252ac1e20fbcab106099f3696ef5b23ac
+size 792674
diff --git a/assets/sample1/spec/ground_truth_s1.png b/assets/sample1/spec/ground_truth_s1.png
new file mode 100644
index 0000000000000000000000000000000000000000..30deb410116f9ea5571dea52ff3935bc8dcfee2f
--- /dev/null
+++ b/assets/sample1/spec/ground_truth_s1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:87fab6fb1e5cb7415e106b4aed57a108e8160cee893cd9c0ece3ecab2ca5bded
+size 619171
diff --git a/assets/sample1/spec/ground_truth_s2.png b/assets/sample1/spec/ground_truth_s2.png
new file mode 100644
index 0000000000000000000000000000000000000000..07444ce0fd1df8f4a671729d58b7891536170a56
--- /dev/null
+++ b/assets/sample1/spec/ground_truth_s2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3c160511e75a0c0e3998e634490f71db2f6a403b408ec2af8312b614aba6ad12
+size 685215
diff --git a/assets/sample2/GroundTruth/mix.wav b/assets/sample2/GroundTruth/mix.wav
new file mode 100644
index 0000000000000000000000000000000000000000..bcd6f7213bde7e9215008b6d8b4cae1412c585fb
--- /dev/null
+++ b/assets/sample2/GroundTruth/mix.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:26f24e81d600dab059662afec2d0ef6197875ce8e233e554c423d435527560a8
+size 192044
diff --git a/assets/sample2/GroundTruth/s1.wav b/assets/sample2/GroundTruth/s1.wav
new file mode 100644
index 0000000000000000000000000000000000000000..446b30d3e874c0e28b383489cef785e3cfec7a83
--- /dev/null
+++ b/assets/sample2/GroundTruth/s1.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bbfcb21a93c1c1896800a5d057f8e9c497c27c92a36233c90920c6a2e1831ee2
+size 192044
diff --git a/assets/sample2/GroundTruth/s2.wav b/assets/sample2/GroundTruth/s2.wav
new file mode 100644
index 0000000000000000000000000000000000000000..28b023dbeeb3d179bca5e732d935bb07f981ffb0
--- /dev/null
+++ b/assets/sample2/GroundTruth/s2.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ad6579a1ee21ac0d9fb416f9d79a10d65f52f7f5a51759355d259141e966e33e
+size 192044
diff --git a/assets/sample2/TFGNet/s1.wav b/assets/sample2/TFGNet/s1.wav
new file mode 100644
index 0000000000000000000000000000000000000000..b7bfbbca5fa3745fb072779e8a59a7a2202ed601
--- /dev/null
+++ b/assets/sample2/TFGNet/s1.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:401a5efe1d3feda8a1c34a4f887bd353114166561d6aa81e28097fc68933004b
+size 384080
diff --git a/assets/sample2/TFGNet/s2.wav b/assets/sample2/TFGNet/s2.wav
new file mode 100644
index 0000000000000000000000000000000000000000..3d8f0796cae5ddc7c3efe84f0bc0e44cc6b10d96
--- /dev/null
+++ b/assets/sample2/TFGNet/s2.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:59c8a509d19b47254ff37a7f15b0194a49bd2b154ad2725ac7d07efe9ff174bd
+size 384080
diff --git a/assets/sample2/TIGER/s1.wav b/assets/sample2/TIGER/s1.wav
new file mode 100644
index 0000000000000000000000000000000000000000..a8c61eff87eea39b0b0b9a668838e618ba6237c4
--- /dev/null
+++ b/assets/sample2/TIGER/s1.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5e804f6d6c6d63f3756427e8c6832a9a28a0bbbcb808e2f7faf14acdc5adc5d8
+size 384080
diff --git a/assets/sample2/TIGER/s2.wav b/assets/sample2/TIGER/s2.wav
new file mode 100644
index 0000000000000000000000000000000000000000..f4ec2adebc434bd870c4116f2014035be4da9194
--- /dev/null
+++ b/assets/sample2/TIGER/s2.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:45f92e18ce5371fcc3d07a7010e406efad0092090a3003e9252f901e12e353ba
+size 384080
diff --git a/assets/sample2/spec/TFGNet_s1.png b/assets/sample2/spec/TFGNet_s1.png
new file mode 100644
index 0000000000000000000000000000000000000000..1e8fc2d5355eaaf0288c1606d38eecc5361131c1
--- /dev/null
+++ b/assets/sample2/spec/TFGNet_s1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b95c0a00dcd44f794747fc86d9f916031994e9f21728a31ecac97c2e941fc522
+size 434257
diff --git a/assets/sample2/spec/TFGNet_s2.png b/assets/sample2/spec/TFGNet_s2.png
new file mode 100644
index 0000000000000000000000000000000000000000..81d5ca8904760add9bac16fcdd8de0b91638f9a6
--- /dev/null
+++ b/assets/sample2/spec/TFGNet_s2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6bdd22e56556506cec95009049243cdf1286acd2dfd27e073b8ee75319d9c2bf
+size 727285
diff --git a/assets/sample2/spec/TIGER_s1.png b/assets/sample2/spec/TIGER_s1.png
new file mode 100644
index 0000000000000000000000000000000000000000..b1d6bd11d93e25b4f929d3394826ad641a374910
--- /dev/null
+++ b/assets/sample2/spec/TIGER_s1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f09dba6957f58fe02dbb4557d036d4edae00d101b1e96edfc76de27e44f4ba5a
+size 565372
diff --git a/assets/sample2/spec/TIGER_s2.png b/assets/sample2/spec/TIGER_s2.png
new file mode 100644
index 0000000000000000000000000000000000000000..188f6d3f09329e87f0ca784e4f970a56f572ac61
--- /dev/null
+++ b/assets/sample2/spec/TIGER_s2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ba061680e3a36ceebe158c40f51db15cd2c322de33c41adaaf8c0358b624dd08
+size 582429
diff --git a/assets/sample2/spec/ground_truth_s1.png b/assets/sample2/spec/ground_truth_s1.png
new file mode 100644
index 0000000000000000000000000000000000000000..41ed0fe6dee762c53314c65926fa71907cc4a41e
--- /dev/null
+++ b/assets/sample2/spec/ground_truth_s1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:11216b3f40d513c5297b54da5dc989f96305efac6c915f893c3524d6c0734622
+size 534150
diff --git a/assets/sample2/spec/ground_truth_s2.png b/assets/sample2/spec/ground_truth_s2.png
new file mode 100644
index 0000000000000000000000000000000000000000..931b9de96e5b23528e68ac32caa0a71ad0b5558c
--- /dev/null
+++ b/assets/sample2/spec/ground_truth_s2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:def1f97d15ef8c61d8e8a2e3144e861dcacd6c945e19375be3a2ca3749f15653
+size 584546
diff --git a/assets/sample3/GroundTruth/mix.wav b/assets/sample3/GroundTruth/mix.wav
new file mode 100644
index 0000000000000000000000000000000000000000..b01655e1b58ec92473bd96a096d14eee872376a9
--- /dev/null
+++ b/assets/sample3/GroundTruth/mix.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:92be7d03f4ca97e3ee88272490838fe77d64207fb0d326cd70391da061498a8a
+size 192044
diff --git a/assets/sample3/GroundTruth/s1.wav b/assets/sample3/GroundTruth/s1.wav
new file mode 100644
index 0000000000000000000000000000000000000000..59dbff1b2a3b4ef4ce57bdf4763bfb152c73c5d3
--- /dev/null
+++ b/assets/sample3/GroundTruth/s1.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4b4071c7f390ce9e9a3ccfcf58afaec1e5b237aa150c3c87cf46a3c5a361e2f7
+size 192044
diff --git a/assets/sample3/GroundTruth/s2.wav b/assets/sample3/GroundTruth/s2.wav
new file mode 100644
index 0000000000000000000000000000000000000000..82e669cafb571aea92459f3f830e6dfc3bda19dc
--- /dev/null
+++ b/assets/sample3/GroundTruth/s2.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:48354c78aa7af0099c3bad5df18b587bc4ee897bd1ea00ac6f3db77c69055711
+size 192044
diff --git a/assets/sample3/TFGNet/s1.wav b/assets/sample3/TFGNet/s1.wav
new file mode 100644
index 0000000000000000000000000000000000000000..bbbd1692e80eb4300fa1868d819550d58c1b8de6
--- /dev/null
+++ b/assets/sample3/TFGNet/s1.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a669856a8a7ca1c4439a2cf1b88cfa5830a031449ef7950653120fd6ce0eeeca
+size 384080
diff --git a/assets/sample3/TFGNet/s2.wav b/assets/sample3/TFGNet/s2.wav
new file mode 100644
index 0000000000000000000000000000000000000000..085bfc4614e658db90a01a70778b8c783494bfa1
--- /dev/null
+++ b/assets/sample3/TFGNet/s2.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e409459991ef9b50a36267b6ea0cefa47263b25a3c6a6bb7d04a99681a06abaf
+size 384080
diff --git a/assets/sample3/TIGER/s1.wav b/assets/sample3/TIGER/s1.wav
new file mode 100644
index 0000000000000000000000000000000000000000..8f6bd186b0a1b1b073aeaa7e952d3435bca97846
--- /dev/null
+++ b/assets/sample3/TIGER/s1.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c6784a1d0a0b2e7a1d69a523ba824b5493465301727fe79e0a99d0fc6f52d0ef
+size 384080
diff --git a/assets/sample3/TIGER/s2.wav b/assets/sample3/TIGER/s2.wav
new file mode 100644
index 0000000000000000000000000000000000000000..cf2510e2f8487e8f409469bad1d537931c8ad084
--- /dev/null
+++ b/assets/sample3/TIGER/s2.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:188edc3e4fac51b7494e5d7e06f103c250d185876e3cc71af908f613829de97a
+size 384080
diff --git a/assets/sample3/spec/TFGNet_s1.png b/assets/sample3/spec/TFGNet_s1.png
new file mode 100644
index 0000000000000000000000000000000000000000..9b14b798b1dc01ad60a000bd85bd602a9eed88e1
--- /dev/null
+++ b/assets/sample3/spec/TFGNet_s1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6cc1f43f5f52fa9faf2e2485a6196b5959f9679f996b3522d310e63a7c3e55f7
+size 631623
diff --git a/assets/sample3/spec/TFGNet_s2.png b/assets/sample3/spec/TFGNet_s2.png
new file mode 100644
index 0000000000000000000000000000000000000000..a4fa8cc3f3cb125bbbef5c07658abbcfa3d47c86
--- /dev/null
+++ b/assets/sample3/spec/TFGNet_s2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:83d0c746152c5808b33e18d04de0561cd503552b66299676e534eff24f0d0eea
+size 541215
diff --git a/assets/sample3/spec/TIGER_s1.png b/assets/sample3/spec/TIGER_s1.png
new file mode 100644
index 0000000000000000000000000000000000000000..a2babb155ef321359b40a667c2eee81f6b2af90e
--- /dev/null
+++ b/assets/sample3/spec/TIGER_s1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b3913b89117845e2b1d0a794c4203eee01558bc6a1702cee4ca01c7d8e6ba104
+size 550682
diff --git a/assets/sample3/spec/TIGER_s2.png b/assets/sample3/spec/TIGER_s2.png
new file mode 100644
index 0000000000000000000000000000000000000000..408cab28c1f42232211437d243c229b87ef4a6f2
--- /dev/null
+++ b/assets/sample3/spec/TIGER_s2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7d1e4d09989f758371536f84c2576ac72a31914f4446fd589b37040a1302b40e
+size 660406
diff --git a/assets/sample3/spec/ground_truth_s1.png b/assets/sample3/spec/ground_truth_s1.png
new file mode 100644
index 0000000000000000000000000000000000000000..6c2e9341adb25258ecda566fa3045c843e5f8807
--- /dev/null
+++ b/assets/sample3/spec/ground_truth_s1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6f170d33ffc66d730693cc071273d9aef3e58be850f9bd8b479cc93ced99d0cb
+size 548248
diff --git a/assets/sample3/spec/ground_truth_s2.png b/assets/sample3/spec/ground_truth_s2.png
new file mode 100644
index 0000000000000000000000000000000000000000..dc8e323920bd84fac4269dbc8f41598ea1c4f2b7
--- /dev/null
+++ b/assets/sample3/spec/ground_truth_s2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:80942ea5e1742fb77f653625540f36e08821c333631d5b9a6f1d37d570b68a8b
+size 686893
diff --git a/audio_test.py b/audio_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..2602981634f91d4f03746f21927ae58407fdbdd4
--- /dev/null
+++ b/audio_test.py
@@ -0,0 +1,116 @@
+import os
+import random
+from typing import Union
+import soundfile as sf
+import torch
+import yaml
+import json
+import argparse
+import numpy as np
+import pandas as pd
+from tqdm import tqdm
+from pprint import pprint
+from scipy.io import wavfile
+import warnings
+import torchaudio
+warnings.filterwarnings("ignore")
+import look2hear.models
+import look2hear.datas
+from look2hear.metrics import MetricsTracker
+from look2hear.utils import tensors_to_device, RichProgressBarTheme, MyMetricsTextColumn, BatchesProcessedColumn
+
+from rich.progress import (
+ BarColumn,
+ Progress,
+ TextColumn,
+ TimeRemainingColumn,
+ TransferSpeedColumn,
+)
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--conf_dir",
+ default="local/mixit_conf.yml",
+ help="Full path to save best validation model")
+
+
+compute_metrics = ["si_sdr", "sdr"]
+os.environ['CUDA_VISIBLE_DEVICES'] = "8"
+
+def main(config):
+ metricscolumn = MyMetricsTextColumn(style=RichProgressBarTheme.metrics)
+ progress = Progress(
+ TextColumn("[bold blue]Testing", justify="right"),
+ BarColumn(bar_width=None),
+ "•",
+ BatchesProcessedColumn(style=RichProgressBarTheme.batch_progress),
+ "•",
+ TransferSpeedColumn(),
+ "•",
+ TimeRemainingColumn(),
+ "•",
+ metricscolumn
+ )
+ # import pdb; pdb.set_trace()
+ config["train_conf"]["main_args"]["exp_dir"] = os.path.join(
+ os.getcwd(), "Experiments", "checkpoint", config["train_conf"]["exp"]["exp_name"]
+ )
+ model_path = os.path.join(config["train_conf"]["main_args"]["exp_dir"], "best_model.pth")
+ # import pdb; pdb.set_trace()
+ # conf["train_conf"]["masknet"].update({"n_src": 2})
+ model = getattr(look2hear.models, config["train_conf"]["audionet"]["audionet_name"]).from_pretrain(
+ model_path,
+ sample_rate=config["train_conf"]["datamodule"]["data_config"]["sample_rate"],
+ **config["train_conf"]["audionet"]["audionet_config"],
+ )
+ if config["train_conf"]["training"]["gpus"]:
+ device = "cuda"
+ model.to(device)
+ model_device = next(model.parameters()).device
+ datamodule: object = getattr(look2hear.datas, config["train_conf"]["datamodule"]["data_name"])(
+ **config["train_conf"]["datamodule"]["data_config"]
+ )
+ datamodule.setup()
+ _, _ , test_set = datamodule.make_sets
+
+ # Randomly choose the indexes of sentences to save.
+ ex_save_dir = os.path.join(config["train_conf"]["main_args"]["exp_dir"], "results/")
+ os.makedirs(ex_save_dir, exist_ok=True)
+ metrics = MetricsTracker(
+ save_file=os.path.join(ex_save_dir, "metrics.csv"))
+    with torch.no_grad(), progress:
+ for idx in progress.track(range(len(test_set))):
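+            # NOTE: only test-set index 825 is separated and saved below; remove this check to run the full test set.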
+ if idx == 825:
+ # Forward the network on the mixture.
+ mix, sources, key = tensors_to_device(test_set[idx],
+ device=model_device)
+ est_sources = model(mix[None])
+ mix_np = mix
+ sources_np = sources
+ est_sources_np = est_sources.squeeze(0)
+ # metrics(mix=mix_np,
+ # clean=sources_np,
+ # estimate=est_sources_np,
+ # key=key)
+ save_dir = os.path.join("./result/TIGER", "idx{}".format(idx))
+ # est_sources_np = normalize_tensor_wav(est_sources_np)
+ for i in range(est_sources_np.shape[0]):
+ os.makedirs(os.path.join(save_dir, "s{}/".format(i + 1)), exist_ok=True)
+ # torchaudio.save(os.path.join(save_dir, "s{}/".format(i + 1)) + key, est_sources_np[i].unsqueeze(0).cpu(), 16000)
+ torchaudio.save(os.path.join(save_dir, "s{}/".format(i + 1)) + key.split("/")[-1], est_sources_np[i].unsqueeze(0).cpu(), 16000)
+ # if idx % 50 == 0:
+ # metricscolumn.update(metrics.update())
+ metrics.final()
+
+
+if __name__ == "__main__":
+ args = parser.parse_args()
+ arg_dic = dict(vars(args))
+
+ # Load training config
+ with open(args.conf_dir, "rb") as f:
+ train_conf = yaml.safe_load(f)
+ arg_dic["train_conf"] = train_conf
+ # print(arg_dic)
+ main(arg_dic)
diff --git a/audio_train.py b/audio_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e53b9d714e459598edd1730d6add4421f6eb47b
--- /dev/null
+++ b/audio_train.py
@@ -0,0 +1,202 @@
+import os
+import sys
+import torch
+from torch import Tensor
+import argparse
+import json
+import yaml
+import look2hear.datas
+import look2hear.models
+import look2hear.system
+import look2hear.losses
+import look2hear.metrics
+import look2hear.utils
+from look2hear.system import make_optimizer
+from dataclasses import dataclass
+from torch.optim.lr_scheduler import ReduceLROnPlateau
+from torch.utils.data import DataLoader
+import pytorch_lightning as pl
+from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, RichProgressBar
+from pytorch_lightning.callbacks.progress.rich_progress import *
+from rich.console import Console
+from pytorch_lightning.loggers import TensorBoardLogger
+from pytorch_lightning.loggers.wandb import WandbLogger
+from pytorch_lightning.strategies.ddp import DDPStrategy
+from rich import print, reconfigure
+from collections.abc import MutableMapping
+from look2hear.utils import print_only, MyRichProgressBar, RichProgressBarTheme
+
+import warnings
+
+warnings.filterwarnings("ignore")
+
+import wandb
+wandb.login()
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+ "--conf_dir",
+ default="local/conf.yml",
+    help="Full path to the training configuration .yml file",
+)
+
+def main(config):
+ print_only(
+ "Instantiating datamodule <{}>".format(config["datamodule"]["data_name"])
+ )
+ datamodule: object = getattr(look2hear.datas, config["datamodule"]["data_name"])(
+ **config["datamodule"]["data_config"]
+ )
+ datamodule.setup()
+
+ train_loader, val_loader, test_loader = datamodule.make_loader
+
+ # Define model and optimizer
+ print_only(
+ "Instantiating AudioNet <{}>".format(config["audionet"]["audionet_name"])
+ )
+ model = getattr(look2hear.models, config["audionet"]["audionet_name"])(
+ sample_rate=config["datamodule"]["data_config"]["sample_rate"],
+ **config["audionet"]["audionet_config"],
+ )
+ # import pdb; pdb.set_trace()
+ print_only("Instantiating Optimizer <{}>".format(config["optimizer"]["optim_name"]))
+ optimizer = make_optimizer(model.parameters(), **config["optimizer"])
+
+ # Define scheduler
+ scheduler = None
+ if config["scheduler"]["sche_name"]:
+ print_only(
+ "Instantiating Scheduler <{}>".format(config["scheduler"]["sche_name"])
+ )
+ if config["scheduler"]["sche_name"] != "DPTNetScheduler":
+ scheduler = getattr(torch.optim.lr_scheduler, config["scheduler"]["sche_name"])(
+ optimizer=optimizer, **config["scheduler"]["sche_config"]
+ )
+ else:
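+            # The DPTNet scheduler is stepped once per optimizer step ("interval": "step") rather than once per epoch.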
+ scheduler = {
+ "scheduler": getattr(look2hear.system.schedulers, config["scheduler"]["sche_name"])(
+ optimizer, len(train_loader) // config["datamodule"]["data_config"]["batch_size"], 64
+ ),
+ "interval": "step",
+ }
+
+ # Just after instantiating, save the args. Easy loading in the future.
+ config["main_args"]["exp_dir"] = os.path.join(
+ os.getcwd(), "Experiments", "checkpoint", config["exp"]["exp_name"]
+ )
+ exp_dir = config["main_args"]["exp_dir"]
+ os.makedirs(exp_dir, exist_ok=True)
+ conf_path = os.path.join(exp_dir, "conf.yml")
+ with open(conf_path, "w") as outfile:
+ yaml.safe_dump(config, outfile)
+
+ # Define Loss function.
+ print_only(
+ "Instantiating Loss, Train <{}>, Val <{}>".format(
+ config["loss"]["train"]["sdr_type"], config["loss"]["val"]["sdr_type"]
+ )
+ )
+ loss_func = {
+ "train": getattr(look2hear.losses, config["loss"]["train"]["loss_func"])(
+ getattr(look2hear.losses, config["loss"]["train"]["sdr_type"]),
+ **config["loss"]["train"]["config"],
+ ),
+ "val": getattr(look2hear.losses, config["loss"]["val"]["loss_func"])(
+ getattr(look2hear.losses, config["loss"]["val"]["sdr_type"]),
+ **config["loss"]["val"]["config"],
+ ),
+ }
+
+ print_only("Instantiating System <{}>".format(config["training"]["system"]))
+ system = getattr(look2hear.system, config["training"]["system"])(
+ audio_model=model,
+ loss_func=loss_func,
+ optimizer=optimizer,
+ train_loader=train_loader,
+ val_loader=val_loader,
+ test_loader=test_loader,
+ scheduler=scheduler,
+ config=config,
+ )
+
+ # Define callbacks
+ print_only("Instantiating ModelCheckpoint")
+ callbacks = []
+ checkpoint_dir = os.path.join(exp_dir)
+ checkpoint = ModelCheckpoint(
+ checkpoint_dir,
+ filename="{epoch}",
+ monitor="val_loss/dataloader_idx_0",
+ mode="min",
+ save_top_k=5,
+ verbose=True,
+ save_last=True,
+ )
+ callbacks.append(checkpoint)
+
+ if config["training"]["early_stop"]:
+ print_only("Instantiating EarlyStopping")
+ callbacks.append(EarlyStopping(**config["training"]["early_stop"]))
+ callbacks.append(MyRichProgressBar(theme=RichProgressBarTheme()))
+
+ # Don't ask GPU if they are not available.
+ gpus = config["training"]["gpus"] if torch.cuda.is_available() else None
+ distributed_backend = "cuda" if torch.cuda.is_available() else None
+
+ # default logger used by trainer
+ logger_dir = os.path.join(os.getcwd(), "Experiments", "tensorboard_logs")
+ os.makedirs(os.path.join(logger_dir, config["exp"]["exp_name"]), exist_ok=True)
+ # comet_logger = TensorBoardLogger(logger_dir, name=config["exp"]["exp_name"])
+ comet_logger = WandbLogger(
+ name=config["exp"]["exp_name"],
+ save_dir=os.path.join(logger_dir, config["exp"]["exp_name"]),
+ project="Real-work-dataset",
+ # offline=True
+ )
+
+ trainer = pl.Trainer(
+ max_epochs=config["training"]["epochs"],
+ callbacks=callbacks,
+ default_root_dir=exp_dir,
+ devices=gpus,
+ accelerator=distributed_backend,
+ strategy=DDPStrategy(find_unused_parameters=True),
+ limit_train_batches=1.0, # Useful for fast experiment
+ gradient_clip_val=5.0,
+ logger=comet_logger,
+ sync_batchnorm=True,
+ # precision="bf16-mixed",
+ # num_sanity_val_steps=0,
+ # sync_batchnorm=True,
+ # fast_dev_run=True,
+ )
+ trainer.fit(system)
+ print_only("Finished Training")
+ best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
+ with open(os.path.join(exp_dir, "best_k_models.json"), "w") as f:
+ json.dump(best_k, f, indent=0)
+
+ state_dict = torch.load(checkpoint.best_model_path)
+ system.load_state_dict(state_dict=state_dict["state_dict"])
+ system.cpu()
+
+ to_save = system.audio_model.serialize()
+ torch.save(to_save, os.path.join(exp_dir, "best_model.pth"))
+
+
+if __name__ == "__main__":
+ import yaml
+ from pprint import pprint
+ from look2hear.utils.parser_utils import (
+ prepare_parser_from_dict,
+ parse_args_as_dict,
+ )
+
+ args = parser.parse_args()
+ with open(args.conf_dir) as f:
+ def_conf = yaml.safe_load(f)
+ parser = prepare_parser_from_dict(def_conf, parser=parser)
+
+ arg_dic, plain_args = parse_args_as_dict(parser, return_plain_args=True)
+ # pprint(arg_dic)
+ main(arg_dic)
diff --git a/configs/tiger-large.yml b/configs/tiger-large.yml
new file mode 100644
index 0000000000000000000000000000000000000000..61be3a25962433e4333e307933632c589e6e4d5b
--- /dev/null
+++ b/configs/tiger-large.yml
@@ -0,0 +1,71 @@
+# Network config
+audionet:
+ audionet_name: TIGER
+ audionet_config:
+ out_channels: 128
+ in_channels: 256
+ num_blocks: 8
+ upsampling_depth: 5
+ win: 640
+ stride: 160
+ num_sources: 2
+
+# Loss config
+loss:
+ train:
+ loss_func: PITLossWrapper
+ sdr_type: pairwise_neg_snr
+ config:
+ pit_from: pw_mtx
+ threshold_byloss: false
+ val:
+ loss_func: PITLossWrapper
+ sdr_type: pairwise_neg_sisdr
+ config:
+ pit_from: pw_mtx
+ threshold_byloss: false
+
+# Training config
+training:
+ system: AudioLightningModule
+ gpus: [2,3,4,5,6,7,8,9]
+ parallel: ddp
+ epochs: 500
+ early_stop:
+ monitor: val_loss/dataloader_idx_0
+ mode: min
+ patience: 20
+ verbose: true
+ SpeedAug: False
+
+# Optim config
+optimizer:
+ optim_name: adam
+ lr: 0.001
+ weight_decay: 0
+
+# Sche config
+scheduler:
+ sche_name: ReduceLROnPlateau
+ sche_config:
+ patience: 10
+ factor: 0.5
+
+# Data config
+datamodule:
+ data_name: EchoSetDataModule
+ data_config:
+ train_dir: DataPreProcess/EchoSet/train
+ valid_dir: DataPreProcess/EchoSet/val
+ test_dir: DataPreProcess/EchoSet/test
+ n_src: 2
+ sample_rate: 16000
+ segment: 3.0
+ normalize_audio: false
+ batch_size: 1
+ num_workers: 8
+ pin_memory: true
+ persistent_workers: false
+
+exp:
+ exp_name: TIGER-EchoSet
\ No newline at end of file
diff --git a/configs/tiger-small.yml b/configs/tiger-small.yml
new file mode 100644
index 0000000000000000000000000000000000000000..7b9e6c22cbe6bfb6a9faa206b34f1e47e668dc62
--- /dev/null
+++ b/configs/tiger-small.yml
@@ -0,0 +1,71 @@
+# Network config
+audionet:
+ audionet_name: TIGER
+ audionet_config:
+ out_channels: 128
+ in_channels: 256
+ num_blocks: 4
+ upsampling_depth: 5
+ win: 640
+ stride: 160
+ num_sources: 2
+
+# Loss config
+loss:
+ train:
+ loss_func: PITLossWrapper
+ sdr_type: pairwise_neg_snr
+ config:
+ pit_from: pw_mtx
+ threshold_byloss: false
+ val:
+ loss_func: PITLossWrapper
+ sdr_type: pairwise_neg_sisdr
+ config:
+ pit_from: pw_mtx
+ threshold_byloss: false
+
+# Training config
+training:
+ system: AudioLightningModule
+ gpus: [2,3,4,5,6,7,8,9]
+ parallel: ddp
+ epochs: 500
+ early_stop:
+ monitor: val_loss/dataloader_idx_0
+ mode: min
+ patience: 20
+ verbose: true
+ SpeedAug: False
+
+# Optim config
+optimizer:
+ optim_name: adam
+ lr: 0.001
+ weight_decay: 0
+
+# Sche config
+scheduler:
+ sche_name: ReduceLROnPlateau
+ sche_config:
+ patience: 10
+ factor: 0.5
+
+# Data config
+datamodule:
+ data_name: EchoSetDataModule
+ data_config:
+ train_dir: DataPreProcess/EchoSet/train
+ valid_dir: DataPreProcess/EchoSet/val
+ test_dir: DataPreProcess/EchoSet/test
+ n_src: 2
+ sample_rate: 16000
+ segment: 3.0
+ normalize_audio: false
+ batch_size: 1
+ num_workers: 8
+ pin_memory: true
+ persistent_workers: false
+
+exp:
+ exp_name: TIGER-EchoSet
\ No newline at end of file
diff --git a/evaluated_mac_params.py b/evaluated_mac_params.py
new file mode 100644
index 0000000000000000000000000000000000000000..3972bd6c0c41ab6b6316449359dd560e9f3c1d76
--- /dev/null
+++ b/evaluated_mac_params.py
@@ -0,0 +1,53 @@
+import os
+import argparse
+import json
+import time
+import torch
+import torch.nn as nn
+from torch.optim.lr_scheduler import ReduceLROnPlateau
+from torch.utils.data import DataLoader
+import pytorch_lightning as pl
+from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
+from look2hear.utils.parser_utils import prepare_parser_from_dict, parse_args_as_dict
+import look2hear.models
+import yaml
+from ptflops import get_model_complexity_info
+from rich import print
+
+def check_parameters(net):
+ """
+ Returns module parameters. Mb
+ """
+ parameters = sum(param.numel() for param in net.parameters())
+ return parameters / 10 ** 6
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+ "--exp_dir", default="exp/tmp", help="Full path to save best validation model"
+)
+
+with open("configs/tiger.yml") as f:
+ def_conf = yaml.safe_load(f)
+parser = prepare_parser_from_dict(def_conf, parser=parser)
+
+arg_dic, plain_args = parse_args_as_dict(parser, return_plain_args=True)
+audiomodel = getattr(look2hear.models, arg_dic["audionet"]["audionet_name"])(
+ sample_rate=arg_dic["datamodule"]["data_config"]["sample_rate"],
+ **arg_dic["audionet"]["audionet_config"]
+)
+# Use the Apple MPS backend when available, otherwise fall back to CPU
+device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
+a = torch.randn(1, 1, 16000).to(device)
+total_macs = 0
+total_params = 0
+model = audiomodel.to(device)
+with torch.no_grad():
+ macs, params = get_model_complexity_info(
+ model, (16000,), as_strings=False, print_per_layer_stat=True, verbose=False
+ )
+print(model(a).shape)
+total_macs += macs
+total_params += params
+print("MACs: ", total_macs / 10.0 ** 9)
+print("Params: ", total_params / 10.0 ** 6)
\ No newline at end of file
diff --git a/index.html b/index.html
new file mode 100644
index 0000000000000000000000000000000000000000..7f4c035f90a5d5e99585fadf759e77f2bf8a7542
--- /dev/null
+++ b/index.html
@@ -0,0 +1,791 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ TIGER: Time-frequency Interleaved Gain Extraction and Reconstruction for Efficient Speech Separation
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ TIGER: Time-frequency Interleaved Gain Extraction and Reconstruction for Efficient
+ Speech Separation
+
+
+
+
+
+
+ ━ ICLR 2025 ━
+
+
+
+
+
+
+ Mohan Xu*
+
+
+
+
+ Kai Li*
+
+
+
+
+ Guo Chen
+
+
+
+
+ Xiaolin Hu
+
+
+
+
+
+
+
+
+
+
+
+
+ TIGER is a lightweight model
+ for speech separation which
+ effectively extracts key acoustic features through frequency band-split, multi-scale
+ and full-frequency-frame modeling.
+
+
+
+
+ Abstract
+
+
+
+ In recent years, much speech separation research has focused primarily on improving model performance.
+ However, for low-latency speech processing systems, computational efficiency is equally critical.
+ In this paper, we propose a speech separation model with significantly reduced parameter size and
+ computational cost:
+ Time-Frequency Interleaved Gain Extraction and Reconstruction Network (TIGER) .
+ TIGER leverages prior knowledge to divide frequency
+ bands and applies compression on frequency information.
+ We employ a multi-scale selective attention (MSA) module to extract contextual features,
+ while introducing a full-frequency-frame attention (F^3A) module to capture both temporal and
+ frequency contextual information.
+ Additionally, to more realistically evaluate the performance of speech separation models in complex
+ acoustic environments,
+ we introduce a novel dataset called EchoSet . This dataset includes noise and more realistic
+ reverberation
+ (e.g., considering object occlusions and material properties), with speech from two speakers overlapping
+ at random proportions.
+ Experimental results demonstrated that TIGER significantly outperformed
+ state-of-the-art (SOTA) model TF-GridNet
+ on the EchoSet dataset in both inference speed and separation quality, while reducing the number of
+ parameters by 94.3 %
+ and the MACs by 95.3 %. These results indicate that by utilizing frequency band-split and
+ interleaved modeling structures,
+ TIGER achieves a substantial reduction in parameters and computational costs while maintaining high
+ performance.
+ Notably, TIGER is the first speech separation model with fewer than 1 million parameters
+ that achieves performance close to the SOTA model.
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ Overall pipeline of the model architecture of TIGER and
+ its
+ modules .
+
+
+
+
+
+
+
+
+ Results on Speech Separation
+
+
+
+
+ Performance comparisons of TIGER and other existing separation models on Libri2Mix, LRS2-2Mix, and EchoSet .
+ Bold indicates optimal performance, and italics indicate suboptimal performance.
+
+
+
+
+
+ Efficiency comparisons of TIGER and other models.
+
+
+
+
+
+ Results on Cinematic Sound Separation
+
+
+
+ Comparison of performance and efficiency of cinematic sound
+ separation models on DnR .
+ ‘*’ means the result comes from the original paper of DnR.
+
+ Speech Separation Demo
+
+ EchoSet Sample I: Mixture, Ground Truth, TF-GridNet, TIGER (audio demos)
+
+ EchoSet Sample II: Mixture, Ground Truth, TF-GridNet, TIGER (audio demos)
+
+ EchoSet Sample III: Mixture, Ground Truth, TF-GridNet, TIGER (audio demos)
+
+ Cinematic Sound Demo
+
+ Deadpool: Dialog, Music, Effect (video demos)
+
+ The Wandering Earth: Dialog, Music, Effect (video demos)
+
+ Captain America: Dialog, Music, Effect (video demos)
\ No newline at end of file
diff --git a/inference_dnr.py b/inference_dnr.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5daf6bb5c1ef6cd750c7eaab3007bb1561de91a
--- /dev/null
+++ b/inference_dnr.py
@@ -0,0 +1,28 @@
+import yaml
+import os
+import look2hear.models
+import argparse
+import torch
+import torchaudio
+os.environ["CUDA_VISIBLE_DEVICES"] = "2"
+
+# audio path
+parser = argparse.ArgumentParser()
+parser.add_argument("--audio_path", default="test/test_mixture_466.wav", help="Path to audio file")
+
+device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+print(device)
+
+# Load model
+model = look2hear.models.TIGERDNR.from_pretrained("JusperLee/TIGER-DnR", cache_dir="cache")
+model.to(device)
+model.eval()
+
+args = parser.parse_args()
+audio = torchaudio.load(args.audio_path)[0].to(device)
+
+with torch.no_grad():
+ all_target_dialog, all_target_effect, all_target_music = model(audio[None])
+
+torchaudio.save("test/test_target_dialog_466.wav", all_target_dialog.cpu(), 44100)
+torchaudio.save("test/test_target_effect_466.wav", all_target_effect.cpu(), 44100)
+torchaudio.save("test/test_target_music_466.wav", all_target_music.cpu(), 44100)
\ No newline at end of file
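The script above assumes its input mixture is already at 44.1 kHz, the rate used in the torchaudio.save calls; if that is not guaranteed, a resampling step along the following lines could be added before inference. This is a sketch using torchaudio's Resample transform, mirroring the resampling that inference_speech.py below performs at 16 kHz.

```python
# Sketch: bring an arbitrary-rate mixture to 44.1 kHz before DnR separation.
# The 44100 target matches the sample rate used by the save calls above.
import torchaudio
import torchaudio.transforms as T

target_sr = 44100
waveform, orig_sr = torchaudio.load("test/test_mixture_466.wav")
if orig_sr != target_sr:
    waveform = T.Resample(orig_freq=orig_sr, new_freq=target_sr)(waveform)
```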
diff --git a/inference_speech.py b/inference_speech.py
new file mode 100644
index 0000000000000000000000000000000000000000..830c2f0f5da633f8b169a3d5c9806a790f5c2505
--- /dev/null
+++ b/inference_speech.py
@@ -0,0 +1,115 @@
+import yaml
+import os
+import look2hear.models
+import argparse
+import torch
+import torchaudio
+import torchaudio.transforms as T # Added for resampling
+
+# --- Argument Parsing ---
+parser = argparse.ArgumentParser(description="Separate speech sources using Look2Hear TIGER model.")
+parser.add_argument("--audio_path", default="test/mix.wav", help="Path to audio file (mixture).")
+parser.add_argument("--output_dir", default="separated_audio", help="Directory to save separated audio files.")
+parser.add_argument("--model_cache_dir", default="cache", help="Directory to cache downloaded model.")
+
+# Parse arguments once at the beginning
+args = parser.parse_args()
+
+audio_path = args.audio_path
+output_dir = args.output_dir
+cache_dir = args.model_cache_dir
+device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+print(f"Using device: {device}")
+
+# Load model
+
+print("Loading TIGER model...")
+# Ensure cache directory exists if specified
+if cache_dir:
+ os.makedirs(cache_dir, exist_ok=True)
+# Load the pretrained model
+model = look2hear.models.TIGER.from_pretrained("JusperLee/TIGER-speech", cache_dir=cache_dir)
+model.to(device)
+model.eval()
+
+
+# --- Audio Loading and Preprocessing ---
+# Define the target sample rate expected by the model (usually 16kHz for TIGER)
+
+target_sr = 16000
+print(f"Loading audio from: {audio_path}")
+try:
+ # Load audio and get its original sample rate
+ waveform, original_sr = torchaudio.load(audio_path)
+except Exception as e:
+ print(f"Error loading audio file {audio_path}: {e}")
+ exit(1)
+print(f"Original sample rate: {original_sr} Hz, Target sample rate: {target_sr} Hz")
+
+# Resample if necessary
+if original_sr != target_sr:
+ print(f"Resampling audio from {original_sr} Hz to {target_sr} Hz...")
+ resampler = T.Resample(orig_freq=original_sr, new_freq=target_sr)
+ waveform = resampler(waveform)
+ print("Resampling complete.")
+
+# Move waveform to the target device
+audio = waveform.to(device)
+
+# Prepare the input tensor for the model
+# Model likely expects a batch dimension [B, T] or [B, C, T]
+# Assuming input is mono or model handles channels; add batch dim
+# If audio has channel dim [C, T], keep it. If it's just [T], add channel dim first.
+
+if audio.dim() == 1:
+ audio = audio.unsqueeze(0) # Add channel dimension -> [1, T]
+# Add batch dimension -> [1, C, T]
+# The original audio[None] is equivalent to unsqueeze(0) on the batch dimension
+audio_input = audio.unsqueeze(0).to(device)
+print(f"Audio tensor prepared with shape: {audio_input.shape}")
+
+# --- Speech Separation ---
+
+# Create output directory if it doesn't exist
+
+os.makedirs(output_dir, exist_ok=True)
+print(f"Output directory: {output_dir}")
+print("Performing separation...")
+
+with torch.no_grad():
+ # Pass the prepared input tensor to the model
+ ests_speech = model(audio_input) # Expected output shape: [B, num_spk, T]
+
+# Process the estimated sources
+
+# Remove the batch dimension -> [num_spk, T]
+
+ests_speech = ests_speech.squeeze(0)
+
+num_speakers = ests_speech.shape[0]
+
+print(f"Separation complete. Detected {num_speakers} potential speakers.")
+
+
+
+# --- Save Separated Audio ---
+
+# Dynamically save all separated tracks
+
+for i in range(num_speakers):
+ output_filename = os.path.join(output_dir, f"spk{i+1}.wav")
+ speaker_track = ests_speech[i].cpu() # Get the i-th speaker track and move to CPU
+ print(f"Saving speaker {i+1} to {output_filename}")
+ try:
+ torchaudio.save(
+ output_filename,
+ speaker_track, # Save the individual track
+ target_sr # Save with the target sample rate
+ )
+ except Exception as e:
+ print(f"Error saving file {output_filename}: {e}")
diff --git a/look2hear/datas/Libri2Mix16.py b/look2hear/datas/Libri2Mix16.py
new file mode 100644
index 0000000000000000000000000000000000000000..74c81d03768653d7293359720c4e855e1e8250ea
--- /dev/null
+++ b/look2hear/datas/Libri2Mix16.py
@@ -0,0 +1,281 @@
+import os
+import json
+import librosa
+import numpy as np
+from typing import Any, Tuple
+import scipy
+import soundfile as sf
+import torch
+import random
+from collections import defaultdict
+from pytorch_lightning import LightningDataModule
+# from pytorch_lightning.core.mixins import HyperparametersMixin
+import torchaudio
+from torch.utils.data import ConcatDataset, DataLoader, Dataset
+from typing import Any, Dict, Optional, Tuple
+from pytorch_lightning.utilities import rank_zero_only
+
+@rank_zero_only
+def print_(message: str):
+ print(message)
+
+
+def normalize_tensor_wav(wav_tensor, eps=1e-8, std=None):
+ mean = wav_tensor.mean(-1, keepdim=True)
+ if std is None:
+ std = wav_tensor.std(-1, keepdim=True)
+ return (wav_tensor - mean) / (std + eps)
+
+def find_bottom_directories(root_dir):
+ bottom_directories = []
+ for dirpath, dirnames, filenames in os.walk(root_dir):
+        # If a directory has no subdirectories, treat it as a bottom-level directory
+ if not dirnames:
+ bottom_directories.append(dirpath)
+ return bottom_directories
+
+def compute_mch_rms_dB(mch_wav, fs=16000, energy_thresh=-50):
+ """Return the wav RMS calculated only in the active portions"""
+ mean_square = max(1e-20, torch.mean(mch_wav ** 2))
+ return 10 * np.log10(mean_square)
+
+class Libri2MixDataset(Dataset):
+ def __init__(
+ self,
+ json_dir: str = "",
+ n_src: int = 2,
+ sample_rate: int = 8000,
+ segment: float = 4.0,
+ normalize_audio: bool = False,
+ ) -> None:
+ super().__init__()
+ self.EPS = 1e-8
+        if json_dir is None:
+ raise ValueError("JSON DIR is None!")
+ if n_src not in [1, 2]:
+ raise ValueError("{} is not in [1, 2]".format(n_src))
+ self.json_dir = json_dir
+ self.sample_rate = sample_rate
+ self.normalize_audio = normalize_audio
+
+ if segment is None:
+ self.seg_len = None
+ self.fps_len = None
+ else:
+ self.seg_len = int(segment * sample_rate)
+
+ self.n_src = n_src
+ self.test = self.seg_len is None
+ mix_json = os.path.join(json_dir, "mix_both.json")
+ sources_json = [
+ os.path.join(json_dir, source + ".json") for source in ["s1", "s2"]
+ ]
+
+ with open(mix_json, "r") as f:
+ mix_infos = json.load(f)
+ sources_infos = []
+ for src_json in sources_json:
+ with open(src_json, "r") as f:
+ sources_infos.append(json.load(f))
+
+ self.mix = []
+ self.sources = []
+ if self.n_src == 1:
+ orig_len = len(mix_infos) * 2
+ drop_utt, drop_len = 0, 0
+ if not self.test:
+ for i in range(len(mix_infos) - 1, -1, -1):
+ if mix_infos[i][1] < self.seg_len:
+ drop_utt = drop_utt + 1
+ drop_len = drop_len + mix_infos[i][1]
+ del mix_infos[i]
+ for src_inf in sources_infos:
+ del src_inf[i]
+ else:
+ for src_inf in sources_infos:
+ self.mix.append(mix_infos[i])
+ self.sources.append(src_inf[i])
+ else:
+ for i in range(len(mix_infos)):
+ for src_inf in sources_infos:
+ self.mix.append(mix_infos[i])
+ self.sources.append(src_inf[i])
+
+ print_(
+ "Drop {} utts({:.2f} h) from {} (shorter than {} samples)".format(
+ drop_utt, drop_len / sample_rate / 3600, orig_len, self.seg_len
+ )
+ )
+ self.length = len(self.mix)
+
+ elif self.n_src == 2:
+ orig_len = len(mix_infos)
+ drop_utt, drop_len = 0, 0
+ if not self.test:
+ for i in range(len(mix_infos) - 1, -1, -1): # Go backward
+ if mix_infos[i][1] < self.seg_len:
+ drop_utt = drop_utt + 1
+ drop_len = drop_len + mix_infos[i][1]
+ del mix_infos[i]
+ for src_inf in sources_infos:
+ del src_inf[i]
+
+ print_(
+ "Drop {} utts({:.2f} h) from {} (shorter than {} samples)".format(
+                    drop_utt, drop_len / sample_rate / 3600, orig_len, self.seg_len
+ )
+ )
+ self.mix = mix_infos
+ self.sources = sources_infos
+ self.length = len(self.mix)
+
+ def __len__(self):
+ return self.length
+
+ def preprocess_audio_only(self, idx: int):
+ if self.n_src == 1:
+ if self.mix[idx][1] == self.seg_len or self.test:
+ rand_start = 0
+ else:
+ rand_start = np.random.randint(0, self.mix[idx][1] - self.seg_len)
+ if self.test:
+ stop = None
+ else:
+ stop = rand_start + self.seg_len
+ # Load mixture
+ x, _ = sf.read(
+ self.mix[idx][0], start=rand_start, stop=stop, dtype="float32"
+ )
+ # Load sources
+ s, _ = sf.read(
+ self.sources[idx][0], start=rand_start, stop=stop, dtype="float32"
+ )
+ # torch from numpy
+ target = torch.from_numpy(s)
+ mixture = torch.from_numpy(x)
+ if self.normalize_audio:
+ m_std = mixture.std(-1, keepdim=True)
+ mixture = normalize_tensor_wav(mixture, eps=self.EPS, std=m_std)
+ target = normalize_tensor_wav(target, eps=self.EPS, std=m_std)
+ return mixture, target.unsqueeze(0), self.mix[idx][0].split("/")[-1]
+ # import pdb; pdb.set_trace()
+ if self.n_src == 2:
+ if self.mix[idx][1] == self.seg_len or self.test:
+ rand_start = 0
+ else:
+ rand_start = np.random.randint(0, self.mix[idx][1] - self.seg_len)
+ if self.test:
+ stop = None
+ else:
+ stop = rand_start + self.seg_len
+ # Load mixture
+ x, _ = sf.read(
+ self.mix[idx][0], start=rand_start, stop=stop, dtype="float32"
+ )
+ # Load sources
+ source_arrays = []
+ for src in self.sources:
+ s, _ = sf.read(
+ src[idx][0], start=rand_start, stop=stop, dtype="float32"
+ )
+ source_arrays.append(s)
+ sources = torch.from_numpy(np.vstack(source_arrays))
+ mixture = torch.sum(sources, dim=0)
+
+ if self.normalize_audio:
+ m_std = mixture.std(-1, keepdim=True)
+ mixture = normalize_tensor_wav(mixture, eps=self.EPS, std=m_std)
+ sources = normalize_tensor_wav(sources, eps=self.EPS, std=m_std)
+
+ return mixture, sources, self.mix[idx][0].split("/")[-1]
+
+ def __getitem__(self, index: int):
+ return self.preprocess_audio_only(index)
+
+class Libri2MixModuleRemix(LightningDataModule):
+ def __init__(
+ self,
+ train_dir: str,
+ valid_dir: str,
+ test_dir: str,
+ n_src: int = 2,
+ sample_rate: int = 8000,
+ segment: float = 4.0,
+ normalize_audio: bool = False,
+ batch_size: int = 64,
+ num_workers: int = 0,
+ pin_memory: bool = False,
+ persistent_workers: bool = False,
+ ) -> None:
+ super().__init__()
+ self.save_hyperparameters(logger=False)
+
+ self.data_train: Optional[Dataset] = None
+ self.data_val: Optional[Dataset] = None
+ self.data_test: Optional[Dataset] = None
+
+ def setup(self, stage: Optional[str] = None) -> None:
+ """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.
+
+ This method is called by Lightning before `trainer.fit()`, `trainer.validate()`, `trainer.test()`, and
+ `trainer.predict()`, so be careful not to execute things like random split twice! Also, it is called after
+ `self.prepare_data()` and there is a barrier in between which ensures that all the processes proceed to
+ `self.setup()` once the data is prepared and available for use.
+
+ :param stage: The stage to setup. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. Defaults to ``None``.
+ """
+ # load and split datasets only if not loaded already
+ if not self.data_train and not self.data_val and not self.data_test:
+ self.data_train = Libri2MixDataset(
+ json_dir=self.hparams.train_dir,
+ n_src=self.hparams.n_src,
+ sample_rate=self.hparams.sample_rate,
+ segment=self.hparams.segment,
+ normalize_audio=self.hparams.normalize_audio,
+ )
+ self.data_val = Libri2MixDataset(
+ json_dir=self.hparams.valid_dir,
+ n_src=self.hparams.n_src,
+ sample_rate=self.hparams.sample_rate,
+ segment=None,
+ normalize_audio=self.hparams.normalize_audio,
+ )
+ self.data_test = Libri2MixDataset(
+ json_dir=self.hparams.test_dir,
+ n_src=self.hparams.n_src,
+ sample_rate=self.hparams.sample_rate,
+ segment=None,
+ normalize_audio=self.hparams.normalize_audio,
+ )
+
+ def train_dataloader(self) -> DataLoader:
+ return DataLoader(
+ self.data_train,
+ batch_size=self.hparams.batch_size,
+ num_workers=self.hparams.num_workers,
+ shuffle=True,
+ pin_memory=True,
+ )
+
+ def val_dataloader(self) -> DataLoader:
+ return DataLoader(
+ self.data_val,
+ batch_size=1,
+ num_workers=self.hparams.num_workers,
+ shuffle=False,
+ pin_memory=True,
+ )
+
+ def test_dataloader(self) -> DataLoader:
+ return DataLoader(
+ self.data_test,
+ batch_size=1,
+ num_workers=self.hparams.num_workers,
+ shuffle=False,
+ pin_memory=True,
+ )
+
+ @property
+ def make_loader(self):
+ return self.train_dataloader(), self.val_dataloader(), self.test_dataloader()
+
\ No newline at end of file
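Below is a sketch of how Libri2MixModuleRemix might be wired up. The manifest directories are hypothetical placeholders; each must contain the mix_both.json, s1.json and s2.json files the dataset reads.

```python
# Sketch: constructing the Libri2Mix data module and pulling one training batch.
# Directory paths are placeholders for folders holding the JSON manifests.
from look2hear.datas import Libri2MixModuleRemix

datamodule = Libri2MixModuleRemix(
    train_dir="DataPreProcess/Libri2Mix/tr",
    valid_dir="DataPreProcess/Libri2Mix/cv",
    test_dir="DataPreProcess/Libri2Mix/tt",
    n_src=2,
    sample_rate=16000,
    segment=3.0,
    batch_size=4,
    num_workers=4,
)
datamodule.setup()
mixture, sources, keys = next(iter(datamodule.train_dataloader()))
print(mixture.shape, sources.shape)  # e.g. [4, 48000] and [4, 2, 48000]
```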
diff --git a/look2hear/datas/__init__.py b/look2hear/datas/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d518ab704fb778822ac51b000423d00f3fa6481
--- /dev/null
+++ b/look2hear/datas/__init__.py
@@ -0,0 +1,9 @@
+from .echosetdatamodule import EchoSetDataModule
+from .Libri2Mix16 import Libri2MixModuleRemix
+from .lrs2datamodule import LRS2DataModule
+
+__all__ = [
+ "EchoSetDataModule",
+ "Libri2MixModuleRemix",
+ "LRS2DataModule",
+]
diff --git a/look2hear/datas/echosetdatamodule.py b/look2hear/datas/echosetdatamodule.py
new file mode 100644
index 0000000000000000000000000000000000000000..568c28dea8b6e53ed48f9191072f5a87eb6355de
--- /dev/null
+++ b/look2hear/datas/echosetdatamodule.py
@@ -0,0 +1,283 @@
+import os
+import json
+import numpy as np
+from typing import Any, Tuple
+import soundfile as sf
+import torch
+from pytorch_lightning import LightningDataModule
+# from pytorch_lightning.core.mixins import HyperparametersMixin
+from torch.utils.data import ConcatDataset, DataLoader, Dataset
+from typing import Dict, Iterable, List, Iterator
+from rich import print
+from pytorch_lightning.utilities import rank_zero_only
+
+
+@rank_zero_only
+def print_(message: str):
+ print(message)
+
+
+def normalize_tensor_wav(wav_tensor, eps=1e-8, std=None):
+ mean = wav_tensor.mean(-1, keepdim=True)
+ if std is None:
+ std = wav_tensor.std(-1, keepdim=True)
+ return (wav_tensor - mean) / (std + eps)
+
+
+class MP3DDataset(Dataset):
+ def __init__(
+ self,
+ json_dir: str = "",
+ n_src: int = 2,
+ sample_rate: int = 8000,
+ segment: float = 4.0,
+ normalize_audio: bool = False,
+ ) -> None:
+ super().__init__()
+ self.EPS = 1e-8
+        if json_dir is None:
+ raise ValueError("JSON DIR is None!")
+ if n_src not in [1, 2]:
+ raise ValueError("{} is not in [1, 2]".format(n_src))
+ self.json_dir = json_dir
+ self.sample_rate = sample_rate
+ self.normalize_audio = normalize_audio
+
+ if segment is None:
+ self.seg_len = None
+ self.fps_len = None
+ else:
+ self.seg_len = int(segment * sample_rate)
+
+ self.n_src = n_src
+ self.test = self.seg_len is None
+ mix_json = os.path.join(json_dir, "mix.json")
+ sources_json = [
+ os.path.join(json_dir, source + ".json") for source in ["s1", "s2"]
+ ]
+
+ with open(mix_json, "r") as f:
+ mix_infos = json.load(f)
+ sources_infos = []
+ for src_json in sources_json:
+ with open(src_json, "r") as f:
+ sources_infos.append(json.load(f))
+
+ self.mix = []
+ self.sources = []
+ if self.n_src == 1:
+ orig_len = len(mix_infos) * 2
+ drop_utt, drop_len = 0, 0
+ if not self.test:
+ for i in range(len(mix_infos) - 1, -1, -1):
+ if mix_infos[i][1] < self.seg_len:
+ drop_utt = drop_utt + 1
+ drop_len = drop_len + mix_infos[i][1]
+ del mix_infos[i]
+ for src_inf in sources_infos:
+ del src_inf[i]
+ else:
+ for src_inf in sources_infos:
+ self.mix.append(mix_infos[i])
+ self.sources.append(src_inf[i])
+ else:
+ for i in range(len(mix_infos)):
+ for src_inf in sources_infos:
+ self.mix.append(mix_infos[i])
+ self.sources.append(src_inf[i])
+
+ print_(
+ "Drop {} utts({:.2f} h) from {} (shorter than {} samples)".format(
+ drop_utt, drop_len / sample_rate / 3600, orig_len, self.seg_len
+ )
+ )
+ self.length = len(self.mix)
+
+ elif self.n_src == 2:
+ orig_len = len(mix_infos)
+ drop_utt, drop_len = 0, 0
+ if not self.test:
+ for i in range(len(mix_infos) - 1, -1, -1): # Go backward
+ if mix_infos[i][1] < self.seg_len:
+ drop_utt = drop_utt + 1
+ drop_len = drop_len + mix_infos[i][1]
+ del mix_infos[i]
+ for src_inf in sources_infos:
+ del src_inf[i]
+
+ print_(
+ "Drop {} utts({:.2f} h) from {} (shorter than {} samples)".format(
+                    drop_utt, drop_len / sample_rate / 3600, orig_len, self.seg_len
+ )
+ )
+ self.mix = mix_infos
+ self.sources = sources_infos
+ self.length = len(self.mix)
+
+ def __len__(self):
+ return self.length
+
+ def preprocess_audio_only(self, idx: int):
+ if self.n_src == 1:
+ if self.mix[idx][1] == self.seg_len or self.test:
+ rand_start = 0
+ else:
+ rand_start = np.random.randint(0, self.mix[idx][1] - self.seg_len)
+ if self.test:
+ stop = None
+ else:
+ stop = rand_start + self.seg_len
+ # Load mixture
+ x, _ = sf.read(
+ self.mix[idx][0], start=rand_start, stop=stop, dtype="float32"
+ )
+ # Load sources
+ s, _ = sf.read(
+ self.sources[idx][0], start=rand_start, stop=stop, dtype="float32"
+ )
+ # torch from numpy
+ target = torch.from_numpy(s)
+ mixture = torch.from_numpy(x)
+ if self.normalize_audio:
+ m_std = mixture.std(-1, keepdim=True)
+ mixture = normalize_tensor_wav(mixture, eps=self.EPS, std=m_std)
+ target = normalize_tensor_wav(target, eps=self.EPS, std=m_std)
+ # return mixture, target.unsqueeze(0), self.mix[idx][0].split("/")[-1]
+ return mixture, target.unsqueeze(0), self.mix[idx][0]
+ # import pdb; pdb.set_trace()
+ if self.n_src == 2:
+ if self.mix[idx][1] == self.seg_len or self.test:
+ rand_start = 0
+ else:
+ rand_start = np.random.randint(0, self.mix[idx][1] - self.seg_len)
+ if self.test:
+ stop = None
+ else:
+ stop = rand_start + self.seg_len
+ # Load mixture
+ x, _ = sf.read(
+ self.mix[idx][0], start=rand_start, stop=stop, dtype="float32"
+ )
+ # Load sources
+ source_arrays = []
+ for src in self.sources:
+ s, _ = sf.read(
+ src[idx][0], start=rand_start, stop=stop, dtype="float32"
+ )
+ source_arrays.append(s)
+ sources = torch.from_numpy(np.vstack(source_arrays))
+ mixture = torch.from_numpy(x)
+
+ if self.normalize_audio:
+ m_std = mixture.std(-1, keepdim=True)
+ mixture = normalize_tensor_wav(mixture, eps=self.EPS, std=m_std)
+ sources = normalize_tensor_wav(sources, eps=self.EPS, std=m_std)
+
+ # return mixture, sources, self.mix[idx][0].split("/")[-1]
+ return mixture, sources, self.mix[idx][0]
+
+ def __getitem__(self, index: int):
+ return self.preprocess_audio_only(index)
+
+
+class EchoSetDataModule(object):
+ def __init__(
+ self,
+ train_dir: str,
+ valid_dir: str,
+ test_dir: str,
+ n_src: int = 2,
+ sample_rate: int = 8000,
+ segment: float = 4.0,
+ normalize_audio: bool = False,
+ batch_size: int = 64,
+ num_workers: int = 0,
+ pin_memory: bool = False,
+ persistent_workers: bool = False,
+ ) -> None:
+ super().__init__()
+        if train_dir is None or valid_dir is None or test_dir is None:
+ raise ValueError("JSON DIR is None!")
+ if n_src not in [1, 2]:
+ raise ValueError("{} is not in [1, 2]".format(n_src))
+
+ # this line allows to access init params with 'self.hparams' attribute
+ self.train_dir = train_dir
+ self.valid_dir = valid_dir
+ self.test_dir = test_dir
+ self.n_src = n_src
+ self.sample_rate = sample_rate
+ self.segment = segment
+ self.normalize_audio = normalize_audio
+ self.batch_size = batch_size
+ self.num_workers = num_workers
+ self.pin_memory = pin_memory
+ self.persistent_workers = persistent_workers
+
+ self.data_train: Dataset = None
+ self.data_val: Dataset = None
+ self.data_test: Dataset = None
+
+ def setup(self) -> None:
+ self.data_train = MP3DDataset(
+ json_dir=self.train_dir,
+ n_src=self.n_src,
+ sample_rate=self.sample_rate,
+ segment=self.segment,
+ normalize_audio=self.normalize_audio,
+ )
+ self.data_val = MP3DDataset(
+ json_dir=self.valid_dir,
+ n_src=self.n_src,
+ sample_rate=self.sample_rate,
+ segment=None,
+ normalize_audio=self.normalize_audio,
+ )
+ self.data_test = MP3DDataset(
+ json_dir=self.test_dir,
+ n_src=self.n_src,
+ sample_rate=self.sample_rate,
+ segment=None,
+ normalize_audio=self.normalize_audio,
+ )
+
+ def train_dataloader(self) -> DataLoader:
+ return DataLoader(
+ dataset=self.data_train,
+ batch_size=self.batch_size,
+ shuffle=True,
+ num_workers=self.num_workers,
+ persistent_workers=self.persistent_workers,
+ pin_memory=self.pin_memory,
+ drop_last=True,
+ )
+
+ def val_dataloader(self) -> DataLoader:
+ return DataLoader(
+ dataset=self.data_val,
+ shuffle=False,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ persistent_workers=self.persistent_workers,
+ pin_memory=self.pin_memory,
+ drop_last=False,
+ )
+
+ def test_dataloader(self) -> DataLoader:
+ return DataLoader(
+ dataset=self.data_test,
+ shuffle=False,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ persistent_workers=self.persistent_workers,
+ pin_memory=self.pin_memory,
+ drop_last=False,
+ )
+
+ @property
+ def make_loader(self):
+ return self.train_dataloader(), self.val_dataloader(), self.test_dataloader()
+
+ @property
+ def make_sets(self):
+ return self.data_train, self.data_val, self.data_test
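Unlike the Lightning-based module above, EchoSetDataModule is a plain object: setup() is called manually and the loaders are taken from the make_loader property. The paths below are hypothetical; each directory needs mix.json, s1.json and s2.json.

```python
# Sketch: manual setup of the EchoSet data module. Paths are placeholders.
from look2hear.datas import EchoSetDataModule

dm = EchoSetDataModule(
    train_dir="DataPreProcess/EchoSet/train",
    valid_dir="DataPreProcess/EchoSet/val",
    test_dir="DataPreProcess/EchoSet/test",
    n_src=2,
    sample_rate=16000,
    segment=4.0,
    batch_size=2,
)
dm.setup()
train_loader, val_loader, test_loader = dm.make_loader
```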
diff --git a/look2hear/datas/lrs2datamodule.py b/look2hear/datas/lrs2datamodule.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1abdf13ce5988f176ca4d63103ae6b07083d348
--- /dev/null
+++ b/look2hear/datas/lrs2datamodule.py
@@ -0,0 +1,358 @@
+import os
+import json
+import numpy as np
+from typing import Any, Tuple
+import soundfile as sf
+import torch
+from pytorch_lightning import LightningDataModule
+from pytorch_lightning.core.mixins import HyperparametersMixin
+from torch.utils.data import ConcatDataset, DataLoader, Dataset
+from typing import Dict, Iterable, List, Iterator
+from rich import print
+from pytorch_lightning.utilities import rank_zero_only
+
+
+@rank_zero_only
+def print_(message: str):
+ print(message)
+
+
+def normalize_tensor_wav(wav_tensor, eps=1e-8, std=None):
+ mean = wav_tensor.mean(-1, keepdim=True)
+ if std is None:
+ std = wav_tensor.std(-1, keepdim=True)
+ return (wav_tensor - mean) / (std + eps)
+
+
+class LRS2Dataset(Dataset):
+ def __init__(
+ self,
+ json_dir: str = "",
+ n_src: int = 2,
+ sample_rate: int = 8000,
+ fps: int = 25,
+ segment: float = 4.0,
+ normalize_audio: bool = False,
+ audio_only: bool = True,
+ ) -> None:
+ super().__init__()
+ self.EPS = 1e-8
+        if json_dir is None:
+ raise ValueError("JSON DIR is None!")
+ if n_src not in [1, 2]:
+ raise ValueError("{} is not in [1, 2]".format(n_src))
+ self.json_dir = json_dir
+ self.sample_rate = sample_rate
+ self.normalize_audio = normalize_audio
+ self.audio_only = audio_only
+
+ if segment is None:
+ self.seg_len = None
+ self.fps_len = None
+ else:
+ self.seg_len = int(segment * sample_rate)
+ self.fps_len = int(segment * fps)
+ self.n_src = n_src
+ self.test = self.seg_len is None
+ mix_json = os.path.join(json_dir, "mix.json")
+ sources_json = [
+ os.path.join(json_dir, source + ".json") for source in ["s1", "s2"]
+ ]
+
+ with open(mix_json, "r") as f:
+ mix_infos = json.load(f)
+ sources_infos = []
+ for src_json in sources_json:
+ with open(src_json, "r") as f:
+ sources_infos.append(json.load(f))
+
+ self.mix = []
+ self.sources = []
+ if self.n_src == 1:
+ orig_len = len(mix_infos) * 2
+ drop_utt, drop_len = 0, 0
+ if not self.test:
+ for i in range(len(mix_infos) - 1, -1, -1):
+ if mix_infos[i][1] < self.seg_len:
+ drop_utt = drop_utt + 1
+ drop_len = drop_len + mix_infos[i][1]
+ del mix_infos[i]
+ for src_inf in sources_infos:
+ del src_inf[i]
+ else:
+ for src_inf in sources_infos:
+ self.mix.append(mix_infos[i])
+ self.sources.append(src_inf[i])
+ else:
+ for i in range(len(mix_infos)):
+ for src_inf in sources_infos:
+ self.mix.append(mix_infos[i])
+ self.sources.append(src_inf[i])
+
+ print_(
+ "Drop {} utts({:.2f} h) from {} (shorter than {} samples)".format(
+ drop_utt, drop_len / sample_rate / 3600, orig_len, self.seg_len
+ )
+ )
+ self.length = len(self.mix)
+
+ elif self.n_src == 2:
+ orig_len = len(mix_infos)
+ drop_utt, drop_len = 0, 0
+ if not self.test:
+ for i in range(len(mix_infos) - 1, -1, -1): # Go backward
+ if mix_infos[i][1] < self.seg_len:
+ drop_utt = drop_utt + 1
+ drop_len = drop_len + mix_infos[i][1]
+ del mix_infos[i]
+ for src_inf in sources_infos:
+ del src_inf[i]
+
+ print_(
+ "Drop {} utts({:.2f} h) from {} (shorter than {} samples)".format(
+                    drop_utt, drop_len / sample_rate / 3600, orig_len, self.seg_len
+ )
+ )
+ self.mix = mix_infos
+ self.sources = sources_infos
+ self.length = len(self.mix)
+
+ def __len__(self):
+ return self.length
+
+ def preprocess_audio_only(self, idx: int):
+ if self.n_src == 1:
+ if self.mix[idx][1] == self.seg_len or self.test:
+ rand_start = 0
+ else:
+ rand_start = np.random.randint(0, self.mix[idx][1] - self.seg_len)
+ if self.test:
+ stop = None
+ else:
+ stop = rand_start + self.seg_len
+ # Load mixture
+ x, _ = sf.read(
+ self.mix[idx][0], start=rand_start, stop=stop, dtype="float32"
+ )
+ # Load sources
+ s, _ = sf.read(
+ self.sources[idx][0], start=rand_start, stop=stop, dtype="float32"
+ )
+ # torch from numpy
+ target = torch.from_numpy(s)
+ mixture = torch.from_numpy(x)
+ if self.normalize_audio:
+ m_std = mixture.std(-1, keepdim=True)
+ mixture = normalize_tensor_wav(mixture, eps=self.EPS, std=m_std)
+ target = normalize_tensor_wav(target, eps=self.EPS, std=m_std)
+ return mixture, target.unsqueeze(0), self.mix[idx][0].split("/")[-1]
+ # import pdb; pdb.set_trace()
+ if self.n_src == 2:
+ if self.mix[idx][1] == self.seg_len or self.test:
+ rand_start = 0
+ else:
+ rand_start = np.random.randint(0, self.mix[idx][1] - self.seg_len)
+ if self.test:
+ stop = None
+ else:
+ stop = rand_start + self.seg_len
+ # Load mixture
+ x, _ = sf.read(
+ self.mix[idx][0], start=rand_start, stop=stop, dtype="float32"
+ )
+ # Load sources
+ source_arrays = []
+ for src in self.sources:
+ s, _ = sf.read(
+ src[idx][0], start=rand_start, stop=stop, dtype="float32"
+ )
+ source_arrays.append(s)
+ sources = torch.from_numpy(np.vstack(source_arrays))
+ mixture = torch.from_numpy(x)
+
+ if self.normalize_audio:
+ m_std = mixture.std(-1, keepdim=True)
+ mixture = normalize_tensor_wav(mixture, eps=self.EPS, std=m_std)
+ sources = normalize_tensor_wav(sources, eps=self.EPS, std=m_std)
+
+ return mixture, sources, self.mix[idx][0].split("/")[-1]
+
+ def preprocess_audio_visual(self, idx: int):
+ if self.n_src == 1:
+ if self.mix[idx][1] == self.seg_len or self.test:
+ rand_start = 0
+ else:
+ rand_start = np.random.randint(0, self.mix[idx][1] - self.seg_len)
+ if self.test:
+ stop = None
+ else:
+ stop = rand_start + self.seg_len
+
+ mix_source, _ = sf.read(
+ self.mix[idx][0], start=rand_start, stop=stop, dtype="float32"
+ )
+ source = sf.read(
+ self.sources[idx][0], start=rand_start, stop=stop, dtype="float32"
+ )[0]
+ source_mouth = None
+
+ source = torch.from_numpy(source)
+ mixture = torch.from_numpy(mix_source)
+
+ if self.normalize_audio:
+ m_std = mixture.std(-1, keepdim=True)
+ mixture = normalize_tensor_wav(mixture, eps=self.EPS, std=m_std)
+ source = normalize_tensor_wav(source, eps=self.EPS, std=m_std)
+ return mixture, source, source_mouth, self.mix[idx][0].split("/")[-1]
+
+ if self.n_src == 2:
+ if self.mix[idx][1] == self.seg_len or self.test:
+ rand_start = 0
+ else:
+ rand_start = np.random.randint(0, self.mix[idx][1] - self.seg_len)
+ if self.test:
+ stop = None
+ else:
+ stop = rand_start + self.seg_len
+
+ mix_source, _ = sf.read(
+ self.mix[idx][0], start=rand_start, stop=stop, dtype="float32"
+ )
+ sources = []
+ for src in self.sources[idx]:
+ # import pdb; pdb.set_trace()
+ sources.append(
+ sf.read(src[0], start=rand_start, stop=stop, dtype="float32")[0]
+ )
+ # import pdb; pdb.set_trace()
+ sources_mouths = None
+ # import pdb; pdb.set_trace()
+ sources = torch.stack([torch.from_numpy(source) for source in sources])
+ mixture = torch.from_numpy(mix_source)
+
+ if self.normalize_audio:
+ m_std = mixture.std(-1, keepdim=True)
+ mixture = normalize_tensor_wav(mixture, eps=self.EPS, std=m_std)
+ sources = normalize_tensor_wav(sources, eps=self.EPS, std=m_std)
+
+ return mixture, sources, sources_mouths, self.mix[idx][0].split("/")[-1]
+
+ def __getitem__(self, index: int):
+ if self.audio_only:
+ return self.preprocess_audio_only(index)
+ else:
+ return self.preprocess_audio_visual(index)
+
+
+class LRS2DataModule(object):
+ def __init__(
+ self,
+ train_dir: str,
+ valid_dir: str,
+ test_dir: str,
+ n_src: int = 2,
+ sample_rate: int = 8000,
+ fps: int = 25,
+ segment: float = 4.0,
+ normalize_audio: bool = False,
+ batch_size: int = 64,
+ num_workers: int = 0,
+ pin_memory: bool = False,
+ persistent_workers: bool = False,
+ audio_only: bool = True,
+ ) -> None:
+ super().__init__()
+        if train_dir is None or valid_dir is None or test_dir is None:
+ raise ValueError("JSON DIR is None!")
+ if n_src not in [1, 2]:
+ raise ValueError("{} is not in [1, 2]".format(n_src))
+
+ # this line allows to access init params with 'self.hparams' attribute
+ self.train_dir = train_dir
+ self.valid_dir = valid_dir
+ self.test_dir = test_dir
+ self.n_src = n_src
+ self.sample_rate = sample_rate
+ self.fps = fps
+ self.segment = segment
+ self.normalize_audio = normalize_audio
+ self.batch_size = batch_size
+ self.num_workers = num_workers
+ self.pin_memory = pin_memory
+ self.persistent_workers = persistent_workers
+ self.audio_only = audio_only
+
+ self.data_train: Dataset = None
+ self.data_val: Dataset = None
+ self.data_test: Dataset = None
+
+ def setup(self) -> None:
+ self.data_train = LRS2Dataset(
+ json_dir=self.train_dir,
+ n_src=self.n_src,
+ sample_rate=self.sample_rate,
+ fps=self.fps,
+ segment=self.segment,
+ normalize_audio=self.normalize_audio,
+ audio_only=self.audio_only,
+ )
+ self.data_val = LRS2Dataset(
+ json_dir=self.valid_dir,
+ n_src=self.n_src,
+ sample_rate=self.sample_rate,
+ fps=self.fps,
+ segment=self.segment,
+ normalize_audio=self.normalize_audio,
+ audio_only=self.audio_only,
+ )
+ self.data_test = LRS2Dataset(
+ json_dir=self.test_dir,
+ n_src=self.n_src,
+ sample_rate=self.sample_rate,
+ fps=self.fps,
+ segment=self.segment,
+ normalize_audio=self.normalize_audio,
+ audio_only=self.audio_only,
+ )
+
+ def train_dataloader(self) -> DataLoader:
+ return DataLoader(
+ dataset=self.data_train,
+ batch_size=self.batch_size,
+ shuffle=True,
+ num_workers=self.num_workers,
+ persistent_workers=self.persistent_workers,
+ pin_memory=self.pin_memory,
+ drop_last=True,
+ )
+
+ def val_dataloader(self) -> DataLoader:
+ return DataLoader(
+ dataset=self.data_val,
+ shuffle=False,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ persistent_workers=self.persistent_workers,
+ pin_memory=self.pin_memory,
+ drop_last=True,
+ )
+
+ def test_dataloader(self) -> DataLoader:
+ return DataLoader(
+ dataset=self.data_test,
+ shuffle=False,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ persistent_workers=self.persistent_workers,
+ pin_memory=self.pin_memory,
+ drop_last=True,
+ )
+
+ @property
+ def make_loader(self):
+ return self.train_dataloader(), self.val_dataloader(), self.test_dataloader()
+
+ @property
+ def make_sets(self):
+ return self.data_train, self.data_val, self.data_test
diff --git a/look2hear/layers/__init__.py b/look2hear/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b85d700342170fd432d6fdf5508692b92cf1eb6
--- /dev/null
+++ b/look2hear/layers/__init__.py
@@ -0,0 +1,50 @@
+from .cnnlayers import (
+ TAC,
+ Conv1DBlock,
+ ConvNormAct,
+ ConvNorm,
+ NormAct,
+ Video1DConv,
+ Concat,
+ FRCNNBlock,
+ FRCNNBlockTCN,
+ Bottomup,
+ BottomupTCN,
+ Bottomup_Concat_Topdown,
+ Bottomup_Concat_Topdown_TCN,
+)
+from .rnnlayers import DPRNN, DPRNNBlock, DPRNNLinear, LSTMBlockTF, TransformerBlockTF
+from .enc_dec import make_enc_dec, FreeFB
+from .normalizations import gLN, cLN, LN, bN
+from .stft import forward_stft, inverse_stft
+from .stft_tfgn import Stft
+
+__all__ = [
+ "TAC",
+ "DPRNN",
+ "DPRNNBlock",
+ "DPRNNLinear",
+ "LSTMBlockTF",
+ "TransformerBlockTF",
+ "Conv1DBlock",
+ "ConvNormAct",
+ "ConvNorm",
+ "NormAct",
+ "Video1DConv",
+ "Concat",
+ "FRCNNBlock",
+ "FRCNNBlockTCN",
+ "Bottomup",
+ "Bottomup",
+ "Bottomup_Concat_Topdown",
+ "Bottomup_Concat_Topdown_TCN",
+ "make_enc_dec",
+ "FreeFB",
+ "gLN",
+ "cLN",
+ "LN",
+ "bN",
+ "forward_stft",
+ "inverse_stft",
+ "Stft",
+]
diff --git a/look2hear/layers/activations.py b/look2hear/layers/activations.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2e875bdc5b0dcc6efe6ec7eea7d4aec0dada330
--- /dev/null
+++ b/look2hear/layers/activations.py
@@ -0,0 +1,68 @@
+import torch
+from torch import nn
+
+
+def linear():
+ return nn.Identity()
+
+
+def relu():
+ return nn.ReLU()
+
+
+def prelu():
+ return nn.PReLU()
+
+
+def leaky_relu():
+ return nn.LeakyReLU()
+
+
+def sigmoid():
+ return nn.Sigmoid()
+
+
+def softmax(dim=None):
+ return nn.Softmax(dim=dim)
+
+
+def tanh():
+ return nn.Tanh()
+
+
+def gelu():
+ return nn.GELU()
+
+
+def register_activation(custom_act):
+ if (
+ custom_act.__name__ in globals().keys()
+ or custom_act.__name__.lower() in globals().keys()
+ ):
+ raise ValueError(
+ f"Activation {custom_act.__name__} already exists. Choose another name."
+ )
+ globals().update({custom_act.__name__: custom_act})
+
+
+def get(identifier):
+ if identifier is None:
+ return None
+ elif callable(identifier):
+ return identifier
+ elif isinstance(identifier, str):
+ cls = globals().get(identifier)
+ if cls is None:
+ raise ValueError(
+ "Could not interpret activation identifier: " + str(identifier)
+ )
+ return cls
+ else:
+ raise ValueError(
+ "Could not interpret activation identifier: " + str(identifier)
+ )
+
+
+if __name__ == "__main__":
+ print(globals().keys())
+ print(globals().get("tanh"))
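The registry above lets new activations be added by name and retrieved with get(). A short sketch, where the swish factory is purely illustrative and not part of the shipped registry:

```python
# Sketch: register a custom activation factory and fetch it by name.
import torch.nn as nn
from look2hear.layers import activations

def swish():
    return nn.SiLU()

activations.register_activation(swish)  # adds "swish" to the module registry
act = activations.get("swish")()        # returns an nn.SiLU instance
```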
diff --git a/look2hear/layers/cnnlayers.py b/look2hear/layers/cnnlayers.py
new file mode 100644
index 0000000000000000000000000000000000000000..3351b669e47da9b91ee00c7c8cfb6eb41e1150e7
--- /dev/null
+++ b/look2hear/layers/cnnlayers.py
@@ -0,0 +1,887 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from . import normalizations, activations
+
+
+class _Chop1d(nn.Module):
+ """To ensure the output length is the same as the input."""
+
+ def __init__(self, chop_size):
+ super().__init__()
+ self.chop_size = chop_size
+
+ def forward(self, x):
+ return x[..., : -self.chop_size].contiguous()
+
+
+class Conv1DBlock(nn.Module):
+ def __init__(
+ self,
+ in_chan,
+ hid_chan,
+ skip_out_chan,
+ kernel_size,
+ padding,
+ dilation,
+ norm_type="gLN",
+ causal=False,
+ ):
+ super(Conv1DBlock, self).__init__()
+ self.skip_out_chan = skip_out_chan
+ conv_norm = normalizations.get(norm_type)
+ in_conv1d = nn.Conv1d(in_chan, hid_chan, 1)
+ depth_conv1d = nn.Conv1d(
+ hid_chan,
+ hid_chan,
+ kernel_size,
+ padding=padding,
+ dilation=dilation,
+ groups=hid_chan,
+ )
+ if causal:
+ depth_conv1d = nn.Sequential(depth_conv1d, _Chop1d(padding))
+ self.shared_block = nn.Sequential(
+ in_conv1d,
+ nn.PReLU(),
+ conv_norm(hid_chan),
+ depth_conv1d,
+ nn.PReLU(),
+ conv_norm(hid_chan),
+ )
+ self.res_conv = nn.Conv1d(hid_chan, in_chan, 1)
+ if skip_out_chan:
+ self.skip_conv = nn.Conv1d(hid_chan, skip_out_chan, 1)
+
+ def forward(self, x):
+ r"""Input shape $(batch, feats, seq)$."""
+ shared_out = self.shared_block(x)
+ res_out = self.res_conv(shared_out)
+ if not self.skip_out_chan:
+ return res_out
+ skip_out = self.skip_conv(shared_out)
+ return res_out, skip_out
+
+
+class ConvNormAct(nn.Module):
+ """
+ This class defines the convolution layer with normalization and a PReLU
+ activation
+ """
+
+ def __init__(
+ self,
+ in_chan,
+ out_chan,
+ kernel_size,
+ stride=1,
+ groups=1,
+ dilation=1,
+ padding=0,
+ norm_type="gLN",
+ act_type="prelu",
+ ):
+ super(ConvNormAct, self).__init__()
+ self.conv = nn.Conv1d(
+ in_chan,
+ out_chan,
+ kernel_size,
+ stride=stride,
+ dilation=dilation,
+ padding=padding,
+ bias=True,
+ groups=groups,
+ )
+ self.norm = normalizations.get(norm_type)(out_chan)
+ self.act = activations.get(act_type)()
+
+ def forward(self, x):
+ output = self.conv(x)
+ output = self.norm(output)
+ return self.act(output)
+
+
+class ConvNorm(nn.Module):
+ def __init__(
+ self,
+ in_chan,
+ out_chan,
+ kernel_size,
+ stride=1,
+ groups=1,
+ dilation=1,
+ padding=0,
+ norm_type="gLN",
+ ):
+ super(ConvNorm, self).__init__()
+ self.conv = nn.Conv1d(
+ in_chan,
+ out_chan,
+ kernel_size,
+ stride,
+ padding,
+ dilation,
+ bias=True,
+ groups=groups,
+ )
+ self.norm = normalizations.get(norm_type)(out_chan)
+
+ def forward(self, x):
+ output = self.conv(x)
+ return self.norm(output)
+
+
+class NormAct(nn.Module):
+ """
+ This class defines a normalization and PReLU activation
+ """
+
+ def __init__(
+ self, out_chan, norm_type="gLN", act_type="prelu",
+ ):
+ """
+        :param out_chan: number of output channels
+ """
+ super(NormAct, self).__init__()
+ # self.norm = nn.GroupNorm(1, nOut, eps=1e-08)
+ self.norm = normalizations.get(norm_type)(out_chan)
+ self.act = activations.get(act_type)()
+
+ def forward(self, input):
+ output = self.norm(input)
+ return self.act(output)
+
+
+class Video1DConv(nn.Module):
+ """
+ video part 1-D Conv Block
+ in_chan: video Encoder output channels
+ out_chan: dconv channels
+ kernel_size: the depthwise conv kernel size
+ dilation: the depthwise conv dilation
+ residual: Whether to use residual connection
+ skip_con: Whether to use skip connection
+ first_block: first block, not residual
+ """
+
+ def __init__(
+ self,
+ in_chan,
+ out_chan,
+ kernel_size,
+ dilation=1,
+ residual=True,
+ skip_con=True,
+ first_block=True,
+ ):
+ super(Video1DConv, self).__init__()
+ self.first_block = first_block
+ # first block, not residual
+ self.residual = residual and not first_block
+ self.bn = nn.BatchNorm1d(in_chan) if not first_block else None
+ self.relu = nn.ReLU() if not first_block else None
+ self.dconv = nn.Conv1d(
+ in_chan,
+ in_chan,
+ kernel_size,
+ groups=in_chan,
+ dilation=dilation,
+ padding=(dilation * (kernel_size - 1)) // 2,
+ bias=True,
+ )
+ self.bconv = nn.Conv1d(in_chan, out_chan, 1)
+ self.sconv = nn.Conv1d(in_chan, out_chan, 1)
+ self.skip_con = skip_con
+
+ def forward(self, x):
+ """
+ x: [B, N, T]
+ out: [B, N, T]
+ """
+ if not self.first_block:
+ y = self.bn(self.relu(x))
+ y = self.dconv(y)
+ else:
+ y = self.dconv(x)
+ # skip connection
+ if self.skip_con:
+ skip = self.sconv(y)
+ if self.residual:
+ y = y + x
+ return skip, y
+ else:
+ return skip, y
+ else:
+ y = self.bconv(y)
+ if self.residual:
+ y = y + x
+ return y
+ else:
+ return y
+
+
+class Concat(nn.Module):
+ def __init__(self, ain_chan, vin_chan, out_chan):
+ super(Concat, self).__init__()
+ self.ain_chan = ain_chan
+ self.vin_chan = vin_chan
+ # project
+ self.conv1d = nn.Sequential(
+ nn.Conv1d(ain_chan + vin_chan, out_chan, 1), nn.PReLU()
+ )
+
+ def forward(self, a, v):
+ # up-sample video features
+ v = torch.nn.functional.interpolate(v, size=a.size(-1))
+ # concat: n x (A+V) x Ta
+ y = torch.cat([a, v], dim=1)
+ # conv1d
+ return self.conv1d(y)
+
+
+class FRCNNBlock(nn.Module):
+ def __init__(
+ self,
+ in_chan=128,
+ out_chan=512,
+ upsampling_depth=4,
+ norm_type="gLN",
+ act_type="prelu",
+ ):
+ super().__init__()
+ self.proj_1x1 = ConvNormAct(
+ in_chan,
+ out_chan,
+ kernel_size=1,
+ stride=1,
+ groups=1,
+ dilation=1,
+ padding=0,
+ norm_type=norm_type,
+ act_type=act_type,
+ )
+ self.depth = upsampling_depth
+ self.spp_dw = nn.ModuleList([])
+ self.spp_dw.append(
+ ConvNorm(
+ out_chan,
+ out_chan,
+ kernel_size=5,
+ stride=1,
+ groups=out_chan,
+ dilation=1,
+ padding=((5 - 1) // 2) * 1,
+ norm_type=norm_type,
+ )
+ )
+ # ----------Down Sample Layer----------
+ for i in range(1, upsampling_depth):
+ self.spp_dw.append(
+ ConvNorm(
+ out_chan,
+ out_chan,
+ kernel_size=5,
+ stride=2,
+ groups=out_chan,
+ dilation=1,
+ padding=((5 - 1) // 2) * 1,
+ norm_type=norm_type,
+ )
+ )
+ # ----------Fusion Layer----------
+ self.fuse_layers = nn.ModuleList([])
+ for i in range(upsampling_depth):
+ fuse_layer = nn.ModuleList([])
+ for j in range(upsampling_depth):
+ if i == j:
+ fuse_layer.append(None)
+ elif j - i == 1:
+ fuse_layer.append(None)
+ elif i - j == 1:
+ fuse_layer.append(
+ ConvNorm(
+ out_chan,
+ out_chan,
+ kernel_size=5,
+ stride=2,
+ groups=out_chan,
+ dilation=1,
+ padding=((5 - 1) // 2) * 1,
+ norm_type=norm_type,
+ )
+ )
+ self.fuse_layers.append(fuse_layer)
+ self.concat_layer = nn.ModuleList([])
+ # ----------Concat Layer----------
+ for i in range(upsampling_depth):
+ if i == 0 or i == upsampling_depth - 1:
+ self.concat_layer.append(
+ ConvNormAct(
+ out_chan * 2,
+ out_chan,
+ 1,
+ 1,
+ norm_type=norm_type,
+ act_type=act_type,
+ )
+ )
+ else:
+ self.concat_layer.append(
+ ConvNormAct(
+ out_chan * 3,
+ out_chan,
+ 1,
+ 1,
+ norm_type=norm_type,
+ act_type=act_type,
+ )
+ )
+ self.last_layer = nn.Sequential(
+ ConvNormAct(
+ out_chan * upsampling_depth,
+ out_chan,
+ 1,
+ 1,
+ norm_type=norm_type,
+ act_type=act_type,
+ )
+ )
+ self.res_conv = nn.Conv1d(out_chan, in_chan, 1)
+ # ----------parameters-------------
+ self.depth = upsampling_depth
+
+ def forward(self, x):
+ """
+ :param x: input feature map
+ :return: transformed feature map
+ """
+ residual = x.clone()
+ # Reduce --> project high-dimensional feature maps to low-dimensional space
+ output1 = self.proj_1x1(x)
+ output = [self.spp_dw[0](output1)]
+ for k in range(1, self.depth):
+ out_k = self.spp_dw[k](output[-1])
+ output.append(out_k)
+
+ x_fuse = []
+ for i in range(len(self.fuse_layers)):
+ wav_length = output[i].shape[-1]
+ y = torch.cat(
+ (
+ self.fuse_layers[i][0](output[i - 1])
+ if i - 1 >= 0
+ else torch.Tensor().to(output1.device),
+ output[i],
+ F.interpolate(output[i + 1], size=wav_length, mode="nearest")
+ if i + 1 < self.depth
+ else torch.Tensor().to(output1.device),
+ ),
+ dim=1,
+ )
+ x_fuse.append(self.concat_layer[i](y))
+
+ wav_length = output[0].shape[-1]
+ for i in range(1, len(x_fuse)):
+ x_fuse[i] = F.interpolate(x_fuse[i], size=wav_length, mode="nearest")
+
+ concat = self.last_layer(torch.cat(x_fuse, dim=1))
+ expanded = self.res_conv(concat)
+ return expanded + residual
+
+
+class Bottomup(nn.Module):
+ def __init__(
+ self,
+ in_chan=128,
+ out_chan=512,
+ upsampling_depth=4,
+ norm_type="gLN",
+ act_type="prelu",
+ ):
+ super().__init__()
+ self.proj_1x1 = ConvNormAct(
+ in_chan,
+ out_chan,
+ kernel_size=1,
+ stride=1,
+ groups=1,
+ dilation=1,
+ padding=0,
+ norm_type=norm_type,
+ act_type=act_type,
+ )
+ self.depth = upsampling_depth
+ self.spp_dw = nn.ModuleList([])
+ self.spp_dw.append(
+ ConvNorm(
+ out_chan,
+ out_chan,
+ kernel_size=5,
+ stride=1,
+ groups=out_chan,
+ dilation=1,
+ padding=((5 - 1) // 2) * 1,
+ norm_type=norm_type,
+ )
+ )
+ # ----------Down Sample Layer----------
+ for i in range(1, upsampling_depth):
+ self.spp_dw.append(
+ ConvNorm(
+ out_chan,
+ out_chan,
+ kernel_size=5,
+ stride=2,
+ groups=out_chan,
+ dilation=1,
+ padding=((5 - 1) // 2) * 1,
+ norm_type=norm_type,
+ )
+ )
+
+ def forward(self, x):
+ residual = x.clone()
+ # Reduce --> project high-dimensional feature maps to low-dimensional space
+ output1 = self.proj_1x1(x)
+ output = [self.spp_dw[0](output1)]
+ for k in range(1, self.depth):
+ out_k = self.spp_dw[k](output[-1])
+ output.append(out_k)
+
+ return residual, output[-1], output
+
+
+class BottomupTCN(nn.Module):
+ def __init__(
+ self,
+ in_chan=128,
+ out_chan=512,
+ upsampling_depth=4,
+ norm_type="gLN",
+ act_type="prelu",
+ ):
+ super().__init__()
+ self.proj_1x1 = ConvNormAct(
+ in_chan,
+ out_chan,
+ kernel_size=1,
+ stride=1,
+ groups=1,
+ dilation=1,
+ padding=0,
+ norm_type=norm_type,
+ act_type=act_type,
+ )
+ self.depth = upsampling_depth
+ self.spp_dw = nn.ModuleList([])
+ self.spp_dw.append(
+ Video1DConv(out_chan, out_chan, 3, skip_con=False, first_block=True)
+ )
+ # ----------Down Sample Layer----------
+ for i in range(1, upsampling_depth):
+ self.spp_dw.append(
+ Video1DConv(out_chan, out_chan, 3, skip_con=False, first_block=False)
+ )
+
+ def forward(self, x):
+ residual = x.clone()
+ # Reduce --> project high-dimensional feature maps to low-dimensional space
+ output1 = self.proj_1x1(x)
+ output = [self.spp_dw[0](output1)]
+ for k in range(1, self.depth):
+ out_k = self.spp_dw[k](output[-1])
+ output.append(out_k)
+
+ return residual, output[-1], output
+
+
+class Bottomup_Concat_Topdown(nn.Module):
+ def __init__(
+ self,
+ in_chan=128,
+ out_chan=512,
+ upsampling_depth=4,
+ norm_type="gLN",
+ act_type="prelu",
+ ):
+ super().__init__()
+ # ----------Fusion Layer----------
+ self.fuse_layers = nn.ModuleList([])
+ for i in range(upsampling_depth):
+ fuse_layer = nn.ModuleList([])
+ for j in range(upsampling_depth):
+ if i == j:
+ fuse_layer.append(None)
+ elif j - i == 1:
+ fuse_layer.append(None)
+ elif i - j == 1:
+ fuse_layer.append(
+ ConvNorm(
+ out_chan,
+ out_chan,
+ kernel_size=5,
+ stride=2,
+ groups=out_chan,
+ dilation=1,
+ padding=((5 - 1) // 2) * 1,
+ norm_type=norm_type,
+ )
+ )
+ self.fuse_layers.append(fuse_layer)
+ self.concat_layer = nn.ModuleList([])
+ # ----------Concat Layer----------
+ for i in range(upsampling_depth):
+ if i == 0 or i == upsampling_depth - 1:
+ self.concat_layer.append(
+ ConvNormAct(
+ out_chan * 3,
+ out_chan,
+ 1,
+ 1,
+ norm_type=norm_type,
+ act_type=act_type,
+ )
+ )
+ else:
+ self.concat_layer.append(
+ ConvNormAct(
+ out_chan * 4,
+ out_chan,
+ 1,
+ 1,
+ norm_type=norm_type,
+ act_type=act_type,
+ )
+ )
+ self.last_layer = nn.Sequential(
+ ConvNormAct(
+ out_chan * upsampling_depth,
+ out_chan,
+ 1,
+ 1,
+ norm_type=norm_type,
+ act_type=act_type,
+ )
+ )
+ self.res_conv = nn.Conv1d(out_chan, in_chan, 1)
+ # ----------parameters-------------
+ self.depth = upsampling_depth
+
+ def forward(self, residual, bottomup, topdown):
+ x_fuse = []
+ for i in range(len(self.fuse_layers)):
+ wav_length = bottomup[i].shape[-1]
+ y = torch.cat(
+ (
+ self.fuse_layers[i][0](bottomup[i - 1])
+ if i - 1 >= 0
+ else torch.Tensor().to(bottomup[i].device),
+ bottomup[i],
+ F.interpolate(bottomup[i + 1], size=wav_length, mode="nearest")
+ if i + 1 < self.depth
+ else torch.Tensor().to(bottomup[i].device),
+ F.interpolate(topdown, size=wav_length, mode="nearest"),
+ ),
+ dim=1,
+ )
+ x_fuse.append(self.concat_layer[i](y))
+
+ wav_length = bottomup[0].shape[-1]
+ for i in range(1, len(x_fuse)):
+ x_fuse[i] = F.interpolate(x_fuse[i], size=wav_length, mode="nearest")
+
+ concat = self.last_layer(torch.cat(x_fuse, dim=1))
+ expanded = self.res_conv(concat)
+ return expanded + residual
+
+
+class Bottomup_Concat_Topdown_TCN(nn.Module):
+ def __init__(
+ self,
+ in_chan=128,
+ out_chan=512,
+ upsampling_depth=4,
+ norm_type="gLN",
+ act_type="prelu",
+ ):
+ super().__init__()
+ # ----------Fusion Layer----------
+ self.fuse_layers = nn.ModuleList([])
+ for i in range(upsampling_depth):
+ fuse_layer = nn.ModuleList([])
+ for j in range(upsampling_depth):
+ if i == j:
+ fuse_layer.append(None)
+ elif j - i == 1:
+ fuse_layer.append(None)
+ elif i - j == 1:
+ fuse_layer.append(None)
+ self.fuse_layers.append(fuse_layer)
+ self.concat_layer = nn.ModuleList([])
+ # ----------Concat Layer----------
+ for i in range(upsampling_depth):
+ if i == 0 or i == upsampling_depth - 1:
+ self.concat_layer.append(
+ ConvNormAct(
+ out_chan * 3,
+ out_chan,
+ 1,
+ 1,
+ norm_type=norm_type,
+ act_type=act_type,
+ )
+ )
+ else:
+ self.concat_layer.append(
+ ConvNormAct(
+ out_chan * 4,
+ out_chan,
+ 1,
+ 1,
+ norm_type=norm_type,
+ act_type=act_type,
+ )
+ )
+ self.last_layer = nn.Sequential(
+ ConvNormAct(
+ out_chan * upsampling_depth,
+ out_chan,
+ 1,
+ 1,
+ norm_type=norm_type,
+ act_type=act_type,
+ )
+ )
+ self.res_conv = nn.Conv1d(out_chan, in_chan, 1)
+ # ----------parameters-------------
+ self.depth = upsampling_depth
+
+ def forward(self, residual, bottomup, topdown):
+ x_fuse = []
+ for i in range(len(self.fuse_layers)):
+ wav_length = bottomup[i].shape[-1]
+ y = torch.cat(
+ (
+ bottomup[i - 1]
+ if i - 1 >= 0
+ else torch.Tensor().to(bottomup[i].device),
+ bottomup[i],
+ bottomup[i + 1]
+ if i + 1 < self.depth
+ else torch.Tensor().to(bottomup[i].device),
+ F.interpolate(topdown, size=wav_length, mode="nearest"),
+ ),
+ dim=1,
+ )
+ x_fuse.append(self.concat_layer[i](y))
+
+ concat = self.last_layer(torch.cat(x_fuse, dim=1))
+ expanded = self.res_conv(concat)
+ return expanded + residual
+
+
+class FRCNNBlockTCN(nn.Module):
+ def __init__(
+ self,
+ in_chan=128,
+ out_chan=512,
+ upsampling_depth=4,
+ norm_type="gLN",
+ act_type="prelu",
+ ):
+ super().__init__()
+ self.proj_1x1 = ConvNormAct(
+ in_chan,
+ out_chan,
+ kernel_size=1,
+ stride=1,
+ groups=1,
+ dilation=1,
+ padding=0,
+ norm_type=norm_type,
+ act_type=act_type,
+ )
+ self.depth = upsampling_depth
+ self.spp_dw = nn.ModuleList([])
+ self.spp_dw.append(
+ Video1DConv(out_chan, out_chan, 3, skip_con=False, first_block=True)
+ )
+ # ----------Down Sample Layer----------
+ for i in range(1, upsampling_depth):
+ self.spp_dw.append(
+ Video1DConv(out_chan, out_chan, 3, skip_con=False, first_block=False)
+ )
+ # ----------Fusion Layer----------
+ self.fuse_layers = nn.ModuleList([])
+ for i in range(upsampling_depth):
+ fuse_layer = nn.ModuleList([])
+ for j in range(upsampling_depth):
+ if i == j:
+ fuse_layer.append(None)
+ elif j - i == 1:
+ fuse_layer.append(None)
+ elif i - j == 1:
+ fuse_layer.append(None)
+ self.fuse_layers.append(fuse_layer)
+ self.concat_layer = nn.ModuleList([])
+ # ----------Concat Layer----------
+ for i in range(upsampling_depth):
+ if i == 0 or i == upsampling_depth - 1:
+ self.concat_layer.append(
+ ConvNormAct(
+ out_chan * 2,
+ out_chan,
+ 1,
+ 1,
+ norm_type=norm_type,
+ act_type=act_type,
+ )
+ )
+ else:
+ self.concat_layer.append(
+ ConvNormAct(
+ out_chan * 3,
+ out_chan,
+ 1,
+ 1,
+ norm_type=norm_type,
+ act_type=act_type,
+ )
+ )
+ self.last_layer = nn.Sequential(
+ ConvNormAct(
+ out_chan * upsampling_depth,
+ out_chan,
+ 1,
+ 1,
+ norm_type=norm_type,
+ act_type=act_type,
+ )
+ )
+ self.res_conv = nn.Conv1d(out_chan, in_chan, 1)
+ # ----------parameters-------------
+ self.depth = upsampling_depth
+
+ def forward(self, x):
+ """
+ :param x: input feature map
+ :return: transformed feature map
+ """
+ residual = x.clone()
+ # Reduce --> project high-dimensional feature maps to low-dimensional space
+ output1 = self.proj_1x1(x)
+ output = [self.spp_dw[0](output1)]
+ for k in range(1, self.depth):
+ out_k = self.spp_dw[k](output[-1])
+ output.append(out_k)
+
+ x_fuse = []
+ for i in range(len(self.fuse_layers)):
+ wav_length = output[i].shape[-1]
+ y = torch.cat(
+ (
+ output[i - 1] if i - 1 >= 0 else torch.Tensor().to(output1.device),
+ output[i],
+ output[i + 1]
+ if i + 1 < self.depth
+ else torch.Tensor().to(output1.device),
+ ),
+ dim=1,
+ )
+ x_fuse.append(self.concat_layer[i](y))
+
+ concat = self.last_layer(torch.cat(x_fuse, dim=1))
+ expanded = self.res_conv(concat)
+ return expanded + residual
+
+
+class TAC(nn.Module):
+ """Transform-Average-Concatenate inter-microphone-channel permutation invariant communication block [1].
+ Args:
+ input_dim (int): Number of features of input representation.
+ hidden_dim (int, optional): size of hidden layers in TAC operations.
+ activation (str, optional): type of activation used. See asteroid.masknn.activations.
+ norm_type (str, optional): type of normalization layer used. See asteroid.masknn.norms.
+ .. note:: Supports inputs of shape :math:`(batch, mic\_channels, features, chunk\_size, n\_chunks)`
+ as in FasNet-TAC. The operations are applied for each element in ``chunk_size`` and ``n_chunks``.
+ Output is of same shape as input.
+ References
+ [1] : Luo, Yi, et al. "End-to-end microphone permutation and number invariant multi-channel
+ speech separation." ICASSP 2020.
+ """
+
+ def __init__(self, input_dim, hidden_dim=384, activation="prelu", norm_type="gLN"):
+ super().__init__()
+ self.hidden_dim = hidden_dim
+ self.input_tf = nn.Sequential(
+ nn.Linear(input_dim, hidden_dim), activations.get(activation)()
+ )
+ self.avg_tf = nn.Sequential(
+ nn.Linear(hidden_dim, hidden_dim), activations.get(activation)()
+ )
+ self.concat_tf = nn.Sequential(
+ nn.Linear(2 * hidden_dim, input_dim), activations.get(activation)()
+ )
+ self.norm = normalizations.get(norm_type)(input_dim)
+
+ def forward(self, x, valid_mics=None):
+ """
+ Args:
+ x: (:class:`torch.Tensor`): Input multi-channel DPRNN features.
+ Shape: :math:`(batch, mic\_channels, features, chunk\_size, n\_chunks)`.
+ valid_mics: (:class:`torch.LongTensor`): tensor containing effective number of microphones on each batch.
+ Batches can be composed of examples coming from arrays with a different
+ number of microphones and thus the ``mic_channels`` dimension is padded.
+ E.g. torch.tensor([4, 3]) means first example has 4 channels and the second 3.
+ Shape: :math`(batch)`.
+ Returns:
+ output (:class:`torch.Tensor`): features for each mic_channel after TAC inter-channel processing.
+ Shape :math:`(batch, mic\_channels, features, chunk\_size, n\_chunks)`.
+ """
+ # Input is 5D because it is multi-channel DPRNN. DPRNN single channel is 4D.
+ batch_size, nmics, channels, chunk_size, n_chunks = x.size()
+ if valid_mics is None:
+ valid_mics = torch.LongTensor([nmics] * batch_size)
+ # First operation: transform the input for each frame and independently on each mic channel.
+ output = self.input_tf(
+ x.permute(0, 3, 4, 1, 2).reshape(
+ batch_size * nmics * chunk_size * n_chunks, channels
+ )
+ ).reshape(batch_size, chunk_size, n_chunks, nmics, self.hidden_dim)
+
+ # Mean pooling across channels
+ if valid_mics.max() == 0:
+ # Fixed geometry array
+ mics_mean = output.mean(1)
+ else:
+ # Only consider valid channels in each batch element: each example can have different number of microphones.
+ mics_mean = [
+ output[b, :, :, : valid_mics[b]].mean(2).unsqueeze(0)
+ for b in range(batch_size)
+ ] # 1, dim1*dim2, H
+ mics_mean = torch.cat(mics_mean, 0) # B*dim1*dim2, H
+
+ # The average is processed by a non-linear transform
+ mics_mean = self.avg_tf(
+ mics_mean.reshape(batch_size * chunk_size * n_chunks, self.hidden_dim)
+ )
+ mics_mean = (
+ mics_mean.reshape(batch_size, chunk_size, n_chunks, self.hidden_dim)
+ .unsqueeze(3)
+ .expand_as(output)
+ )
+
+ # Concatenate the transformed average in each channel with the original feats and
+ # project back to same number of features
+ output = torch.cat([output, mics_mean], -1)
+ output = self.concat_tf(
+ output.reshape(batch_size * chunk_size * n_chunks * nmics, -1)
+ ).reshape(batch_size, chunk_size, n_chunks, nmics, -1)
+ output = self.norm(
+ output.permute(0, 3, 4, 1, 2).reshape(
+ batch_size * nmics, -1, chunk_size, n_chunks
+ )
+ ).reshape(batch_size, nmics, -1, chunk_size, n_chunks)
+
+ output += x
+ return output
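A sketch exercising TAC with the 5-D layout its docstring describes, (batch, mic_channels, features, chunk_size, n_chunks); the sizes are arbitrary, and valid_mics marks the second example as having only three live microphones.

```python
# Sketch: TAC forward pass on dummy multi-channel features. Sizes are arbitrary.
import torch
from look2hear.layers import TAC

tac = TAC(input_dim=64, hidden_dim=128)
x = torch.randn(2, 4, 64, 50, 10)        # batch=2, mics=4, features=64
valid_mics = torch.LongTensor([4, 3])    # second example has 3 valid channels
out = tac(x, valid_mics)
print(out.shape)                         # torch.Size([2, 4, 64, 50, 10])
```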
diff --git a/look2hear/layers/enc_dec.py b/look2hear/layers/enc_dec.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8315483c335e0caae3b3d928e05ccd83b7885ce
--- /dev/null
+++ b/look2hear/layers/enc_dec.py
@@ -0,0 +1,463 @@
+import warnings
+from typing import Optional
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+
+def make_enc_dec(
+ fb_name,
+ n_filters,
+ kernel_size,
+ stride=None,
+ sample_rate=8000.0,
+ who_is_pinv=None,
+ padding=0,
+ output_padding=0,
+ **kwargs,
+):
+ """Creates congruent encoder and decoder from the same filterbank family.
+ Args:
+ fb_name (str, className): Filterbank family from which to make encoder
+ and decoder. To choose among [``'free'``, ``'analytic_free'``,
+ ``'param_sinc'``, ``'stft'``]. Can also be a class defined in a
+ submodule in this subpackage (e.g. :class:`~.FreeFB`).
+ n_filters (int): Number of filters.
+ kernel_size (int): Length of the filters.
+ stride (int, optional): Stride of the convolution.
+ If None (default), set to ``kernel_size // 2``.
+ sample_rate (float): Sample rate of the expected audio.
+ Defaults to 8000.0.
+ who_is_pinv (str, optional): If `None`, no pseudo-inverse filters will
+ be used. If string (among [``'encoder'``, ``'decoder'``]), decides
+ which of ``Encoder`` or ``Decoder`` will be the pseudo inverse of
+ the other one.
+ padding (int): Zero-padding added to both sides of the input.
+ Passed to Encoder and Decoder.
+ output_padding (int): Additional size added to one side of the output shape.
+ Passed to Decoder.
+ **kwargs: Arguments which will be passed to the filterbank class
+ additionally to the usual `n_filters`, `kernel_size` and `stride`.
+ Depends on the filterbank family.
+ Returns:
+ :class:`.Encoder`, :class:`.Decoder`
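+ Example (illustrative sketch; the learned ``'free'`` filterbank is the simplest choice):
+ >>> import torch
+ >>> enc, dec = make_enc_dec("free", n_filters=512, kernel_size=16, stride=8)
+ >>> tf_rep = enc(torch.randn(4, 1, 8000))  # -> (4, 512, 999)
+ >>> rec = dec(tf_rep, length=8000)  # -> (4, 1, 8000)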
+ """
+ fb_class = get(fb_name)
+
+ if who_is_pinv in ["dec", "decoder"]:
+ fb = fb_class(
+ n_filters, kernel_size, stride=stride, sample_rate=sample_rate, **kwargs
+ )
+ enc = Encoder(fb, padding=padding)
+ # Decoder filterbank is pseudo inverse of encoder filterbank.
+ dec = Decoder.pinv_of(fb)
+ elif who_is_pinv in ["enc", "encoder"]:
+ fb = fb_class(
+ n_filters, kernel_size, stride=stride, sample_rate=sample_rate, **kwargs
+ )
+ dec = Decoder(fb, padding=padding, output_padding=output_padding)
+ # Encoder filterbank is pseudo inverse of decoder filterbank.
+ enc = Encoder.pinv_of(fb)
+ else:
+ fb = fb_class(
+ n_filters, kernel_size, stride=stride, sample_rate=sample_rate, **kwargs
+ )
+ enc = Encoder(fb, padding=padding)
+ # Filters between encoder and decoder should not be shared.
+ fb = fb_class(
+ n_filters, kernel_size, stride=stride, sample_rate=sample_rate, **kwargs
+ )
+ dec = Decoder(fb, padding=padding, output_padding=output_padding)
+ return enc, dec
+
+
+def register_filterbank(custom_fb):
+ """Register a custom filterbank, gettable with `filterbanks.get`.
+ Args:
+ custom_fb: Custom filterbank to register.
+ """
+ if (
+ custom_fb.__name__ in globals().keys()
+ or custom_fb.__name__.lower() in globals().keys()
+ ):
+ raise ValueError(
+ f"Filterbank {custom_fb.__name__} already exists. Choose another name."
+ )
+ globals().update({custom_fb.__name__: custom_fb})
+
+
+def get(identifier):
+ """Returns a filterbank class from a string. Returns its input if it
+ is callable (already a :class:`.Filterbank` for example).
+ Args:
+ identifier (str or Callable or None): the filterbank identifier.
+ Returns:
+ :class:`.Filterbank` or None
+ """
+ if identifier is None:
+ return None
+ elif callable(identifier):
+ return identifier
+ elif isinstance(identifier, str):
+ cls = globals().get(identifier)
+ if cls is None:
+ raise ValueError(
+ "Could not interpret filterbank identifier: " + str(identifier)
+ )
+ return cls
+ else:
+ raise ValueError(
+ "Could not interpret filterbank identifier: " + str(identifier)
+ )
+
+
+class Filterbank(nn.Module):
+ """Base Filterbank class.
+ Each subclass has to implement a ``filters`` method.
+ Args:
+ n_filters (int): Number of filters.
+ kernel_size (int): Length of the filters.
+ stride (int, optional): Stride of the conv or transposed conv. (Hop size).
+ If None (default), set to ``kernel_size // 2``.
+ sample_rate (float): Sample rate of the expected audio.
+ Defaults to 8000.
+ Attributes:
+ n_feats_out (int): Number of output filters.
+ """
+
+ def __init__(self, n_filters, kernel_size, stride=None, sample_rate=8000.0):
+ super(Filterbank, self).__init__()
+ self.n_filters = n_filters
+ self.kernel_size = kernel_size
+ self.stride = stride if stride else self.kernel_size // 2
+ # If not specified otherwise in the filterbank's init, output
+ # number of features is equal to number of required filters.
+ self.n_feats_out = n_filters
+ self.sample_rate = sample_rate
+
+ def filters(self):
+ """Abstract method for filters."""
+ raise NotImplementedError
+
+ def pre_analysis(self, wav: torch.Tensor):
+ """Apply transform before encoder convolution."""
+ return wav
+
+ def post_analysis(self, spec: torch.Tensor):
+ """Apply transform after encoder convolution."""
+ return spec
+
+ def pre_synthesis(self, spec: torch.Tensor):
+ """Apply transform before decoder transposed convolution."""
+ return spec
+
+ def post_synthesis(self, wav: torch.Tensor):
+ """Apply transform after decoder transposed convolution."""
+ return wav
+
+ def get_config(self):
+ """Returns dictionary of arguments to re-instantiate the class.
+ Needs to be subclassed if the filterbank takes arguments in addition
+ to ``n_filters``, ``kernel_size``, ``stride`` and ``sample_rate``.
+ """
+ config = {
+ "fb_name": self.__class__.__name__,
+ "n_filters": self.n_filters,
+ "kernel_size": self.kernel_size,
+ "stride": self.stride,
+ "sample_rate": self.sample_rate,
+ }
+ return config
+
+ def forward(self, waveform):
+ raise NotImplementedError(
+ "Filterbanks must be wrapped with an Encoder or a Decoder."
+ )
+
+
+class _EncDec(nn.Module):
+ """Base private class for Encoder and Decoder.
+ Common parameters and methods.
+ Args:
+ filterbank (:class:`Filterbank`): Filterbank instance. The filterbank
+ to use as an encoder or a decoder.
+ is_pinv (bool): Whether to be the pseudo inverse of filterbank.
+ Attributes:
+ filterbank (:class:`Filterbank`)
+ stride (int)
+ is_pinv (bool)
+ """
+
+ def __init__(self, filterbank, is_pinv=False):
+ super(_EncDec, self).__init__()
+ self.filterbank = filterbank
+ self.sample_rate = getattr(filterbank, "sample_rate", None)
+ self.stride = self.filterbank.stride
+ self.is_pinv = is_pinv
+
+ def filters(self):
+ return self.filterbank.filters()
+
+ def compute_filter_pinv(self, filters):
+ """Computes pseudo inverse filterbank of given filters."""
+ scale = self.filterbank.stride / self.filterbank.kernel_size
+ shape = filters.shape
+ ifilt = torch.pinverse(filters.squeeze()).transpose(-1, -2).view(shape)
+ # Compensate for the overlap-add.
+ return ifilt * scale
+
+ def get_filters(self):
+ """Returns filters or pinv filters depending on `is_pinv` attribute"""
+ if self.is_pinv:
+ return self.compute_filter_pinv(self.filters())
+ else:
+ return self.filters()
+
+ def get_config(self):
+ """Returns dictionary of arguments to re-instantiate the class."""
+ config = {"is_pinv": self.is_pinv}
+ base_config = self.filterbank.get_config()
+ return dict(list(base_config.items()) + list(config.items()))
+
+
+class Encoder(_EncDec):
+ r"""Encoder class.
+ Add encoding methods to Filterbank classes.
+ Not intended to be subclassed.
+ Args:
+ filterbank (:class:`Filterbank`): The filterbank to use
+ as an encoder.
+ is_pinv (bool): Whether to be the pseudo inverse of filterbank.
+ as_conv1d (bool): Whether to behave like nn.Conv1d.
+ If True (default), forwarding input with shape :math:`(batch, 1, time)`
+ will output a tensor of shape :math:`(batch, freq, conv\_time)`.
+ If False, will output a tensor of shape :math:`(batch, 1, freq, conv\_time)`.
+ padding (int): Zero-padding added to both sides of the input.
+ """
+
+ def __init__(self, filterbank, is_pinv=False, as_conv1d=True, padding=0):
+ super(Encoder, self).__init__(filterbank, is_pinv=is_pinv)
+ self.as_conv1d = as_conv1d
+ self.n_feats_out = self.filterbank.n_feats_out
+ self.kernel_size = self.filterbank.kernel_size
+ self.padding = padding
+
+ @classmethod
+ def pinv_of(cls, filterbank, **kwargs):
+ """Returns an :class:`~.Encoder`, pseudo inverse of a
+ :class:`~.Filterbank` or :class:`~.Decoder`."""
+ if isinstance(filterbank, Filterbank):
+ return cls(filterbank, is_pinv=True, **kwargs)
+ elif isinstance(filterbank, Decoder):
+ return cls(filterbank.filterbank, is_pinv=True, **kwargs)
+
+ def forward(self, waveform):
+ """Convolve input waveform with the filters from a filterbank.
+ Args:
+ waveform (:class:`torch.Tensor`): any tensor with samples along the
+ last dimension. The waveform representation, with any leading
+ batch/channel dimensions.
+ Returns:
+ :class:`torch.Tensor`: The corresponding TF domain signal.
+ Shapes
+ >>> (time, ) -> (freq, conv_time)
+ >>> (batch, time) -> (batch, freq, conv_time) # Avoid
+ >>> if as_conv1d:
+ >>> (batch, 1, time) -> (batch, freq, conv_time)
+ >>> (batch, chan, time) -> (batch, chan, freq, conv_time)
+ >>> else:
+ >>> (batch, chan, time) -> (batch, chan, freq, conv_time)
+ >>> (batch, any, dim, time) -> (batch, any, dim, freq, conv_time)
+ """
+ filters = self.get_filters()
+ waveform = self.filterbank.pre_analysis(waveform)
+ spec = multishape_conv1d(
+ waveform,
+ filters=filters,
+ stride=self.stride,
+ padding=self.padding,
+ as_conv1d=self.as_conv1d,
+ )
+ return self.filterbank.post_analysis(spec)
+
+
+def multishape_conv1d(
+ waveform: torch.Tensor,
+ filters: torch.Tensor,
+ stride: int,
+ padding: int = 0,
+ as_conv1d: bool = True,
+) -> torch.Tensor:
+ if waveform.ndim == 1:
+ # Assumes 1D input with shape (time,)
+ # Output will be (freq, conv_time)
+ return F.conv1d(
+ waveform[None, None], filters, stride=stride, padding=padding
+ ).squeeze()
+ elif waveform.ndim == 2:
+ # Assume 2D input with shape (batch or channels, time)
+ # Output will be (batch or channels, freq, conv_time)
+ warnings.warn(
+ "Input tensor was 2D. Applying the corresponding "
+ "Decoder to the current output will result in a 3D "
+ "tensor. This behaviour was introduced to match "
+ "Conv1D and ConvTranspose1D, please use 3D inputs "
+ "to avoid it. For example, this can be done with "
+ "input_tensor.unsqueeze(1)."
+ )
+ return F.conv1d(waveform.unsqueeze(1), filters, stride=stride, padding=padding)
+ elif waveform.ndim == 3:
+ batch, channels, time_len = waveform.shape
+ if channels == 1 and as_conv1d:
+ # That's the common single channel case (batch, 1, time)
+ # Output will be (batch, freq, stft_time), behaves as Conv1D
+ return F.conv1d(waveform, filters, stride=stride, padding=padding)
+ else:
+ # Return batched convolution: input (batch, chan, time) gives output
+ # (batch, chan, freq, conv_time). Useful for multichannel transforms. If as_conv1d is
+ # False, (batch, 1, time) will output (batch, 1, freq, conv_time), useful for
+ # consistency.
+ return batch_packed_1d_conv(
+ waveform, filters, stride=stride, padding=padding
+ )
+ else: # waveform.ndim > 3
+ # This is to compute "multi"multichannel convolution.
+ # Input can be (*, time), output will be (*, freq, conv_time)
+ return batch_packed_1d_conv(waveform, filters, stride=stride, padding=padding)
+
+
+def batch_packed_1d_conv(
+ inp: torch.Tensor, filters: torch.Tensor, stride: int = 1, padding: int = 0
+):
+ # Here we perform multichannel / multi-source convolution.
+ # Output should be (batch, channels, freq, conv_time)
+ batched_conv = F.conv1d(
+ inp.view(-1, 1, inp.shape[-1]), filters, stride=stride, padding=padding
+ )
+ output_shape = inp.shape[:-1] + batched_conv.shape[-2:]
+ return batched_conv.view(output_shape)
+
+
+class Decoder(_EncDec):
+ """Decoder class.
+ Add decoding methods to Filterbank classes.
+ Not intended to be subclassed.
+ Args:
+ filterbank (:class:`Filterbank`): The filterbank to use as a decoder.
+ is_pinv (bool): Whether to be the pseudo inverse of filterbank.
+ padding (int): Zero-padding added to both sides of the input.
+ output_padding (int): Additional size added to one side of the
+ output shape.
+ .. note::
+ ``padding`` and ``output_padding`` arguments are directly passed to
+ ``F.conv_transpose1d``.
+ """
+
+ def __init__(self, filterbank, is_pinv=False, padding=0, output_padding=0):
+ super().__init__(filterbank, is_pinv=is_pinv)
+ self.padding = padding
+ self.output_padding = output_padding
+
+ @classmethod
+ def pinv_of(cls, filterbank):
+ """Returns a Decoder, pseudo inverse of a filterbank or Encoder."""
+ if isinstance(filterbank, Filterbank):
+ return cls(filterbank, is_pinv=True)
+ elif isinstance(filterbank, Encoder):
+ return cls(filterbank.filterbank, is_pinv=True)
+
+ def forward(self, spec, length: Optional[int] = None) -> torch.Tensor:
+ """Applies transposed convolution to a TF representation.
+ This is equivalent to overlap-add.
+ Args:
+ spec (:class:`torch.Tensor`): 3D or 4D Tensor. The TF
+ representation. (Output of :func:`Encoder.forward`).
+ length: desired output length.
+ Returns:
+ :class:`torch.Tensor`: The corresponding time domain signal.
+ """
+ filters = self.get_filters()
+ spec = self.filterbank.pre_synthesis(spec)
+ wav = multishape_conv_transpose1d(
+ spec,
+ filters,
+ stride=self.stride,
+ padding=self.padding,
+ output_padding=self.output_padding,
+ )
+ wav = self.filterbank.post_synthesis(wav)
+ if length is not None:
+ length = min(length, wav.shape[-1])
+ return wav[..., :length]
+ return wav
+
+
+def multishape_conv_transpose1d(
+ spec: torch.Tensor,
+ filters: torch.Tensor,
+ stride: int = 1,
+ padding: int = 0,
+ output_padding: int = 0,
+) -> torch.Tensor:
+ if spec.ndim == 2:
+ # Input is (freq, conv_time), output is (time)
+ return F.conv_transpose1d(
+ spec.unsqueeze(0),
+ filters,
+ stride=stride,
+ padding=padding,
+ output_padding=output_padding,
+ ).squeeze()
+ if spec.ndim == 3:
+ # Input is (batch, freq, conv_time), output is (batch, 1, time)
+ return F.conv_transpose1d(
+ spec,
+ filters,
+ stride=stride,
+ padding=padding,
+ output_padding=output_padding,
+ )
+ else:
+ # Multiply all the left dimensions together and group them in the
+ # batch. Make the convolution and restore.
+ view_as = (-1,) + spec.shape[-2:]
+ out = F.conv_transpose1d(
+ spec.reshape(view_as),
+ filters,
+ stride=stride,
+ padding=padding,
+ output_padding=output_padding,
+ )
+ return out.view(spec.shape[:-2] + (-1,))
+
+
+class FreeFB(Filterbank):
+ """Free filterbank without any constraints. Equivalent to
+ :class:`nn.Conv1d`.
+ Args:
+ n_filters (int): Number of filters.
+ kernel_size (int): Length of the filters.
+ stride (int, optional): Stride of the convolution.
+ If None (default), set to ``kernel_size // 2``.
+ sample_rate (float): Sample rate of the expected audio.
+ Defaults to 8000.
+ Attributes:
+ n_feats_out (int): Number of output filters.
+ References
+ [1] : "Filterbank design for end-to-end speech separation". ICASSP 2020.
+ Manuel Pariente, Samuele Cornell, Antoine Deleforge, Emmanuel Vincent.
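+ Example (illustrative):
+ >>> fb = FreeFB(n_filters=512, kernel_size=16)
+ >>> fb.filters().shape
+ torch.Size([512, 1, 16])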
+ """
+
+ def __init__(
+ self, n_filters, kernel_size, stride=None, sample_rate=8000.0, **kwargs
+ ):
+ super().__init__(n_filters, kernel_size, stride=stride, sample_rate=sample_rate)
+ self._filters = nn.Parameter(torch.ones(n_filters, 1, kernel_size))
+ for p in self.parameters():
+ nn.init.xavier_normal_(p)
+
+ def filters(self):
+ return self._filters
+
+
+free = FreeFB
diff --git a/look2hear/layers/normalizations.py b/look2hear/layers/normalizations.py
new file mode 100644
index 0000000000000000000000000000000000000000..e83d95d25404d218143f0ed367ed3df396b4b787
--- /dev/null
+++ b/look2hear/layers/normalizations.py
@@ -0,0 +1,183 @@
+import torch
+from torch.autograd import Variable
+import torch.nn as nn
+import numpy as np
+from typing import List
+from torch.nn.modules.batchnorm import _BatchNorm
+from collections.abc import Iterable
+
+
+def norm(x, dims: List[int], EPS: float = 1e-8):
+ mean = x.mean(dim=dims, keepdim=True)
+ var2 = torch.var(x, dim=dims, keepdim=True, unbiased=False)
+ value = (x - mean) / torch.sqrt(var2 + EPS)
+ return value
+
+
+def glob_norm(x, EPS: float = 1e-8):
+ dims: List[int] = torch.arange(1, len(x.shape)).tolist()
+ return norm(x, dims, EPS)
+
+
+class MLayerNorm(nn.Module):
+ def __init__(self, channel_size):
+ super().__init__()
+ self.channel_size = channel_size
+ self.gamma = nn.Parameter(torch.ones(channel_size), requires_grad=True)
+ self.beta = nn.Parameter(torch.ones(channel_size), requires_grad=True)
+
+ def apply_gain_and_bias(self, normed_x):
+ """Assumes input of size `[batch, channel, *]`."""
+ return (self.gamma * normed_x.transpose(1, -1) + self.beta).transpose(1, -1)
+
+ def forward(self, x, EPS: float = 1e-8):
+ pass
+
+
+class GlobalLN(MLayerNorm):
+ def forward(self, x, EPS: float = 1e-8):
+ value = glob_norm(x, EPS)
+ return self.apply_gain_and_bias(value)
+
+
+class ChannelLN(MLayerNorm):
+ def forward(self, x, EPS: float = 1e-8):
+ mean = torch.mean(x, dim=1, keepdim=True)
+ var = torch.var(x, dim=1, keepdim=True, unbiased=False)
+ return self.apply_gain_and_bias((x - mean) / (var + EPS).sqrt())
+
+
+# class CumulateLN(MLayerNorm):
+# def forward(self, x, EPS: float = 1e-8):
+# batch, channels, time = x.size()
+# cum_sum = torch.cumsum(x.sum(1, keepdim=True), dim=1)
+# cum_pow_sum = torch.cumsum(x.pow(2).sum(1, keepdim=True), dim=1)
+# cnt = torch.arange(
+# start=channels, end=channels * (time + 1), step=channels, dtype=x.dtype, device=x.device
+# ).view(1, 1, -1)
+# cum_mean = cum_sum / cnt
+# cum_var = (cum_pow_sum / cnt) - cum_mean.pow(2)
+# return self.apply_gain_and_bias((x - cum_mean) / (cum_var + EPS).sqrt())
+
+
+class BatchNorm(_BatchNorm):
+ """Wrapper class for pytorch BatchNorm1D and BatchNorm2D"""
+
+ def _check_input_dim(self, input):
+ if input.dim() < 2 or input.dim() > 4:
+ raise ValueError(
+ "expected 2D, 3D or 4D input (got {}D input)".format(input.dim())
+ )
+
+
+class CumulativeLayerNorm(nn.LayerNorm):
+ def __init__(self, dim, elementwise_affine=True):
+ super(CumulativeLayerNorm, self).__init__(
+ dim, elementwise_affine=elementwise_affine
+ )
+
+ def forward(self, x):
+ # x: N x C x L
+ # N x L x C
+ x = torch.transpose(x, 1, -1)
+ # N x L x C == only channel norm
+ x = super().forward(x)
+ # N x C x L
+ x = torch.transpose(x, 1, -1)
+ return x
+
+
+class CumulateLN(nn.Module):
+ def __init__(self, dimension, eps=1e-8, trainable=True):
+ super(CumulateLN, self).__init__()
+
+ self.eps = eps
+ if trainable:
+ self.gain = nn.Parameter(torch.ones(1, dimension, 1))
+ self.bias = nn.Parameter(torch.zeros(1, dimension, 1))
+ else:
+ self.gain = Variable(torch.ones(1, dimension, 1), requires_grad=False)
+ self.bias = Variable(torch.zeros(1, dimension, 1), requires_grad=False)
+
+ def forward(self, input):
+ # input size: (Batch, Freq, Time)
+ # cumulative mean for each time step
+
+ batch_size = input.size(0)
+ channel = input.size(1)
+ time_step = input.size(2)
+
+ step_sum = input.sum(1) # B, T
+ step_pow_sum = input.pow(2).sum(1) # B, T
+ cum_sum = torch.cumsum(step_sum, dim=1) # B, T
+ cum_pow_sum = torch.cumsum(step_pow_sum, dim=1) # B, T
+
+ entry_cnt = np.arange(channel, channel * (time_step + 1), channel)
+ entry_cnt = torch.from_numpy(entry_cnt).type(input.type())
+ entry_cnt = entry_cnt.view(1, -1).expand_as(cum_sum)
+
+ cum_mean = cum_sum / entry_cnt # B, T
+ cum_var = (cum_pow_sum - 2 * cum_mean * cum_sum) / entry_cnt + cum_mean.pow(
+ 2
+ ) # B, T
+ cum_std = (cum_var + self.eps).sqrt() # B, T
+
+ cum_mean = cum_mean.unsqueeze(1)
+ cum_std = cum_std.unsqueeze(1)
+
+ x = (input - cum_mean.expand_as(input)) / cum_std.expand_as(input)
+ return x * self.gain.expand_as(x).type(x.type()) + self.bias.expand_as(x).type(
+ x.type()
+ )
+
+class LayerNormalization4D(nn.Module):
+ def __init__(self, input_dimension: Iterable, eps: float = 1e-5):
+ super(LayerNormalization4D, self).__init__()
+ assert len(input_dimension) == 2
+ param_size = [1, input_dimension[0], 1, input_dimension[1]]
+
+ self.dim = (1, 3) if param_size[-1] > 1 else (1,)
+ self.gamma = nn.Parameter(torch.Tensor(*param_size).to(torch.float32))
+ self.beta = nn.Parameter(torch.Tensor(*param_size).to(torch.float32))
+ nn.init.ones_(self.gamma)
+ nn.init.zeros_(self.beta)
+ self.eps = eps
+
+ def forward(self, x: torch.Tensor):
+ mu_ = x.mean(dim=self.dim, keepdim=True)
+ std_ = torch.sqrt(x.var(dim=self.dim, unbiased=False, keepdim=True) + self.eps)
+ x_hat = ((x - mu_) / std_) * self.gamma + self.beta
+ return x_hat
+
+# Aliases.
+gLN = GlobalLN
+cLN = CumulateLN
+LN = CumulativeLayerNorm
+bN = BatchNorm
+LN4D = LayerNormalization4D
+
+def get(identifier):
+ """Returns a norm class from a string. Returns its input if it
+ is callable (already a :class:`._LayerNorm` for example).
+
+ Args:
+ identifier (str or Callable or None): the norm identifier.
+
+ Returns:
+ :class:`._LayerNorm` or None
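+ Example (illustrative):
+ >>> import torch
+ >>> norm_layer = get("gLN")(64)  # GlobalLN over 64 channels
+ >>> norm_layer(torch.randn(2, 64, 100)).shape
+ torch.Size([2, 64, 100])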
+ """
+ if identifier is None:
+ return None
+ elif callable(identifier):
+ return identifier
+ elif isinstance(identifier, str):
+ cls = globals().get(identifier)
+ if cls is None:
+ raise ValueError(
+ "Could not interpret normalization identifier: " + str(identifier)
+ )
+ return cls
+ else:
+ raise ValueError(
+ "Could not interpret normalization identifier: " + str(identifier)
+ )
diff --git a/look2hear/layers/rnnlayers.py b/look2hear/layers/rnnlayers.py
new file mode 100644
index 0000000000000000000000000000000000000000..c657e7bf527b7512d6f3639af053cdf205a77c92
--- /dev/null
+++ b/look2hear/layers/rnnlayers.py
@@ -0,0 +1,921 @@
+import torch
+import math
+import inspect
+from torch import nn
+from torch import Tensor
+from typing import Tuple
+from typing import Optional
+from torch.nn.functional import fold, unfold
+import numpy as np
+
+from . import activations, normalizations
+from .normalizations import gLN
+
+
+def has_arg(fn, name):
+ """Checks if a callable accepts a given keyword argument.
+ Args:
+ fn (callable): Callable to inspect.
+ name (str): Check if ``fn`` can be called with ``name`` as a keyword
+ argument.
+ Returns:
+ bool: whether ``fn`` accepts a ``name`` keyword argument.
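+ Example (illustrative):
+ >>> from torch import nn
+ >>> has_arg(nn.Softmax, "dim")
+ True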
+ """
+ signature = inspect.signature(fn)
+ parameter = signature.parameters.get(name)
+ if parameter is None:
+ return False
+ return parameter.kind in (
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
+ inspect.Parameter.KEYWORD_ONLY,
+ )
+
+
+class SingleRNN(nn.Module):
+ """Module for a RNN block.
+ Inspired from https://github.com/yluo42/TAC/blob/master/utility/models.py
+ Licensed under CC BY-NC-SA 3.0 US.
+ Args:
+ rnn_type (str): Select from ``'RNN'``, ``'LSTM'``, ``'GRU'``. Can
+ also be passed in lowercase letters.
+ input_size (int): Dimension of the input feature. The input should have
+ shape [batch, seq_len, input_size].
+ hidden_size (int): Dimension of the hidden state.
+ n_layers (int, optional): Number of layers used in RNN. Default is 1.
+ dropout (float, optional): Dropout ratio. Default is 0.
+ bidirectional (bool, optional): Whether the RNN layers are
+ bidirectional. Default is ``False``.
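+ Example (illustrative):
+ >>> import torch
+ >>> rnn = SingleRNN("LSTM", input_size=64, hidden_size=128, bidirectional=True)
+ >>> rnn(torch.randn(2, 100, 64)).shape  # output_size = 2 * hidden_size
+ torch.Size([2, 100, 256])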
+ """
+
+ def __init__(
+ self,
+ rnn_type,
+ input_size,
+ hidden_size,
+ n_layers=1,
+ dropout=0,
+ bidirectional=False,
+ ):
+ super(SingleRNN, self).__init__()
+ assert rnn_type.upper() in ["RNN", "LSTM", "GRU"]
+ rnn_type = rnn_type.upper()
+ self.rnn_type = rnn_type
+ self.input_size = input_size
+ self.hidden_size = hidden_size
+ self.n_layers = n_layers
+ self.dropout = dropout
+ self.bidirectional = bidirectional
+ self.rnn = getattr(nn, rnn_type)(
+ input_size,
+ hidden_size,
+ num_layers=n_layers,
+ dropout=dropout,
+ batch_first=True,
+ bidirectional=bool(bidirectional),
+ )
+
+ @property
+ def output_size(self):
+ return self.hidden_size * (2 if self.bidirectional else 1)
+
+ def forward(self, inp):
+ """ Input shape [batch, seq, feats] """
+ self.rnn.flatten_parameters() # Enables faster multi-GPU training.
+ output = inp
+ rnn_output, _ = self.rnn(output)
+ return rnn_output
+
+
+class LSTMBlockTF(nn.Module):
+ def __init__(
+ self,
+ in_chan,
+ hid_size,
+ norm_type="gLN",
+ bidirectional=True,
+ rnn_type="LSTM",
+ num_layers=1,
+ dropout=0,
+ ):
+ super(LSTMBlockTF, self).__init__()
+ self.RNN = SingleRNN(
+ rnn_type,
+ in_chan,
+ hid_size,
+ num_layers,
+ dropout=dropout,
+ bidirectional=bidirectional,
+ )
+ self.linear = nn.Linear(self.RNN.output_size, in_chan)
+ self.norm = normalizations.get(norm_type)(in_chan)
+
+ def forward(self, x):
+ B, F, T = x.size()
+ output = self.RNN(x.transpose(1, 2)) # B, T, N
+ output = self.linear(output)
+ output = output.transpose(1, -1) # B, N, T
+ output = self.norm(output)
+ return output + x
+
+
+# ===================Transformer======================
+class Linear(nn.Module):
+ """
+ Wrapper class of torch.nn.Linear
+ Weight initialize by xavier initialization and bias initialize to zeros.
+ """
+
+ def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
+ super(Linear, self).__init__()
+ self.linear = nn.Linear(in_features, out_features, bias=bias)
+ nn.init.xavier_uniform_(self.linear.weight)
+ if bias:
+ nn.init.zeros_(self.linear.bias)
+
+ def forward(self, x):
+ return self.linear(x)
+
+
+class Swish(nn.Module):
+ """
+ Swish is a smooth, non-monotonic function that consistently matches or outperforms ReLU on deep networks applied
+ to a variety of challenging domains such as Image classification and Machine translation.
+ """
+
+ def __init__(self):
+ super(Swish, self).__init__()
+
+ def forward(self, inputs):
+ return inputs * inputs.sigmoid()
+
+
+class Transpose(nn.Module):
+ """ Wrapper class of torch.transpose() for Sequential module. """
+
+ def __init__(self, shape: tuple):
+ super(Transpose, self).__init__()
+ self.shape = shape
+
+ def forward(self, x: Tensor) -> Tensor:
+ return x.transpose(*self.shape)
+
+
+class GLU(nn.Module):
+ """
+ The gating mechanism is called Gated Linear Units (GLU), which was first introduced for natural language processing
+ in the paper “Language Modeling with Gated Convolutional Networks”
+ """
+
+ def __init__(self, dim: int) -> None:
+ super(GLU, self).__init__()
+ self.dim = dim
+
+ def forward(self, inputs: Tensor) -> Tensor:
+ outputs, gate = inputs.chunk(2, dim=self.dim)
+ return outputs * gate.sigmoid()
+
+
+class FeedForwardModule(nn.Module):
+ def __init__(
+ self, encoder_dim: int = 512, expansion_factor: int = 4, dropout_p: float = 0.1,
+ ) -> None:
+ super(FeedForwardModule, self).__init__()
+ self.sequential = nn.Sequential(
+ nn.LayerNorm(encoder_dim),
+ Linear(encoder_dim, encoder_dim * expansion_factor, bias=True),
+ Swish(),
+ nn.Dropout(p=dropout_p),
+ Linear(encoder_dim * expansion_factor, encoder_dim, bias=True),
+ nn.Dropout(p=dropout_p),
+ )
+
+ def forward(self, inputs):
+ return self.sequential(inputs)
+
+
+class PositionalEncoding(nn.Module):
+ """
+ Positional Encoding proposed in "Attention Is All You Need".
+ Since transformer contains no recurrence and no convolution, in order for the model to make
+ use of the order of the sequence, we must add some positional information.
+ "Attention Is All You Need" use sine and cosine functions of different frequencies:
+ PE_(pos, 2i) = sin(pos / power(10000, 2i / d_model))
+ PE_(pos, 2i+1) = cos(pos / power(10000, 2i / d_model))
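+ Example (illustrative):
+ >>> pe = PositionalEncoding(d_model=512)
+ >>> pe(100).shape
+ torch.Size([1, 100, 512])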
+ """
+
+ def __init__(self, d_model: int = 512, max_len: int = 10000) -> None:
+ super(PositionalEncoding, self).__init__()
+ pe = torch.zeros(max_len, d_model, requires_grad=False)
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+ div_term = torch.exp(
+ torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
+ )
+ pe[:, 0::2] = torch.sin(position * div_term)
+ pe[:, 1::2] = torch.cos(position * div_term)
+ pe = pe.unsqueeze(0)
+ self.register_buffer("pe", pe)
+
+ def forward(self, length: int) -> Tensor:
+ return self.pe[:, :length]
+
+
+class RelativeMultiHeadAttention(nn.Module):
+ """
+ Multi-head attention with relative positional encoding.
+ This concept was proposed in the "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+ Args:
+ d_model (int): The dimension of model
+ num_heads (int): The number of attention heads.
+ dropout_p (float): probability of dropout
+ Inputs: query, key, value, pos_embedding, mask
+ - **query** (batch, time, dim): Tensor containing query vector
+ - **key** (batch, time, dim): Tensor containing key vector
+ - **value** (batch, time, dim): Tensor containing value vector
+ - **pos_embedding** (batch, time, dim): Positional embedding tensor
+ - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked
+ Returns:
+ - **outputs**: Tensor produced by the relative multi-head attention module.
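+ Example (illustrative sketch; shapes only, weights are randomly initialized):
+ >>> import torch
+ >>> mha = RelativeMultiHeadAttention(d_model=512, num_heads=8)
+ >>> x = torch.randn(2, 50, 512)
+ >>> pos = torch.randn(2, 50, 512)
+ >>> mha(x, x, x, pos).shape
+ torch.Size([2, 50, 512])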
+ """
+
+ def __init__(
+ self, d_model: int = 512, num_heads: int = 16, dropout_p: float = 0.1,
+ ):
+ super(RelativeMultiHeadAttention, self).__init__()
+ assert d_model % num_heads == 0, "d_model % num_heads should be zero."
+ self.d_model = d_model
+ self.d_head = int(d_model / num_heads)
+ self.num_heads = num_heads
+ self.sqrt_dim = math.sqrt(d_model)
+
+ self.query_proj = Linear(d_model, d_model)
+ self.key_proj = Linear(d_model, d_model)
+ self.value_proj = Linear(d_model, d_model)
+ self.pos_proj = Linear(d_model, d_model, bias=False)
+
+ self.dropout = nn.Dropout(p=dropout_p)
+ self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
+ self.v_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
+ torch.nn.init.xavier_uniform_(self.u_bias)
+ torch.nn.init.xavier_uniform_(self.v_bias)
+
+ self.out_proj = Linear(d_model, d_model)
+
+ def forward(
+ self,
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ pos_embedding: Tensor,
+ mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ batch_size = value.size(0)
+
+ query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)
+ key = (
+ self.key_proj(key)
+ .view(batch_size, -1, self.num_heads, self.d_head)
+ .permute(0, 2, 1, 3)
+ )
+ value = (
+ self.value_proj(value)
+ .view(batch_size, -1, self.num_heads, self.d_head)
+ .permute(0, 2, 1, 3)
+ )
+ pos_embedding = self.pos_proj(pos_embedding).view(
+ batch_size, -1, self.num_heads, self.d_head
+ )
+
+ content_score = torch.matmul(
+ (query + self.u_bias).transpose(1, 2), key.transpose(2, 3)
+ )
+ pos_score = torch.matmul(
+ (query + self.v_bias).transpose(1, 2), pos_embedding.permute(0, 2, 3, 1)
+ )
+ pos_score = self._relative_shift(pos_score)
+
+ score = (content_score + pos_score) / self.sqrt_dim
+
+ if mask is not None:
+ mask = mask.unsqueeze(1)
+ score.masked_fill_(mask, -1e9)
+
+ attn = torch.nn.functional.softmax(score, -1)
+ attn = self.dropout(attn)
+
+ context = torch.matmul(attn, value).transpose(1, 2)
+ context = context.contiguous().view(batch_size, -1, self.d_model)
+
+ return self.out_proj(context)
+
+ def _relative_shift(self, pos_score: Tensor) -> Tensor:
+ batch_size, num_heads, seq_length1, seq_length2 = pos_score.size()
+ zeros = pos_score.new_zeros(batch_size, num_heads, seq_length1, 1)
+ padded_pos_score = torch.cat([zeros, pos_score], dim=-1)
+
+ padded_pos_score = padded_pos_score.view(
+ batch_size, num_heads, seq_length2 + 1, seq_length1
+ )
+ pos_score = padded_pos_score[:, :, 1:].view_as(pos_score)
+
+ return pos_score
+
+
+class MultiHeadedSelfAttentionModule(nn.Module):
+ """
+ Conformer employs multi-headed self-attention (MHSA) while integrating an important technique from Transformer-XL,
+ the relative sinusoidal positional encoding scheme. The relative positional encoding allows the self-attention
+ module to generalize better to different input lengths, and the resulting encoder is more robust to variance in
+ the utterance length. Conformer uses pre-norm residual units with dropout, which helps with training
+ and regularizing deeper models.
+ Args:
+ d_model (int): The dimension of model
+ num_heads (int): The number of attention heads.
+ dropout_p (float): probability of dropout
+ device (torch.device): torch device (cuda or cpu)
+ Inputs: inputs, mask
+ - **inputs** (batch, time, dim): Tensor containing input vector
+ - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked
+ Returns:
+ - **outputs** (batch, time, dim): Tensor produced by the relative multi-headed self-attention module.
+ """
+
+ def __init__(
+ self, d_model: int, num_heads: int, dropout_p: float = 0.1, is_casual=True
+ ):
+ super(MultiHeadedSelfAttentionModule, self).__init__()
+ self.positional_encoding = PositionalEncoding(d_model)
+ self.layer_norm = nn.LayerNorm(d_model)
+ self.attention = RelativeMultiHeadAttention(d_model, num_heads, dropout_p)
+ self.dropout = nn.Dropout(p=dropout_p)
+ self.is_casual = is_casual
+
+ def forward(self, inputs: Tensor):
+ batch_size, seq_length, _ = inputs.size()
+ pos_embedding = self.positional_encoding(seq_length)
+ pos_embedding = pos_embedding.repeat(batch_size, 1, 1)
+
+ mask = None
+ if self.is_casual:
+ mask = torch.triu(
+ torch.ones((seq_length, seq_length), dtype=torch.uint8).to(
+ inputs.device
+ ),
+ diagonal=1,
+ )
+ mask = mask.unsqueeze(0).expand(batch_size, -1, -1).bool() # [B, L, L]
+
+ inputs = self.layer_norm(inputs)
+ outputs = self.attention(
+ inputs, inputs, inputs, pos_embedding=pos_embedding, mask=mask
+ )
+
+ return self.dropout(outputs)
+
+
+class ResidualConnectionModule(nn.Module):
+ """
+ Residual Connection Module.
+ outputs = (module(inputs) x module_factor + inputs x input_factor)
+ """
+
+ def __init__(
+ self, module: nn.Module, module_factor: float = 1.0, input_factor: float = 1.0
+ ):
+ super(ResidualConnectionModule, self).__init__()
+ self.module = module
+ self.module_factor = module_factor
+ self.input_factor = input_factor
+
+ def forward(self, inputs):
+ return (self.module(inputs) * self.module_factor) + (inputs * self.input_factor)
+
+
+class DepthwiseConv1d(nn.Module):
+ """
+ When groups == in_channels and out_channels == K * in_channels, where K is a positive integer,
+ this operation is termed in literature as depthwise convolution.
+ Args:
+ in_channels (int): Number of channels in the input
+ out_channels (int): Number of channels produced by the convolution
+ kernel_size (int or tuple): Size of the convolving kernel
+ stride (int, optional): Stride of the convolution. Default: 1
+ padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
+ bias (bool, optional): If True, adds a learnable bias to the output. Default: True
+ Inputs: inputs
+ - **inputs** (batch, in_channels, time): Tensor containing input vector
+ Returns: outputs
+ - **outputs** (batch, out_channels, time): Tensor produced by the depthwise 1-D convolution.
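+ Example (illustrative; with the causal padding used here the time length is preserved):
+ >>> import torch
+ >>> conv = DepthwiseConv1d(64, 64, kernel_size=31)
+ >>> conv(torch.randn(2, 64, 100)).shape
+ torch.Size([2, 64, 100])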
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ stride: int = 1,
+ padding: int = 0,
+ bias: bool = False,
+ is_casual: bool = True,
+ ) -> None:
+ super(DepthwiseConv1d, self).__init__()
+ assert (
+ out_channels % in_channels == 0
+ ), "out_channels should be constant multiple of in_channels"
+ if is_casual:
+ padding = kernel_size - 1
+ else:
+ padding = (kernel_size - 1) // 2
+ self.conv = nn.Conv1d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ groups=in_channels,
+ stride=stride,
+ padding=padding,
+ bias=bias,
+ )
+ self.is_casual = is_casual
+ self.kernel_size = kernel_size
+
+ def forward(self, inputs: Tensor) -> Tensor:
+ if self.is_casual:
+ return self.conv(inputs)[:, :, : -(self.kernel_size - 1)]
+ return self.conv(inputs)
+
+
+class PointwiseConv1d(nn.Module):
+ """
+ A conv1d with kernel size == 1 is termed a pointwise convolution in the literature.
+ This operation is often used to match dimensions.
+ Args:
+ in_channels (int): Number of channels in the input
+ out_channels (int): Number of channels produced by the convolution
+ stride (int, optional): Stride of the convolution. Default: 1
+ padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
+ bias (bool, optional): If True, adds a learnable bias to the output. Default: True
+ Inputs: inputs
+ - **inputs** (batch, in_channels, time): Tensor containing input vector
+ Returns: outputs
+ - **outputs** (batch, out_channels, time): Tensor produced by the pointwise 1-D convolution.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ stride: int = 1,
+ padding: int = 0,
+ bias: bool = True,
+ ) -> None:
+ super(PointwiseConv1d, self).__init__()
+ self.conv = nn.Conv1d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ stride=stride,
+ padding=padding,
+ bias=bias,
+ )
+
+ def forward(self, inputs: Tensor) -> Tensor:
+ return self.conv(inputs)
+
+
+class ConformerConvModule(nn.Module):
+ """
+ Conformer convolution module starts with a pointwise convolution and a gated linear unit (GLU).
+ This is followed by a single 1-D depthwise convolution layer. Batchnorm is deployed just after the convolution
+ to aid training deep models.
+ Args:
+ in_channels (int): Number of channels in the input
+ kernel_size (int or tuple, optional): Size of the convolving kernel Default: 31
+ dropout_p (float, optional): probability of dropout
+ device (torch.device): torch device (cuda or cpu)
+ Inputs: inputs
+ inputs (batch, time, dim): Tensor contains input sequences
+ Outputs: outputs
+ outputs (batch, time, dim): Tensor produced by the conformer convolution module.
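+ Example (illustrative):
+ >>> import torch
+ >>> conv_module = ConformerConvModule(in_channels=64, kernel_size=31)
+ >>> conv_module(torch.randn(2, 100, 64)).shape
+ torch.Size([2, 100, 64])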
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ kernel_size: int = 31,
+ expansion_factor: int = 2,
+ dropout_p: float = 0.1,
+ is_casual: bool = True,
+ ) -> None:
+ super(ConformerConvModule, self).__init__()
+ assert (
+ kernel_size - 1
+ ) % 2 == 0, "kernel_size should be an odd number for 'SAME' padding"
+ assert expansion_factor == 2, "Currently, Only Supports expansion_factor 2"
+
+ self.sequential = nn.Sequential(
+ nn.LayerNorm(in_channels),
+ Transpose(shape=(1, 2)),
+ PointwiseConv1d(
+ in_channels,
+ in_channels * expansion_factor,
+ stride=1,
+ padding=0,
+ bias=True,
+ ),
+ GLU(dim=1),
+ DepthwiseConv1d(
+ in_channels, in_channels, kernel_size, stride=1, is_casual=is_casual
+ ),
+ nn.BatchNorm1d(in_channels),
+ Swish(),
+ PointwiseConv1d(in_channels, in_channels, stride=1, padding=0, bias=True),
+ nn.Dropout(p=dropout_p),
+ )
+
+ def forward(self, inputs: Tensor) -> Tensor:
+ return self.sequential(inputs).transpose(1, 2)
+
+
+class TransformerLayer(nn.Module):
+ def __init__(
+ self, in_chan=128, n_head=8, n_att=1, dropout=0.1, max_len=500, is_casual=True
+ ):
+ super(TransformerLayer, self).__init__()
+ self.in_chan = in_chan
+ self.n_head = n_head
+ self.dropout = dropout
+ self.max_len = max_len
+ self.n_att = n_att
+
+ self.seq = nn.Sequential(
+ ResidualConnectionModule(
+ FeedForwardModule(in_chan, expansion_factor=4, dropout_p=dropout),
+ module_factor=0.5,
+ ),
+ ResidualConnectionModule(
+ MultiHeadedSelfAttentionModule(in_chan, n_head, dropout, is_casual)
+ ),
+ ResidualConnectionModule(
+ ConformerConvModule(in_chan, 31, 2, dropout, is_casual=is_casual)
+ ),
+ ResidualConnectionModule(
+ FeedForwardModule(in_chan, expansion_factor=4, dropout_p=dropout),
+ module_factor=0.5,
+ ),
+ nn.LayerNorm(in_chan),
+ )
+
+ def forward(self, x):
+ return self.seq(x)
+
+
+class TransformerBlockTF(nn.Module):
+ def __init__(
+ self,
+ in_chan,
+ n_head=8,
+ n_att=1,
+ dropout=0.1,
+ max_len=500,
+ norm_type="cLN",
+ is_casual=True,
+ ):
+ super(TransformerBlockTF, self).__init__()
+ self.transformer = TransformerLayer(
+ in_chan, n_head, n_att, dropout, max_len, is_casual
+ )
+ self.norm = normalizations.get(norm_type)(in_chan)
+
+ def forward(self, x):
+ B, F, T = x.size()
+ output = self.transformer(x.permute(0, 2, 1).contiguous()) # B, T, N
+ output = output.permute(0, 2, 1).contiguous() # B, N, T
+ output = self.norm(output)
+ return output + x
+
+
+# ====================================================
+
+
+class DPRNNBlock(nn.Module):
+ def __init__(
+ self,
+ in_chan,
+ hid_size,
+ norm_type="gLN",
+ bidirectional=True,
+ rnn_type="LSTM",
+ num_layers=1,
+ dropout=0,
+ ):
+ super(DPRNNBlock, self).__init__()
+ self.intra_RNN = SingleRNN(
+ rnn_type,
+ in_chan,
+ hid_size,
+ num_layers,
+ dropout=dropout,
+ bidirectional=True,
+ )
+ self.inter_RNN = SingleRNN(
+ rnn_type,
+ in_chan,
+ hid_size,
+ num_layers,
+ dropout=dropout,
+ bidirectional=bidirectional,
+ )
+ self.intra_linear = nn.Linear(self.intra_RNN.output_size, in_chan)
+ self.intra_norm = normalizations.get(norm_type)(in_chan)
+
+ self.inter_linear = nn.Linear(self.inter_RNN.output_size, in_chan)
+ self.inter_norm = normalizations.get(norm_type)(in_chan)
+
+ def forward(self, x):
+ """ Input shape : [batch, feats, chunk_size, num_chunks] """
+ B, N, K, L = x.size()
+ output = x # for skip connection
+ # Intra-chunk processing
+ x = x.transpose(1, -1).reshape(B * L, K, N)
+ x = self.intra_RNN(x)
+ x = self.intra_linear(x)
+ x = x.reshape(B, L, K, N).transpose(1, -1)
+ x = self.intra_norm(x)
+ output = output + x
+ # Inter-chunk processing
+ x = output.transpose(1, 2).transpose(2, -1).reshape(B * K, L, N)
+ x = self.inter_RNN(x)
+ x = self.inter_linear(x)
+ x = x.reshape(B, K, L, N).transpose(1, -1).transpose(2, -1).contiguous()
+ x = self.inter_norm(x)
+ return output + x
+
+
+class DPRNN(nn.Module):
+ def __init__(
+ self,
+ in_chan,
+ n_src,
+ out_chan=None,
+ bn_chan=128,
+ hid_size=128,
+ chunk_size=100,
+ hop_size=None,
+ n_repeats=6,
+ norm_type="gLN",
+ mask_act="relu",
+ bidirectional=True,
+ rnn_type="LSTM",
+ num_layers=1,
+ dropout=0,
+ ):
+ super(DPRNN, self).__init__()
+ self.in_chan = in_chan
+ out_chan = out_chan if out_chan is not None else in_chan
+ self.out_chan = out_chan
+ self.bn_chan = bn_chan
+ self.hid_size = hid_size
+ self.chunk_size = chunk_size
+ hop_size = hop_size if hop_size is not None else chunk_size // 2
+ self.hop_size = hop_size
+ self.n_repeats = n_repeats
+ self.n_src = n_src
+ self.norm_type = norm_type
+ self.mask_act = mask_act
+ self.bidirectional = bidirectional
+ self.rnn_type = rnn_type
+ self.num_layers = num_layers
+ self.dropout = dropout
+
+ layer_norm = normalizations.get(norm_type)(in_chan)
+ bottleneck_conv = nn.Conv1d(in_chan, bn_chan, 1)
+ self.bottleneck = nn.Sequential(layer_norm, bottleneck_conv)
+
+ # Succession of DPRNNBlocks.
+ net = []
+ for x in range(self.n_repeats):
+ net += [
+ DPRNNBlock(
+ bn_chan,
+ hid_size,
+ norm_type=norm_type,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ num_layers=num_layers,
+ dropout=dropout,
+ )
+ ]
+ self.net = nn.Sequential(*net)
+ # Masking in 3D space
+ net_out_conv = nn.Conv2d(bn_chan, n_src * bn_chan, 1)
+ self.first_out = nn.Sequential(nn.PReLU(), net_out_conv)
+ # Gating and masking in 2D space (after fold)
+ self.net_out = nn.Sequential(nn.Conv1d(bn_chan, bn_chan, 1), nn.Tanh())
+ self.net_gate = nn.Sequential(nn.Conv1d(bn_chan, bn_chan, 1), nn.Sigmoid())
+ self.mask_net = nn.Conv1d(bn_chan, out_chan, 1, bias=False)
+
+ # Get activation function.
+ mask_nl_class = activations.get(mask_act)
+ # For softmax, feed the source dimension.
+ if has_arg(mask_nl_class, "dim"):
+ self.output_act = mask_nl_class(dim=1)
+ else:
+ self.output_act = mask_nl_class()
+
+ def forward(self, mixture_w):
+ r"""Forward.
+ Args:
+ mixture_w (:class:`torch.Tensor`): Tensor of shape $(batch, nfilters, nframes)$
+ Returns:
+ :class:`torch.Tensor`: estimated mask of shape $(batch, nsrc, nfilters, nframes)$
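+ Example (illustrative sketch; small sizes for brevity, and it assumes the default
+ ``relu`` mask activation resolves via ``activations.get``):
+ >>> import torch
+ >>> masker = DPRNN(in_chan=64, n_src=2, bn_chan=32, hid_size=32, chunk_size=20, n_repeats=2)
+ >>> masker(torch.randn(2, 64, 200)).shape
+ torch.Size([2, 2, 64, 200])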
+ """
+ batch, n_filters, n_frames = mixture_w.size()
+ output = self.bottleneck(mixture_w) # [batch, bn_chan, n_frames]
+ output = unfold(
+ output.unsqueeze(-1),
+ kernel_size=(self.chunk_size, 1),
+ padding=(self.chunk_size, 0),
+ stride=(self.hop_size, 1),
+ )
+ n_chunks = output.shape[-1]
+ output = output.reshape(batch, self.bn_chan, self.chunk_size, n_chunks)
+ # Apply stacked DPRNN Blocks sequentially
+ output = self.net(output)
+ # Map to sources with kind of 2D masks
+ output = self.first_out(output)
+ output = output.reshape(
+ batch * self.n_src, self.bn_chan, self.chunk_size, n_chunks
+ )
+ # Overlap and add:
+ # [batch, out_chan, chunk_size, n_chunks] -> [batch, out_chan, n_frames]
+ to_unfold = self.bn_chan * self.chunk_size
+ output = fold(
+ output.reshape(batch * self.n_src, to_unfold, n_chunks),
+ (n_frames, 1),
+ kernel_size=(self.chunk_size, 1),
+ padding=(self.chunk_size, 0),
+ stride=(self.hop_size, 1),
+ )
+ # Reshape to 3D; the Tanh/Sigmoid gating below is left commented out in this variant.
+ output = output.reshape(batch * self.n_src, self.bn_chan, -1)
+ # output = self.net_out(output) * self.net_gate(output)
+ # Compute mask
+ score = self.mask_net(output)
+ est_mask = self.output_act(score)
+ est_mask = est_mask.view(batch, self.n_src, self.out_chan, n_frames)
+ return est_mask
+
+ def get_config(self):
+ config = {
+ "in_chan": self.in_chan,
+ "out_chan": self.out_chan,
+ "bn_chan": self.bn_chan,
+ "hid_size": self.hid_size,
+ "chunk_size": self.chunk_size,
+ "hop_size": self.hop_size,
+ "n_repeats": self.n_repeats,
+ "n_src": self.n_src,
+ "norm_type": self.norm_type,
+ "mask_act": self.mask_act,
+ "bidirectional": self.bidirectional,
+ "rnn_type": self.rnn_type,
+ "num_layers": self.num_layers,
+ "dropout": self.dropout,
+ }
+ return config
+
+
+class DPRNNLinear(nn.Module):
+ def __init__(
+ self,
+ in_chan,
+ n_src,
+ out_chan=None,
+ bn_chan=128,
+ hid_size=128,
+ chunk_size=100,
+ hop_size=None,
+ n_repeats=6,
+ norm_type="gLN",
+ mask_act="relu",
+ bidirectional=True,
+ rnn_type="LSTM",
+ num_layers=1,
+ dropout=0,
+ ):
+ super(DPRNNLinear, self).__init__()
+ self.in_chan = in_chan
+ out_chan = out_chan if out_chan is not None else in_chan
+ self.out_chan = out_chan
+ self.bn_chan = bn_chan
+ self.hid_size = hid_size
+ self.chunk_size = chunk_size
+ hop_size = hop_size if hop_size is not None else chunk_size // 2
+ self.hop_size = hop_size
+ self.n_repeats = n_repeats
+ self.n_src = n_src
+ self.norm_type = norm_type
+ self.mask_act = mask_act
+ self.bidirectional = bidirectional
+ self.rnn_type = rnn_type
+ self.num_layers = num_layers
+ self.dropout = dropout
+
+ layer_norm = normalizations.get(norm_type)(in_chan)
+ bottleneck_conv = nn.Conv1d(in_chan, bn_chan, 1)
+ self.bottleneck = nn.Sequential(layer_norm, bottleneck_conv)
+
+ # Succession of DPRNNBlocks.
+ net = []
+ for x in range(self.n_repeats):
+ net += [
+ DPRNNBlock(
+ bn_chan,
+ hid_size,
+ norm_type=norm_type,
+ bidirectional=bidirectional,
+ rnn_type=rnn_type,
+ num_layers=num_layers,
+ dropout=dropout,
+ )
+ ]
+ self.net = nn.Sequential(*net)
+ # Masking in 3D space
+ net_out_conv = nn.Conv2d(bn_chan, n_src * bn_chan, 1)
+ self.first_out = nn.Sequential(nn.PReLU(), net_out_conv)
+ # Gating and masking in 2D space (after fold)
+ # self.net_out = nn.Sequential(nn.Conv1d(bn_chan, bn_chan, 1), nn.Tanh())
+ self.net_out = nn.Linear(bn_chan, out_chan)
+ self.net_gate = nn.Sequential(nn.Conv1d(bn_chan, bn_chan, 1), nn.Sigmoid())
+ self.mask_net = nn.Conv1d(bn_chan, out_chan, 1, bias=False)
+
+ # Get activation function.
+ mask_nl_class = activations.get(mask_act)
+ # For softmax, feed the source dimension.
+ if has_arg(mask_nl_class, "dim"):
+ self.output_act = mask_nl_class(dim=1)
+ else:
+ self.output_act = mask_nl_class()
+
+ def forward(self, mixture_w):
+ r"""Forward.
+ Args:
+ mixture_w (:class:`torch.Tensor`): Tensor of shape $(batch, nfilters, nframes)$
+ Returns:
+ :class:`torch.Tensor`: estimated mask of shape $(batch, nsrc, nfilters, nframes)$
+ """
+ batch, n_filters, n_frames = mixture_w.size()
+ output = self.bottleneck(mixture_w) # [batch, bn_chan, n_frames]
+ output = unfold(
+ output.unsqueeze(-1),
+ kernel_size=(self.chunk_size, 1),
+ padding=(self.chunk_size, 0),
+ stride=(self.hop_size, 1),
+ )
+ n_chunks = output.shape[-1]
+ output = output.reshape(batch, self.bn_chan, self.chunk_size, n_chunks)
+ # Apply stacked DPRNN Blocks sequentially
+ output = self.net(output)
+ # Map to sources with kind of 2D masks
+ output = self.first_out(output)
+ output = output.reshape(
+ batch * self.n_src, self.bn_chan, self.chunk_size, n_chunks
+ )
+ # Overlap and add:
+ # [batch, out_chan, chunk_size, n_chunks] -> [batch, out_chan, n_frames]
+ to_unfold = self.bn_chan * self.chunk_size
+ output = fold(
+ output.reshape(batch * self.n_src, to_unfold, n_chunks),
+ (n_frames, 1),
+ kernel_size=(self.chunk_size, 1),
+ padding=(self.chunk_size, 0),
+ stride=(self.hop_size, 1),
+ )
+ # Apply gating
+ output = output.reshape(batch * self.n_src, self.bn_chan, -1)
+ # nn.Linear acts on the last dim, so move bn_chan last before applying it.
+ output = self.net_out(output.transpose(1, 2)).transpose(1, 2) * self.net_gate(
+ output
+ )
+ # Compute mask
+ score = self.mask_net(output)
+ est_mask = self.output_act(score)
+ est_mask = est_mask.view(batch, self.n_src, self.out_chan, n_frames)
+ return est_mask
+
+ def get_config(self):
+ config = {
+ "in_chan": self.in_chan,
+ "out_chan": self.out_chan,
+ "bn_chan": self.bn_chan,
+ "hid_size": self.hid_size,
+ "chunk_size": self.chunk_size,
+ "hop_size": self.hop_size,
+ "n_repeats": self.n_repeats,
+ "n_src": self.n_src,
+ "norm_type": self.norm_type,
+ "mask_act": self.mask_act,
+ "bidirectional": self.bidirectional,
+ "rnn_type": self.rnn_type,
+ "num_layers": self.num_layers,
+ "dropout": self.dropout,
+ }
+ return config
diff --git a/look2hear/layers/stft.py b/look2hear/layers/stft.py
new file mode 100644
index 0000000000000000000000000000000000000000..0bef618093aa219f59cbd967f4ed134d20b0837e
--- /dev/null
+++ b/look2hear/layers/stft.py
@@ -0,0 +1,797 @@
+# Copyright 2019 Jian Wu
+# License: Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+import math
+
+import numpy as np
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as tf
+import librosa.filters as filters
+
+from typing import Optional, Tuple
+from distutils.version import LooseVersion
+
+EPSILON = float(np.finfo(np.float32).eps)
+TORCH_VERSION = th.__version__
+
+if TORCH_VERSION >= LooseVersion("1.7"):
+ from torch.fft import fft as fft_func
+else:
+ pass
+
+
+def export_jit(transform: nn.Module) -> nn.Module:
+ """
+ Export transform module for inference
+ """
+ export_out = [module for module in transform if module.exportable()]
+ return nn.Sequential(*export_out)
+
+
+def init_window(wnd: str, frame_len: int, device: th.device = "cpu") -> th.Tensor:
+ """
+ Return window coefficient
+ Args:
+ wnd: window name
+ frame_len: length of the frame
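+ Example (illustrative):
+ >>> init_window("hann", 512).shape
+ torch.Size([512])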
+ """
+
+ def sqrthann(frame_len, periodic=True):
+ return th.hann_window(frame_len, periodic=periodic) ** 0.5
+
+ if wnd not in ["bartlett", "hann", "hamm", "blackman", "rect", "sqrthann"]:
+ raise RuntimeError(f"Unknown window type: {wnd}")
+
+ wnd_tpl = {
+ "sqrthann": sqrthann,
+ "hann": th.hann_window,
+ "hamm": th.hamming_window,
+ "blackman": th.blackman_window,
+ "bartlett": th.bartlett_window,
+ "rect": th.ones,
+ }
+ if wnd != "rect":
+ # match with librosa
+ c = wnd_tpl[wnd](frame_len, periodic=True)
+ else:
+ c = wnd_tpl[wnd](frame_len)
+ return c.to(device)
+
+
+def init_kernel(
+ frame_len: int,
+ frame_hop: int,
+ window: th.Tensor,
+ round_pow_of_two: bool = True,
+ normalized: bool = False,
+ inverse: bool = False,
+ mode: str = "librosa",
+) -> Tuple[th.Tensor, th.Tensor]:
+ """
+ Return STFT kernels
+ Args:
+ frame_len: length of the frame
+ frame_hop: hop size between frames
+ window: window tensor
+ round_pow_of_two: if true, choose round(#power_of_two) as the FFT size
+ normalized: return normalized DFT matrix
+ inverse: return iDFT matrix
+ mode: framing mode (librosa or kaldi)
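+ Example (illustrative; assumes torch >= 1.7 so the ``torch.fft`` path is taken):
+ >>> win = init_window("hann", 400)
+ >>> K, win = init_kernel(400, 160, win)  # fft_size rounded up to 512
+ >>> K.shape  # (2 * fft_size, 1, fft_size)
+ torch.Size([1024, 1, 512])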
+ """
+ if mode not in ["librosa", "kaldi"]:
+ raise ValueError(f"Unsupported mode: {mode}")
+ # FFT size: B
+ if round_pow_of_two or mode == "kaldi":
+ fft_size = 2 ** math.ceil(math.log2(frame_len))
+ else:
+ fft_size = frame_len
+ # center padding window if needed
+ if mode == "librosa" and fft_size != frame_len:
+ lpad = (fft_size - frame_len) // 2
+ window = tf.pad(window, (lpad, fft_size - frame_len - lpad))
+ if normalized:
+ # make K^H * K = I
+ S = fft_size ** 0.5
+ else:
+ S = 1
+ # W x B x 2
+ if TORCH_VERSION >= LooseVersion("1.7"):
+ K = fft_func(th.eye(fft_size) / S, dim=-1)
+ K = th.stack([K.real, K.imag], dim=-1)
+ else:
+ I = th.stack([th.eye(fft_size), th.zeros(fft_size, fft_size)], dim=-1)
+ K = th.fft(I / S, 1)
+ if mode == "kaldi":
+ K = K[:frame_len]
+ if inverse and not normalized:
+ # to make K^H * K = I
+ K = K / fft_size
+ # 2 x B x W
+ K = th.transpose(K, 0, 2)
+ # 2B x 1 x W
+ K = th.reshape(K, (fft_size * 2, 1, K.shape[-1]))
+ return K.to(window.device), window
+
+
+def mel_filter(
+ frame_len: int,
+ round_pow_of_two: bool = True,
+ num_bins: Optional[int] = None,
+ sr: int = 16000,
+ num_mels: int = 80,
+ fmin: float = 0.0,
+ fmax: Optional[float] = None,
+ norm: bool = False,
+) -> th.Tensor:
+ """
+ Return mel filter coefficients
+ Args:
+ frame_len: length of the frame
+ round_pow_of_two: if true, choose round(#power_of_two) as the FFT size
+ num_bins: number of the frequency bins produced by STFT
+ num_mels: number of the mel bands
+ fmin: lowest frequency (in Hz)
+ fmax: highest frequency (in Hz)
+ norm: normalize the mel filter coefficients
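+ Example (illustrative; note that recent librosa versions require keyword-only
+ arguments for ``filters.mel``, so this sketch assumes an older librosa API):
+ >>> fbank = mel_filter(512, sr=16000, num_mels=80)
+ >>> fbank.shape  # (num_mels, N // 2 + 1)
+ torch.Size([80, 257])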
+ """
+ # FFT points
+ if num_bins is None:
+ N = 2 ** math.ceil(math.log2(frame_len)) if round_pow_of_two else frame_len
+ else:
+ N = (num_bins - 1) * 2
+ # fmin & fmax
+ freq_upper = sr // 2
+ if fmax is None:
+ fmax = freq_upper
+ else:
+ fmax = min(fmax + freq_upper if fmax < 0 else fmax, freq_upper)
+ fmin = max(0, fmin)
+ # mel filter coefficients
+ mel = filters.mel(
+ sr,
+ N,
+ n_mels=num_mels,
+ fmax=fmax,
+ fmin=fmin,
+ htk=True,
+ norm="slaney" if norm else None,
+ )
+ # num_mels x (N // 2 + 1)
+ return th.tensor(mel, dtype=th.float32)
+
+
+def speed_perturb_filter(
+ src_sr: int, dst_sr: int, cutoff_ratio: float = 0.95, num_zeros: int = 64
+) -> th.Tensor:
+ """
+ Return speed perturb filters, reference:
+ https://github.com/danpovey/filtering/blob/master/lilfilter/resampler.py
+ Args:
+ src_sr: sample rate of the source signal
+ dst_sr: sample rate of the target signal
+ Return:
+ weight (Tensor): coefficients of the filter
+ """
+ if src_sr == dst_sr:
+ raise ValueError(f"src_sr should not be equal to dst_sr: {src_sr}/{dst_sr}")
+ gcd = math.gcd(src_sr, dst_sr)
+ src_sr = src_sr // gcd
+ dst_sr = dst_sr // gcd
+ if src_sr == 1 or dst_sr == 1:
+ raise ValueError("do not support integer downsample/upsample")
+ zeros_per_block = min(src_sr, dst_sr) * cutoff_ratio
+ padding = 1 + int(num_zeros / zeros_per_block)
+ # dst_sr x src_sr x K
+ times = (
+ np.arange(dst_sr)[:, None, None] / float(dst_sr)
+ - np.arange(src_sr)[None, :, None] / float(src_sr)
+ - np.arange(2 * padding + 1)[None, None, :]
+ + padding
+ )
+ window = np.heaviside(1 - np.abs(times / padding), 0.0) * (
+ 0.5 + 0.5 * np.cos(times / padding * math.pi)
+ )
+ weight = np.sinc(times * zeros_per_block) * window * zeros_per_block / float(src_sr)
+ return th.tensor(weight, dtype=th.float32)
+
+
+def splice_feature(
+ feats: th.Tensor, lctx: int = 1, rctx: int = 1, op: str = "cat"
+) -> th.Tensor:
+ """
+ Splice feature
+ Args:
+ feats (Tensor): N x ... x T x F, original feature
+ lctx: left context
+ rctx: right context
+ op: operator on feature context
+ Return:
+ splice (Tensor): feature with context padded
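+ Example (illustrative):
+ >>> import torch as th
+ >>> feats = th.randn(4, 100, 40)
+ >>> splice_feature(feats, lctx=2, rctx=2, op="cat").shape
+ torch.Size([4, 100, 200])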
+ """
+ if lctx + rctx == 0:
+ return feats
+ if op not in ["cat", "stack"]:
+ raise ValueError(f"Unknown op for feature splicing: {op}")
+ # [N x ... x T x F, ...]
+ ctx = []
+ T = feats.shape[-2]
+ for c in range(-lctx, rctx + 1):
+ idx = th.arange(c, c + T, device=feats.device, dtype=th.int64)
+ idx = th.clamp(idx, min=0, max=T - 1)
+ ctx.append(th.index_select(feats, -2, idx))
+ if op == "cat":
+ # N x ... x T x FD
+ splice = th.cat(ctx, -1)
+ else:
+ # N x ... x T x F x D
+ splice = th.stack(ctx, -1)
+ return splice
+
+
+def _forward_stft(
+ wav: th.Tensor,
+ kernel: th.Tensor,
+ window: th.Tensor,
+ return_polar: bool = False,
+ pre_emphasis: float = 0,
+ frame_hop: int = 256,
+ onesided: bool = False,
+ center: bool = False,
+ eps: float = EPSILON,
+) -> th.Tensor:
+ """
+ STFT function implemented by conv1d (not efficient, but we don't care during training)
+ Args:
+ wav (Tensor): N x (C) x S
+ kernel (Tensor): STFT transform kernels, from init_kernel(...)
+ return_polar: return [magnitude; phase] Tensor or [real; imag] Tensor
+ pre_emphasis: factor of preemphasis
+ frame_hop: frame hop size in number samples
+ onesided: return half FFT bins
+ center: if true, we assumed to have centered frames
+ Return:
+ transform (Tensor): STFT transform results
+ """
+ wav_dim = wav.dim()
+ if wav_dim not in [2, 3]:
+ raise RuntimeError(f"STFT expects a 2D/3D tensor, but got {wav_dim:d}D")
+ # if N x S, reshape N x 1 x S
+ # else: reshape NC x 1 x S
+ N, S = wav.shape[0], wav.shape[-1]
+ wav = wav.view(-1, 1, S)
+ # NC x 1 x S+2P
+ if center:
+ pad = kernel.shape[-1] // 2
+ # NOTE: match with librosa
+ wav = tf.pad(wav, (pad, pad), mode="reflect")
+ # STFT
+ kernel = kernel * window
+ if pre_emphasis > 0:
+ # NC x W x T
+ frames = tf.unfold(
+ wav[:, None], (1, kernel.shape[-1]), stride=frame_hop, padding=0
+ )
+ # follow Kaldi's Preemphasize
+ frames[:, 1:] = frames[:, 1:] - pre_emphasis * frames[:, :-1]
+ frames[:, 0] *= 1 - pre_emphasis
+ # 1 x 2B x W, NC x W x T, NC x 2B x T
+ packed = th.matmul(kernel[:, 0][None, ...], frames)
+ else:
+ packed = tf.conv1d(wav, kernel, stride=frame_hop, padding=0)
+ # NC x 2B x T => N x C x 2B x T
+ if wav_dim == 3:
+ packed = packed.view(N, -1, packed.shape[-2], packed.shape[-1])
+ # N x (C) x B x T
+ real, imag = th.chunk(packed, 2, dim=-2)
+ # N x (C) x B/2+1 x T
+ if onesided:
+ num_bins = kernel.shape[0] // 4 + 1
+ real = real[..., :num_bins, :]
+ imag = imag[..., :num_bins, :]
+ if return_polar:
+ mag = (real ** 2 + imag ** 2 + eps) ** 0.5
+ pha = th.atan2(imag, real)
+ return th.stack([mag, pha], dim=-1)
+ else:
+ return th.stack([real, imag], dim=-1)
+
+
+def _inverse_stft(
+ transform: th.Tensor,
+ kernel: th.Tensor,
+ window: th.Tensor,
+ return_polar: bool = False,
+ frame_hop: int = 256,
+ onesided: bool = False,
+ center: bool = False,
+ eps: float = EPSILON,
+) -> th.Tensor:
+ """
+    iSTFT function implemented by conv_transpose1d
+    Args:
+        transform (Tensor): STFT transform results
+        kernel (Tensor): STFT transform kernels, from init_kernel(...)
+        return_polar (bool): keep the same value as used in _forward_stft
+        frame_hop: frame hop size in number of samples
+        onesided: whether the transform contains only half of the FFT bins
+        center: keep the same value as used in _forward_stft
+ Return:
+ wav (Tensor), N x S
+ """
+ # (N) x F x T x 2
+ transform_dim = transform.dim()
+ # if F x T x 2, reshape 1 x F x T x 2
+ if transform_dim == 3:
+ transform = th.unsqueeze(transform, 0)
+ if transform_dim != 4:
+ raise RuntimeError(f"Expect 4D tensor, but got {transform_dim}D")
+
+ if return_polar:
+ real = transform[..., 0] * th.cos(transform[..., 1])
+ imag = transform[..., 0] * th.sin(transform[..., 1])
+ else:
+ real, imag = transform[..., 0], transform[..., 1]
+
+ if onesided:
+ # [self.num_bins - 2, ..., 1]
+ reverse = range(kernel.shape[0] // 4 - 1, 0, -1)
+ # extend matrix: N x B x T
+ real = th.cat([real, real[:, reverse]], 1)
+ imag = th.cat([imag, -imag[:, reverse]], 1)
+ # pack: N x 2B x T
+ packed = th.cat([real, imag], dim=1)
+ # N x 1 x T
+ wav = tf.conv_transpose1d(packed, kernel * window, stride=frame_hop, padding=0)
+ # normalized audio samples
+ # refer: https://github.com/pytorch/audio/blob/2ebbbf511fb1e6c47b59fd32ad7e66023fa0dff1/torchaudio/functional.py#L171
+ num_frames = packed.shape[-1]
+ win_length = window.shape[0]
+ # W x T
+ win = th.repeat_interleave(window[..., None] ** 2, num_frames, dim=-1)
+ # Do OLA on windows
+ # v1)
+ I = th.eye(win_length, device=win.device)[:, None]
+ denorm = tf.conv_transpose1d(win[None, ...], I, stride=frame_hop, padding=0)
+ # v2)
+ # num_samples = (num_frames - 1) * frame_hop + win_length
+ # denorm = tf.fold(win[None, ...], (num_samples, 1), (win_length, 1),
+ # stride=frame_hop)[..., 0]
+ if center:
+ pad = kernel.shape[-1] // 2
+ wav = wav[..., pad:-pad]
+ denorm = denorm[..., pad:-pad]
+ wav = wav / (denorm + eps)
+ # N x S
+ return wav.squeeze(1)
+
+
+def _pytorch_stft(
+ wav: th.Tensor,
+ frame_len: int,
+ frame_hop: int,
+ n_fft: int = 512,
+ return_polar: bool = False,
+    window: th.Tensor = None,
+ normalized: bool = False,
+ onesided: bool = True,
+ center: bool = False,
+ eps: float = EPSILON,
+) -> th.Tensor:
+ """
+ Wrapper of PyTorch STFT function
+ Args:
+ wav (Tensor): source audio signal
+ frame_len: length of the frame
+ frame_hop: hop size between frames
+ n_fft: number of the FFT size
+ return_polar: return the results in polar coordinate
+ window: window tensor
+        center: same definition as the parameter in librosa.stft
+ normalized: use normalized DFT kernel
+ onesided: output onesided STFT
+ Return:
+ transform (Tensor), STFT transform results
+ """
+ if TORCH_VERSION < LooseVersion("1.7"):
+ raise RuntimeError("Can not use this function as TORCH_VERSION < 1.7")
+ wav_dim = wav.dim()
+ if wav_dim not in [2, 3]:
+ raise RuntimeError(f"STFT expect 2D/3D tensor, but got {wav_dim:d}D")
+ # if N x C x S, reshape NC x S
+ wav = wav.view(-1, wav.shape[-1])
+ # STFT: N x F x T x 2
+ stft = th.stft(
+ wav,
+ n_fft,
+ hop_length=frame_hop,
+ win_length=window.shape[-1],
+ window=window,
+ center=center,
+ normalized=normalized,
+ onesided=onesided,
+ return_complex=False,
+ )
+ if wav_dim == 3:
+ N, F, T, _ = stft.shape
+ stft = stft.view(N, -1, F, T, 2)
+ # N x (C) x F x T x 2
+ if not return_polar:
+ return stft
+ # N x (C) x F x T
+ real, imag = stft[..., 0], stft[..., 1]
+ mag = (real ** 2 + imag ** 2 + eps) ** 0.5
+ pha = th.atan2(imag, real)
+ return th.stack([mag, pha], dim=-1)
+
+
+def _pytorch_istft(
+ transform: th.Tensor,
+ frame_len: int,
+ frame_hop: int,
+ window: th.Tensor,
+ n_fft: int = 512,
+ return_polar: bool = False,
+ normalized: bool = False,
+ onesided: bool = True,
+ center: bool = False,
+ eps: float = EPSILON,
+) -> th.Tensor:
+ """
+ Wrapper of PyTorch iSTFT function
+ Args:
+ transform (Tensor): results of STFT
+ frame_len: length of the frame
+ frame_hop: hop size between frames
+ window: window tensor
+ n_fft: number of the FFT size
+ return_polar: keep same with _pytorch_stft
+        center: same definition as the parameter in librosa.stft
+ normalized: use normalized DFT kernel
+ onesided: output onesided STFT
+ Return:
+ wav (Tensor): synthetic audio
+ """
+ if TORCH_VERSION < LooseVersion("1.7"):
+ raise RuntimeError("Can not use this function as TORCH_VERSION < 1.7")
+
+ transform_dim = transform.dim()
+ # if F x T x 2, reshape 1 x F x T x 2
+ if transform_dim == 3:
+ transform = th.unsqueeze(transform, 0)
+ if transform_dim != 4:
+ raise RuntimeError(f"Expect 4D tensor, but got {transform_dim}D")
+
+ if return_polar:
+ real = transform[..., 0] * th.cos(transform[..., 1])
+ imag = transform[..., 0] * th.sin(transform[..., 1])
+ transform = th.stack([real, imag], -1)
+ # stft is a complex tensor of PyTorch
+ stft = th.view_as_complex(transform)
+ # (N) x S
+ wav = th.istft(
+ stft,
+ n_fft,
+ hop_length=frame_hop,
+ win_length=window.shape[-1],
+ window=window,
+ center=center,
+ normalized=normalized,
+ onesided=onesided,
+ return_complex=False,
+ )
+ return wav
+
+
+def forward_stft(
+ wav: th.Tensor,
+ frame_len: int,
+ frame_hop: int,
+ window: str = "sqrthann",
+ round_pow_of_two: bool = True,
+ return_polar: bool = False,
+ pre_emphasis: float = 0,
+ normalized: bool = False,
+ onesided: bool = True,
+ center: bool = False,
+ mode: str = "librosa",
+ eps: float = EPSILON,
+) -> th.Tensor:
+ """
+    STFT function implementation, equivalent to the STFT layer
+    Args:
+        wav: source audio signal
+        frame_len: length of the frame
+        frame_hop: hop size between frames
+        return_polar: return [magnitude; phase] Tensor or [real; imag] Tensor
+        window: window name
+        center: center flag (similar to that in librosa.stft)
+        round_pow_of_two: if True, round the FFT size up to the next power of two
+        pre_emphasis: preemphasis factor
+        normalized: use normalized DFT kernel
+        onesided: output onesided STFT
+        mode: STFT mode, "kaldi", "librosa" or "torch"
+ Return:
+ transform: results of STFT
+ """
+ window = init_window(window, frame_len, device=wav.device)
+ if mode == "torch":
+ n_fft = 2 ** math.ceil(math.log2(frame_len)) if round_pow_of_two else frame_len
+ return _pytorch_stft(
+ wav,
+ frame_len,
+ frame_hop,
+ n_fft=n_fft,
+ return_polar=return_polar,
+ window=window,
+ normalized=normalized,
+ onesided=onesided,
+ center=center,
+ eps=eps,
+ )
+ else:
+ kernel, window = init_kernel(
+ frame_len,
+ frame_hop,
+ window=window,
+ round_pow_of_two=round_pow_of_two,
+ normalized=normalized,
+ inverse=False,
+ mode=mode,
+ )
+ return _forward_stft(
+ wav,
+ kernel,
+ window,
+ return_polar=return_polar,
+ frame_hop=frame_hop,
+ pre_emphasis=pre_emphasis,
+ onesided=onesided,
+ center=center,
+ eps=eps,
+ )
+
+
+def inverse_stft(
+ transform: th.Tensor,
+ frame_len: int,
+ frame_hop: int,
+ return_polar: bool = False,
+ window: str = "sqrthann",
+ round_pow_of_two: bool = True,
+ normalized: bool = False,
+ onesided: bool = True,
+ center: bool = False,
+ mode: str = "librosa",
+ eps: float = EPSILON,
+) -> th.Tensor:
+ """
+    iSTFT function implementation, equivalent to the iSTFT layer
+    Args:
+        transform: results of STFT
+        frame_len: length of the frame
+        frame_hop: hop size between frames
+        return_polar: keep the same value as used in forward_stft(...)
+        window: window name
+        center: center flag (similar to that in librosa.stft)
+        round_pow_of_two: if True, round the FFT size up to the next power of two
+        normalized: use normalized DFT kernel
+        onesided: output onesided STFT
+        mode: STFT mode, "kaldi", "librosa" or "torch"
+ Return:
+ wav: synthetic signals
+ """
+ window = init_window(window, frame_len, device=transform.device)
+ if mode == "torch":
+ n_fft = 2 ** math.ceil(math.log2(frame_len)) if round_pow_of_two else frame_len
+ return _pytorch_istft(
+ transform,
+ frame_len,
+ frame_hop,
+ n_fft=n_fft,
+ return_polar=return_polar,
+ window=window,
+ normalized=normalized,
+ onesided=onesided,
+ center=center,
+ eps=eps,
+ )
+ else:
+ kernel, window = init_kernel(
+ frame_len,
+ frame_hop,
+ window,
+ round_pow_of_two=round_pow_of_two,
+ normalized=normalized,
+ inverse=True,
+ mode=mode,
+ )
+ return _inverse_stft(
+ transform,
+ kernel,
+ window,
+ return_polar=return_polar,
+ frame_hop=frame_hop,
+ onesided=onesided,
+ center=center,
+ eps=eps,
+ )
+
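+# A hedged round-trip sketch for the functional API above (parameter values are
+# illustrative only): analysis with forward_stft followed by synthesis with
+# inverse_stft should approximately reconstruct the input, up to edge effects,
+# as long as the same frame/window settings are used.
+#
+#   wav = th.randn(2, 16000)               # N x S
+#   spec = forward_stft(wav, 512, 256)     # N x F x T x 2, F = 257 here
+#   rec = inverse_stft(spec, 512, 256)     # N x S'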
+
+class STFTBase(nn.Module):
+ """
+ Base layer for (i)STFT
+ Args:
+ frame_len: length of the frame
+ frame_hop: hop size between frames
+ window: window name
+        center: center flag (similar to that in librosa.stft)
+        round_pow_of_two: if True, round the FFT size up to the next power of two
+ normalized: use normalized DFT kernel
+ pre_emphasis: factor of preemphasis
+ mode: STFT mode, "kaldi" or "librosa" or "torch"
+ onesided: output onesided STFT
+ inverse: using iDFT kernel (for iSTFT)
+ """
+
+ def __init__(
+ self,
+ frame_len: int,
+ frame_hop: int,
+ window: str = "sqrthann",
+ round_pow_of_two: bool = True,
+ normalized: bool = False,
+ pre_emphasis: float = 0,
+ onesided: bool = True,
+ inverse: bool = False,
+ center: bool = False,
+ mode: str = "librosa",
+ ) -> None:
+ super(STFTBase, self).__init__()
+ if mode != "torch":
+ K, w = init_kernel(
+ frame_len,
+ frame_hop,
+ init_window(window, frame_len),
+ round_pow_of_two=round_pow_of_two,
+ normalized=normalized,
+ inverse=inverse,
+ mode=mode,
+ )
+ self.K = nn.Parameter(K, requires_grad=False)
+ self.w = nn.Parameter(w, requires_grad=False)
+ self.num_bins = self.K.shape[0] // 4 + 1
+ self.pre_emphasis = pre_emphasis
+ self.win_length = self.K.shape[2]
+ else:
+ self.K = None
+ w = init_window(window, frame_len)
+ self.w = nn.Parameter(w, requires_grad=False)
+ fft_size = (
+ 2 ** math.ceil(math.log2(frame_len)) if round_pow_of_two else frame_len
+ )
+ self.num_bins = fft_size // 2 + 1
+ self.pre_emphasis = 0
+ self.win_length = fft_size
+ self.frame_len = frame_len
+ self.frame_hop = frame_hop
+ self.window = window
+ self.normalized = normalized
+ self.onesided = onesided
+ self.center = center
+ self.mode = mode
+
+ def num_frames(self, wav_len: th.Tensor) -> th.Tensor:
+ """
+ Compute number of the frames
+ """
+ assert th.sum(wav_len <= self.win_length) == 0
+ if self.center:
+            wav_len = wav_len + self.win_length
+ return (
+ th.div(wav_len - self.win_length, self.frame_hop, rounding_mode="trunc") + 1
+ )
+
+ def extra_repr(self) -> str:
+ str_repr = (
+ f"num_bins={self.num_bins}, win_length={self.win_length}, "
+ + f"stride={self.frame_hop}, window={self.window}, "
+ + f"center={self.center}, mode={self.mode}"
+ )
+ if not self.onesided:
+ str_repr += f", onesided={self.onesided}"
+ if self.pre_emphasis > 0:
+ str_repr += f", pre_emphasis={self.pre_emphasis}"
+ if self.normalized:
+ str_repr += f", normalized={self.normalized}"
+ return str_repr
+
+
+class STFT(STFTBase):
+ """
+ Short-time Fourier Transform as a Layer
+ """
+
+ def __init__(self, *args, **kwargs):
+ super(STFT, self).__init__(*args, inverse=False, **kwargs)
+
+ def forward(
+ self, wav: th.Tensor, return_polar: bool = False, eps: float = EPSILON
+ ) -> th.Tensor:
+ """
+        Accept single- or multi-channel raw waveform and output the STFT transform
+        Args:
+            wav (Tensor): input signal, N x (C) x S
+        Return:
+            transform (Tensor): N x (C) x F x T x 2
+ """
+ if self.mode == "torch":
+ return _pytorch_stft(
+ wav,
+ self.frame_len,
+ self.frame_hop,
+ n_fft=(self.num_bins - 1) * 2,
+ return_polar=return_polar,
+ window=self.w,
+ normalized=self.normalized,
+ onesided=self.onesided,
+ center=self.center,
+ eps=eps,
+ )
+ else:
+ return _forward_stft(
+ wav,
+ self.K,
+ self.w,
+ return_polar=return_polar,
+ frame_hop=self.frame_hop,
+ pre_emphasis=self.pre_emphasis,
+ onesided=self.onesided,
+ center=self.center,
+ eps=eps,
+ )
+
+
+class iSTFT(STFTBase):
+ """
+ Inverse Short-time Fourier Transform as a Layer
+ """
+
+ def __init__(self, *args, **kwargs):
+ super(iSTFT, self).__init__(*args, inverse=True, **kwargs)
+
+ def forward(
+ self, transform: th.Tensor, return_polar: bool = False, eps: float = EPSILON
+ ) -> th.Tensor:
+ """
+ Accept phase & magnitude and output raw waveform
+ Args
+ transform (Tensor): STFT output, N x F x T x 2
+ Return
+ s (Tensor): N x S
+ """
+ if self.mode == "torch":
+ return _pytorch_istft(
+ transform,
+ self.frame_len,
+ self.frame_hop,
+ n_fft=(self.num_bins - 1) * 2,
+ return_polar=return_polar,
+ window=self.w,
+ normalized=self.normalized,
+ onesided=self.onesided,
+ center=self.center,
+ eps=eps,
+ )
+ else:
+ return _inverse_stft(
+ transform,
+ self.K,
+ self.w,
+ return_polar=return_polar,
+ frame_hop=self.frame_hop,
+ onesided=self.onesided,
+ center=self.center,
+ eps=eps,
+ )
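+
+
+# Hedged layer-level sketch (shapes illustrative): STFT/iSTFT wrap the same
+# kernels as the functional API in nn.Module form with non-trainable
+# parameters, so they can be registered inside a separation model and moved to
+# the right device together with it.
+#
+#   stft, istft = STFT(512, 256), iSTFT(512, 256)
+#   spec = stft(th.randn(2, 16000))        # N x F x T x 2
+#   rec = istft(spec)                      # N x S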
diff --git a/look2hear/layers/stft_tfgn.py b/look2hear/layers/stft_tfgn.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc399d78bd5e8827ab9401bf08cac34f03a2288d
--- /dev/null
+++ b/look2hear/layers/stft_tfgn.py
@@ -0,0 +1,245 @@
+from typing import Optional, Tuple, Union
+
+import librosa
+import numpy as np
+import torch
+from packaging.version import parse as V
+from torch_complex.tensor import ComplexTensor
+from typeguard import check_argument_types
+
+from ..utils.complex_utils import is_complex
+from ..utils.inversible_interface import InversibleInterface
+from ..utils.nets_utils import make_pad_mask
+
+is_torch_1_10_plus = V(torch.__version__) >= V("1.10.0")
+is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0")
+is_torch_1_7_plus = V(torch.__version__) >= V("1.7")
+
+
+class Stft(torch.nn.Module, InversibleInterface):
+ def __init__(
+ self,
+ n_fft: int = 512,
+ win_length: int = None,
+ hop_length: int = 128,
+ window: Optional[str] = "hann",
+ center: bool = True,
+ normalized: bool = False,
+ onesided: bool = True,
+ ):
+ assert check_argument_types()
+ super().__init__()
+ self.n_fft = n_fft
+ if win_length is None:
+ self.win_length = n_fft
+ else:
+ self.win_length = win_length
+ self.hop_length = hop_length
+ self.center = center
+ self.normalized = normalized
+ self.onesided = onesided
+ if window is not None and not hasattr(torch, f"{window}_window"):
+ raise ValueError(f"{window} window is not implemented")
+ self.window = window
+
+ def extra_repr(self):
+ return (
+ f"n_fft={self.n_fft}, "
+ f"win_length={self.win_length}, "
+ f"hop_length={self.hop_length}, "
+ f"center={self.center}, "
+ f"normalized={self.normalized}, "
+ f"onesided={self.onesided}"
+ )
+
+ def forward(
+ self, input: torch.Tensor, ilens: torch.Tensor = None
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """STFT forward function.
+
+ Args:
+ input: (Batch, Nsamples) or (Batch, Nsample, Channels)
+ ilens: (Batch)
+ Returns:
+ output: (Batch, Frames, Freq, 2) or (Batch, Frames, Channels, Freq, 2)
+
+ """
+ bs = input.size(0)
+ if input.dim() == 3:
+ multi_channel = True
+ # input: (Batch, Nsample, Channels) -> (Batch * Channels, Nsample)
+ input = input.transpose(1, 2).reshape(-1, input.size(1))
+ else:
+ multi_channel = False
+
+ # NOTE(kamo):
+ # The default behaviour of torch.stft is compatible with librosa.stft
+ # about padding and scaling.
+ # Note that it's different from scipy.signal.stft
+
+ # output: (Batch, Freq, Frames, 2=real_imag)
+ # or (Batch, Channel, Freq, Frames, 2=real_imag)
+ if self.window is not None:
+ window_func = getattr(torch, f"{self.window}_window")
+ window = window_func(
+ self.win_length, dtype=input.dtype, device=input.device
+ )
+ else:
+ window = None
+
+        # For compatibility with ARM devices, which could not run torch.stft()
+        # on older PyTorch versions due to the lack of MKL, there is an
+        # alternative implementation based on librosa.
+        # Note: PyTorch >= 1.10.0 has native support for FFT and STFT on all
+        # CPU targets, including ARM.
+ if is_torch_1_10_plus or input.is_cuda or torch.backends.mkl.is_available():
+ stft_kwargs = dict(
+ n_fft=self.n_fft,
+ win_length=self.win_length,
+ hop_length=self.hop_length,
+ center=self.center,
+ window=window,
+ normalized=self.normalized,
+ onesided=self.onesided,
+ )
+ if is_torch_1_7_plus:
+ stft_kwargs["return_complex"] = False
+ output = torch.stft(input, **stft_kwargs)
+ else:
+ if self.training:
+ raise NotImplementedError(
+ "stft is implemented with librosa on this device, which does not "
+ "support the training mode."
+ )
+
+            # build the librosa kwargs separately from the torch.stft kwargs
+            # note: librosa does not support a win_length that is < n_fft,
+            # but the window can be manually padded (see below).
+ stft_kwargs = dict(
+ n_fft=self.n_fft,
+ win_length=self.n_fft,
+ hop_length=self.hop_length,
+ center=self.center,
+ window=window,
+ pad_mode="reflect",
+ )
+
+ if window is not None:
+ # pad the given window to n_fft
+ n_pad_left = (self.n_fft - window.shape[0]) // 2
+ n_pad_right = self.n_fft - window.shape[0] - n_pad_left
+ stft_kwargs["window"] = torch.cat(
+ [torch.zeros(n_pad_left), window, torch.zeros(n_pad_right)], 0
+ ).numpy()
+ else:
+ win_length = (
+ self.win_length if self.win_length is not None else self.n_fft
+ )
+ stft_kwargs["window"] = torch.ones(win_length)
+
+ output = []
+            # iterate over instances in the batch
+ for i, instance in enumerate(input):
+ stft = librosa.stft(input[i].numpy(), **stft_kwargs)
+ output.append(torch.tensor(np.stack([stft.real, stft.imag], -1)))
+ output = torch.stack(output, 0)
+ if not self.onesided:
+ len_conj = self.n_fft - output.shape[1]
+ conj = output[:, 1 : 1 + len_conj].flip(1)
+ conj[:, :, :, -1].data *= -1
+ output = torch.cat([output, conj], 1)
+ if self.normalized:
+ output = output * (stft_kwargs["window"].shape[0] ** (-0.5))
+
+ # output: (Batch, Freq, Frames, 2=real_imag)
+ # -> (Batch, Frames, Freq, 2=real_imag)
+ output = output.transpose(1, 2)
+ if multi_channel:
+ # output: (Batch * Channel, Frames, Freq, 2=real_imag)
+ # -> (Batch, Frame, Channel, Freq, 2=real_imag)
+ output = output.view(bs, -1, output.size(1), output.size(2), 2).transpose(
+ 1, 2
+ )
+
+ if ilens is not None:
+ if self.center:
+ pad = self.n_fft // 2
+ ilens = ilens + 2 * pad
+
+ if is_torch_1_9_plus:
+ olens = (
+ torch.div(
+ ilens - self.n_fft, self.hop_length, rounding_mode="trunc"
+ )
+ + 1
+ )
+ else:
+ olens = (ilens - self.n_fft) // self.hop_length + 1
+ output.masked_fill_(make_pad_mask(olens, output, 1), 0.0)
+ else:
+ olens = None
+
+ return output, olens
+
+ def inverse(
+ self, input: Union[torch.Tensor, ComplexTensor], ilens: torch.Tensor = None
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """Inverse STFT.
+
+ Args:
+ input: Tensor(batch, T, F, 2) or ComplexTensor(batch, T, F)
+ ilens: (batch,)
+ Returns:
+ wavs: (batch, samples)
+ ilens: (batch,)
+ """
+ if V(torch.__version__) >= V("1.6.0"):
+ istft = torch.functional.istft
+ else:
+ try:
+ import torchaudio
+ except ImportError:
+ raise ImportError(
+ "Please install torchaudio>=0.3.0 or use torch>=1.6.0"
+ )
+
+ if not hasattr(torchaudio.functional, "istft"):
+ raise ImportError(
+ "Please install torchaudio>=0.3.0 or use torch>=1.6.0"
+ )
+ istft = torchaudio.functional.istft
+
+ if self.window is not None:
+ window_func = getattr(torch, f"{self.window}_window")
+ if is_complex(input):
+ datatype = input.real.dtype
+ else:
+ datatype = input.dtype
+ window = window_func(self.win_length, dtype=datatype, device=input.device)
+ else:
+ window = None
+
+ if is_complex(input):
+ input = torch.stack([input.real, input.imag], dim=-1)
+ elif input.shape[-1] != 2:
+ raise TypeError("Invalid input type")
+ input = input.transpose(1, 2)
+ input = torch.complex(input[:,:,:,0], input[:,:,:,1])
+
+ wavs = istft(
+ input,
+ n_fft=self.n_fft,
+ hop_length=self.hop_length,
+ win_length=self.win_length,
+ window=window,
+ center=self.center,
+ normalized=self.normalized,
+ onesided=self.onesided,
+ length=ilens.max() if ilens is not None else ilens,
+ )
+
+ return wavs, ilens
\ No newline at end of file
diff --git a/look2hear/losses/__init__.py b/look2hear/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..76c874db26f4479517a2866969b70e9b342240c8
--- /dev/null
+++ b/look2hear/losses/__init__.py
@@ -0,0 +1,37 @@
+from .matrix import pairwise_neg_sisdr
+from .matrix import pairwise_neg_sdsdr
+from .matrix import pairwise_neg_snr
+from .matrix import singlesrc_neg_sisdr
+from .matrix import singlesrc_neg_sdsdr
+from .matrix import singlesrc_neg_snr
+from .matrix import multisrc_neg_sisdr
+from .matrix import multisrc_neg_sdsdr
+from .matrix import multisrc_neg_snr
+from .matrix import freq_mae_wavl1loss
+from .matrix import pairwise_neg_sisdr_freq_mse
+from .matrix import pairwise_neg_snr_multidecoder
+from .pit_wrapper import PITLossWrapper
+from .mixit import MixITLossWrapper
+from .matrix import PairwiseNegSDR
+from .matrix import SingleSrcNegSDR
+from .sisnri import SISNRi
+
+__all__ = [
+ "SISNRi",
+ "MixITLossWrapper",
+ "PITLossWrapper",
+ "PairwiseNegSDR",
+ "SingleSrcNegSDR",
+ "singlesrc_neg_sisdr",
+ "pairwise_neg_sisdr",
+ "multisrc_neg_sisdr",
+ "pairwise_neg_sdsdr",
+ "singlesrc_neg_sdsdr",
+ "multisrc_neg_sdsdr",
+ "pairwise_neg_snr",
+ "singlesrc_neg_snr",
+ "multisrc_neg_snr",
+ "freq_mae_wavl1loss",
+ "pairwise_neg_sisdr_freq_mse",
+ "pairwise_neg_snr_multidecoder"
+]
diff --git a/look2hear/losses/matrix.py b/look2hear/losses/matrix.py
new file mode 100644
index 0000000000000000000000000000000000000000..07a7acc4a9db34e82730f42408a6a8b1a80786d8
--- /dev/null
+++ b/look2hear/losses/matrix.py
@@ -0,0 +1,200 @@
+import torch
+from torch.nn.modules.loss import _Loss
+
+
+class PairwiseNegSDR(_Loss):
+ def __init__(self, sdr_type, zero_mean=True, take_log=True, EPS=1e-8):
+ super().__init__()
+ assert sdr_type in ["snr", "sisdr", "sdsdr"]
+ self.sdr_type = sdr_type
+ self.zero_mean = zero_mean
+ self.take_log = take_log
+ self.EPS = EPS
+
+ def forward(self, ests, targets):
+ if targets.size() != ests.size() or targets.ndim != 3:
+ raise TypeError(
+ f"Inputs must be of shape [batch, n_src, time], got {targets.size()} and {ests.size()} instead"
+ )
+ assert targets.size() == ests.size()
+ # Step 1. Zero-mean norm
+ if self.zero_mean:
+ mean_source = torch.mean(targets, dim=2, keepdim=True)
+ mean_estimate = torch.mean(ests, dim=2, keepdim=True)
+ targets = targets - mean_source
+ ests = ests - mean_estimate
+ # Step 2. Pair-wise SI-SDR. (Reshape to use broadcast)
+ s_target = torch.unsqueeze(targets, dim=1)
+ s_estimate = torch.unsqueeze(ests, dim=2)
+ if self.sdr_type in ["sisdr", "sdsdr"]:
+ # [batch, n_src, n_src, 1]
+ pair_wise_dot = torch.sum(s_estimate * s_target, dim=3, keepdim=True)
+ # [batch, 1, n_src, 1]
+ s_target_energy = torch.sum(s_target ** 2, dim=3, keepdim=True) + self.EPS
+ # [batch, n_src, n_src, time]
+ pair_wise_proj = pair_wise_dot * s_target / s_target_energy
+ else:
+ # [batch, n_src, n_src, time]
+ pair_wise_proj = s_target.repeat(1, s_target.shape[2], 1, 1)
+ if self.sdr_type in ["sdsdr", "snr"]:
+ e_noise = s_estimate - s_target
+ else:
+ e_noise = s_estimate - pair_wise_proj
+ # [batch, n_src, n_src]
+ pair_wise_sdr = torch.sum(pair_wise_proj ** 2, dim=3) / (
+ torch.sum(e_noise ** 2, dim=3) + self.EPS
+ )
+ if self.take_log:
+ pair_wise_sdr = 10 * torch.log10(pair_wise_sdr + self.EPS)
+ return -pair_wise_sdr
+
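+# Hedged usage sketch (shapes illustrative): the pairwise losses return a
+# [batch, n_src, n_src] matrix of negative SDRs between every estimate/target
+# pair; a PIT wrapper (see pit_wrapper.py) then selects the permutation with
+# the lowest loss.
+#
+#   ests = torch.randn(4, 2, 16000)
+#   refs = torch.randn(4, 2, 16000)
+#   pw = PairwiseNegSDR("sisdr")(ests, refs)   # shape: [4, 2, 2]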
+
+class SingleSrcNegSDR(_Loss):
+ def __init__(
+ self, sdr_type, zero_mean=True, take_log=True, reduction="none", EPS=1e-8
+ ):
+        if reduction == "sum":
+            raise NotImplementedError("reduction='sum' is not supported")
+ super().__init__(reduction=reduction)
+
+ assert sdr_type in ["snr", "sisdr", "sdsdr"]
+ self.sdr_type = sdr_type
+ self.zero_mean = zero_mean
+ self.take_log = take_log
+        self.EPS = EPS
+
+ def forward(self, ests, targets):
+ if targets.size() != ests.size() or targets.ndim != 2:
+ raise TypeError(
+ f"Inputs must be of shape [batch, time], got {targets.size()} and {ests.size()} instead"
+ )
+ # Step 1. Zero-mean norm
+ if self.zero_mean:
+ mean_source = torch.mean(targets, dim=1, keepdim=True)
+ mean_estimate = torch.mean(ests, dim=1, keepdim=True)
+ targets = targets - mean_source
+ ests = ests - mean_estimate
+ # Step 2. Pair-wise SI-SDR.
+ if self.sdr_type in ["sisdr", "sdsdr"]:
+ # [batch, 1]
+ dot = torch.sum(ests * targets, dim=1, keepdim=True)
+ # [batch, 1]
+ s_target_energy = torch.sum(targets ** 2, dim=1, keepdim=True) + self.EPS
+ # [batch, time]
+ scaled_target = dot * targets / s_target_energy
+ else:
+ # [batch, time]
+ scaled_target = targets
+ if self.sdr_type in ["sdsdr", "snr"]:
+ e_noise = ests - targets
+ else:
+ e_noise = ests - scaled_target
+ # [batch]
+ losses = torch.sum(scaled_target ** 2, dim=1) / (
+ torch.sum(e_noise ** 2, dim=1) + self.EPS
+ )
+ if self.take_log:
+ losses = 10 * torch.log10(losses + self.EPS)
+ losses = losses.mean() if self.reduction == "mean" else losses
+ return -losses
+
+
+class MultiSrcNegSDR(_Loss):
+ def __init__(self, sdr_type, zero_mean=True, take_log=True, EPS=1e-8):
+ super().__init__()
+
+ assert sdr_type in ["snr", "sisdr", "sdsdr"]
+ self.sdr_type = sdr_type
+ self.zero_mean = zero_mean
+ self.take_log = take_log
+        self.EPS = EPS
+
+ def forward(self, ests, targets):
+ if targets.size() != ests.size() or targets.ndim != 3:
+ raise TypeError(
+ f"Inputs must be of shape [batch, n_src, time], got {targets.size()} and {ests.size()} instead"
+ )
+ # Step 1. Zero-mean norm
+ if self.zero_mean:
+ mean_source = torch.mean(targets, dim=2, keepdim=True)
+ mean_est = torch.mean(ests, dim=2, keepdim=True)
+ targets = targets - mean_source
+ ests = ests - mean_est
+ # Step 2. Pair-wise SI-SDR.
+ if self.sdr_type in ["sisdr", "sdsdr"]:
+ # [batch, n_src]
+ pair_wise_dot = torch.sum(ests * targets, dim=2, keepdim=True)
+ # [batch, n_src]
+ s_target_energy = torch.sum(targets ** 2, dim=2, keepdim=True) + self.EPS
+ # [batch, n_src, time]
+ scaled_targets = pair_wise_dot * targets / s_target_energy
+ else:
+ # [batch, n_src, time]
+ scaled_targets = targets
+ if self.sdr_type in ["sdsdr", "snr"]:
+ e_noise = ests - targets
+ else:
+ e_noise = ests - scaled_targets
+ # [batch, n_src]
+ pair_wise_sdr = torch.sum(scaled_targets ** 2, dim=2) / (
+ torch.sum(e_noise ** 2, dim=2) + self.EPS
+ )
+ if self.take_log:
+ pair_wise_sdr = 10 * torch.log10(pair_wise_sdr + self.EPS)
+ return -torch.mean(pair_wise_sdr, dim=-1)
+
+class freq_MAE_WavL1Loss(_Loss):
+ def __init__(self, win=2048, stride=512):
+ super().__init__()
+ self.EPS = 1e-8
+ self.win = win
+ self.stride = stride
+
+ def forward(self, ests, targets):
+ B, nsrc, _ = ests.shape
+ est_spec = torch.stft(ests.view(-1, ests.shape[-1]), n_fft=self.win, hop_length=self.stride,
+ window=torch.hann_window(self.win).to(ests.device).float(),
+ return_complex=True)
+ est_target = torch.stft(targets.view(-1, targets.shape[-1]), n_fft=self.win, hop_length=self.stride,
+ window=torch.hann_window(self.win).to(ests.device).float(),
+ return_complex=True)
+ freq_L1 = (est_spec.real - est_target.real).abs().mean((1,2)) + (est_spec.imag - est_target.imag).abs().mean((1,2))
+ freq_L1 = freq_L1.view(B, nsrc).mean(-1)
+
+ wave_l1 = (ests - targets).abs().mean(-1)
+ # print(freq_L1.shape, wave_l1.shape)
+ wave_l1 = wave_l1.view(B, nsrc).mean(-1)
+ return freq_L1 + wave_l1
+
+class freq_MSE_Loss(_Loss):
+ def __init__(self, win=640, stride=160):
+ super().__init__()
+ self.EPS = 1e-8
+ self.win = win
+ self.stride = stride
+
+ def forward(self, ests, targets):
+ B, nsrc, _ = ests.shape
+ est_spec = torch.stft(ests.view(-1, ests.shape[-1]), n_fft=self.win, hop_length=self.stride,
+ window=torch.hann_window(self.win).to(ests.device).float(),
+ return_complex=True)
+ est_target = torch.stft(targets.view(-1, targets.shape[-1]), n_fft=self.win, hop_length=self.stride,
+ window=torch.hann_window(self.win).to(ests.device).float(),
+ return_complex=True)
+ freq_mse = (est_spec.real - est_target.real).square().mean((1,2)) + (est_spec.imag - est_target.imag).square().mean((1,2))
+ freq_mse = freq_mse.view(B, nsrc).mean(-1)
+
+ return freq_mse
+
+# aliases
+pairwise_neg_sisdr = PairwiseNegSDR("sisdr")
+pairwise_neg_sdsdr = PairwiseNegSDR("sdsdr")
+pairwise_neg_snr = PairwiseNegSDR("snr")
+singlesrc_neg_sisdr = SingleSrcNegSDR("sisdr")
+singlesrc_neg_sdsdr = SingleSrcNegSDR("sdsdr")
+singlesrc_neg_snr = SingleSrcNegSDR("snr")
+multisrc_neg_sisdr = MultiSrcNegSDR("sisdr")
+multisrc_neg_sdsdr = MultiSrcNegSDR("sdsdr")
+multisrc_neg_snr = MultiSrcNegSDR("snr")
+freq_mae_wavl1loss = freq_MAE_WavL1Loss()
+pairwise_neg_sisdr_freq_mse = (PairwiseNegSDR("sisdr"), freq_MSE_Loss())
+pairwise_neg_snr_multidecoder = (PairwiseNegSDR("snr"), MultiSrcNegSDR("snr"))
\ No newline at end of file
diff --git a/look2hear/losses/mixit.py b/look2hear/losses/mixit.py
new file mode 100644
index 0000000000000000000000000000000000000000..750466ff26f3bef07a407bd124b52467b5d350d8
--- /dev/null
+++ b/look2hear/losses/mixit.py
@@ -0,0 +1,311 @@
+# import torch
+# import torch.nn as nn
+# from itertools import combinations
+# from pprint import pprint
+
+
+# def parts_mixgen(lists):
+# parts = []
+# for k in range(len(lists) + 1):
+# for sublist in combinations(lists, k):
+# rest = []
+# if sublist != () and len(sublist) != len(lists):
+# for item in lists:
+# if item not in sublist:
+# rest.append(item)
+# parts.append([sublist, rest])
+# return parts
+
+
+# def parts_mixgen_bisection(lists, srcmix, mix):
+# if mix == 0:
+# yield []
+# else:
+# for c in combinations(lists, srcmix):
+# rest = [x for x in lists if x not in c]
+# for r in parts_mixgen_bisection(rest, srcmix, mix - 1):
+# yield [list(c), *r]
+
+
+# class MixIT(nn.Module):
+# def __init__(self, loss_func, bisection=True):
+# super().__init__()
+# self.loss_func = loss_func
+# self.bisection = bisection
+
+# def forward(self, ests, targets, return_ests=None, **kwargs):
+# assert ests.shape[0] == targets.shape[0]
+# assert ests.shape[-1] == targets.shape[-1]
+
+# if self.bisection:
+# loss, min_loss_idx, parts = self.mixit_bisection(self.loss_func, ests, targets, **kwargs)
+# else:
+# loss, min_loss_idx, parts = self.mixit_non_bisection(self.loss_func, ests, targets, **kwargs)
+
+# mean_loss = torch.mean(loss)
+# if not return_ests:
+# return mean_loss
+
+# reordered = self.reorder_source(ests, targets, min_loss_idx, parts)
+# return mean_loss, reordered
+
+# def mixit_bisection(self, loss_func, ests, targets, **kwargs):
+# n_mix = targets.shape[1]
+# n_src = ests.shape[1]
+# srcmix = n_src // n_mix
+
+# parts = parts_mixgen_bisection(range(n_src), srcmix, n_mix)
+
+# loss_lists = []
+# for part in parts:
+# ests_mix = torch.stack([ests[:, i, :].sum(1) for i in part], dim=1)
+# loss = loss_func(ests_mix, targets, **kwargs)
+# loss_lists.append(loss[:, None])
+# loss_lists = torch.cat(loss_lists, dim=1)
+# min_loss, min_loss_indexes = torch.min(loss_lists, dim=1, keepdim=True)
+# return min_loss, min_loss_indexes, parts
+
+# def mixit_non_bisection(self, loss_func, ests, targets, **kwargs):
+# n_mix = targets.shape[1]
+# n_src = ests.shape[1]
+
+# parts = parts_mixgen(range(n_src))
+
+# loss_lists = []
+# for part in parts:
+# ests_mix = torch.stack([ests[:, i, :].sum(1) for i in part], dim=1)
+# loss = loss_func(ests_mix, targets, **kwargs)
+# loss_lists.append(loss[:, None])
+# loss_lists = torch.cat(loss_lists, dim=1)
+# min_loss, min_loss_indexes = torch.min(loss_lists, dim=1, keepdim=True)
+# return min_loss, min_loss_indexes, parts
+
+# def reoder_source(self, ests, targets, min_loss_idx, parts, **kwargs):
+# ordered = torch.zeros_like(targets)
+# for b, idx in enumerate(min_loss_idx):
+# right_partition = parts[idx]
+# ordered[b, :, :] = torch.stack(
+# [ests[b, idx, :][None, :, :].sum(1) for idx in right_partition], dim=1
+# )
+# return ordered
+
+# if __name__ == "__main__":
+# print(parts_mixgen(range(4)))
+# print([item for item in parts_mixgen_bisection(range(4), 2, 2)])
+
+import warnings
+from itertools import combinations
+import torch
+from torch import nn
+
+
+class MixITLossWrapper(nn.Module):
+ r"""Mixture invariant loss wrapper.
+ Args:
+ loss_func: function with signature (est_targets, targets, **kwargs).
+        generalized (bool): Determines how MixIT is applied. If False,
+            apply MixIT to any number of mixtures, as long as they contain
+            the same number of sources (see :meth:`~MixITLossWrapper.best_part_mixit`).
+            If True (default), apply MixIT to two mixtures, but those mixtures do not
+            necessarily have to contain the same number of sources.
+            See :meth:`~MixITLossWrapper.best_part_mixit_generalized`.
+ For each of these modes, the best partition and reordering will be
+ automatically computed.
+ Examples:
+ >>> import torch
+ >>> from asteroid.losses import multisrc_mse
+ >>> mixtures = torch.randn(10, 2, 16000)
+ >>> est_sources = torch.randn(10, 4, 16000)
+ >>> # Compute MixIT loss based on pairwise losses
+ >>> loss_func = MixITLossWrapper(multisrc_mse)
+ >>> loss_val = loss_func(est_sources, mixtures)
+ References
+ [1] Scott Wisdom et al. "Unsupervised sound separation using
+ mixtures of mixtures." arXiv:2006.12701 (2020)
+ """
+
+ def __init__(self, loss_func, generalized=True):
+ super().__init__()
+ self.loss_func = loss_func
+ self.generalized = generalized
+
+ def forward(self, est_targets, targets, return_est=False, **kwargs):
+ r"""Find the best partition and return the loss.
+ Args:
+ est_targets: torch.Tensor. Expected shape :math:`(batch, nsrc, *)`.
+ The batch of target estimates.
+ targets: torch.Tensor. Expected shape :math:`(batch, nmix, ...)`.
+ The batch of training targets
+ return_est: Boolean. Whether to return the estimated mixtures
+ estimates (To compute metrics or to save example).
+ **kwargs: additional keyword argument that will be passed to the
+ loss function.
+ Returns:
+ - Best partition loss for each batch sample, average over
+ the batch. torch.Tensor(loss_value)
+ - The estimated mixtures (estimated sources summed according to the partition)
+ if return_est is True. torch.Tensor of shape :math:`(batch, nmix, ...)`.
+ """
+ # Check input dimensions
+ assert est_targets.shape[0] == targets.shape[0]
+ assert est_targets.shape[2] == targets.shape[2]
+
+ if not self.generalized:
+ min_loss, min_loss_idx, parts = self.best_part_mixit(
+ self.loss_func, est_targets, targets, **kwargs
+ )
+ else:
+ min_loss, min_loss_idx, parts = self.best_part_mixit_generalized(
+ self.loss_func, est_targets, targets, **kwargs
+ )
+ # Take the mean over the batch
+ mean_loss = torch.mean(min_loss)
+ if not return_est:
+ return mean_loss
+ # Order and sum on the best partition to get the estimated mixtures
+ reordered = self.reorder_source(est_targets, targets, min_loss_idx, parts)
+ return mean_loss, reordered
+
+ @staticmethod
+ def best_part_mixit(loss_func, est_targets, targets, **kwargs):
+ r"""Find best partition of the estimated sources that gives the minimum
+ loss for the MixIT training paradigm in [1]. Valid for any number of
+        mixtures, as long as they contain the same number of sources.
+ Args:
+ loss_func: function with signature ``(est_targets, targets, **kwargs)``
+ The loss function to get batch losses from.
+ est_targets: torch.Tensor. Expected shape :math:`(batch, nsrc, ...)`.
+ The batch of target estimates.
+ targets: torch.Tensor. Expected shape :math:`(batch, nmix, ...)`.
+ The batch of training targets (mixtures).
+ **kwargs: additional keyword argument that will be passed to the
+ loss function.
+ Returns:
+ - :class:`torch.Tensor`:
+ The loss corresponding to the best permutation of size (batch,).
+ - :class:`torch.LongTensor`:
+ The indices of the best partition.
+ - :class:`list`:
+ list of the possible partitions of the sources.
+ """
+ nmix = targets.shape[1]
+ nsrc = est_targets.shape[1]
+ if nsrc % nmix != 0:
+ raise ValueError(
+ "The mixtures are assumed to contain the same number of sources"
+ )
+ nsrcmix = nsrc // nmix
+
+        # Generate all partitions of a list lst of length n into l = n // k
+        # ordered parts of size k (each part is later summed into one mixture).
+        # The total number of such ordered partitions is n! / (k!)^l.
+        # The algorithm recursively distributes items over parts.
+ def parts_mixit(lst, k, l):
+ if l == 0:
+ yield []
+ else:
+ for c in combinations(lst, k):
+ rest = [x for x in lst if x not in c]
+ for r in parts_mixit(rest, k, l - 1):
+ yield [list(c), *r]
+
+ # Generate all the possible partitions
+ parts = list(parts_mixit(range(nsrc), nsrcmix, nmix))
+ # Compute the loss corresponding to each partition
+ loss_set = MixITLossWrapper.loss_set_from_parts(
+ loss_func, est_targets=est_targets, targets=targets, parts=parts, **kwargs
+ )
+ # Indexes and values of min losses for each batch element
+ min_loss, min_loss_indexes = torch.min(loss_set, dim=1, keepdim=True)
+ return min_loss, min_loss_indexes, parts
+
+ @staticmethod
+ def best_part_mixit_generalized(loss_func, est_targets, targets, **kwargs):
+ r"""Find best partition of the estimated sources that gives the minimum
+        loss for the MixIT training paradigm in [1]. Valid only for two mixtures,
+        but those mixtures do not necessarily have to contain the same number of
+        sources, e.g. the case where one mixture is silent is allowed.
+ Args:
+ loss_func: function with signature ``(est_targets, targets, **kwargs)``
+ The loss function to get batch losses from.
+ est_targets: torch.Tensor. Expected shape :math:`(batch, nsrc, ...)`.
+ The batch of target estimates.
+ targets: torch.Tensor. Expected shape :math:`(batch, nmix, ...)`.
+ The batch of training targets (mixtures).
+ **kwargs: additional keyword argument that will be passed to the
+ loss function.
+ Returns:
+ - :class:`torch.Tensor`:
+ The loss corresponding to the best permutation of size (batch,).
+ - :class:`torch.LongTensor`:
+ The indexes of the best permutations.
+ - :class:`list`:
+ list of the possible partitions of the sources.
+ """
+ nmix = targets.shape[1] # number of mixtures
+ nsrc = est_targets.shape[1] # number of estimated sources
+ if nmix != 2:
+ raise ValueError("Works only with two mixtures")
+
+ # Generate all unique partitions of any size from a list lst of
+ # length n. Algorithm recursively distributes items over parts
+ def parts_mixit_gen(lst):
+ partitions = []
+ for k in range(len(lst) + 1):
+ for c in combinations(lst, k):
+ rest = []
+ if c != () and len(c) != len(lst):
+ for item in lst:
+ if item not in c:
+ rest.append(item)
+ partitions.append([c, rest])
+ return partitions
+
+ # Generate all the possible partitions
+ parts = parts_mixit_gen(range(nsrc))
+ # Compute the loss corresponding to each partition
+ loss_set = MixITLossWrapper.loss_set_from_parts(
+ loss_func, est_targets=est_targets, targets=targets, parts=parts, **kwargs
+ )
+ # Indexes and values of min losses for each batch element
+ min_loss, min_loss_indexes = torch.min(loss_set, dim=1, keepdim=True)
+ return min_loss, min_loss_indexes, parts
+
+ @staticmethod
+ def loss_set_from_parts(loss_func, est_targets, targets, parts, **kwargs):
+ """Common loop between both best_part_mixit"""
+ loss_set = []
+ for partition in parts:
+ # sum the sources according to the given partition
+ est_mixes = torch.stack(
+ [est_targets[:, idx, :].sum(1) for idx in partition], dim=1
+ )
+ # get loss for the given partition
+ loss_set.append(loss_func(est_mixes, targets, **kwargs)[:, None])
+ loss_set = torch.cat(loss_set, dim=1)
+ return loss_set
+
+ @staticmethod
+ def reorder_source(est_targets, targets, min_loss_idx, parts):
+ """Reorder sources according to the best partition.
+ Args:
+ est_targets: torch.Tensor. Expected shape :math:`(batch, nsrc, ...)`.
+ The batch of target estimates.
+ targets: torch.Tensor. Expected shape :math:`(batch, nmix, ...)`.
+ The batch of training targets.
+ min_loss_idx: torch.LongTensor. The indexes of the best permutations.
+ parts: list of the possible partitions of the sources.
+ Returns:
+ :class:`torch.Tensor`: Reordered sources of shape :math:`(batch, nmix, time)`.
+ """
+ # For each batch there is a different min_loss_idx
+ ordered = torch.zeros_like(targets)
+ for b, idx in enumerate(min_loss_idx):
+ right_partition = parts[idx]
+ # Sum the estimated sources to get the estimated mixtures
+ ordered[b, :, :] = torch.stack(
+ [est_targets[b, idx, :][None, :, :].sum(1) for idx in right_partition],
+ dim=1,
+ )
+
+ return ordered
diff --git a/look2hear/losses/pit_wrapper.py b/look2hear/losses/pit_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..3434e36f449faf2517d37327349e489c191db1dc
--- /dev/null
+++ b/look2hear/losses/pit_wrapper.py
@@ -0,0 +1,180 @@
+from itertools import permutations
+import torch
+from torch import nn
+from scipy.optimize import linear_sum_assignment
+
+
+class PITLossWrapper(nn.Module):
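+    """Permutation invariant training (PIT) loss wrapper.
+
+    Depending on ``pit_from``, the wrapped ``loss_func`` is expected to return
+    either a pairwise loss matrix (the ``pw_mtx*`` modes), a per-pair loss
+    (``pw_pt``) or a loss averaged over a full permutation (``perm_avg``). The
+    wrapper then searches for the best permutation, exhaustively for up to
+    three sources (or when ``perm_reduce`` is given) and with the Hungarian
+    algorithm otherwise, and averages the corresponding losses over the batch.
+    When ``threshold_byloss`` is True, only per-sample losses above -30 are
+    kept (samples already separated beyond 30 dB are excluded), provided at
+    least one such sample exists. The multidecoder modes combine losses from
+    every decoder output, weighted by block index if ``equidistant_weight`` is
+    True.
+    """
+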
+ def __init__(
+ self, loss_func, pit_from="pw_mtx", equidistant_weight=False, perm_reduce=None, threshold_byloss=True
+ ):
+ super().__init__()
+ self.loss_func = loss_func
+ self.pit_from = pit_from
+ self.perm_reduce = perm_reduce
+ self.threshold_byloss = threshold_byloss
+ self.equidistant_weight = equidistant_weight
+ if self.pit_from not in ["pw_mtx", "pw_pt", "perm_avg", "pw_mtx_broadcast", "pw_mtx_multidecoder_keepmtx", "pw_mtx_multidecoder_batchmin"]:
+ raise ValueError(
+ "Unsupported loss function type {} for now. Expected"
+ "one of [`pw_mtx`, `pw_pt`, `perm_avg`, `pw_mtx_broadcast`]".format(self.pit_from)
+ )
+
+ def forward(self, ests, targets, return_ests=False, reduce_kwargs=None, **kwargs):
+ n_src = targets.shape[1]
+ if self.pit_from == "pw_mtx":
+ pw_loss = self.loss_func(ests, targets, **kwargs)
+ elif self.pit_from == "pw_mtx_broadcast":
+ pw_loss = self.loss_func[0](ests, targets, **kwargs)
+ elif self.pit_from == "pw_mtx_multidecoder_keepmtx":
+ ests_last_block = ests[-1]
+ pw_loss = self.loss_func[0](ests_last_block, targets, **kwargs)
+ elif self.pit_from == "pw_mtx_multidecoder_batchmin":
+ blocks_num = len(ests)
+ ests = torch.cat(ests, dim=0)
+ targets = torch.cat([targets] * blocks_num, dim=0)
+ pw_loss = self.loss_func(ests, targets, **kwargs)
+ elif self.pit_from == "pw_pt":
+ pw_loss = self.get_pw_losses(self.loss_func, ests, targets, **kwargs)
+ elif self.pit_from == "perm_avg":
+ min_loss, batch_indices = self.best_perm_from_perm_avg_loss(
+ self.loss_func, ests, targets, **kwargs
+ )
+ # print(batch_indices)
+ mean_loss = torch.mean(min_loss)
+ if not return_ests:
+ return mean_loss
+ reordered = self.reordered_sources(ests, batch_indices)
+ return mean_loss, reordered
+ else:
+ return
+ # import pdb; pdb.set_trace()
+ assert pw_loss.ndim == 3, (
+ "Something went wrong with the loss " "function, please read the docs."
+ )
+ assert (
+ pw_loss.shape[0] == targets.shape[0]
+ ), "PIT loss needs same batch dim as input"
+
+ reduce_kwargs = reduce_kwargs if reduce_kwargs is not None else dict()
+ min_loss, batch_indices = self.find_best_perm(
+ pw_loss, perm_reduce=self.perm_reduce, **reduce_kwargs
+ )
+ if self.pit_from == "pw_mtx_multidecoder_keepmtx":
+ reordered = []
+ mean_loss = 0
+ for i in range(len(ests)):
+ reordered_ests_each_block = self.reordered_sources(ests[i], batch_indices)
+ reordered.append(reordered_ests_each_block)
+ loss_each_block = self.loss_func[1](reordered_ests_each_block, targets, **kwargs)
+ if self.threshold_byloss:
+ if loss_each_block[loss_each_block > -30].nelement() > 0:
+ loss_each_block = loss_each_block[loss_each_block > -30]
+ if self.equidistant_weight:
+ mean_loss = mean_loss + (i + 1) * 1 / len(ests) * loss_each_block.mean()
+ else:
+ mean_loss = mean_loss + 1 / len(ests) * loss_each_block.mean()
+ reordered = torch.cat(reordered, dim=0)
+ if not return_ests:
+ return mean_loss
+ return mean_loss, reordered
+ else:
+ if self.threshold_byloss:
+ if min_loss[min_loss > -30].nelement() > 0:
+ min_loss = min_loss[min_loss > -30]
+ mean_loss = torch.mean(min_loss)
+ reordered = self.reordered_sources(ests, batch_indices)
+ # import pdb; pdb.set_trace()
+ if self.pit_from == "pw_mtx_broadcast":
+ mean_loss += 0.5 * self.loss_func[1](reordered, targets, **kwargs).mean()
+ if not return_ests:
+ return mean_loss
+ return mean_loss, reordered
+
+ def get_pw_losses(self, loss_func, ests, targets, **kwargs):
+ B, n_src, _ = targets.shape
+ pair_wise_losses = targets.new_empty(B, n_src, n_src)
+ for est_idx, est_src in enumerate(ests.transpose(0, 1)):
+ for target_idx, target_src in enumerate(targets.transpose(0, 1)):
+ pair_wise_losses[:, est_idx, target_idx] = loss_func(
+ est_src, target_src, **kwargs
+ )
+ return pair_wise_losses
+
+ def best_perm_from_perm_avg_loss(self, loss_func, ests, targets, **kwargs):
+ n_src = targets.shape[1]
+ perms = torch.tensor(list(permutations(range(n_src))), dtype=torch.long)
+ # import pdb; pdb.set_trace()
+ loss_set = torch.stack(
+ [loss_func(ests[:, perm], targets) for perm in perms], dim=1
+ )
+ min_loss, min_loss_idx = torch.min(loss_set, dim=1)
+ batch_indices = torch.stack([perms[m] for m in min_loss_idx], dim=0)
+ return min_loss, batch_indices
+
+ def reordered_sources(self, sources, batch_indices):
+ reordered_sources = torch.stack(
+ [torch.index_select(s, 0, b) for s, b in zip(sources, batch_indices)]
+ )
+ return reordered_sources
+
+ def find_best_perm(self, pair_wise_losses, perm_reduce=None, **kwargs):
+ n_src = pair_wise_losses.shape[-1]
+ if perm_reduce is not None or n_src <= 3:
+ min_loss, batch_indices = self.find_best_perm_factorial(
+ pair_wise_losses, perm_reduce=perm_reduce, **kwargs
+ )
+ else:
+ min_loss, batch_indices = self.find_best_perm_hungarian(pair_wise_losses)
+ return min_loss, batch_indices
+
+ def find_best_perm_factorial(self, pair_wise_losses, perm_reduce=None, **kwargs):
+ n_src = pair_wise_losses.shape[-1]
+ # After transposition, dim 1 corresp. to sources and dim 2 to estimates
+ pwl = pair_wise_losses.transpose(-1, -2)
+ perms = pwl.new_tensor(list(permutations(range(n_src))), dtype=torch.long)
+ # Column permutation indices
+ idx = torch.unsqueeze(perms, 2)
+ # Loss mean of each permutation
+ if perm_reduce is None:
+ # one-hot, [n_src!, n_src, n_src]
+ # import pdb; pdb.set_trace()
+ perms_one_hot = pwl.new_zeros((*perms.size(), n_src)).scatter_(2, idx, 1)
+ loss_set = torch.einsum("bij,pij->bp", [pwl, perms_one_hot])
+ loss_set /= n_src
+ else:
+ # batch = pwl.shape[0]; n_perm = idx.shape[0]
+ # [batch, n_src!, n_src] : Pairwise losses for each permutation.
+ pwl_set = pwl[:, torch.arange(n_src), idx.squeeze(-1)]
+ # Apply reduce [batch, n_src!, n_src] --> [batch, n_src!]
+ loss_set = perm_reduce(pwl_set, **kwargs)
+ # Indexes and values of min losses for each batch element
+ min_loss, min_loss_idx = torch.min(loss_set, dim=1)
+
+ # Permutation indices for each batch.
+ batch_indices = torch.stack([perms[m] for m in min_loss_idx], dim=0)
+ return min_loss, batch_indices
+
+ def find_best_perm_hungarian(self, pair_wise_losses: torch.Tensor):
+ pwl = pair_wise_losses.transpose(-1, -2)
+ # Just bring the numbers to cpu(), not the graph
+ pwl_copy = pwl.detach().cpu()
+ # Loop over batch + row indices are always ordered for square matrices.
+ batch_indices = torch.tensor(
+ [linear_sum_assignment(pwl)[1] for pwl in pwl_copy]
+ ).to(pwl.device)
+ min_loss = torch.gather(pwl, 2, batch_indices[..., None]).mean([-1, -2])
+ return min_loss, batch_indices
+
+
+if __name__ == "__main__":
+ import torch
+    from matrix import pairwise_neg_sisdr
+
+ ests = torch.randn(10, 2, 32000)
+ targets = torch.randn(10, 2, 32000)
+
+ pit_wrapper_1 = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx")
+ pit_wrapper_2 = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx")
+ print(pit_wrapper_1(ests, targets))
+ print(pit_wrapper_2(ests, targets))
diff --git a/look2hear/losses/sisnri.py b/look2hear/losses/sisnri.py
new file mode 100644
index 0000000000000000000000000000000000000000..c879546baf1922e0612072a7a4b394201d3e857b
--- /dev/null
+++ b/look2hear/losses/sisnri.py
@@ -0,0 +1,42 @@
+import torch
+from itertools import permutations
+
+class SISNRi(object):
+ def __init__(self):
+        super(SISNRi, self).__init__()
+
+ def sisnr(self, mix, est, ref, eps = 1e-8):
+ """
+        input:
+              mix: B x L
+              est: B x L
+              ref: B x L
+        output: B
+ """
+ est = est - torch.mean(est, dim = -1, keepdim = True)
+ ref = ref - torch.mean(ref, dim = -1, keepdim = True)
+ mix = mix - torch.mean(mix, dim = -1, keepdim = True)
+ est_p = (torch.sum(est * ref, dim = -1, keepdim = True) * ref) / torch.sum(ref * ref, dim = -1, keepdim = True)
+ est_v = est - est_p
+ mix_p = (torch.sum(mix * ref, dim = -1, keepdim = True) * ref) / torch.sum(ref * ref, dim = -1, keepdim = True)
+ mix_v = mix - mix_p
+ est_sisnr = 10 * torch.log10((torch.sum(est_p * est_p, dim = -1, keepdim = True) + eps) / (torch.sum(est_v * est_v, dim = -1, keepdim = True) + eps))
+ mix_sisnr = 10 * torch.log10((torch.sum(mix_p * mix_p, dim = -1, keepdim = True) + eps) / (torch.sum(mix_v * mix_v, dim = -1, keepdim = True) + eps))
+ return est_sisnr - mix_sisnr
+
+ def compute_loss(self, mix, ests, refs):
+ """
+        input:
+              mix: B x L
+              ests: num_spk x B x L
+              refs: num_spk x B x L
+        output: scalar
+ """
+
+ def sisnr_loss(permute):
+ # B
+ return torch.mean(torch.stack([self.sisnr(mix, ests[s], refs[t]) for s, t in enumerate(permute)]), dim = 0, keepdim = True)
+ num_spks = len(ests)
+ # pmt_num x B
+ sisnr_mat = torch.stack([sisnr_loss(p) for p in permutations(range(num_spks))])
+ # B
+ max_pmt, _ = torch.max(sisnr_mat, dim=0)
+ return -torch.mean(max_pmt)
\ No newline at end of file
diff --git a/look2hear/metrics/__init__.py b/look2hear/metrics/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd65a55beba10e7ea1b6c3b000bf9d4e67c876af
--- /dev/null
+++ b/look2hear/metrics/__init__.py
@@ -0,0 +1,4 @@
+from .wrapper import MetricsTracker
+from .splitwrapper import SPlitMetricsTracker
+
+__all__ = ["MetricsTracker", "SPlitMetricsTracker"]
diff --git a/look2hear/metrics/splitwrapper.py b/look2hear/metrics/splitwrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d985b6d2a4636dd244f053d85469f001171228f
--- /dev/null
+++ b/look2hear/metrics/splitwrapper.py
@@ -0,0 +1,115 @@
+import csv
+import torch
+import numpy as np
+import logging
+
+# from torch_mir_eval.separation import bss_eval_sources
+from ..losses import (
+ PITLossWrapper,
+ pairwise_neg_sisdr,
+ pairwise_neg_snr,
+ singlesrc_neg_sisdr,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class SPlitMetricsTracker:
+ def __init__(self, save_file: str = ""):
+ self.one_all_snrs = []
+ self.one_all_snrs_i = []
+ self.one_all_sisnrs = []
+ self.one_all_sisnrs_i = []
+ self.two_all_snrs = []
+ self.two_all_snrs_i = []
+ self.two_all_sisnrs = []
+ self.two_all_sisnrs_i = []
+ csv_columns = [
+ "snt_id",
+ "one_snr",
+ "one_snr_i",
+ "one_si-snr",
+ "one_si-snr_i",
+ "two_snr",
+ "two_snr_i",
+ "two_si-snr",
+ "two_si-snr_i",
+ ]
+ self.results_csv = open(save_file, "w")
+ self.writer = csv.DictWriter(self.results_csv, fieldnames=csv_columns)
+ self.writer.writeheader()
+ self.pit_sisnr = PITLossWrapper(pairwise_neg_sisdr, pit_from="pw_mtx")
+ self.pit_snr = PITLossWrapper(pairwise_neg_snr, pit_from="pw_mtx")
+
+ def __call__(self, mix, clean, estimate, key):
+ _, ests_np = self.pit_snr(
+ estimate.unsqueeze(0), clean.unsqueeze(0), return_ests=True
+ )
+ # sisnr
+ two_sisnr = self.pit_sisnr(ests_np[:, 0:2], clean.unsqueeze(0)[:, 0:2])
+ one_sisnr = self.pit_sisnr(
+ ests_np[:, 2].unsqueeze(1), clean.unsqueeze(0)[:, 2].unsqueeze(1)
+ )
+ mix = torch.stack([mix] * clean.shape[0], dim=0)
+ two_sisnr_baseline = self.pit_sisnr(
+ mix.unsqueeze(0)[:, 0:2], clean.unsqueeze(0)[:, 0:2]
+ )
+ one_sisnr_baseline = self.pit_sisnr(
+ mix.unsqueeze(0)[:, 2].unsqueeze(1), clean.unsqueeze(0)[:, 2].unsqueeze(1)
+ )
+ two_sisnr_i = two_sisnr - two_sisnr_baseline
+ one_sisnr_i = one_sisnr - one_sisnr_baseline
+ # sdr
+ two_snr = self.pit_snr(ests_np[:, 0:2], clean.unsqueeze(0)[:, 0:2])
+ one_snr = self.pit_snr(
+ ests_np[:, 2].unsqueeze(1), clean.unsqueeze(0)[:, 2].unsqueeze(1)
+ )
+ two_snr_baseline = self.pit_snr(
+ mix.unsqueeze(0)[:, 0:2], clean.unsqueeze(0)[:, 0:2]
+ )
+ one_snr_baseline = self.pit_snr(
+ mix.unsqueeze(0)[:, 2].unsqueeze(1), clean.unsqueeze(0)[:, 2].unsqueeze(1)
+ )
+ two_snr_i = two_snr - two_snr_baseline
+ one_snr_i = one_snr - one_snr_baseline
+
+ row = {
+ "snt_id": key,
+ "one_snr": -one_snr.item(),
+ "one_snr_i": -one_snr_i.item(),
+ "one_si-snr": -one_sisnr.item(),
+ "one_si-snr_i": -one_sisnr_i.item(),
+ "two_snr": -two_snr.item(),
+ "two_snr_i": -two_snr_i.item(),
+ "two_si-snr": -two_sisnr.item(),
+ "two_si-snr_i": -two_sisnr_i.item(),
+ }
+ self.writer.writerow(row)
+ # Metric Accumulation
+ self.one_all_snrs.append(-one_snr.item())
+ self.one_all_snrs_i.append(-one_snr_i.item())
+ self.one_all_sisnrs.append(-one_sisnr.item())
+ self.one_all_sisnrs_i.append(-one_sisnr_i.item())
+ self.two_all_snrs.append(-two_snr.item())
+ self.two_all_snrs_i.append(-two_snr_i.item())
+ self.two_all_sisnrs.append(-two_sisnr.item())
+ self.two_all_sisnrs_i.append(-two_sisnr_i.item())
+
+ def final(self,):
+ row = {
+ "snt_id": "avg",
+ "one_snr": np.array(self.one_all_snrs).mean(),
+ "one_snr_i": np.array(self.one_all_snrs_i).mean(),
+ "one_si-snr": np.array(self.one_all_sisnrs).mean(),
+ "one_si-snr_i": np.array(self.one_all_sisnrs_i).mean(),
+ "two_snr": np.array(self.two_all_snrs).mean(),
+ "two_snr_i": np.array(self.two_all_snrs_i).mean(),
+ "two_si-snr": np.array(self.two_all_sisnrs).mean(),
+ "two_si-snr_i": np.array(self.two_all_sisnrs_i).mean(),
+ }
+ self.writer.writerow(row)
+ # logger.info("Mean SISNR is {}".format(row["si-snr"]))
+ # logger.info("Mean SISNRi is {}".format(row["si-snr_i"]))
+ # logger.info("Mean SDR is {}".format(row["sdr"]))
+ # logger.info("Mean SDRi is {}".format(row["sdr_i"]))
+ self.results_csv.close()
diff --git a/look2hear/metrics/wrapper.py b/look2hear/metrics/wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c2a10ce66a4c9c8617ce85594834715dc7ec162
--- /dev/null
+++ b/look2hear/metrics/wrapper.py
@@ -0,0 +1,84 @@
+import csv
+import torch
+import numpy as np
+import logging
+
+from torch_mir_eval.separation import bss_eval_sources
+import fast_bss_eval
+from ..losses import (
+ PITLossWrapper,
+ pairwise_neg_sisdr,
+ pairwise_neg_snr,
+ singlesrc_neg_sisdr,
+ PairwiseNegSDR,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class MetricsTracker:
+ def __init__(self, save_file: str = ""):
+ self.all_sdrs = []
+ self.all_sdrs_i = []
+ self.all_sisnrs = []
+ self.all_sisnrs_i = []
+ csv_columns = ["snt_id", "sdr", "sdr_i", "si-snr", "si-snr_i"]
+ self.results_csv = open(save_file, "w")
+ self.writer = csv.DictWriter(self.results_csv, fieldnames=csv_columns)
+ self.writer.writeheader()
+ self.pit_sisnr = PITLossWrapper(
+ PairwiseNegSDR("sisdr", zero_mean=False), pit_from="pw_mtx"
+ )
+ self.pit_snr = PITLossWrapper(
+ PairwiseNegSDR("snr", zero_mean=False), pit_from="pw_mtx"
+ )
+
+ def __call__(self, mix, clean, estimate, key):
+ # sisnr
+ sisnr = self.pit_sisnr(estimate.unsqueeze(0), clean.unsqueeze(0))
+ mix = torch.stack([mix] * clean.shape[0], dim=0)
+ sisnr_baseline = self.pit_sisnr(mix.unsqueeze(0), clean.unsqueeze(0))
+ sisnr_i = sisnr - sisnr_baseline
+
+ # sdr
+ sdr = -fast_bss_eval.sdr_pit_loss(estimate, clean).mean()
+ sdr_baseline = -fast_bss_eval.sdr_pit_loss(mix, clean).mean()
+ sdr_i = sdr - sdr_baseline
+ # import pdb; pdb.set_trace()
+ row = {
+ "snt_id": key,
+ "sdr": sdr.item(),
+ "sdr_i": sdr_i.item(),
+ "si-snr": -sisnr.item(),
+ "si-snr_i": -sisnr_i.item(),
+ }
+ self.writer.writerow(row)
+ # Metric Accumulation
+ self.all_sdrs.append(sdr.item())
+ self.all_sdrs_i.append(sdr_i.item())
+ self.all_sisnrs.append(-sisnr.item())
+ self.all_sisnrs_i.append(-sisnr_i.item())
+
+ def update(self, ):
+ return {"sdr_i": np.array(self.all_sdrs_i).mean(),
+ "si-snr_i": np.array(self.all_sisnrs_i).mean()
+ }
+
+ def final(self,):
+ row = {
+ "snt_id": "avg",
+ "sdr": np.array(self.all_sdrs).mean(),
+ "sdr_i": np.array(self.all_sdrs_i).mean(),
+ "si-snr": np.array(self.all_sisnrs).mean(),
+ "si-snr_i": np.array(self.all_sisnrs_i).mean(),
+ }
+ self.writer.writerow(row)
+ row = {
+ "snt_id": "std",
+ "sdr": np.array(self.all_sdrs).std(),
+ "sdr_i": np.array(self.all_sdrs_i).std(),
+ "si-snr": np.array(self.all_sisnrs).std(),
+ "si-snr_i": np.array(self.all_sisnrs_i).std(),
+ }
+ self.writer.writerow(row)
+ self.results_csv.close()
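
A minimal usage sketch for the MetricsTracker defined above. The shapes — mixture [T], clean/estimate [n_src, T] — are assumptions inferred from how __call__ stacks the mixture, and the CSV path and the eval_pairs iterable are illustrative placeholders:

    import torch
    from look2hear.metrics.wrapper import MetricsTracker

    tracker = MetricsTracker(save_file="results.csv")          # hypothetical output path
    for key, (mix, clean, estimate) in enumerate(eval_pairs):  # mix: [T], clean/estimate: [n_src, T]
        tracker(mix, clean, estimate, str(key))                # writes one per-utterance row
    print(tracker.update())                                    # running means of sdr_i / si-snr_i
    tracker.final()                                            # appends avg and std rows, closes the CSV
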
diff --git a/look2hear/models/__init__.py b/look2hear/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..00bc01fb67cd5b4d1d563122ea3f2e18d48eec40
--- /dev/null
+++ b/look2hear/models/__init__.py
@@ -0,0 +1,42 @@
+from .tiger import TIGER
+from .tiger_dnr import TIGERDNR
+from .base_model import BaseModel
+
+__all__ = [
+ "TIGER",
+]
+
+
+def register_model(custom_model):
+ """Register a custom model, gettable with `models.get`.
+
+ Args:
+ custom_model: Custom model to register.
+
+ """
+ if (
+ custom_model.__name__ in globals().keys()
+ or custom_model.__name__.lower() in globals().keys()
+ ):
+ raise ValueError(
+ f"Model {custom_model.__name__} already exists. Choose another name."
+ )
+ globals().update({custom_model.__name__: custom_model})
+
+
+def get(identifier):
+ """Returns an model class from a string (case-insensitive).
+
+ Args:
+ identifier (str): the model name.
+
+ Returns:
+ :class:`torch.nn.Module` subclass (the model class).
+ """
+ if isinstance(identifier, str):
+ to_get = {k.lower(): v for k, v in globals().items()}
+ cls = to_get.get(identifier.lower())
+ if cls is None:
+ raise ValueError(f"Could not interpret model name : {str(identifier)}")
+ return cls
+ raise ValueError(f"Could not interpret model name : {str(identifier)}")
diff --git a/look2hear/models/base_model.py b/look2hear/models/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..d82764064553fd889cb9dc7739330a94fff27518
--- /dev/null
+++ b/look2hear/models/base_model.py
@@ -0,0 +1,94 @@
+###
+# Author: Kai Li
+# Date: 2021-06-17 23:08:32
+# LastEditors: Please set LastEditors
+# LastEditTime: 2022-05-26 18:06:22
+###
+import torch
+import torch.nn as nn
+from huggingface_hub import PyTorchModelHubMixin
+
+def _unsqueeze_to_3d(x):
+ """Normalize shape of `x` to [batch, n_chan, time]."""
+ if x.ndim == 1:
+ return x.reshape(1, 1, -1)
+ elif x.ndim == 2:
+ return x.unsqueeze(1)
+ else:
+ return x
+
+
+def pad_to_appropriate_length(x, lcm):
+ values_to_pad = int(x.shape[-1]) % lcm
+ if values_to_pad:
+ appropriate_shape = x.shape
+ padded_x = torch.zeros(
+ list(appropriate_shape[:-1])
+ + [appropriate_shape[-1] + lcm - values_to_pad],
+ dtype=torch.float32,
+ ).to(x.device)
+ padded_x[..., : x.shape[-1]] = x
+ return padded_x
+ return x
+
+
+class BaseModel(nn.Module, PyTorchModelHubMixin, repo_url="https://github.com/JusperLee/Apollo", pipeline_tag="audio-to-audio"):
+ def __init__(self, sample_rate, in_chan=1):
+ super().__init__()
+ self._sample_rate = sample_rate
+ self._in_chan = in_chan
+
+ def forward(self, *args, **kwargs):
+ raise NotImplementedError
+
+ def sample_rate(self,):
+ return self._sample_rate
+
+ @staticmethod
+ def load_state_dict_in_audio(model, pretrained_dict):
+ model_dict = model.state_dict()
+ update_dict = {}
+ for k, v in pretrained_dict.items():
+ if "audio_model" in k:
+ update_dict[k[12:]] = v
+ model_dict.update(update_dict)
+ model.load_state_dict(model_dict)
+ return model
+
+ # @staticmethod
+ # def from_pretrain(pretrained_model_conf_or_path, *args, **kwargs):
+ # from . import get
+
+ # conf = torch.load(
+ # pretrained_model_conf_or_path, map_location="cpu"
+ # ) # Attempt to find the model and instantiate it.
+
+ # model_class = get(conf["model_name"])
+ # # model_class = get("Conv_TasNet")
+ # model = model_class(*args, **kwargs)
+ # model.load_state_dict(conf["state_dict"])
+ # return model
+
+ def serialize(self):
+ import pytorch_lightning as pl # Not used in torch.hub
+
+ model_conf = dict(
+ model_name=self.__class__.__name__,
+ state_dict=self.get_state_dict(),
+ model_args=self.get_model_args(),
+ )
+ # Additional infos
+ infos = dict()
+ infos["software_versions"] = dict(
+ torch_version=torch.__version__, pytorch_lightning_version=pl.__version__,
+ )
+ model_conf["infos"] = infos
+ return model_conf
+
+ def get_state_dict(self):
+ """In case the state dict needs to be modified before sharing the model."""
+ return self.state_dict()
+
+ def get_model_args(self):
+ """Should return args to re-instantiate the class."""
+ raise NotImplementedError
\ No newline at end of file
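
A sketch of how the package returned by serialize() could be saved and restored; the path is illustrative and `model` stands for any BaseModel subclass instance. Note that the TIGER classes below return only a placeholder from get_model_args(), so re-instantiating purely from model_args is not possible for them:

    import torch

    package = model.serialize()                      # model: an existing BaseModel subclass instance
    torch.save(package, "model_package.pth")         # hypothetical path

    package = torch.load("model_package.pth", map_location="cpu")
    model.load_state_dict(package["state_dict"])     # restore weights into an existing instance
    print(package["model_name"], package["infos"]["software_versions"])
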
diff --git a/look2hear/models/tiger.py b/look2hear/models/tiger.py
new file mode 100644
index 0000000000000000000000000000000000000000..51fcbb0cb4b8d09efbe629baa6c2d8903b240860
--- /dev/null
+++ b/look2hear/models/tiger.py
@@ -0,0 +1,649 @@
+import inspect
+import torch
+import numpy as np
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+from .base_model import BaseModel
+from ..layers import activations, normalizations
+
+
+def GlobLN(nOut):
+ return nn.GroupNorm(1, nOut, eps=1e-8)
+
+
+class ConvNormAct(nn.Module):
+ """
+ This class defines the convolution layer with normalization and a PReLU
+ activation
+ """
+
+ def __init__(self, nIn, nOut, kSize, stride=1, groups=1):
+ """
+ :param nIn: number of input channels
+ :param nOut: number of output channels
+ :param kSize: kernel size
+ :param stride: stride rate for down-sampling. Default is 1
+ """
+ super().__init__()
+ padding = int((kSize - 1) / 2)
+ self.conv = nn.Conv1d(
+ nIn, nOut, kSize, stride=stride, padding=padding, bias=True, groups=groups
+ )
+ self.norm = GlobLN(nOut)
+ self.act = nn.PReLU()
+
+ def forward(self, input):
+ output = self.conv(input)
+ output = self.norm(output)
+ return self.act(output)
+
+
+class ConvNorm(nn.Module):
+ """
+ This class defines the convolution layer with normalization (no activation)
+ """
+
+ def __init__(self, nIn, nOut, kSize, stride=1, groups=1, bias=True):
+ """
+ :param nIn: number of input channels
+ :param nOut: number of output channels
+ :param kSize: kernel size
+ :param stride: stride rate for down-sampling. Default is 1
+ """
+ super().__init__()
+ padding = int((kSize - 1) / 2)
+ self.conv = nn.Conv1d(
+ nIn, nOut, kSize, stride=stride, padding=padding, bias=bias, groups=groups
+ )
+ self.norm = GlobLN(nOut)
+
+ def forward(self, input):
+ output = self.conv(input)
+ return self.norm(output)
+
+class ATTConvActNorm(nn.Module):
+ def __init__(
+ self,
+ in_chan: int = 1,
+ out_chan: int = 1,
+ kernel_size: int = -1,
+ stride: int = 1,
+ groups: int = 1,
+ dilation: int = 1,
+ padding: int = None,
+ norm_type: str = None,
+ act_type: str = None,
+ n_freqs: int = -1,
+ xavier_init: bool = False,
+ bias: bool = True,
+ is2d: bool = False,
+ *args,
+ **kwargs,
+ ):
+ super(ATTConvActNorm, self).__init__()
+ self.in_chan = in_chan
+ self.out_chan = out_chan
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.groups = groups
+ self.dilation = dilation
+ self.padding = padding
+ self.norm_type = norm_type
+ self.act_type = act_type
+ self.n_freqs = n_freqs
+ self.xavier_init = xavier_init
+ self.bias = bias
+
+ if self.padding is None:
+ self.padding = 0 if self.stride > 1 else "same"
+
+ if kernel_size > 0:
+ conv = nn.Conv2d if is2d else nn.Conv1d
+
+ self.conv = conv(
+ in_channels=self.in_chan,
+ out_channels=self.out_chan,
+ kernel_size=self.kernel_size,
+ stride=self.stride,
+ padding=self.padding,
+ dilation=self.dilation,
+ groups=self.groups,
+ bias=self.bias,
+ )
+ if self.xavier_init:
+ nn.init.xavier_uniform_(self.conv.weight)
+ else:
+ self.conv = nn.Identity()
+
+ self.act = activations.get(self.act_type)()
+ self.norm = normalizations.get(self.norm_type)(
+ (self.out_chan, self.n_freqs) if self.norm_type == "LayerNormalization4D" else self.out_chan
+ )
+
+ def forward(self, x: torch.Tensor):
+ output = self.conv(x)
+ output = self.act(output)
+ output = self.norm(output)
+ return output
+
+ def get_config(self):
+ encoder_args = {}
+
+ for k, v in (self.__dict__).items():
+ if not k.startswith("_") and k != "training":
+ if not inspect.ismethod(v):
+ encoder_args[k] = v
+
+ return encoder_args
+
+class DilatedConvNorm(nn.Module):
+ """
+ This class defines the dilated convolution with normalized output.
+ """
+
+ def __init__(self, nIn, nOut, kSize, stride=1, d=1, groups=1):
+ """
+ :param nIn: number of input channels
+ :param nOut: number of output channels
+ :param kSize: kernel size
+ :param stride: optional stride rate for down-sampling
+ :param d: optional dilation rate
+ """
+ super().__init__()
+ self.conv = nn.Conv1d(
+ nIn,
+ nOut,
+ kSize,
+ stride=stride,
+ dilation=d,
+ padding=((kSize - 1) // 2) * d,
+ groups=groups,
+ )
+ # self.norm = nn.GroupNorm(1, nOut, eps=1e-08)
+ self.norm = GlobLN(nOut)
+
+ def forward(self, input):
+ output = self.conv(input)
+ return self.norm(output)
+
+
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_size, drop=0.1):
+ super().__init__()
+ self.fc1 = ConvNorm(in_features, hidden_size, 1, bias=False)
+ self.dwconv = nn.Conv1d(
+ hidden_size, hidden_size, 5, 1, 2, bias=True, groups=hidden_size
+ )
+ self.act = nn.ReLU()
+ self.fc2 = ConvNorm(hidden_size, in_features, 1, bias=False)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.dwconv(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+class InjectionMultiSum(nn.Module):
+ def __init__(self, inp: int, oup: int, kernel: int = 1) -> None:
+ super().__init__()
+ groups = 1
+ if inp == oup:
+ groups = inp
+ self.local_embedding = ConvNorm(inp, oup, kernel, groups=groups, bias=False)
+ self.global_embedding = ConvNorm(inp, oup, kernel, groups=groups, bias=False)
+ self.global_act = ConvNorm(inp, oup, kernel, groups=groups, bias=False)
+ self.act = nn.Sigmoid()
+
+ def forward(self, x_l, x_g):
+ """
+ x_g: global features
+ x_l: local features
+ """
+ B, N, T = x_l.shape
+ local_feat = self.local_embedding(x_l)
+
+ global_act = self.global_act(x_g)
+ sig_act = F.interpolate(self.act(global_act), size=T, mode="nearest")
+ # sig_act = self.act(global_act)
+
+ global_feat = self.global_embedding(x_g)
+ global_feat = F.interpolate(global_feat, size=T, mode="nearest")
+
+ out = local_feat * sig_act + global_feat
+ return out
+
+class InjectionMulti(nn.Module):
+ def __init__(self, inp: int, oup: int, kernel: int = 1) -> None:
+ super().__init__()
+ groups = 1
+ if inp == oup:
+ groups = inp
+ self.local_embedding = ConvNorm(inp, oup, kernel, groups=groups, bias=False)
+ self.global_act = ConvNorm(inp, oup, kernel, groups=groups, bias=False)
+ self.act = nn.Sigmoid()
+
+ def forward(self, x_l, x_g):
+ """
+ x_g: global features
+ x_l: local features
+ """
+ B, N, T = x_l.shape
+ local_feat = self.local_embedding(x_l)
+
+ global_act = self.global_act(x_g)
+ sig_act = F.interpolate(self.act(global_act), size=T, mode="nearest")
+ # sig_act = self.act(global_act)
+
+ out = local_feat * sig_act
+ return out
+
+class UConvBlock(nn.Module):
+ """
+ This class defines the block which performs successive downsampling and
+ upsampling so that the input features can be analyzed at multiple resolutions.
+ """
+
+ def __init__(self, out_channels=128, in_channels=512, upsampling_depth=4, model_T=True):
+ super().__init__()
+ self.proj_1x1 = ConvNormAct(out_channels, in_channels, 1, stride=1, groups=1)
+ self.depth = upsampling_depth
+ self.spp_dw = nn.ModuleList()
+ self.spp_dw.append(
+ DilatedConvNorm(
+ in_channels, in_channels, kSize=5, stride=1, groups=in_channels, d=1
+ )
+ )
+ for i in range(1, upsampling_depth):
+ self.spp_dw.append(
+ DilatedConvNorm(
+ in_channels,
+ in_channels,
+ kSize=5,
+ stride=2,
+ groups=in_channels,
+ d=1,
+ )
+ )
+
+ self.loc_glo_fus = nn.ModuleList([])
+ for i in range(upsampling_depth):
+ self.loc_glo_fus.append(InjectionMultiSum(in_channels, in_channels))
+
+ self.res_conv = nn.Conv1d(in_channels, out_channels, 1)
+
+ self.globalatt = Mlp(in_channels, in_channels, drop=0.1)
+
+ self.last_layer = nn.ModuleList([])
+ for i in range(self.depth - 1):
+ self.last_layer.append(InjectionMultiSum(in_channels, in_channels, 5))
+
+ def forward(self, x):
+ """
+ :param x: input feature map
+ :return: transformed feature map
+ """
+ residual = x.clone()
+ # Reduce --> project high-dimensional feature maps to low-dimensional space
+ output1 = self.proj_1x1(x)
+ output = [self.spp_dw[0](output1)]
+
+ # Do the downsampling process from the previous level
+ for k in range(1, self.depth):
+ out_k = self.spp_dw[k](output[-1])
+ output.append(out_k)
+
+ # global features
+ global_f = torch.zeros(
+ output[-1].shape, requires_grad=True, device=output1.device
+ )
+ for fea in output:
+ global_f = global_f + F.adaptive_avg_pool1d(
+ fea, output_size=output[-1].shape[-1]
+ )
+ # global_f = global_f + fea
+ global_f = self.globalatt(global_f) # [B, N, T]
+
+ x_fused = []
+ # Gather them now in reverse order
+ for idx in range(self.depth):
+ local = output[idx]
+ x_fused.append(self.loc_glo_fus[idx](local, global_f))
+
+ expanded = None
+ for i in range(self.depth - 2, -1, -1):
+ if i == self.depth - 2:
+ expanded = self.last_layer[i](x_fused[i], x_fused[i - 1])
+ else:
+ expanded = self.last_layer[i](x_fused[i], expanded)
+ # import pdb; pdb.set_trace()
+ return self.res_conv(expanded) + residual
+
+class MultiHeadSelfAttention2D(nn.Module):
+ def __init__(
+ self,
+ in_chan: int,
+ n_freqs: int,
+ n_head: int = 4,
+ hid_chan: int = 4,
+ act_type: str = "prelu",
+ norm_type: str = "LayerNormalization4D",
+ dim: int = 3,
+ *args,
+ **kwargs,
+ ):
+ super(MultiHeadSelfAttention2D, self).__init__()
+ self.in_chan = in_chan
+ self.n_freqs = n_freqs
+ self.n_head = n_head
+ self.hid_chan = hid_chan
+ self.act_type = act_type
+ self.norm_type = norm_type
+ self.dim = dim
+
+ assert self.in_chan % self.n_head == 0
+
+ self.Queries = nn.ModuleList()
+ self.Keys = nn.ModuleList()
+ self.Values = nn.ModuleList()
+
+ for _ in range(self.n_head):
+ self.Queries.append(
+ ATTConvActNorm(
+ in_chan=self.in_chan,
+ out_chan=self.hid_chan,
+ kernel_size=1,
+ act_type=self.act_type,
+ norm_type=self.norm_type,
+ n_freqs=self.n_freqs,
+ is2d=True,
+ )
+ )
+ self.Keys.append(
+ ATTConvActNorm(
+ in_chan=self.in_chan,
+ out_chan=self.hid_chan,
+ kernel_size=1,
+ act_type=self.act_type,
+ norm_type=self.norm_type,
+ n_freqs=self.n_freqs,
+ is2d=True,
+ )
+ )
+ self.Values.append(
+ ATTConvActNorm(
+ in_chan=self.in_chan,
+ out_chan=self.in_chan // self.n_head,
+ kernel_size=1,
+ act_type=self.act_type,
+ norm_type=self.norm_type,
+ n_freqs=self.n_freqs,
+ is2d=True,
+ )
+ )
+
+ self.attn_concat_proj = ATTConvActNorm(
+ in_chan=self.in_chan,
+ out_chan=self.in_chan,
+ kernel_size=1,
+ act_type=self.act_type,
+ norm_type=self.norm_type,
+ n_freqs=self.n_freqs,
+ is2d=True,
+ )
+
+ def forward(self, x: torch.Tensor):
+ if self.dim == 4:
+ x = x.transpose(-2, -1).contiguous()
+
+ batch_size, _, time, freq = x.size()
+ residual = x
+
+ all_Q = [q(x) for q in self.Queries] # [B, E, T, F]
+ all_K = [k(x) for k in self.Keys] # [B, E, T, F]
+ all_V = [v(x) for v in self.Values] # [B, C/n_head, T, F]
+
+ Q = torch.cat(all_Q, dim=0) # [B', E, T, F] B' = B*n_head
+ K = torch.cat(all_K, dim=0) # [B', E, T, F]
+ V = torch.cat(all_V, dim=0) # [B', C/n_head, T, F]
+
+ Q = Q.transpose(1, 2).flatten(start_dim=2) # [B', T, E*F]
+ K = K.transpose(1, 2).flatten(start_dim=2) # [B', T, E*F]
+ V = V.transpose(1, 2) # [B', T, C/n_head, F]
+ old_shape = V.shape
+ V = V.flatten(start_dim=2) # [B', T, C*F/n_head]
+ emb_dim = Q.shape[-1] # C*F/n_head
+
+ attn_mat = torch.matmul(Q, K.transpose(1, 2)) / (emb_dim**0.5) # [B', T, T]
+ attn_mat = F.softmax(attn_mat, dim=2) # [B', T, T]
+ V = torch.matmul(attn_mat, V) # [B', T, C*F/n_head]
+ V = V.reshape(old_shape) # [B', T, C/n_head, F]
+ V = V.transpose(1, 2) # [B', C/n_head, T, F]
+ emb_dim = V.shape[1] # C/n_head
+
+ x = V.view([self.n_head, batch_size, emb_dim, time, freq]) # [n_head, B, C/n_head, T, F]
+ x = x.transpose(0, 1).contiguous() # [B, n_head, C/n_head, T, F]
+
+ x = x.view([batch_size, self.n_head * emb_dim, time, freq]) # [B, C, T, F]
+ x = self.attn_concat_proj(x) # [B, C, T, F]
+
+ x = x + residual
+
+ if self.dim == 4:
+ x = x.transpose(-2, -1).contiguous()
+
+ return x
+
+
+class Recurrent(nn.Module):
+ def __init__(
+ self,
+ out_channels=128,
+ in_channels=512,
+ nband=8,
+ upsampling_depth=3,
+ n_head=4,
+ att_hid_chan=4,
+ kernel_size: int = 8,
+ stride: int = 1,
+ _iter=4
+ ):
+ super().__init__()
+ self.nband = nband
+
+ self.freq_path = nn.ModuleList([
+ UConvBlock(out_channels, in_channels, upsampling_depth),
+ MultiHeadSelfAttention2D(out_channels, 1, n_head=n_head, hid_chan=att_hid_chan, act_type="prelu", norm_type="LayerNormalization4D", dim=4),
+ normalizations.get("LayerNormalization4D")((out_channels, 1))
+ ])
+
+ self.frame_path = nn.ModuleList([
+ UConvBlock(out_channels, in_channels, upsampling_depth),
+ MultiHeadSelfAttention2D(out_channels, 1, n_head=n_head, hid_chan=att_hid_chan, act_type="prelu", norm_type="LayerNormalization4D", dim=4),
+ normalizations.get("LayerNormalization4D")((out_channels, 1))
+ ])
+
+ self.iter = _iter
+ self.concat_block = nn.Sequential(
+ nn.Conv2d(out_channels, out_channels, 1, 1, groups=out_channels), nn.PReLU()
+ )
+
+ def forward(self, x):
+ # B, nband, N, T
+ B, nband, N, T = x.shape
+ x = x.permute(0, 2, 1, 3).contiguous() # B, N, nband, T
+ mixture = x.clone()
+ for i in range(self.iter):
+ if i == 0:
+ x = self.freq_time_process(x, B, nband, N, T) # B, N, nband, T
+ else:
+ x = self.freq_time_process(self.concat_block(mixture + x), B, nband, N, T) # B, N, nband, T
+
+ return x.permute(0, 2, 1, 3).contiguous() # B, nband, N, T
+
+ def freq_time_process(self, x, B, nband, N, T):
+ # Process Frequency Path
+ residual_1 = x.clone()
+ x = x.permute(0, 3, 1, 2).contiguous() # B, T, N, nband
+ freq_fea = self.freq_path[0](x.view(B*T, N, nband)) # B*T, N, nband
+ freq_fea = freq_fea.view(B, T, N, nband).permute(0, 2, 1, 3).contiguous() # B, N, T, nband
+ freq_fea = self.freq_path[1](freq_fea) # B, N, T, nband
+ freq_fea = self.freq_path[2](freq_fea) # B, N, T, nband
+ freq_fea = freq_fea.permute(0, 1, 3, 2).contiguous()
+ x = freq_fea + residual_1 # B, N, nband, T
+ # Process Frame Path
+ residual_2 = x.clone()
+ x2 = x.permute(0, 2, 1, 3).contiguous()
+ frame_fea = self.frame_path[0](x2.view(B*nband, N, T)) # B*nband, N, T
+ frame_fea = frame_fea.view(B, nband, N, T).permute(0, 2, 1, 3).contiguous()
+ frame_fea = self.frame_path[1](frame_fea) # B, N, nband, T
+ frame_fea = self.frame_path[2](frame_fea) # B, N, nband, T
+ x = frame_fea + residual_2 # B, N, nband, T
+ return x
+
+class TIGER(BaseModel):
+ def __init__(
+ self,
+ out_channels=128,
+ in_channels=512,
+ num_blocks=16,
+ upsampling_depth=4,
+ att_n_head=4,
+ att_hid_chan=4,
+ att_kernel_size=8,
+ att_stride=1,
+ win=2048,
+ stride=512,
+ num_sources=2,
+ sample_rate=44100,
+ ):
+ super(TIGER, self).__init__(sample_rate=sample_rate)
+
+ self.sample_rate = sample_rate
+ self.win = win
+ self.stride = stride
+ self.group = self.win // 2
+ self.enc_dim = self.win // 2 + 1
+ self.feature_dim = out_channels
+ self.num_output = num_sources
+ self.eps = torch.finfo(torch.float32).eps
+
+ # 0-1k (25 hop), 1k-2k (100 hop), 2k-4k (250 hop), 4k-8k (500 hop)
+ bandwidth_25 = int(np.floor(25 / (sample_rate / 2.) * self.enc_dim))
+ bandwidth_100 = int(np.floor(100 / (sample_rate / 2.) * self.enc_dim))
+ bandwidth_250 = int(np.floor(250 / (sample_rate / 2.) * self.enc_dim))
+ bandwidth_500 = int(np.floor(500 / (sample_rate / 2.) * self.enc_dim))
+ self.band_width = [bandwidth_25]*40
+ self.band_width += [bandwidth_100]*10
+ self.band_width += [bandwidth_250]*8
+ self.band_width += [bandwidth_500]*8
+ self.band_width.append(self.enc_dim - np.sum(self.band_width))
+ self.nband = len(self.band_width)
+ print(self.band_width)
+
+ self.BN = nn.ModuleList([])
+ for i in range(self.nband):
+ self.BN.append(nn.Sequential(nn.GroupNorm(1, self.band_width[i]*2, self.eps),
+ nn.Conv1d(self.band_width[i]*2, self.feature_dim, 1)
+ )
+ )
+
+ self.separator = Recurrent(self.feature_dim, in_channels, self.nband, upsampling_depth, att_n_head, att_hid_chan, att_kernel_size, att_stride, num_blocks)
+
+ self.mask = nn.ModuleList([])
+ for i in range(self.nband):
+ self.mask.append(nn.Sequential(
+ nn.PReLU(),
+ nn.Conv1d(self.feature_dim, self.band_width[i]*4*num_sources, 1, groups=num_sources)
+ )
+ )
+
+ def pad_input(self, input, window, stride):
+ """
+ Zero-padding input according to window/stride size.
+ """
+ batch_size, nsample = input.shape
+
+ # pad the signals at the end for matching the window/stride size
+ rest = window - (stride + nsample % window) % window
+ if rest > 0:
+ pad = torch.zeros(batch_size, rest).type(input.type())
+ input = torch.cat([input, pad], 1)
+ pad_aux = torch.zeros(batch_size, stride).type(input.type())
+ input = torch.cat([pad_aux, input, pad_aux], 1)
+
+ return input, rest
+
+ def forward(self, input):
+ # input shape: (B, C, T)
+ was_one_d = False
+ if input.ndim == 1:
+ was_one_d = True
+ input = input.unsqueeze(0).unsqueeze(1)
+ if input.ndim == 2:
+ was_one_d = True
+ input = input.unsqueeze(1)
+ if input.ndim == 3:
+ input = input
+ batch_size, nch, nsample = input.shape
+ input = input.view(batch_size*nch, -1)
+
+ # frequency-domain separation
+ spec = torch.stft(input, n_fft=self.win, hop_length=self.stride,
+ window=torch.hann_window(self.win).to(input.device).type(input.type()),
+ return_complex=True)
+
+ # print(spec.shape)
+
+ # concat real and imag, split to subbands
+ spec_RI = torch.stack([spec.real, spec.imag], 1) # B*nch, 2, F, T
+ subband_spec_RI = []
+ subband_spec = []
+ band_idx = 0
+ for i in range(len(self.band_width)):
+ subband_spec_RI.append(spec_RI[:,:,band_idx:band_idx+self.band_width[i]].contiguous())
+ subband_spec.append(spec[:,band_idx:band_idx+self.band_width[i]]) # B*nch, BW, T
+ band_idx += self.band_width[i]
+
+ # normalization and bottleneck
+ subband_feature = []
+ for i in range(len(self.band_width)):
+ subband_feature.append(self.BN[i](subband_spec_RI[i].view(batch_size*nch, self.band_width[i]*2, -1)))
+ subband_feature = torch.stack(subband_feature, 1) # B, nband, N, T
+ # import pdb; pdb.set_trace()
+ # separator
+ sep_output = self.separator(subband_feature.view(batch_size*nch, self.nband, self.feature_dim, -1)) # B, nband, N, T
+ sep_output = sep_output.view(batch_size*nch, self.nband, self.feature_dim, -1)
+
+ sep_subband_spec = []
+ for i in range(self.nband):
+ this_output = self.mask[i](sep_output[:,i]).view(batch_size*nch, 2, 2, self.num_output, self.band_width[i], -1)
+ this_mask = this_output[:,0] * torch.sigmoid(this_output[:,1]) # B*nch, 2, K, BW, T
+ this_mask_real = this_mask[:,0] # B*nch, K, BW, T
+ this_mask_imag = this_mask[:,1] # B*nch, K, BW, T
+ # force the real masks to sum to 1 and the imaginary masks to sum to 0 across sources
+ this_mask_real_sum = this_mask_real.sum(1).unsqueeze(1) # B*nch, 1, BW, T
+ this_mask_imag_sum = this_mask_imag.sum(1).unsqueeze(1) # B*nch, 1, BW, T
+ this_mask_real = this_mask_real - (this_mask_real_sum - 1) / self.num_output
+ this_mask_imag = this_mask_imag - this_mask_imag_sum / self.num_output
+ est_spec_real = subband_spec[i].real.unsqueeze(1) * this_mask_real - subband_spec[i].imag.unsqueeze(1) * this_mask_imag # B*nch, K, BW, T
+ est_spec_imag = subband_spec[i].real.unsqueeze(1) * this_mask_imag + subband_spec[i].imag.unsqueeze(1) * this_mask_real # B*nch, K, BW, T
+ sep_subband_spec.append(torch.complex(est_spec_real, est_spec_imag))
+ sep_subband_spec = torch.cat(sep_subband_spec, 2)
+
+ output = torch.istft(sep_subband_spec.view(batch_size*nch*self.num_output, self.enc_dim, -1),
+ n_fft=self.win, hop_length=self.stride,
+ window=torch.hann_window(self.win).to(input.device).type(input.type()), length=nsample)
+ output = output.view(batch_size*nch, self.num_output, -1)
+ # if was_one_d:
+ # return output.squeeze(0)
+ return output
+
+ def get_model_args(self):
+ model_args = {"n_sample_rate": 2}
+ return model_args
\ No newline at end of file
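
A minimal inference sketch for the TIGER class above, assuming the default 44.1 kHz configuration; the one-second random input is illustrative and the forward pass may be slow on CPU:

    import torch
    from look2hear.models import TIGER

    model = TIGER().eval()              # defaults: win=2048, stride=512, num_sources=2
    mix = torch.randn(1, 1, 44100)      # [batch, channels, samples]
    with torch.no_grad():
        est = model(mix)                # [batch * channels, num_sources, samples]
    print(est.shape)                    # expected: torch.Size([1, 2, 44100])
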
diff --git a/look2hear/models/tiger_dnr.py b/look2hear/models/tiger_dnr.py
new file mode 100644
index 0000000000000000000000000000000000000000..27f7ab131e14a57b82bfc3abce9e62b71dbbb4bc
--- /dev/null
+++ b/look2hear/models/tiger_dnr.py
@@ -0,0 +1,878 @@
+import inspect
+import torch
+import numpy as np
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+from .base_model import BaseModel
+from ..layers import activations, normalizations
+
+
+def GlobLN(nOut):
+ return nn.GroupNorm(1, nOut, eps=1e-8)
+
+
+class ConvNormAct(nn.Module):
+ """
+ This class defines the convolution layer with normalization and a PReLU
+ activation
+ """
+
+ def __init__(self, nIn, nOut, kSize, stride=1, groups=1):
+ """
+ :param nIn: number of input channels
+ :param nOut: number of output channels
+ :param kSize: kernel size
+ :param stride: stride rate for down-sampling. Default is 1
+ """
+ super().__init__()
+ padding = int((kSize - 1) / 2)
+ self.conv = nn.Conv1d(
+ nIn, nOut, kSize, stride=stride, padding=padding, bias=True, groups=groups
+ )
+ self.norm = GlobLN(nOut)
+ self.act = nn.PReLU()
+
+ def forward(self, input):
+ output = self.conv(input)
+ output = self.norm(output)
+ return self.act(output)
+
+
+class ConvNorm(nn.Module):
+ """
+ This class defines the convolution layer with normalization (no activation)
+ """
+
+ def __init__(self, nIn, nOut, kSize, stride=1, groups=1, bias=True):
+ """
+ :param nIn: number of input channels
+ :param nOut: number of output channels
+ :param kSize: kernel size
+ :param stride: stride rate for down-sampling. Default is 1
+ """
+ super().__init__()
+ padding = int((kSize - 1) / 2)
+ self.conv = nn.Conv1d(
+ nIn, nOut, kSize, stride=stride, padding=padding, bias=bias, groups=groups
+ )
+ self.norm = GlobLN(nOut)
+
+ def forward(self, input):
+ output = self.conv(input)
+ return self.norm(output)
+
+
+class ATTConvActNorm(nn.Module):
+ def __init__(
+ self,
+ in_chan: int = 1,
+ out_chan: int = 1,
+ kernel_size: int = -1,
+ stride: int = 1,
+ groups: int = 1,
+ dilation: int = 1,
+ padding: int = None,
+ norm_type: str = None,
+ act_type: str = None,
+ n_freqs: int = -1,
+ xavier_init: bool = False,
+ bias: bool = True,
+ is2d: bool = False,
+ *args,
+ **kwargs,
+ ):
+ super(ATTConvActNorm, self).__init__()
+ self.in_chan = in_chan
+ self.out_chan = out_chan
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.groups = groups
+ self.dilation = dilation
+ self.padding = padding
+ self.norm_type = norm_type
+ self.act_type = act_type
+ self.n_freqs = n_freqs
+ self.xavier_init = xavier_init
+ self.bias = bias
+
+ if self.padding is None:
+ self.padding = 0 if self.stride > 1 else "same"
+
+ if kernel_size > 0:
+ conv = nn.Conv2d if is2d else nn.Conv1d
+
+ self.conv = conv(
+ in_channels=self.in_chan,
+ out_channels=self.out_chan,
+ kernel_size=self.kernel_size,
+ stride=self.stride,
+ padding=self.padding,
+ dilation=self.dilation,
+ groups=self.groups,
+ bias=self.bias,
+ )
+ if self.xavier_init:
+ nn.init.xavier_uniform_(self.conv.weight)
+ else:
+ self.conv = nn.Identity()
+
+ self.act = activations.get(self.act_type)()
+ self.norm = normalizations.get(self.norm_type)(
+ (self.out_chan, self.n_freqs)
+ if self.norm_type == "LayerNormalization4D"
+ else self.out_chan
+ )
+
+ def forward(self, x: torch.Tensor):
+ output = self.conv(x)
+ output = self.act(output)
+ output = self.norm(output)
+ return output
+
+ def get_config(self):
+ encoder_args = {}
+
+ for k, v in (self.__dict__).items():
+ if not k.startswith("_") and k != "training":
+ if not inspect.ismethod(v):
+ encoder_args[k] = v
+
+ return encoder_args
+
+
+class DilatedConvNorm(nn.Module):
+ """
+ This class defines the dilated convolution with normalized output.
+ """
+
+ def __init__(self, nIn, nOut, kSize, stride=1, d=1, groups=1):
+ """
+ :param nIn: number of input channels
+ :param nOut: number of output channels
+ :param kSize: kernel size
+ :param stride: optional stride rate for down-sampling
+ :param d: optional dilation rate
+ """
+ super().__init__()
+ self.conv = nn.Conv1d(
+ nIn,
+ nOut,
+ kSize,
+ stride=stride,
+ dilation=d,
+ padding=((kSize - 1) // 2) * d,
+ groups=groups,
+ )
+ # self.norm = nn.GroupNorm(1, nOut, eps=1e-08)
+ self.norm = GlobLN(nOut)
+
+ def forward(self, input):
+ output = self.conv(input)
+ return self.norm(output)
+
+
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_size, drop=0.1):
+ super().__init__()
+ self.fc1 = ConvNorm(in_features, hidden_size, 1, bias=False)
+ self.dwconv = nn.Conv1d(
+ hidden_size, hidden_size, 5, 1, 2, bias=True, groups=hidden_size
+ )
+ self.act = nn.ReLU()
+ self.fc2 = ConvNorm(hidden_size, in_features, 1, bias=False)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.dwconv(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class InjectionMultiSum(nn.Module):
+ def __init__(self, inp: int, oup: int, kernel: int = 1) -> None:
+ super().__init__()
+ groups = 1
+ if inp == oup:
+ groups = inp
+ self.local_embedding = ConvNorm(inp, oup, kernel, groups=groups, bias=False)
+ self.global_embedding = ConvNorm(inp, oup, kernel, groups=groups, bias=False)
+ self.global_act = ConvNorm(inp, oup, kernel, groups=groups, bias=False)
+ self.act = nn.Sigmoid()
+
+ def forward(self, x_l, x_g):
+ """
+ x_g: global features
+ x_l: local features
+ """
+ B, N, T = x_l.shape
+ local_feat = self.local_embedding(x_l)
+
+ global_act = self.global_act(x_g)
+ sig_act = F.interpolate(self.act(global_act), size=T, mode="nearest")
+ # sig_act = self.act(global_act)
+
+ global_feat = self.global_embedding(x_g)
+ global_feat = F.interpolate(global_feat, size=T, mode="nearest")
+
+ out = local_feat * sig_act + global_feat
+ return out
+
+
+class InjectionMulti(nn.Module):
+ def __init__(self, inp: int, oup: int, kernel: int = 1) -> None:
+ super().__init__()
+ groups = 1
+ if inp == oup:
+ groups = inp
+ self.local_embedding = ConvNorm(inp, oup, kernel, groups=groups, bias=False)
+ self.global_act = ConvNorm(inp, oup, kernel, groups=groups, bias=False)
+ self.act = nn.Sigmoid()
+
+ def forward(self, x_l, x_g):
+ """
+ x_g: global features
+ x_l: local features
+ """
+ B, N, T = x_l.shape
+ local_feat = self.local_embedding(x_l)
+
+ global_act = self.global_act(x_g)
+ sig_act = F.interpolate(self.act(global_act), size=T, mode="nearest")
+ # sig_act = self.act(global_act)
+
+ out = local_feat * sig_act
+ return out
+
+
+class UConvBlock(nn.Module):
+ """
+ This class defines the block which performs successive downsampling and
+ upsampling in order to be able to analyze the input features in multiple
+ resolutions.
+ """
+
+ def __init__(
+ self, out_channels=128, in_channels=512, upsampling_depth=4, model_T=True
+ ):
+ super().__init__()
+ self.proj_1x1 = ConvNormAct(out_channels, in_channels, 1, stride=1, groups=1)
+ self.depth = upsampling_depth
+ self.spp_dw = nn.ModuleList()
+ self.spp_dw.append(
+ DilatedConvNorm(
+ in_channels, in_channels, kSize=5, stride=1, groups=in_channels, d=1
+ )
+ )
+ for i in range(1, upsampling_depth):
+ self.spp_dw.append(
+ DilatedConvNorm(
+ in_channels,
+ in_channels,
+ kSize=5,
+ stride=2,
+ groups=in_channels,
+ d=1,
+ )
+ )
+
+ self.loc_glo_fus = nn.ModuleList([])
+ for i in range(upsampling_depth):
+ self.loc_glo_fus.append(InjectionMultiSum(in_channels, in_channels))
+
+ self.res_conv = nn.Conv1d(in_channels, out_channels, 1)
+
+ self.globalatt = Mlp(in_channels, in_channels, drop=0.1)
+
+ self.last_layer = nn.ModuleList([])
+ for i in range(self.depth - 1):
+ self.last_layer.append(InjectionMultiSum(in_channels, in_channels, 5))
+
+ def forward(self, x):
+ """
+ :param x: input feature map
+ :return: transformed feature map
+ """
+ residual = x.clone()
+ # Reduce --> project high-dimensional feature maps to low-dimensional space
+ output1 = self.proj_1x1(x)
+ output = [self.spp_dw[0](output1)]
+
+ # Do the downsampling process from the previous level
+ for k in range(1, self.depth):
+ out_k = self.spp_dw[k](output[-1])
+ output.append(out_k)
+
+ # global features
+ global_f = torch.zeros(
+ output[-1].shape, requires_grad=True, device=output1.device
+ )
+ for fea in output:
+ global_f = global_f + F.adaptive_avg_pool1d(
+ fea, output_size=output[-1].shape[-1]
+ )
+ # global_f = global_f + fea
+ global_f = self.globalatt(global_f) # [B, N, T]
+
+ x_fused = []
+ # Gather them now in reverse order
+ for idx in range(self.depth):
+ local = output[idx]
+ x_fused.append(self.loc_glo_fus[idx](local, global_f))
+
+ expanded = None
+ for i in range(self.depth - 2, -1, -1):
+ if i == self.depth - 2:
+ expanded = self.last_layer[i](x_fused[i], x_fused[i - 1])
+ else:
+ expanded = self.last_layer[i](x_fused[i], expanded)
+ # import pdb; pdb.set_trace()
+ return self.res_conv(expanded) + residual
+
+
+class MultiHeadSelfAttention2D(nn.Module):
+ def __init__(
+ self,
+ in_chan: int,
+ n_freqs: int,
+ n_head: int = 4,
+ hid_chan: int = 4,
+ act_type: str = "prelu",
+ norm_type: str = "LayerNormalization4D",
+ dim: int = 3,
+ *args,
+ **kwargs,
+ ):
+ super(MultiHeadSelfAttention2D, self).__init__()
+ self.in_chan = in_chan
+ self.n_freqs = n_freqs
+ self.n_head = n_head
+ self.hid_chan = hid_chan
+ self.act_type = act_type
+ self.norm_type = norm_type
+ self.dim = dim
+
+ assert self.in_chan % self.n_head == 0
+
+ self.Queries = nn.ModuleList()
+ self.Keys = nn.ModuleList()
+ self.Values = nn.ModuleList()
+
+ for _ in range(self.n_head):
+ self.Queries.append(
+ ATTConvActNorm(
+ in_chan=self.in_chan,
+ out_chan=self.hid_chan,
+ kernel_size=1,
+ act_type=self.act_type,
+ norm_type=self.norm_type,
+ n_freqs=self.n_freqs,
+ is2d=True,
+ )
+ )
+ self.Keys.append(
+ ATTConvActNorm(
+ in_chan=self.in_chan,
+ out_chan=self.hid_chan,
+ kernel_size=1,
+ act_type=self.act_type,
+ norm_type=self.norm_type,
+ n_freqs=self.n_freqs,
+ is2d=True,
+ )
+ )
+ self.Values.append(
+ ATTConvActNorm(
+ in_chan=self.in_chan,
+ out_chan=self.in_chan // self.n_head,
+ kernel_size=1,
+ act_type=self.act_type,
+ norm_type=self.norm_type,
+ n_freqs=self.n_freqs,
+ is2d=True,
+ )
+ )
+
+ self.attn_concat_proj = ATTConvActNorm(
+ in_chan=self.in_chan,
+ out_chan=self.in_chan,
+ kernel_size=1,
+ act_type=self.act_type,
+ norm_type=self.norm_type,
+ n_freqs=self.n_freqs,
+ is2d=True,
+ )
+
+ def forward(self, x: torch.Tensor):
+ if self.dim == 4:
+ x = x.transpose(-2, -1).contiguous()
+
+ batch_size, _, time, freq = x.size()
+ residual = x
+
+ all_Q = [q(x) for q in self.Queries] # [B, E, T, F]
+ all_K = [k(x) for k in self.Keys] # [B, E, T, F]
+ all_V = [v(x) for v in self.Values] # [B, C/n_head, T, F]
+
+ Q = torch.cat(all_Q, dim=0) # [B', E, T, F] B' = B*n_head
+ K = torch.cat(all_K, dim=0) # [B', E, T, F]
+ V = torch.cat(all_V, dim=0) # [B', C/n_head, T, F]
+
+ Q = Q.transpose(1, 2).flatten(start_dim=2) # [B', T, E*F]
+ K = K.transpose(1, 2).flatten(start_dim=2) # [B', T, E*F]
+ V = V.transpose(1, 2) # [B', T, C/n_head, F]
+ old_shape = V.shape
+ V = V.flatten(start_dim=2) # [B', T, C*F/n_head]
+ emb_dim = Q.shape[-1] # C*F/n_head
+
+ attn_mat = torch.matmul(Q, K.transpose(1, 2)) / (emb_dim**0.5) # [B', T, T]
+ attn_mat = F.softmax(attn_mat, dim=2) # [B', T, T]
+ V = torch.matmul(attn_mat, V) # [B', T, C*F/n_head]
+ V = V.reshape(old_shape) # [B', T, C/n_head, F]
+ V = V.transpose(1, 2) # [B', C/n_head, T, F]
+ emb_dim = V.shape[1] # C/n_head
+
+ x = V.view(
+ [self.n_head, batch_size, emb_dim, time, freq]
+ ) # [n_head, B, C/n_head, T, F]
+ x = x.transpose(0, 1).contiguous() # [B, n_head, C/n_head, T, F]
+
+ x = x.view([batch_size, self.n_head * emb_dim, time, freq]) # [B, C, T, F]
+ x = self.attn_concat_proj(x) # [B, C, T, F]
+
+ x = x + residual
+
+ if self.dim == 4:
+ x = x.transpose(-2, -1).contiguous()
+
+ return x
+
+
+class Recurrent(nn.Module):
+ def __init__(
+ self,
+ out_channels=128,
+ in_channels=512,
+ nband=8,
+ upsampling_depth=3,
+ n_head=4,
+ att_hid_chan=4,
+ kernel_size: int = 8,
+ stride: int = 1,
+ _iter=4,
+ ):
+ super().__init__()
+ self.nband = nband
+
+ self.freq_path = nn.ModuleList(
+ [
+ UConvBlock(out_channels, in_channels, upsampling_depth),
+ MultiHeadSelfAttention2D(
+ out_channels,
+ 1,
+ n_head=n_head,
+ hid_chan=att_hid_chan,
+ act_type="prelu",
+ norm_type="LayerNormalization4D",
+ dim=4,
+ ),
+ normalizations.get("LayerNormalization4D")((out_channels, 1)),
+ ]
+ )
+
+ self.frame_path = nn.ModuleList(
+ [
+ UConvBlock(out_channels, in_channels, upsampling_depth),
+ MultiHeadSelfAttention2D(
+ out_channels,
+ 1,
+ n_head=n_head,
+ hid_chan=att_hid_chan,
+ act_type="prelu",
+ norm_type="LayerNormalization4D",
+ dim=4,
+ ),
+ normalizations.get("LayerNormalization4D")((out_channels, 1)),
+ ]
+ )
+
+ self.iter = _iter
+ self.concat_block = nn.Sequential(
+ nn.Conv2d(out_channels, out_channels, 1, 1, groups=out_channels), nn.PReLU()
+ )
+
+ def forward(self, x):
+ # B, nband, N, T
+ B, nband, N, T = x.shape
+ x = x.permute(0, 2, 1, 3).contiguous() # B, N, nband, T
+ mixture = x.clone()
+ for i in range(self.iter):
+ if i == 0:
+ x = self.freq_time_process(x, B, nband, N, T) # B, N, nband, T
+ else:
+ x = self.freq_time_process(
+ self.concat_block(mixture + x), B, nband, N, T
+ ) # B, N, nband, T
+
+ return x.permute(0, 2, 1, 3).contiguous() # B, nband, N, T
+
+ def freq_time_process(self, x, B, nband, N, T):
+ # Process Frequency Path
+ residual_1 = x.clone()
+ x = x.permute(0, 3, 1, 2).contiguous() # B, T, N, nband
+ freq_fea = self.freq_path[0](x.view(B * T, N, nband)) # B*T, N, nband
+ freq_fea = (
+ freq_fea.view(B, T, N, nband).permute(0, 2, 1, 3).contiguous()
+ ) # B, N, T, nband
+ freq_fea = self.freq_path[1](freq_fea) # B, N, T, nband
+ freq_fea = self.freq_path[2](freq_fea) # B, N, T, nband
+ freq_fea = freq_fea.permute(0, 1, 3, 2).contiguous()
+ x = freq_fea + residual_1 # B, N, nband, T
+ # Process Frame Path
+ residual_2 = x.clone()
+ x2 = x.permute(0, 2, 1, 3).contiguous()
+ frame_fea = self.frame_path[0](x2.view(B * nband, N, T)) # B*nband, N, T
+ frame_fea = frame_fea.view(B, nband, N, T).permute(0, 2, 1, 3).contiguous()
+ frame_fea = self.frame_path[1](frame_fea) # B, N, nband, T
+ frame_fea = self.frame_path[2](frame_fea) # B, N, nband, T
+ x = frame_fea + residual_2 # B, N, nband, T
+ return x
+
+
+class TIGER(nn.Module):
+ def __init__(
+ self,
+ out_channels=128,
+ in_channels=512,
+ num_blocks=16,
+ upsampling_depth=4,
+ att_n_head=4,
+ att_hid_chan=4,
+ att_kernel_size=8,
+ att_stride=1,
+ win=2048,
+ stride=512,
+ num_sources=2,
+ sample_rate=44100,
+ ):
+ super(TIGER, self).__init__()
+
+ self.sample_rate = sample_rate
+ self.win = win
+ self.stride = stride
+ self.group = self.win // 2
+ self.enc_dim = self.win // 2 + 1
+ self.feature_dim = out_channels
+ self.num_output = num_sources
+ self.eps = torch.finfo(torch.float32).eps
+
+ # 0-1k (50 hop), 1k-2k (100 hop), 2k-4k (250 hop), 4k-8k (500 hop), 8k-16k (1k hop), 16k-20k (2k hop), 20k-inf
+ bandwidth_50 = int(np.floor(50 / (sample_rate / 2.0) * self.enc_dim))
+ bandwidth_100 = int(np.floor(100 / (sample_rate / 2.0) * self.enc_dim))
+ bandwidth_250 = int(np.floor(250 / (sample_rate / 2.0) * self.enc_dim))
+ bandwidth_500 = int(np.floor(500 / (sample_rate / 2.0) * self.enc_dim))
+ bandwidth_1k = int(np.floor(1000 / (sample_rate / 2.0) * self.enc_dim))
+ bandwidth_2k = int(np.floor(2000 / (sample_rate / 2.0) * self.enc_dim))
+ self.band_width = [bandwidth_50] * 20
+ self.band_width += [bandwidth_100] * 10
+ self.band_width += [bandwidth_250] * 8
+ self.band_width += [bandwidth_500] * 8
+ self.band_width += [bandwidth_1k] * 8
+ self.band_width += [bandwidth_2k] * 2
+ self.band_width.append(self.enc_dim - np.sum(self.band_width))
+ self.nband = len(self.band_width)
+ print(self.band_width)
+
+ self.BN = nn.ModuleList([])
+ for i in range(self.nband):
+ self.BN.append(
+ nn.Sequential(
+ nn.GroupNorm(1, self.band_width[i] * 2, self.eps),
+ nn.Conv1d(self.band_width[i] * 2, self.feature_dim, 1),
+ )
+ )
+
+ self.separator = Recurrent(
+ self.feature_dim,
+ in_channels,
+ self.nband,
+ upsampling_depth,
+ att_n_head,
+ att_hid_chan,
+ att_kernel_size,
+ att_stride,
+ num_blocks,
+ )
+
+ self.mask = nn.ModuleList([])
+ for i in range(self.nband):
+ self.mask.append(
+ nn.Sequential(
+ nn.PReLU(),
+ nn.Conv1d(
+ self.feature_dim,
+ self.band_width[i] * 4 * num_sources,
+ 1,
+ groups=num_sources,
+ ),
+ )
+ )
+
+ def pad_input(self, input, window, stride):
+ """
+ Zero-padding input according to window/stride size.
+ """
+ batch_size, nsample = input.shape
+
+ # pad the signals at the end for matching the window/stride size
+ rest = window - (stride + nsample % window) % window
+ if rest > 0:
+ pad = torch.zeros(batch_size, rest).type(input.type())
+ input = torch.cat([input, pad], 1)
+ pad_aux = torch.zeros(batch_size, stride).type(input.type())
+ input = torch.cat([pad_aux, input, pad_aux], 1)
+
+ return input, rest
+
+ def forward(self, input):
+ # input shape: (B, C, T)
+ was_one_d = False
+ if input.ndim == 1:
+ was_one_d = True
+ input = input.unsqueeze(0).unsqueeze(1)
+ if input.ndim == 2:
+ was_one_d = True
+ input = input.unsqueeze(1)
+ if input.ndim == 3:
+ input = input
+ batch_size, nch, nsample = input.shape
+ input = input.view(batch_size * nch, -1)
+
+ # frequency-domain separation
+ spec = torch.stft(
+ input,
+ n_fft=self.win,
+ hop_length=self.stride,
+ window=torch.hann_window(self.win).type(input.dtype).to(input.device),
+ return_complex=True,
+ )
+
+ # print(spec.shape)
+
+ # concat real and imag, split to subbands
+ spec_RI = torch.stack([spec.real, spec.imag], 1) # B*nch, 2, F, T
+ subband_spec_RI = []
+ subband_spec = []
+ band_idx = 0
+ for i in range(len(self.band_width)):
+ subband_spec_RI.append(
+ spec_RI[:, :, band_idx : band_idx + self.band_width[i]].contiguous()
+ )
+ subband_spec.append(
+ spec[:, band_idx : band_idx + self.band_width[i]]
+ ) # B*nch, BW, T
+ band_idx += self.band_width[i]
+
+ # normalization and bottleneck
+ subband_feature = []
+ for i in range(len(self.band_width)):
+ subband_feature.append(
+ self.BN[i](
+ subband_spec_RI[i].view(
+ batch_size * nch, self.band_width[i] * 2, -1
+ )
+ )
+ )
+ subband_feature = torch.stack(subband_feature, 1) # B, nband, N, T
+ # import pdb; pdb.set_trace()
+ # separator
+ sep_output = self.separator(
+ subband_feature.view(batch_size * nch, self.nband, self.feature_dim, -1)
+ ) # B, nband, N, T
+ sep_output = sep_output.view(batch_size * nch, self.nband, self.feature_dim, -1)
+
+ sep_subband_spec = []
+ for i in range(self.nband):
+ this_output = self.mask[i](sep_output[:, i]).view(
+ batch_size * nch, 2, 2, self.num_output, self.band_width[i], -1
+ )
+ this_mask = this_output[:, 0] * torch.sigmoid(
+ this_output[:, 1]
+ ) # B*nch, 2, K, BW, T
+ this_mask_real = this_mask[:, 0] # B*nch, K, BW, T
+ this_mask_imag = this_mask[:, 1] # B*nch, K, BW, T
+ # force the real masks to sum to 1 and the imaginary masks to sum to 0 across sources
+ this_mask_real_sum = this_mask_real.sum(1).unsqueeze(1) # B*nch, 1, BW, T
+ this_mask_imag_sum = this_mask_imag.sum(1).unsqueeze(1) # B*nch, 1, BW, T
+ this_mask_real = this_mask_real - (this_mask_real_sum - 1) / self.num_output
+ this_mask_imag = this_mask_imag - this_mask_imag_sum / self.num_output
+ est_spec_real = (
+ subband_spec[i].real.unsqueeze(1) * this_mask_real
+ - subband_spec[i].imag.unsqueeze(1) * this_mask_imag
+ ) # B*nch, K, BW, T
+ est_spec_imag = (
+ subband_spec[i].real.unsqueeze(1) * this_mask_imag
+ + subband_spec[i].imag.unsqueeze(1) * this_mask_real
+ ) # B*nch, K, BW, T
+ sep_subband_spec.append(torch.complex(est_spec_real, est_spec_imag))
+ sep_subband_spec = torch.cat(sep_subband_spec, 2)
+
+ output = torch.istft(
+ sep_subband_spec.view(batch_size * nch * self.num_output, self.enc_dim, -1),
+ n_fft=self.win,
+ hop_length=self.stride,
+ window=torch.hann_window(self.win).type(input.dtype).to(input.device),
+ length=nsample,
+ )
+ output = output.view(batch_size * nch, self.num_output, -1)
+ # if was_one_d:
+ # return output.squeeze(0)
+ return output
+
+ def get_model_args(self):
+ model_args = {"n_sample_rate": 2}
+ return model_args
+
+
+class TIGERDNR(BaseModel):
+ def __init__(
+ self,
+ out_channels=132,
+ in_channels=256,
+ num_blocks=8,
+ upsampling_depth=5,
+ att_n_head=4,
+ att_hid_chan=4,
+ att_kernel_size=8,
+ att_stride=1,
+ win=2048,
+ stride=512,
+ num_sources=3,
+ sample_rate=44100,
+ ):
+ super(TIGERDNR, self).__init__(sample_rate=sample_rate)
+ self.sr = sample_rate
+
+ self.dialog = TIGER(
+ out_channels=out_channels,
+ in_channels=in_channels,
+ num_blocks=num_blocks,
+ upsampling_depth=upsampling_depth,
+ att_n_head=att_n_head,
+ att_hid_chan=att_hid_chan,
+ att_kernel_size=att_kernel_size,
+ att_stride=att_stride,
+ win=win,
+ stride=stride,
+ num_sources=num_sources,
+ sample_rate=sample_rate,
+ )
+ self.effect = TIGER(
+ out_channels=out_channels,
+ in_channels=in_channels,
+ num_blocks=num_blocks,
+ upsampling_depth=upsampling_depth,
+ att_n_head=att_n_head,
+ att_hid_chan=att_hid_chan,
+ att_kernel_size=att_kernel_size,
+ att_stride=att_stride,
+ win=win,
+ stride=stride,
+ num_sources=num_sources,
+ sample_rate=sample_rate,
+ )
+ self.music = TIGER(
+ out_channels=out_channels,
+ in_channels=in_channels,
+ num_blocks=num_blocks,
+ upsampling_depth=upsampling_depth,
+ att_n_head=att_n_head,
+ att_hid_chan=att_hid_chan,
+ att_kernel_size=att_kernel_size,
+ att_stride=att_stride,
+ win=win,
+ stride=stride,
+ num_sources=num_sources,
+ sample_rate=sample_rate,
+ )
+
+ def wav_chunk_inference(self, model, mixture_tensor, target_length=12.0, hop_length=4.0, batch_size=1, n_tracks=3):
+ """
+ Input:
+ mixture_tensor: Tensor, [1, nch, input_length]
+
+ Output:
+ all_target_tensor: Tensor, [n_tracks, nch, input_length]
+ """
+
+ batch_mixture = mixture_tensor # [1, nch, T]
+ # print(batch_mixture.shape, [:,:int(self.sr*24)])
+
+ # split data into segments
+ batch_length = batch_mixture.shape[-1]
+
+ session = int(self.sr * target_length)
+ target = int(self.sr * target_length)
+ ignore = (session - target) // 2
+ hop = int(self.sr * hop_length)
+ tr_ratio = target_length / hop_length
+ if ignore > 0:
+ zero_pad = torch.zeros(batch_mixture.shape[0], batch_mixture.shape[1], ignore).type(batch_mixture.dtype).to(batch_mixture.device)
+ batch_mixture_pad = torch.cat([zero_pad, batch_mixture, zero_pad], -1)
+ else:
+ batch_mixture_pad = batch_mixture
+ if target - hop > 0:
+ hop_pad = torch.zeros(batch_mixture.shape[0], batch_mixture.shape[1], target-hop).type(batch_mixture.dtype).to(batch_mixture.device)
+ batch_mixture_pad = torch.cat([hop_pad, batch_mixture_pad, hop_pad], -1)
+
+ skip_idx = ignore + target - hop
+ zero_pad = torch.zeros(batch_mixture.shape[0], batch_mixture.shape[1], session).type(batch_mixture.dtype).to(batch_mixture.device)
+ num_session = (batch_mixture_pad.shape[-1] - session) // hop + 2
+ all_target = torch.zeros(batch_mixture_pad.shape[0], n_tracks, batch_mixture_pad.shape[1], batch_mixture_pad.shape[2]).to(batch_mixture_pad.device)
+ all_input = []
+ all_segment_length = []
+
+ for i in range(num_session):
+ this_input = batch_mixture_pad[:,:,i*hop:i*hop+session]
+ segment_length = this_input.shape[-1]
+ if segment_length < session:
+ this_input = torch.cat([this_input, zero_pad[:,:,:session-segment_length]], -1)
+ all_input.append(this_input)
+ all_segment_length.append(segment_length)
+
+ all_input = torch.cat(all_input, 0)
+ num_batch = num_session // batch_size
+ if num_session % batch_size > 0:
+ num_batch += 1
+
+ for i in range(num_batch):
+
+ this_input = all_input[i*batch_size:(i+1)*batch_size]
+ actual_batch_size = this_input.shape[0]
+ with torch.no_grad():
+ est_target = model(this_input)
+ # batch, ntrack, nch, T = est_target.shape
+ # est_target = est_target.transpose(1, 2).view(batch*nch, ntrack, T)
+ est_target = est_target.unsqueeze(2)
+
+ for j in range(actual_batch_size):
+ this_est_target = est_target[j,:,:,:all_segment_length[i*batch_size+j]][:,:,ignore:ignore+target].unsqueeze(0)
+ all_target[:,:,:,ignore+(i*batch_size+j)*hop:ignore+(i*batch_size+j)*hop+target] += this_est_target # [batch, ntrack, nch, T]
+
+ all_target = all_target[:,:,:,skip_idx:skip_idx+batch_length].contiguous() / tr_ratio
+
+ return all_target.squeeze(0)
+
+ def forward(self, mixture_tensor):
+ all_target_dialog = self.wav_chunk_inference(self.dialog, mixture_tensor)[2]
+ all_target_effect = self.wav_chunk_inference(self.effect, mixture_tensor)[1]
+ all_target_music = self.wav_chunk_inference(self.music, mixture_tensor)[0]
+ return all_target_dialog, all_target_effect, all_target_music
+
+ def get_model_args(self):
+ model_args = {"n_sample_rate": 2}
+ return model_args
\ No newline at end of file
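
A hedged usage sketch for TIGERDNR: forward() runs chunked inference with each of the three heads and returns (dialog, effect, music). The [1, nch, T] layout follows how wav_chunk_inference indexes the mixture tensor; the ten-second random input is illustrative and inference is heavy:

    import torch
    from look2hear.models import TIGERDNR

    model = TIGERDNR().eval()               # default 44.1 kHz configuration
    mix = torch.randn(1, 1, 44100 * 10)     # [1, nch, T]
    with torch.no_grad():
        dialog, effect, music = model(mix)  # each roughly [nch, T]
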
diff --git a/look2hear/system/__init__.py b/look2hear/system/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e85918dfdb4b1973860440f97fd35ccb38e8ffe
--- /dev/null
+++ b/look2hear/system/__init__.py
@@ -0,0 +1,11 @@
+from .optimizers import make_optimizer
+from .audio_litmodule import AudioLightningModule
+from .audio_litmodule_multidecoder import AudioLightningModuleMultiDecoder
+from .schedulers import DPTNetScheduler
+
+__all__ = [
+ "make_optimizer",
+ "AudioLightningModule",
+ "DPTNetScheduler",
+ "AudioLightningModuleMultiDecoder"
+]
diff --git a/look2hear/system/audio_litmodule.py b/look2hear/system/audio_litmodule.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc5ec6b5c211338a28bd30bcf05ab276ef86c4db
--- /dev/null
+++ b/look2hear/system/audio_litmodule.py
@@ -0,0 +1,247 @@
+import torch
+import pytorch_lightning as pl
+from torch.optim.lr_scheduler import ReduceLROnPlateau
+from collections.abc import MutableMapping
+from speechbrain.processing.speech_augmentation import SpeedPerturb
+
+def flatten_dict(d, parent_key="", sep="_"):
+ """Flattens a dictionary into a single-level dictionary while preserving
+ parent keys. Taken from
+ `SO `_
+
+ Args:
+ d (MutableMapping): Dictionary to be flattened.
+ parent_key (str): String to use as a prefix to all subsequent keys.
+ sep (str): String to use as a separator between two key levels.
+
+ Returns:
+ dict: Single-level dictionary, flattened.
+ """
+ items = []
+ for k, v in d.items():
+ new_key = parent_key + sep + k if parent_key else k
+ if isinstance(v, MutableMapping):
+ items.extend(flatten_dict(v, new_key, sep=sep).items())
+ else:
+ items.append((new_key, v))
+ return dict(items)
+
+
+class AudioLightningModule(pl.LightningModule):
+ def __init__(
+ self,
+ audio_model=None,
+ video_model=None,
+ optimizer=None,
+ loss_func=None,
+ train_loader=None,
+ val_loader=None,
+ test_loader=None,
+ scheduler=None,
+ config=None,
+ ):
+ super().__init__()
+ self.audio_model = audio_model
+ self.video_model = video_model
+ self.optimizer = optimizer
+ self.loss_func = loss_func
+ self.train_loader = train_loader
+ self.val_loader = val_loader
+ self.test_loader = test_loader
+ self.scheduler = scheduler
+ self.config = {} if config is None else config
+ # Speed Aug
+ self.speedperturb = SpeedPerturb(
+ self.config["datamodule"]["data_config"]["sample_rate"],
+ speeds=[95, 100, 105],
+ perturb_prob=1.0
+ )
+ # Save lightning"s AttributeDict under self.hparams
+ self.default_monitor = "val_loss/dataloader_idx_0"
+ self.save_hyperparameters(self.config_to_hparams(self.config))
+ # self.print(self.audio_model)
+ self.validation_step_outputs = []
+ self.test_step_outputs = []
+
+
+ def forward(self, wav, mouth=None):
+ """Applies forward pass of the model.
+
+ Returns:
+ :class:`torch.Tensor`
+ """
+ return self.audio_model(wav)
+
+ def training_step(self, batch, batch_nb):
+ mixtures, targets, _ = batch
+
+ new_targets = []
+ min_len = -1
+ if self.config["training"]["SpeedAug"] == True:
+ with torch.no_grad():
+ for i in range(targets.shape[1]):
+ new_target = self.speedperturb(targets[:, i, :])
+ new_targets.append(new_target)
+ if i == 0:
+ min_len = new_target.shape[-1]
+ else:
+ if new_target.shape[-1] < min_len:
+ min_len = new_target.shape[-1]
+
+ targets = torch.zeros(
+ targets.shape[0],
+ targets.shape[1],
+ min_len,
+ device=targets.device,
+ dtype=torch.float,
+ )
+ for i, new_target in enumerate(new_targets):
+ targets[:, i, :] = new_targets[i][:, 0:min_len]
+
+ mixtures = targets.sum(1)
+ # print(mixtures.shape)
+ est_sources = self(mixtures)
+ loss = self.loss_func["train"](est_sources, targets)
+
+ self.log(
+ "train_loss",
+ loss,
+ on_epoch=True,
+ prog_bar=True,
+ sync_dist=True,
+ logger=True,
+ )
+
+ return {"loss": loss}
+
+
+ def validation_step(self, batch, batch_nb, dataloader_idx):
+ # cal val loss
+ if dataloader_idx == 0:
+ mixtures, targets, _ = batch
+ # print(mixtures.shape)
+ est_sources = self(mixtures)
+ loss = self.loss_func["val"](est_sources, targets)
+ self.log(
+ "val_loss",
+ loss,
+ on_epoch=True,
+ prog_bar=True,
+ sync_dist=True,
+ logger=True,
+ )
+
+ self.validation_step_outputs.append(loss)
+
+ return {"val_loss": loss}
+
+ # cal test loss
+ if (self.trainer.current_epoch) % 10 == 0 and dataloader_idx == 1:
+ mixtures, targets, _ = batch
+ # print(mixtures.shape)
+ est_sources = self(mixtures)
+ tloss = self.loss_func["val"](est_sources, targets)
+ self.log(
+ "test_loss",
+ tloss,
+ on_epoch=True,
+ prog_bar=True,
+ sync_dist=True,
+ logger=True,
+ )
+ self.test_step_outputs.append(tloss)
+ return {"test_loss": tloss}
+
+ def on_validation_epoch_end(self):
+ # val
+ avg_loss = torch.stack(self.validation_step_outputs).mean()
+ val_loss = torch.mean(self.all_gather(avg_loss))
+ self.log(
+ "lr",
+ self.optimizer.param_groups[0]["lr"],
+ on_epoch=True,
+ prog_bar=True,
+ sync_dist=True,
+ )
+ self.logger.experiment.log(
+ {"learning_rate": self.optimizer.param_groups[0]["lr"], "epoch": self.current_epoch}
+ )
+ self.logger.experiment.log(
+ {"val_pit_sisnr": -val_loss, "epoch": self.current_epoch}
+ )
+
+ # test
+ if (self.trainer.current_epoch) % 10 == 0:
+ avg_loss = torch.stack(self.test_step_outputs).mean()
+ test_loss = torch.mean(self.all_gather(avg_loss))
+ self.logger.experiment.log(
+ {"test_pit_sisnr": -test_loss, "epoch": self.current_epoch}
+ )
+ self.validation_step_outputs.clear() # free memory
+ self.test_step_outputs.clear() # free memory
+
+ def configure_optimizers(self):
+ """Initialize optimizers, batch-wise and epoch-wise schedulers."""
+ if self.scheduler is None:
+ return self.optimizer
+
+ if not isinstance(self.scheduler, (list, tuple)):
+ self.scheduler = [self.scheduler] # support multiple schedulers
+
+ epoch_schedulers = []
+ for sched in self.scheduler:
+ if not isinstance(sched, dict):
+ if isinstance(sched, ReduceLROnPlateau):
+ sched = {"scheduler": sched, "monitor": self.default_monitor}
+ epoch_schedulers.append(sched)
+ else:
+ sched.setdefault("monitor", self.default_monitor)
+ sched.setdefault("frequency", 1)
+ # Backward compat
+ if sched["interval"] == "batch":
+ sched["interval"] = "step"
+ assert sched["interval"] in [
+ "epoch",
+ "step",
+ ], "Scheduler interval should be either step or epoch"
+ epoch_schedulers.append(sched)
+ return [self.optimizer], epoch_schedulers
+
+ # def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
+ # if metric is None:
+ # scheduler.step()
+ # else:
+ # scheduler.step(metric)
+
+ def train_dataloader(self):
+ """Training dataloader"""
+ return self.train_loader
+
+ def val_dataloader(self):
+ """Validation dataloader"""
+ return [self.val_loader, self.test_loader]
+
+ def on_save_checkpoint(self, checkpoint):
+ """Overwrite if you want to save more things in the checkpoint."""
+ checkpoint["training_config"] = self.config
+ return checkpoint
+
+ @staticmethod
+ def config_to_hparams(dic):
+ """Sanitizes the config dict to be handled correctly by torch
+ SummaryWriter. It flattens the config dict, converts ``None`` to
+ ``"None"``, and turns any list or tuple into a torch.Tensor.
+
+ Args:
+ dic (dict): Dictionary to be transformed.
+
+ Returns:
+ dict: Transformed dictionary.
+ """
+ dic = flatten_dict(dic)
+ for k, v in dic.items():
+ if v is None:
+ dic[k] = str(v)
+ elif isinstance(v, (list, tuple)):
+ dic[k] = torch.tensor(v)
+ return dic
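
A rough wiring sketch for AudioLightningModule. The loss and config values are chosen only to satisfy the keys accessed in __init__ and training_step, the DataLoaders are placeholders expected to yield (mixture, targets, key) batches, and the Trainer call is left commented because the epoch-end hooks above assume a W&B-style logger:

    import torch
    import pytorch_lightning as pl
    from look2hear.models import TIGER
    from look2hear.losses import PITLossWrapper, PairwiseNegSDR
    from look2hear.system import AudioLightningModule

    audio_model = TIGER()
    loss = PITLossWrapper(PairwiseNegSDR("sisdr"), pit_from="pw_mtx")
    config = {
        "datamodule": {"data_config": {"sample_rate": 44100}},
        "training": {"SpeedAug": False},
    }
    system = AudioLightningModule(
        audio_model=audio_model,
        optimizer=torch.optim.Adam(audio_model.parameters(), lr=1e-3),
        loss_func={"train": loss, "val": loss},
        train_loader=train_loader,   # placeholder DataLoaders yielding (mixture, targets, key)
        val_loader=val_loader,
        test_loader=test_loader,
        config=config,
    )
    # pl.Trainer(max_epochs=200).fit(system)
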
diff --git a/look2hear/system/audio_litmodule_multidecoder.py b/look2hear/system/audio_litmodule_multidecoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b89663e2a7b58e8050397c06653bb1306ddafb3
--- /dev/null
+++ b/look2hear/system/audio_litmodule_multidecoder.py
@@ -0,0 +1,258 @@
+import torch
+import pytorch_lightning as pl
+from torch.optim.lr_scheduler import ReduceLROnPlateau
+from collections.abc import MutableMapping
+# from speechbrain.processing.speech_augmentation import SpeedPerturb
+
+def flatten_dict(d, parent_key="", sep="_"):
+ """Flattens a dictionary into a single-level dictionary while preserving
+ parent keys. Taken from
+ `SO `_
+
+ Args:
+ d (MutableMapping): Dictionary to be flattened.
+ parent_key (str): String to use as a prefix to all subsequent keys.
+ sep (str): String to use as a separator between two key levels.
+
+ Returns:
+ dict: Single-level dictionary, flattened.
+ """
+ items = []
+ for k, v in d.items():
+ new_key = parent_key + sep + k if parent_key else k
+ if isinstance(v, MutableMapping):
+ items.extend(flatten_dict(v, new_key, sep=sep).items())
+ else:
+ items.append((new_key, v))
+ return dict(items)
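+
+ # Illustrative sketch (hypothetical config):
+ # >>> flatten_dict({"optim": {"lr": 1e-3, "betas": [0.9, 0.999]}})
+ # {'optim_lr': 0.001, 'optim_betas': [0.9, 0.999]}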
+
+
+class AudioLightningModuleMultiDecoder(pl.LightningModule):
+ def __init__(
+ self,
+ audio_model=None,
+ video_model=None,
+ optimizer=None,
+ loss_func=None,
+ train_loader=None,
+ val_loader=None,
+ test_loader=None,
+ scheduler=None,
+ config=None,
+ ):
+ super().__init__()
+ self.audio_model = audio_model
+ self.video_model = video_model
+ self.optimizer = optimizer
+ self.loss_func = loss_func
+ self.train_loader = train_loader
+ self.val_loader = val_loader
+ self.test_loader = test_loader
+ self.scheduler = scheduler
+ self.config = {} if config is None else config
+ # Speed Aug
+ # self.speedperturb = SpeedPerturb(
+ # self.config["datamodule"]["data_config"]["sample_rate"],
+ # speeds=[95, 100, 105],
+ # perturb_prob=1.0
+ # )
+ # Save lightning"s AttributeDict under self.hparams
+ self.default_monitor = "val_loss/dataloader_idx_0"
+ self.save_hyperparameters(self.config_to_hparams(self.config))
+ # self.print(self.audio_model)
+ self.validation_step_outputs = []
+ self.test_step_outputs = []
+
+
+ def forward(self, wav, mouth=None):
+ """Applies forward pass of the model.
+
+ Returns:
+ :class:`torch.Tensor`
+ """
+ if self.training:
+ return self.audio_model(wav)
+ else:
+ ests, num_blocks = self.audio_model(wav)
+ assert ests.shape[0] % num_blocks == 0
+ batch_size = int(ests.shape[0] / num_blocks)
+ return ests[-batch_size:]
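+
+ # Shape sketch (illustrative, e.g. num_blocks=4 and batch size B): the audio
+ # model stacks per-block estimates along the batch axis, so ests is
+ # (num_blocks * B, n_src, T) and only the last block's B estimates are kept
+ # at eval time.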
+
+ def training_step(self, batch, batch_nb):
+ mixtures, targets, _ = batch
+
+ new_targets = []
+ min_len = -1
+ # if self.config["training"]["SpeedAug"] == True:
+ # with torch.no_grad():
+ # for i in range(targets.shape[1]):
+ # new_target = self.speedperturb(targets[:, i, :])
+ # new_targets.append(new_target)
+ # if i == 0:
+ # min_len = new_target.shape[-1]
+ # else:
+ # if new_target.shape[-1] < min_len:
+ # min_len = new_target.shape[-1]
+
+ # targets = torch.zeros(
+ # targets.shape[0],
+ # targets.shape[1],
+ # min_len,
+ # device=targets.device,
+ # dtype=torch.float,
+ # )
+ # for i, new_target in enumerate(new_targets):
+ # targets[:, i, :] = new_targets[i][:, 0:min_len]
+
+ # mixtures = targets.sum(1)
+ # print(mixtures.shape)
+ est_sources, num_blocks = self(mixtures)
+ assert est_sources.shape[0] % num_blocks == 0
+ batch_size = int(est_sources.shape[0] / num_blocks)
+ ests_sources_each_block = []
+ for i in range(num_blocks):
+ ests_sources_each_block.append(est_sources[i * batch_size : (i + 1) * batch_size])
+ loss = self.loss_func["train"](ests_sources_each_block, targets)
+
+ self.log(
+ "train_loss",
+ loss,
+ on_epoch=True,
+ prog_bar=True,
+ sync_dist=True,
+ logger=True,
+ )
+
+ return {"loss": loss}
+
+
+ def validation_step(self, batch, batch_nb, dataloader_idx):
+ # cal val loss
+ if dataloader_idx == 0:
+ mixtures, targets, _ = batch
+ # print(mixtures.shape)
+ est_sources = self(mixtures)
+ loss = self.loss_func["val"](est_sources, targets)
+ self.log(
+ "val_loss",
+ loss,
+ on_epoch=True,
+ prog_bar=True,
+ sync_dist=True,
+ logger=True,
+ )
+
+ self.validation_step_outputs.append(loss)
+
+ return {"val_loss": loss}
+
+ # cal test loss
+ if (self.trainer.current_epoch) % 10 == 0 and dataloader_idx == 1:
+ mixtures, targets, _ = batch
+ # print(mixtures.shape)
+ est_sources = self(mixtures)
+ tloss = self.loss_func["val"](est_sources, targets)
+ self.log(
+ "test_loss",
+ tloss,
+ on_epoch=True,
+ prog_bar=True,
+ sync_dist=True,
+ logger=True,
+ )
+ self.test_step_outputs.append(tloss)
+ return {"test_loss": tloss}
+
+ def on_validation_epoch_end(self):
+ # val
+ avg_loss = torch.stack(self.validation_step_outputs).mean()
+ val_loss = torch.mean(self.all_gather(avg_loss))
+ self.log(
+ "lr",
+ self.optimizer.param_groups[0]["lr"],
+ on_epoch=True,
+ prog_bar=True,
+ sync_dist=True,
+ )
+ self.logger.experiment.log(
+ {"learning_rate": self.optimizer.param_groups[0]["lr"], "epoch": self.current_epoch}
+ )
+ self.logger.experiment.log(
+ {"val_pit_sisnr": -val_loss, "epoch": self.current_epoch}
+ )
+
+ # test
+ if (self.trainer.current_epoch) % 10 == 0:
+ avg_loss = torch.stack(self.test_step_outputs).mean()
+ test_loss = torch.mean(self.all_gather(avg_loss))
+ self.logger.experiment.log(
+ {"test_pit_sisnr": -test_loss, "epoch": self.current_epoch}
+ )
+ self.validation_step_outputs.clear() # free memory
+ self.test_step_outputs.clear() # free memory
+
+ def configure_optimizers(self):
+ """Initialize optimizers, batch-wise and epoch-wise schedulers."""
+ if self.scheduler is None:
+ return self.optimizer
+
+ if not isinstance(self.scheduler, (list, tuple)):
+ self.scheduler = [self.scheduler] # support multiple schedulers
+
+ epoch_schedulers = []
+ for sched in self.scheduler:
+ if not isinstance(sched, dict):
+ if isinstance(sched, ReduceLROnPlateau):
+ sched = {"scheduler": sched, "monitor": self.default_monitor}
+ epoch_schedulers.append(sched)
+ else:
+ sched.setdefault("monitor", self.default_monitor)
+ sched.setdefault("frequency", 1)
+ # Backward compat
+ if sched["interval"] == "batch":
+ sched["interval"] = "step"
+ assert sched["interval"] in [
+ "epoch",
+ "step",
+ ], "Scheduler interval should be either step or epoch"
+ epoch_schedulers.append(sched)
+ return [self.optimizer], epoch_schedulers
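+
+ # Illustrative scheduler config (hypothetical values) accepted by
+ # configure_optimizers above:
+ # scheduler = {
+ #     "scheduler": ReduceLROnPlateau(optimizer, factor=0.5, patience=5),
+ #     "interval": "epoch",
+ #     "monitor": "val_loss/dataloader_idx_0",
+ # }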
+
+ # def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
+ # if metric is None:
+ # scheduler.step()
+ # else:
+ # scheduler.step(metric)
+
+ def train_dataloader(self):
+ """Training dataloader"""
+ return self.train_loader
+
+ def val_dataloader(self):
+ """Validation dataloader"""
+ return [self.val_loader, self.test_loader]
+
+ def on_save_checkpoint(self, checkpoint):
+ """Overwrite if you want to save more things in the checkpoint."""
+ checkpoint["training_config"] = self.config
+ return checkpoint
+
+ @staticmethod
+ def config_to_hparams(dic):
+ """Sanitizes the config dict to be handled correctly by torch
+ SummaryWriter. It flattens the config dict, converts ``None`` to
+ ``"None"`` and converts any list or tuple into torch.Tensors.
+
+ Args:
+ dic (dict): Dictionary to be transformed.
+
+ Returns:
+ dict: Transformed dictionary.
+ """
+ dic = flatten_dict(dic)
+ for k, v in dic.items():
+ if v is None:
+ dic[k] = str(v)
+ elif isinstance(v, (list, tuple)):
+ dic[k] = torch.tensor(v)
+ return dic
diff --git a/look2hear/system/optimizers.py b/look2hear/system/optimizers.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca3293e5f1cdf319fd1b39bb7c8649816f8d5d76
--- /dev/null
+++ b/look2hear/system/optimizers.py
@@ -0,0 +1,106 @@
+from torch.optim.optimizer import Optimizer
+from torch.optim import Adam, RMSprop, SGD, Adadelta, Adagrad, Adamax, AdamW, ASGD
+from torch_optimizer import (
+ AccSGD,
+ AdaBound,
+ AdaMod,
+ DiffGrad,
+ Lamb,
+ NovoGrad,
+ PID,
+ QHAdam,
+ QHM,
+ RAdam,
+ SGDW,
+ Yogi,
+ Ranger,
+ RangerQH,
+ RangerVA,
+)
+
+
+__all__ = [
+ "AccSGD",
+ "AdaBound",
+ "AdaMod",
+ "DiffGrad",
+ "Lamb",
+ "NovoGrad",
+ "PID",
+ "QHAdam",
+ "QHM",
+ "RAdam",
+ "SGDW",
+ "Yogi",
+ "Ranger",
+ "RangerQH",
+ "RangerVA",
+ "Adam",
+ "RMSprop",
+ "SGD",
+ "Adadelta",
+ "Adagrad",
+ "Adamax",
+ "AdamW",
+ "ASGD",
+ "make_optimizer",
+ "get",
+]
+
+
+def make_optimizer(params, optim_name="adam", **kwargs):
+ """
+
+ Args:
+ params (iterable): Output of `nn.Module.parameters()`.
+ optim_name (str or :class:`torch.optim.Optimizer`): Identifier understood
+ by :func:`~.get`.
+ **kwargs (dict): keyword arguments for the optimizer.
+
+ Returns:
+ torch.optim.Optimizer
+
+ Examples:
+ >>> from torch import nn
+ >>> model = nn.Sequential(nn.Linear(10, 10))
+ >>> optimizer = make_optimizer(model.parameters(), optim_name="sgd", lr=1e-3)
+ """
+ return get(optim_name)(params, **kwargs)
+
+
+def register_optimizer(custom_opt):
+ """Register a custom opt, gettable with `optimzers.get`.
+
+ Args:
+ custom_opt: Custom optimizer to register.
+
+ """
+ if (
+ custom_opt.__name__ in globals().keys()
+ or custom_opt.__name__.lower() in globals().keys()
+ ):
+ raise ValueError(
+ f"Activation {custom_opt.__name__} already exists. Choose another name."
+ )
+ globals().update({custom_opt.__name__: custom_opt})
+
+
+def get(identifier):
+ """Returns an optimizer function from a string. Returns its input if it
+ is callable (already a :class:`torch.optim.Optimizer` for example).
+
+ Args:
+ identifier (str or Callable): the optimizer identifier.
+
+ Returns:
+ :class:`torch.optim.Optimizer` or None
+ """
+ if isinstance(identifier, Optimizer):
+ return identifier
+ elif isinstance(identifier, str):
+ to_get = {k.lower(): v for k, v in globals().items()}
+ cls = to_get.get(identifier.lower())
+ if cls is None:
+ raise ValueError(f"Could not interpret optimizer : {str(identifier)}")
+ return cls
+ raise ValueError(f"Could not interpret optimizer : {str(identifier)}")
diff --git a/look2hear/system/schedulers.py b/look2hear/system/schedulers.py
new file mode 100644
index 0000000000000000000000000000000000000000..085dcb7f1b21f479373b6f7bc7aeb5cef33b7731
--- /dev/null
+++ b/look2hear/system/schedulers.py
@@ -0,0 +1,115 @@
+import torch
+from torch.optim.optimizer import Optimizer
+import pytorch_lightning as pl
+
+
+class BaseScheduler(object):
+ """Base class for the step-wise scheduler logic.
+
+ Args:
+ optimizer (Optimizer): Optimizer instance to apply the lr schedule on.
+
+ Subclass this and overwrite ``_get_lr`` to write your own step-wise scheduler.
+ """
+
+ def __init__(self, optimizer):
+ self.optimizer = optimizer
+ self.step_num = 0
+
+ def zero_grad(self):
+ self.optimizer.zero_grad()
+
+ def _get_lr(self):
+ raise NotImplementedError
+
+ def _set_lr(self, lr):
+ for param_group in self.optimizer.param_groups:
+ param_group["lr"] = lr
+
+ def step(self, metrics=None, epoch=None):
+ """Update step-wise learning rate before optimizer.step."""
+ self.step_num += 1
+ lr = self._get_lr()
+ self._set_lr(lr)
+
+ def load_state_dict(self, state_dict):
+ self.__dict__.update(state_dict)
+
+ def state_dict(self):
+ return {key: value for key, value in self.__dict__.items() if key != "optimizer"}
+
+ def as_tensor(self, start=0, stop=100_000):
+ """Returns the scheduler values from start to stop."""
+ lr_list = []
+ for _ in range(start, stop):
+ self.step_num += 1
+ lr_list.append(self._get_lr())
+ self.step_num = 0
+ return torch.tensor(lr_list)
+
+ def plot(self, start=0, stop=100_000): # noqa
+ """Plot the scheduler values from start to stop."""
+ import matplotlib.pyplot as plt
+
+ all_lr = self.as_tensor(start=start, stop=stop)
+ plt.plot(all_lr.numpy())
+ plt.show()
+
+class DPTNetScheduler(BaseScheduler):
+ """Dual Path Transformer Scheduler used in [1]
+
+ Args:
+ optimizer (Optimizer): Optimizer instance to apply lr schedule on.
+ steps_per_epoch (int): Number of steps per epoch.
+ d_model (int): The number of units in the layer output.
+ warmup_steps (int): The number of steps in the warmup stage of training.
+ noam_scale (float): Linear increase rate in first phase.
+ exp_max (float): Max learning rate in second phase.
+ exp_base (float): Exp learning rate base in second phase.
+
+ Schedule:
+ This scheduler increases the learning rate linearly for the first
+ ``warmup_steps``, and then decays it by a factor of 0.98 every two epochs.
+
+ References
+ [1]: Jingjing Chen et al. "Dual-Path Transformer Network: Direct Context-
+ Aware Modeling for End-to-End Monaural Speech Separation" Interspeech 2020.
+ """
+
+ def __init__(
+ self,
+ optimizer,
+ steps_per_epoch,
+ d_model,
+ warmup_steps=4000,
+ noam_scale=1.0,
+ exp_max=0.0004,
+ exp_base=0.98,
+ ):
+ super().__init__(optimizer)
+ self.noam_scale = noam_scale
+ self.d_model = d_model
+ self.warmup_steps = warmup_steps
+ self.exp_max = exp_max
+ self.exp_base = exp_base
+ self.steps_per_epoch = steps_per_epoch
+ self.epoch = 0
+
+ def _get_lr(self):
+ if self.step_num % self.steps_per_epoch == 0:
+ self.epoch += 1
+
+ if self.step_num > self.warmup_steps:
+ # exp decaying
+ lr = self.exp_max * (self.exp_base ** ((self.epoch - 1) // 2))
+ else:
+ # noam
+ lr = (
+ self.noam_scale
+ * self.d_model ** (-0.5)
+ * min(self.step_num ** (-0.5), self.step_num * self.warmup_steps ** (-1.5))
+ )
+ return lr
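+
+ # Worked sketch (warmup_steps=4000, noam_scale=1.0, exp_max=4e-4,
+ # exp_base=0.98 as defaulted above; d_model=64 is a hypothetical value):
+ # during warmup, lr = d_model**-0.5 * step * warmup_steps**-1.5,
+ # e.g. at step 2000: 64**-0.5 * 2000 * 4000**-1.5 ~ 9.9e-4;
+ # after warmup, lr = exp_max * exp_base**((epoch - 1) // 2),
+ # i.e. 4e-4 multiplied by 0.98 every two epochs.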
+
+# Backward compat
+_BaseScheduler = BaseScheduler
\ No newline at end of file
diff --git a/look2hear/utils/__init__.py b/look2hear/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9d90195f22f7a5eb37ca18d9b17bfc2f9126e5c
--- /dev/null
+++ b/look2hear/utils/__init__.py
@@ -0,0 +1,41 @@
+from .stft import STFT
+from .torch_utils import pad_x_to_y, shape_reconstructed, tensors_to_device
+from .parser_utils import (
+ prepare_parser_from_dict,
+ parse_args_as_dict,
+ str_int_float,
+ str2bool,
+ str2bool_arg,
+ isfloat,
+ isint,
+)
+from .lightning_utils import print_only, RichProgressBarTheme, MyRichProgressBar, BatchesProcessedColumn, MyMetricsTextColumn
+from .complex_utils import is_complex, is_torch_complex_tensor, new_complex_like
+from .get_layer_from_string import get_layer
+from .inversible_interface import InversibleInterface
+from .nets_utils import make_pad_mask
+
+__all__ = [
+ "STFT",
+ "pad_x_to_y",
+ "shape_reconstructed",
+ "tensors_to_device",
+ "prepare_parser_from_dict",
+ "parse_args_as_dict",
+ "str_int_float",
+ "str2bool",
+ "str2bool_arg",
+ "isfloat",
+ "isint",
+ "print_only",
+ "RichProgressBarTheme",
+ "MyRichProgressBar",
+ "BatchesProcessedColumn",
+ "MyMetricsTextColumn",
+ "is_complex",
+ "is_torch_complex_tensor",
+ "new_complex_like",
+ "get_layer",
+ "InversibleInterface",
+ "make_pad_mask",
+]
diff --git a/look2hear/utils/complex_utils.py b/look2hear/utils/complex_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e353197bfad9a62731a3265fd43bf628840b54d
--- /dev/null
+++ b/look2hear/utils/complex_utils.py
@@ -0,0 +1,191 @@
+"""Beamformer module."""
+from typing import Sequence, Tuple, Union
+
+import torch
+from packaging.version import parse as V
+from torch_complex import functional as FC
+from torch_complex.tensor import ComplexTensor
+
+EPS = torch.finfo(torch.double).eps
+is_torch_1_8_plus = V(torch.__version__) >= V("1.8.0")
+is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0")
+
+
+def new_complex_like(
+ ref: Union[torch.Tensor, ComplexTensor],
+ real_imag: Tuple[torch.Tensor, torch.Tensor],
+):
+ if isinstance(ref, ComplexTensor):
+ return ComplexTensor(*real_imag)
+ elif is_torch_complex_tensor(ref):
+ return torch.complex(*real_imag)
+ else:
+ raise ValueError(
+ "Please update your PyTorch version to 1.9+ for complex support."
+ )
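+
+ # Illustrative sketch (assumes PyTorch >= 1.9 with native complex support):
+ # >>> ref = torch.randn(2, 3, dtype=torch.complex64)
+ # >>> new_complex_like(ref, (torch.zeros(2, 3), torch.ones(2, 3))).dtype
+ # torch.complex64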
+
+
+def is_torch_complex_tensor(c):
+ return (
+ not isinstance(c, ComplexTensor) and is_torch_1_9_plus and torch.is_complex(c)
+ )
+
+
+def is_complex(c):
+ return isinstance(c, ComplexTensor) or is_torch_complex_tensor(c)
+
+
+def to_double(c):
+ if not isinstance(c, ComplexTensor) and is_torch_1_9_plus and torch.is_complex(c):
+ return c.to(dtype=torch.complex128)
+ else:
+ return c.double()
+
+
+def to_float(c):
+ if not isinstance(c, ComplexTensor) and is_torch_1_9_plus and torch.is_complex(c):
+ return c.to(dtype=torch.complex64)
+ else:
+ return c.float()
+
+
+def cat(seq: Sequence[Union[ComplexTensor, torch.Tensor]], *args, **kwargs):
+ if not isinstance(seq, (list, tuple)):
+ raise TypeError(
+ "cat(): argument 'tensors' (position 1) must be tuple of Tensors, "
+ "not Tensor"
+ )
+ if isinstance(seq[0], ComplexTensor):
+ return FC.cat(seq, *args, **kwargs)
+ else:
+ return torch.cat(seq, *args, **kwargs)
+
+
+def complex_norm(
+ c: Union[torch.Tensor, ComplexTensor], dim=-1, keepdim=False
+) -> torch.Tensor:
+ if not is_complex(c):
+ raise TypeError("Input is not a complex tensor.")
+ if is_torch_complex_tensor(c):
+ return torch.norm(c, dim=dim, keepdim=keepdim)
+ else:
+ if dim is None:
+ return torch.sqrt((c.real**2 + c.imag**2).sum() + EPS)
+ else:
+ return torch.sqrt(
+ (c.real**2 + c.imag**2).sum(dim=dim, keepdim=keepdim) + EPS
+ )
+
+
+def einsum(equation, *operands):
+ # NOTE: Do not mix ComplexTensor and torch.complex in the input!
+ # NOTE (wangyou): Until PyTorch 1.9.0, torch.einsum does not support
+ # mixed input with complex and real tensors.
+ if len(operands) == 1:
+ if isinstance(operands[0], (tuple, list)):
+ operands = operands[0]
+ complex_module = FC if isinstance(operands[0], ComplexTensor) else torch
+ return complex_module.einsum(equation, *operands)
+ elif len(operands) != 2:
+ op0 = operands[0]
+ same_type = all(op.dtype == op0.dtype for op in operands[1:])
+ if same_type:
+ _einsum = FC.einsum if isinstance(op0, ComplexTensor) else torch.einsum
+ return _einsum(equation, *operands)
+ else:
+ raise ValueError("0 or More than 2 operands are not supported.")
+ a, b = operands
+ if isinstance(a, ComplexTensor) or isinstance(b, ComplexTensor):
+ return FC.einsum(equation, a, b)
+ elif is_torch_1_9_plus and (torch.is_complex(a) or torch.is_complex(b)):
+ if not torch.is_complex(a):
+ o_real = torch.einsum(equation, a, b.real)
+ o_imag = torch.einsum(equation, a, b.imag)
+ return torch.complex(o_real, o_imag)
+ elif not torch.is_complex(b):
+ o_real = torch.einsum(equation, a.real, b)
+ o_imag = torch.einsum(equation, a.imag, b)
+ return torch.complex(o_real, o_imag)
+ else:
+ return torch.einsum(equation, a, b)
+ else:
+ return torch.einsum(equation, a, b)
+
+
+def inverse(
+ c: Union[torch.Tensor, ComplexTensor]
+) -> Union[torch.Tensor, ComplexTensor]:
+ if isinstance(c, ComplexTensor):
+ return c.inverse2()
+ else:
+ return c.inverse()
+
+
+def matmul(
+ a: Union[torch.Tensor, ComplexTensor], b: Union[torch.Tensor, ComplexTensor]
+) -> Union[torch.Tensor, ComplexTensor]:
+ # NOTE: Do not mix ComplexTensor and torch.complex in the input!
+ # NOTE (wangyou): Until PyTorch 1.9.0, torch.matmul does not support
+ # multiplication between complex and real tensors.
+ if isinstance(a, ComplexTensor) or isinstance(b, ComplexTensor):
+ return FC.matmul(a, b)
+ elif is_torch_1_9_plus and (torch.is_complex(a) or torch.is_complex(b)):
+ if not torch.is_complex(a):
+ o_real = torch.matmul(a, b.real)
+ o_imag = torch.matmul(a, b.imag)
+ return torch.complex(o_real, o_imag)
+ elif not torch.is_complex(b):
+ o_real = torch.matmul(a.real, b)
+ o_imag = torch.matmul(a.imag, b)
+ return torch.complex(o_real, o_imag)
+ else:
+ return torch.matmul(a, b)
+ else:
+ return torch.matmul(a, b)
+
+
+def trace(a: Union[torch.Tensor, ComplexTensor]):
+ # NOTE (wangyou): until PyTorch 1.9.0, torch.trace does not
+ # support batch processing. Use FC.trace() as fallback.
+ return FC.trace(a)
+
+
+def reverse(a: Union[torch.Tensor, ComplexTensor], dim=0):
+ if isinstance(a, ComplexTensor):
+ return FC.reverse(a, dim=dim)
+ else:
+ return torch.flip(a, dims=(dim,))
+
+
+def solve(b: Union[torch.Tensor, ComplexTensor], a: Union[torch.Tensor, ComplexTensor]):
+ """Solve the linear equation ax = b."""
+ # NOTE: Do not mix ComplexTensor and torch.complex in the input!
+ # NOTE (wangyou): Until PyTorch 1.9.0, torch.solve does not support
+ # mixed input with complex and real tensors.
+ if isinstance(a, ComplexTensor) or isinstance(b, ComplexTensor):
+ if isinstance(a, ComplexTensor) and isinstance(b, ComplexTensor):
+ return FC.solve(b, a, return_LU=False)
+ else:
+ return matmul(inverse(a), b)
+ elif is_torch_1_9_plus and (torch.is_complex(a) or torch.is_complex(b)):
+ if torch.is_complex(a) and torch.is_complex(b):
+ return torch.linalg.solve(a, b)
+ else:
+ return matmul(inverse(a), b)
+ else:
+ if is_torch_1_8_plus:
+ return torch.linalg.solve(a, b)
+ else:
+ return torch.solve(b, a)[0]
+
+
+def stack(seq: Sequence[Union[ComplexTensor, torch.Tensor]], *args, **kwargs):
+ if not isinstance(seq, (list, tuple)):
+ raise TypeError(
+ "stack(): argument 'tensors' (position 1) must be tuple of Tensors, "
+ "not Tensor"
+ )
+ if isinstance(seq[0], ComplexTensor):
+ return FC.stack(seq, *args, **kwargs)
+ else:
+ return torch.stack(seq, *args, **kwargs)
\ No newline at end of file
diff --git a/look2hear/utils/get_layer_from_string.py b/look2hear/utils/get_layer_from_string.py
new file mode 100644
index 0000000000000000000000000000000000000000..3987a8df65fa6fd6a69f74d2d3aa2bc06971512e
--- /dev/null
+++ b/look2hear/utils/get_layer_from_string.py
@@ -0,0 +1,43 @@
+import difflib
+
+import torch
+
+
+def get_layer(l_name, library=torch.nn):
+ """Return layer object handler from library e.g. from torch.nn
+
+ E.g. if l_name=="elu", returns torch.nn.ELU.
+
+ Args:
+ l_name (string): Case-insensitive name for layer in library (e.g. 'elu').
+ library (module): Name of library/module where to search for object handler
+ with l_name e.g. "torch.nn".
+
+ Returns:
+ layer_handler (object): handler for the requested layer e.g. (torch.nn.ELU)
+
+ """
+
+ all_torch_layers = [x for x in dir(library)]
+ match = [x for x in all_torch_layers if l_name.lower() == x.lower()]
+ if len(match) == 0:
+ close_matches = difflib.get_close_matches(
+ l_name, [x.lower() for x in all_torch_layers]
+ )
+ raise NotImplementedError(
+ "Layer with name {} not found in {}.\n Closest matches: {}".format(
+ l_name, str(library), close_matches
+ )
+ )
+ elif len(match) > 1:
+ close_matches = difflib.get_close_matches(
+ l_name, [x.lower() for x in all_torch_layers]
+ )
+ raise NotImplementedError(
+ "Multiple matchs for layer with name {} not found in {}.\n "
+ "All matches: {}".format(l_name, str(library), close_matches)
+ )
+ else:
+ # valid
+ layer_handler = getattr(library, match[0])
+ return layer_handler
\ No newline at end of file
diff --git a/look2hear/utils/inversible_interface.py b/look2hear/utils/inversible_interface.py
new file mode 100644
index 0000000000000000000000000000000000000000..919107ea0c1c8b9e35b957ab3109cf11faa6067f
--- /dev/null
+++ b/look2hear/utils/inversible_interface.py
@@ -0,0 +1,13 @@
+from abc import ABC, abstractmethod
+from typing import Tuple
+
+import torch
+
+
+class InversibleInterface(ABC):
+ @abstractmethod
+ def inverse(
+ self, input: torch.Tensor, input_lengths: torch.Tensor = None
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ # return output, output_lengths
+ raise NotImplementedError
\ No newline at end of file
diff --git a/look2hear/utils/lightning_utils.py b/look2hear/utils/lightning_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..29f2f73cc3ca6c4587c63ed5c3ef051609a0cb1a
--- /dev/null
+++ b/look2hear/utils/lightning_utils.py
@@ -0,0 +1,104 @@
+from rich import print
+from dataclasses import dataclass
+from pytorch_lightning.utilities import rank_zero_only
+from typing import Union
+from pytorch_lightning.callbacks.progress.rich_progress import *
+from rich.console import Console, RenderableType
+from rich.progress_bar import ProgressBar
+from rich.style import Style
+from rich.text import Text
+from rich.progress import (
+ BarColumn,
+ DownloadColumn,
+ Progress,
+ TaskID,
+ TextColumn,
+ TimeRemainingColumn,
+ TransferSpeedColumn,
+ ProgressColumn
+)
+from rich import print, reconfigure
+
+@rank_zero_only
+def print_only(message: str):
+ print(message)
+
+@dataclass
+class RichProgressBarTheme:
+ """Styles to associate to different base components.
+
+ Args:
+ description: Style for the progress bar description, e.g. Epoch x, Testing, etc.
+ progress_bar: Style for the bar in progress.
+ progress_bar_finished: Style for the finished progress bar.
+ progress_bar_pulse: Style for the progress bar when `IterableDataset` is being processed.
+ batch_progress: Style for the progress tracker (i.e. 10/50 batches completed).
+ time: Style for the processed time and estimated time remaining.
+ processing_speed: Style for the speed of the batches being processed.
+ metrics: Style for the metrics.
+
+ https://rich.readthedocs.io/en/stable/style.html
+ """
+
+ description: Union[str, Style] = "#FF4500"
+ progress_bar: Union[str, Style] = "#f92672"
+ progress_bar_finished: Union[str, Style] = "#b7cc8a"
+ progress_bar_pulse: Union[str, Style] = "#f92672"
+ batch_progress: Union[str, Style] = "#fc608a"
+ time: Union[str, Style] = "#45ada2"
+ processing_speed: Union[str, Style] = "#DC143C"
+ metrics: Union[str, Style] = "#228B22"
+
+class BatchesProcessedColumn(ProgressColumn):
+ def __init__(self, style: Union[str, Style]):
+ self.style = style
+ super().__init__()
+
+ def render(self, task) -> RenderableType:
+ total = task.total if task.total != float("inf") else "--"
+ return Text(f"{int(task.completed)}/{int(total)}", style=self.style)
+
+class MyMetricsTextColumn(ProgressColumn):
+ """A column containing text."""
+
+ def __init__(self, style):
+ self._tasks = {}
+ self._current_task_id = 0
+ self._metrics = {}
+ self._style = style
+ super().__init__()
+
+ def update(self, metrics):
+ # Called when metrics are ready to be rendered.
+ # This is to prevent render from causing deadlock issues by requesting metrics
+ # in separate threads.
+ self._metrics = metrics
+
+ def render(self, task) -> Text:
+ text = ""
+ for k, v in self._metrics.items():
+ text += f"{k}: {round(v, 3) if isinstance(v, float) else v} "
+ return Text(text, justify="left", style=self._style)
+
+class MyRichProgressBar(RichProgressBar):
+ """A progress bar prints metrics at the end of each epoch
+ """
+
+ def _init_progress(self, trainer):
+ if self.is_enabled and (self.progress is None or self._progress_stopped):
+ self._reset_progress_bar_ids()
+ reconfigure(**self._console_kwargs)
+ # file = open("Look2Hear/Experiments/run_logs/EdgeFRCNN-Noncausal.log", 'w')
+ self._console: Console = Console(force_terminal=True)
+ self._console.clear_live()
+ self._metric_component = MetricsTextColumn(trainer, self.theme.metrics)
+ self.progress = CustomProgress(
+ *self.configure_columns(trainer),
+ self._metric_component,
+ auto_refresh=False,
+ disable=self.is_disabled,
+ console=self._console,
+ )
+ self.progress.start()
+ # progress has started
+ self._progress_stopped = False
\ No newline at end of file
diff --git a/look2hear/utils/nets_utils.py b/look2hear/utils/nets_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7a677023d87e6370352454994d12e458437e2c2
--- /dev/null
+++ b/look2hear/utils/nets_utils.py
@@ -0,0 +1,503 @@
+# -*- coding: utf-8 -*-
+
+"""Network related utility tools."""
+
+import logging
+from typing import Dict
+
+import numpy as np
+import torch
+
+
+def to_device(m, x):
+ """Send tensor into the device of the module.
+
+ Args:
+ m (torch.nn.Module): Torch module.
+ x (Tensor): Torch tensor.
+
+ Returns:
+ Tensor: Torch tensor located in the same place as torch module.
+
+ """
+ if isinstance(m, torch.nn.Module):
+ device = next(m.parameters()).device
+ elif isinstance(m, torch.Tensor):
+ device = m.device
+ else:
+ raise TypeError(
+ "Expected torch.nn.Module or torch.tensor, " f"bot got: {type(m)}"
+ )
+ return x.to(device)
+
+
+def pad_list(xs, pad_value):
+ """Perform padding for the list of tensors.
+
+ Args:
+ xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
+ pad_value (float): Value for padding.
+
+ Returns:
+ Tensor: Padded tensor (B, Tmax, `*`).
+
+ Examples:
+ >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
+ >>> x
+ [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
+ >>> pad_list(x, 0)
+ tensor([[1., 1., 1., 1.],
+ [1., 1., 0., 0.],
+ [1., 0., 0., 0.]])
+
+ """
+ n_batch = len(xs)
+ max_len = max(x.size(0) for x in xs)
+ pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
+
+ for i in range(n_batch):
+ pad[i, : xs[i].size(0)] = xs[i]
+
+ return pad
+
+
+def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
+ """Make mask tensor containing indices of padded part.
+
+ Args:
+ lengths (LongTensor or List): Batch of lengths (B,).
+ xs (Tensor, optional): The reference tensor.
+ If set, masks will be the same shape as this tensor.
+ length_dim (int, optional): Dimension indicator of the above tensor.
+ See the example.
+
+ Returns:
+ Tensor: Mask tensor containing indices of padded part.
+ dtype=torch.uint8 in PyTorch 1.2-
+ dtype=torch.bool in PyTorch 1.2+ (including 1.2)
+
+ Examples:
+ With only lengths.
+
+ >>> lengths = [5, 3, 2]
+ >>> make_pad_mask(lengths)
+ masks = [[0, 0, 0, 0 ,0],
+ [0, 0, 0, 1, 1],
+ [0, 0, 1, 1, 1]]
+
+ With the reference tensor.
+
+ >>> xs = torch.zeros((3, 2, 4))
+ >>> make_pad_mask(lengths, xs)
+ tensor([[[0, 0, 0, 0],
+ [0, 0, 0, 0]],
+ [[0, 0, 0, 1],
+ [0, 0, 0, 1]],
+ [[0, 0, 1, 1],
+ [0, 0, 1, 1]]], dtype=torch.uint8)
+ >>> xs = torch.zeros((3, 2, 6))
+ >>> make_pad_mask(lengths, xs)
+ tensor([[[0, 0, 0, 0, 0, 1],
+ [0, 0, 0, 0, 0, 1]],
+ [[0, 0, 0, 1, 1, 1],
+ [0, 0, 0, 1, 1, 1]],
+ [[0, 0, 1, 1, 1, 1],
+ [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
+
+ With the reference tensor and dimension indicator.
+
+ >>> xs = torch.zeros((3, 6, 6))
+ >>> make_pad_mask(lengths, xs, 1)
+ tensor([[[0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [1, 1, 1, 1, 1, 1]],
+ [[0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1]],
+ [[0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
+ >>> make_pad_mask(lengths, xs, 2)
+ tensor([[[0, 0, 0, 0, 0, 1],
+ [0, 0, 0, 0, 0, 1],
+ [0, 0, 0, 0, 0, 1],
+ [0, 0, 0, 0, 0, 1],
+ [0, 0, 0, 0, 0, 1],
+ [0, 0, 0, 0, 0, 1]],
+ [[0, 0, 0, 1, 1, 1],
+ [0, 0, 0, 1, 1, 1],
+ [0, 0, 0, 1, 1, 1],
+ [0, 0, 0, 1, 1, 1],
+ [0, 0, 0, 1, 1, 1],
+ [0, 0, 0, 1, 1, 1]],
+ [[0, 0, 1, 1, 1, 1],
+ [0, 0, 1, 1, 1, 1],
+ [0, 0, 1, 1, 1, 1],
+ [0, 0, 1, 1, 1, 1],
+ [0, 0, 1, 1, 1, 1],
+ [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
+
+ """
+ if length_dim == 0:
+ raise ValueError("length_dim cannot be 0: {}".format(length_dim))
+
+ if not isinstance(lengths, list):
+ lengths = lengths.long().tolist()
+
+ bs = int(len(lengths))
+ if maxlen is None:
+ if xs is None:
+ maxlen = int(max(lengths))
+ else:
+ maxlen = xs.size(length_dim)
+ else:
+ assert xs is None
+ assert maxlen >= int(max(lengths))
+
+ seq_range = torch.arange(0, maxlen, dtype=torch.int64)
+ seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
+ seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
+ mask = seq_range_expand >= seq_length_expand
+
+ if xs is not None:
+ assert xs.size(0) == bs, (xs.size(0), bs)
+
+ if length_dim < 0:
+ length_dim = xs.dim() + length_dim
+ # ind = (:, None, ..., None, :, , None, ..., None)
+ ind = tuple(
+ slice(None) if i in (0, length_dim) else None for i in range(xs.dim())
+ )
+ mask = mask[ind].expand_as(xs).to(xs.device)
+ return mask
+
+
+def make_non_pad_mask(lengths, xs=None, length_dim=-1):
+ """Make mask tensor containing indices of non-padded part.
+
+ Args:
+ lengths (LongTensor or List): Batch of lengths (B,).
+ xs (Tensor, optional): The reference tensor.
+ If set, masks will be the same shape as this tensor.
+ length_dim (int, optional): Dimension indicator of the above tensor.
+ See the example.
+
+ Returns:
+ ByteTensor: mask tensor containing indices of non-padded part.
+ dtype=torch.uint8 in PyTorch 1.2-
+ dtype=torch.bool in PyTorch 1.2+ (including 1.2)
+
+ Examples:
+ With only lengths.
+
+ >>> lengths = [5, 3, 2]
+ >>> make_non_pad_mask(lengths)
+ masks = [[1, 1, 1, 1 ,1],
+ [1, 1, 1, 0, 0],
+ [1, 1, 0, 0, 0]]
+
+ With the reference tensor.
+
+ >>> xs = torch.zeros((3, 2, 4))
+ >>> make_non_pad_mask(lengths, xs)
+ tensor([[[1, 1, 1, 1],
+ [1, 1, 1, 1]],
+ [[1, 1, 1, 0],
+ [1, 1, 1, 0]],
+ [[1, 1, 0, 0],
+ [1, 1, 0, 0]]], dtype=torch.uint8)
+ >>> xs = torch.zeros((3, 2, 6))
+ >>> make_non_pad_mask(lengths, xs)
+ tensor([[[1, 1, 1, 1, 1, 0],
+ [1, 1, 1, 1, 1, 0]],
+ [[1, 1, 1, 0, 0, 0],
+ [1, 1, 1, 0, 0, 0]],
+ [[1, 1, 0, 0, 0, 0],
+ [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
+
+ With the reference tensor and dimension indicator.
+
+ >>> xs = torch.zeros((3, 6, 6))
+ >>> make_non_pad_mask(lengths, xs, 1)
+ tensor([[[1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [0, 0, 0, 0, 0, 0]],
+ [[1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0]],
+ [[1, 1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1, 1],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
+ >>> make_non_pad_mask(lengths, xs, 2)
+ tensor([[[1, 1, 1, 1, 1, 0],
+ [1, 1, 1, 1, 1, 0],
+ [1, 1, 1, 1, 1, 0],
+ [1, 1, 1, 1, 1, 0],
+ [1, 1, 1, 1, 1, 0],
+ [1, 1, 1, 1, 1, 0]],
+ [[1, 1, 1, 0, 0, 0],
+ [1, 1, 1, 0, 0, 0],
+ [1, 1, 1, 0, 0, 0],
+ [1, 1, 1, 0, 0, 0],
+ [1, 1, 1, 0, 0, 0],
+ [1, 1, 1, 0, 0, 0]],
+ [[1, 1, 0, 0, 0, 0],
+ [1, 1, 0, 0, 0, 0],
+ [1, 1, 0, 0, 0, 0],
+ [1, 1, 0, 0, 0, 0],
+ [1, 1, 0, 0, 0, 0],
+ [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
+
+ """
+ return ~make_pad_mask(lengths, xs, length_dim)
+
+
+def mask_by_length(xs, lengths, fill=0):
+ """Mask tensor according to length.
+
+ Args:
+ xs (Tensor): Batch of input tensor (B, `*`).
+ lengths (LongTensor or List): Batch of lengths (B,).
+ fill (int or float): Value to fill masked part.
+
+ Returns:
+ Tensor: Batch of masked input tensor (B, `*`).
+
+ Examples:
+ >>> x = torch.arange(5).repeat(3, 1) + 1
+ >>> x
+ tensor([[1, 2, 3, 4, 5],
+ [1, 2, 3, 4, 5],
+ [1, 2, 3, 4, 5]])
+ >>> lengths = [5, 3, 2]
+ >>> mask_by_length(x, lengths)
+ tensor([[1, 2, 3, 4, 5],
+ [1, 2, 3, 0, 0],
+ [1, 2, 0, 0, 0]])
+
+ """
+ assert xs.size(0) == len(lengths)
+ ret = xs.data.new(*xs.size()).fill_(fill)
+ for i, l in enumerate(lengths):
+ ret[i, :l] = xs[i, :l]
+ return ret
+
+
+def th_accuracy(pad_outputs, pad_targets, ignore_label):
+ """Calculate accuracy.
+
+ Args:
+ pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
+ pad_targets (LongTensor): Target label tensors (B, Lmax, D).
+ ignore_label (int): Ignore label id.
+
+ Returns:
+ float: Accuracy value (0.0 - 1.0).
+
+ """
+ pad_pred = pad_outputs.view(
+ pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1)
+ ).argmax(2)
+ mask = pad_targets != ignore_label
+ numerator = torch.sum(
+ pad_pred.masked_select(mask) == pad_targets.masked_select(mask)
+ )
+ denominator = torch.sum(mask)
+ return float(numerator) / float(denominator)
+
+
+def to_torch_tensor(x):
+ """Change to torch.Tensor or ComplexTensor from numpy.ndarray.
+
+ Args:
+ x: Inputs. It should be one of numpy.ndarray, Tensor, ComplexTensor, and dict.
+
+ Returns:
+ Tensor or ComplexTensor: Type converted inputs.
+
+ Examples:
+ >>> xs = np.ones(3, dtype=np.float32)
+ >>> xs = to_torch_tensor(xs)
+ tensor([1., 1., 1.])
+ >>> xs = torch.ones(3, 4, 5)
+ >>> assert to_torch_tensor(xs) is xs
+ >>> xs = {'real': xs, 'imag': xs}
+ >>> to_torch_tensor(xs)
+ ComplexTensor(
+ Real:
+ tensor([1., 1., 1.])
+ Imag;
+ tensor([1., 1., 1.])
+ )
+
+ """
+ # If numpy, change to torch tensor
+ if isinstance(x, np.ndarray):
+ if x.dtype.kind == "c":
+ # Dynamically importing because torch_complex requires python3
+ from torch_complex.tensor import ComplexTensor
+
+ return ComplexTensor(x)
+ else:
+ return torch.from_numpy(x)
+
+ # If {'real': ..., 'imag': ...}, convert to ComplexTensor
+ elif isinstance(x, dict):
+ # Dynamically importing because torch_complex requires python3
+ from torch_complex.tensor import ComplexTensor
+
+ if "real" not in x or "imag" not in x:
+ raise ValueError("has 'real' and 'imag' keys: {}".format(list(x)))
+ # Relative importing because of using python3 syntax
+ return ComplexTensor(x["real"], x["imag"])
+
+ # If torch.Tensor, as it is
+ elif isinstance(x, torch.Tensor):
+ return x
+
+ else:
+ error = (
+ "x must be numpy.ndarray, torch.Tensor or a dict like "
+ "{{'real': torch.Tensor, 'imag': torch.Tensor}}, "
+ "but got {}".format(type(x))
+ )
+ try:
+ from torch_complex.tensor import ComplexTensor
+ except Exception:
+ # If PY2
+ raise ValueError(error)
+ else:
+ # If PY3
+ if isinstance(x, ComplexTensor):
+ return x
+ else:
+ raise ValueError(error)
+
+
+def get_subsample(train_args, mode, arch):
+ """Parse the subsampling factors from the args for the specified `mode` and `arch`.
+
+ Args:
+ train_args: argument Namespace containing options.
+ mode: one of ('asr', 'mt', 'st')
+ arch: one of ('rnn', 'rnn-t', 'rnn_mix', 'rnn_mulenc', 'transformer')
+
+ Returns:
+ np.ndarray / List[np.ndarray]: subsampling factors.
+ """
+ if arch == "transformer":
+ return np.array([1])
+
+ elif mode == "mt" and arch == "rnn":
+ # +1 means input (+1) and layer outputs (train_args.elayers)
+ subsample = np.ones(train_args.elayers + 1, dtype=np.int64)
+ logging.warning("Subsampling is not performed for machine translation.")
+ logging.info("subsample: " + " ".join([str(x) for x in subsample]))
+ return subsample
+
+ elif (
+ (mode == "asr" and arch in ("rnn", "rnn-t"))
+ or (mode == "mt" and arch == "rnn")
+ or (mode == "st" and arch == "rnn")
+ ):
+ subsample = np.ones(train_args.elayers + 1, dtype=np.int64)
+ if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
+ ss = train_args.subsample.split("_")
+ for j in range(min(train_args.elayers + 1, len(ss))):
+ subsample[j] = int(ss[j])
+ else:
+ logging.warning(
+ "Subsampling is not performed for vgg*. "
+ "It is performed in max pooling layers at CNN."
+ )
+ logging.info("subsample: " + " ".join([str(x) for x in subsample]))
+ return subsample
+
+ elif mode == "asr" and arch == "rnn_mix":
+ subsample = np.ones(
+ train_args.elayers_sd + train_args.elayers + 1, dtype=np.int64
+ )
+ if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
+ ss = train_args.subsample.split("_")
+ for j in range(
+ min(train_args.elayers_sd + train_args.elayers + 1, len(ss))
+ ):
+ subsample[j] = int(ss[j])
+ else:
+ logging.warning(
+ "Subsampling is not performed for vgg*. "
+ "It is performed in max pooling layers at CNN."
+ )
+ logging.info("subsample: " + " ".join([str(x) for x in subsample]))
+ return subsample
+
+ elif mode == "asr" and arch == "rnn_mulenc":
+ subsample_list = []
+ for idx in range(train_args.num_encs):
+ subsample = np.ones(train_args.elayers[idx] + 1, dtype=np.int64)
+ if train_args.etype[idx].endswith("p") and not train_args.etype[
+ idx
+ ].startswith("vgg"):
+ ss = train_args.subsample[idx].split("_")
+ for j in range(min(train_args.elayers[idx] + 1, len(ss))):
+ subsample[j] = int(ss[j])
+ else:
+ logging.warning(
+ "Encoder %d: Subsampling is not performed for vgg*. "
+ "It is performed in max pooling layers at CNN.",
+ idx + 1,
+ )
+ logging.info("subsample: " + " ".join([str(x) for x in subsample]))
+ subsample_list.append(subsample)
+ return subsample_list
+
+ else:
+ raise ValueError("Invalid options: mode={}, arch={}".format(mode, arch))
+
+
+def rename_state_dict(
+ old_prefix: str, new_prefix: str, state_dict: Dict[str, torch.Tensor]
+):
+ """Replace keys of old prefix with new prefix in state dict."""
+ # need this list not to break the dict iterator
+ old_keys = [k for k in state_dict if k.startswith(old_prefix)]
+ if len(old_keys) > 0:
+ logging.warning(f"Rename: {old_prefix} -> {new_prefix}")
+ for k in old_keys:
+ v = state_dict.pop(k)
+ new_k = k.replace(old_prefix, new_prefix)
+ state_dict[new_k] = v
+
+
+def get_activation(act):
+ """Return activation function."""
+ # Lazy load to avoid unused import
+ from espnet.nets.pytorch_backend.conformer.swish import Swish
+
+ activation_funcs = {
+ "hardtanh": torch.nn.Hardtanh,
+ "tanh": torch.nn.Tanh,
+ "relu": torch.nn.ReLU,
+ "selu": torch.nn.SELU,
+ "swish": Swish,
+ }
+
+ return activation_funcs[act]()
\ No newline at end of file
diff --git a/look2hear/utils/parser_utils.py b/look2hear/utils/parser_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cec1eaad6e74a24fd94be0c34ff54b001b22628
--- /dev/null
+++ b/look2hear/utils/parser_utils.py
@@ -0,0 +1,151 @@
+import sys
+import argparse
+
+
+def prepare_parser_from_dict(dic, parser=None):
+ """Prepare an argparser from a dictionary.
+
+ Args:
+ dic (dict): Two-level config dictionary with unique bottom-level keys.
+ parser (argparse.ArgumentParser, optional): If a parser already
+ exists, add the keys from the dictionary on the top of it.
+
+ Returns:
+ argparse.ArgumentParser:
+ Parser instance with groups corresponding to the first level keys
+ and arguments corresponding to the second level keys with default
+ values given by the values.
+ """
+
+ def standardized_entry_type(value):
+ """If the default value is None, replace NoneType by str_int_float.
+ If the default value is boolean, look for boolean strings."""
+ if value is None:
+ return str_int_float
+ if isinstance(str2bool(value), bool):
+ return str2bool_arg
+ return type(value)
+
+ if parser is None:
+ parser = argparse.ArgumentParser()
+ for k in dic.keys():
+ group = parser.add_argument_group(k)
+ if isinstance(dic[k], list):
+ entry_type = standardized_entry_type(dic[k])
+ group.add_argument("--" + k, default=dic[k], type=entry_type)
+ elif isinstance(dic[k], dict):
+ for kk in dic[k].keys():
+ entry_type = standardized_entry_type(dic[k][kk])
+ group.add_argument("--" + kk, default=dic[k][kk], type=entry_type)
+ elif isinstance(dic[k], str):
+ entry_type = standardized_entry_type(dic[k])
+ group.add_argument("--" + k, default=dic[k], type=entry_type)
+ return parser
+
+
+def str_int_float(value):
+ """Type to convert strings to int, float (in this order) if possible.
+
+ Args:
+ value (str): Value to convert.
+
+ Returns:
+ int, float, str: Converted value.
+ """
+ if isint(value):
+ return int(value)
+ if isfloat(value):
+ return float(value)
+ elif isinstance(value, str):
+ return value
+
+
+def str2bool(value):
+ """Type to convert strings to Boolean (returns input if not boolean)"""
+ if not isinstance(value, str):
+ return value
+ if value.lower() in ("yes", "true", "y", "1"):
+ return True
+ elif value.lower() in ("no", "false", "n", "0"):
+ return False
+ else:
+ return value
+
+
+def str2bool_arg(value):
+ """Argparse type to convert strings to Boolean"""
+ value = str2bool(value)
+ if isinstance(value, bool):
+ return value
+ raise argparse.ArgumentTypeError("Boolean value expected.")
+
+
+def isfloat(value):
+ """Computes whether `value` can be cast to a float.
+
+ Args:
+ value (str): Value to check.
+
+ Returns:
+ bool: Whether `value` can be cast to a float.
+
+ """
+ try:
+ float(value)
+ return True
+ except ValueError:
+ return False
+
+
+def isint(value):
+ """Computes whether `value` can be cast to an int
+
+ Args:
+ value (str): Value to check.
+
+ Returns:
+ bool: Whether `value` can be cast to an int.
+
+ """
+ try:
+ int(value)
+ return True
+ except ValueError:
+ return False
+
+
+def parse_args_as_dict(parser, return_plain_args=False, args=None):
+ """Get a dict of dicts out of process `parser.parse_args()`
+
+ Top-level keys corresponding to groups and bottom-level keys corresponding
+ to arguments. Under `'main_args'`, the arguments which don't belong to a
+ argparse group (i.e main arguments defined before parsing from a dict) can
+ be found.
+
+ Args:
+ parser (argparse.ArgumentParser): ArgumentParser instance containing
+ groups. Output of `prepare_parser_from_dict`.
+ return_plain_args (bool): Whether to also return the plain output of
+ `parser.parse_args()`.
+ args (list): List of arguments as read from the command line.
+ Used for unit testing.
+
+ Returns:
+ dict:
+ Dictionary of dictionaries containing the arguments. Optionally the
+ direct output `parser.parse_args()`.
+ """
+ args = parser.parse_args(args=args)
+ args_dic = {}
+ for group in parser._action_groups:
+ group_dict = {a.dest: getattr(args, a.dest, None) for a in group._group_actions}
+ args_dic[group.title] = group_dict
+ if sys.version_info.minor == 10:
+ args_dic["main_args"] = args_dic["positional arguments"]
+ del args_dic["positional arguments"]
+ else:
+ args_dic["main_args"] = args_dic["optional arguments"]
+ del args_dic["optional arguments"]
+ if return_plain_args:
+ return args_dic, args
+ return args_dic
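+
+
+ # Illustrative end-to-end sketch (hypothetical config dict):
+ # >>> conf = {"optim": {"lr": 1e-3}, "training": {"epochs": 100}}
+ # >>> parser = prepare_parser_from_dict(conf)
+ # >>> arg_dic = parse_args_as_dict(parser, args=[])
+ # >>> arg_dic["optim"]["lr"], arg_dic["training"]["epochs"]
+ # (0.001, 100)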
diff --git a/look2hear/utils/separator.py b/look2hear/utils/separator.py
new file mode 100644
index 0000000000000000000000000000000000000000..f38077bad5ef49d7af77e9aff4bc1476730036af
--- /dev/null
+++ b/look2hear/utils/separator.py
@@ -0,0 +1,69 @@
+import os
+import warnings
+import torch
+import numpy as np
+import soundfile as sf
+
+
+def get_device(tensor_or_module, default=None):
+ if hasattr(tensor_or_module, "device"):
+ return tensor_or_module.device
+ elif hasattr(tensor_or_module, "parameters"):
+ return next(tensor_or_module.parameters()).device
+ elif default is None:
+ raise TypeError(
+ f"Don't know how to get device of {type(tensor_or_module)} object"
+ )
+ else:
+ return torch.device(default)
+
+
+class Separator:
+ def forward_wav(self, wav, **kwargs):
+ raise NotImplementedError
+
+ def sample_rate(self):
+ raise NotImplementedError
+
+
+def separate(model, wav, **kwargs):
+ if isinstance(wav, np.ndarray):
+ return numpy_separate(model, wav, **kwargs)
+ elif isinstance(wav, torch.Tensor):
+ return torch_separate(model, wav, **kwargs)
+ else:
+ raise ValueError(
+ f"Only support filenames, numpy arrays and torch tensors, received {type(wav)}"
+ )
+
+
+@torch.no_grad()
+def torch_separate(model: Separator, wav: torch.Tensor, **kwargs) -> torch.Tensor:
+ """Core logic of `separate`."""
+ if model.in_channels is not None and wav.shape[-2] != model.in_channels:
+ raise RuntimeError(
+ f"Model supports {model.in_channels}-channel inputs but found audio with {wav.shape[-2]} channels."
+ f"Please match the number of channels."
+ )
+ # Handle device placement
+ input_device = get_device(wav, default="cpu")
+ model_device = get_device(model, default="cpu")
+ wav = wav.to(model_device)
+ # Forward
+ separate_func = getattr(model, "forward_wav", model)
+ out_wavs = separate_func(wav, **kwargs)
+
+ # FIXME: for now this is the best we can do.
+ out_wavs *= wav.abs().sum() / (out_wavs.abs().sum())
+
+ # Back to input device (and numpy if necessary)
+ out_wavs = out_wavs.to(input_device)
+ return out_wavs
+
+
+def numpy_separate(model: Separator, wav: np.ndarray, **kwargs) -> np.ndarray:
+ """Numpy interface to `separate`."""
+ wav = torch.from_numpy(wav)
+ out_wavs = torch_separate(model, wav, **kwargs)
+ out_wavs = out_wavs.data.numpy()
+ return out_wavs
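+
+
+ # Minimal usage sketch (hypothetical model and file; assumes the model exposes
+ # an `in_channels` attribute and a `forward_wav` method as used above):
+ # >>> mix, sr = sf.read("mixture.wav", dtype="float32")
+ # >>> est_sources = separate(model, mix[None, :])  # numpy in -> numpy out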
diff --git a/look2hear/utils/stft.py b/look2hear/utils/stft.py
new file mode 100644
index 0000000000000000000000000000000000000000..0bef618093aa219f59cbd967f4ed134d20b0837e
--- /dev/null
+++ b/look2hear/utils/stft.py
@@ -0,0 +1,797 @@
+# Copyright 2019 Jian Wu
+# License: Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+import math
+
+import numpy as np
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as tf
+import librosa.filters as filters
+
+from typing import Optional, Tuple
+from distutils.version import LooseVersion
+
+EPSILON = float(np.finfo(np.float32).eps)
+TORCH_VERSION = th.__version__
+
+if TORCH_VERSION >= LooseVersion("1.7"):
+ from torch.fft import fft as fft_func
+else:
+ pass
+
+
+def export_jit(transform: nn.Module) -> nn.Module:
+ """
+ Export transform module for inference
+ """
+ export_out = [module for module in transform if module.exportable()]
+ return nn.Sequential(*export_out)
+
+
+def init_window(wnd: str, frame_len: int, device: th.device = "cpu") -> th.Tensor:
+ """
+ Return window coefficient
+ Args:
+ wnd: window name
+ frame_len: length of the frame
+ """
+
+ def sqrthann(frame_len, periodic=True):
+ return th.hann_window(frame_len, periodic=periodic) ** 0.5
+
+ if wnd not in ["bartlett", "hann", "hamm", "blackman", "rect", "sqrthann"]:
+ raise RuntimeError(f"Unknown window type: {wnd}")
+
+ wnd_tpl = {
+ "sqrthann": sqrthann,
+ "hann": th.hann_window,
+ "hamm": th.hamming_window,
+ "blackman": th.blackman_window,
+ "bartlett": th.bartlett_window,
+ "rect": th.ones,
+ }
+ if wnd != "rect":
+ # match with librosa
+ c = wnd_tpl[wnd](frame_len, periodic=True)
+ else:
+ c = wnd_tpl[wnd](frame_len)
+ return c.to(device)
+
+
+def init_kernel(
+ frame_len: int,
+ frame_hop: int,
+ window: th.Tensor,
+ round_pow_of_two: bool = True,
+ normalized: bool = False,
+ inverse: bool = False,
+ mode: str = "librosa",
+) -> Tuple[th.Tensor, th.Tensor]:
+ """
+ Return STFT kernels
+ Args:
+ frame_len: length of the frame
+ frame_hop: hop size between frames
+ window: window tensor
+ round_pow_of_two: if true, round the FFT size up to the next power of two
+ normalized: return normalized DFT matrix
+ inverse: return iDFT matrix
+ mode: framing mode (librosa or kaldi)
+ """
+ if mode not in ["librosa", "kaldi"]:
+ raise ValueError(f"Unsupported mode: {mode}")
+ # FFT size: B
+ if round_pow_of_two or mode == "kaldi":
+ fft_size = 2 ** math.ceil(math.log2(frame_len))
+ else:
+ fft_size = frame_len
+ # center padding window if needed
+ if mode == "librosa" and fft_size != frame_len:
+ lpad = (fft_size - frame_len) // 2
+ window = tf.pad(window, (lpad, fft_size - frame_len - lpad))
+ if normalized:
+ # make K^H * K = I
+ S = fft_size ** 0.5
+ else:
+ S = 1
+ # W x B x 2
+ if TORCH_VERSION >= LooseVersion("1.7"):
+ K = fft_func(th.eye(fft_size) / S, dim=-1)
+ K = th.stack([K.real, K.imag], dim=-1)
+ else:
+ I = th.stack([th.eye(fft_size), th.zeros(fft_size, fft_size)], dim=-1)
+ K = th.fft(I / S, 1)
+ if mode == "kaldi":
+ K = K[:frame_len]
+ if inverse and not normalized:
+ # to make K^H * K = I
+ K = K / fft_size
+ # 2 x B x W
+ K = th.transpose(K, 0, 2)
+ # 2B x 1 x W
+ K = th.reshape(K, (fft_size * 2, 1, K.shape[-1]))
+ return K.to(window.device), window
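+
+ # Shape sketch (illustrative): with frame_len=512, round_pow_of_two=True and a
+ # 512-point window, the returned kernel has shape (2 * 512, 1, 512) (real and
+ # imaginary DFT rows stacked), so _forward_stft can implement the STFT as a
+ # strided conv1d.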
+
+
+def mel_filter(
+ frame_len: int,
+ round_pow_of_two: bool = True,
+ num_bins: Optional[int] = None,
+ sr: int = 16000,
+ num_mels: int = 80,
+ fmin: float = 0.0,
+ fmax: Optional[float] = None,
+ norm: bool = False,
+) -> th.Tensor:
+ """
+ Return mel filter coefficients
+ Args:
+ frame_len: length of the frame
+ round_pow_of_two: if true, round the FFT size up to the next power of two
+ num_bins: number of the frequency bins produced by STFT
+ num_mels: number of the mel bands
+ fmin: lowest frequency (in Hz)
+ fmax: highest frequency (in Hz)
+ norm: normalize the mel filter coefficients
+ """
+ # FFT points
+ if num_bins is None:
+ N = 2 ** math.ceil(math.log2(frame_len)) if round_pow_of_two else frame_len
+ else:
+ N = (num_bins - 1) * 2
+ # fmin & fmax
+ freq_upper = sr // 2
+ if fmax is None:
+ fmax = freq_upper
+ else:
+ fmax = min(fmax + freq_upper if fmax < 0 else fmax, freq_upper)
+ fmin = max(0, fmin)
+ # mel filter coefficients
+ mel = filters.mel(
+ sr,
+ N,
+ n_mels=num_mels,
+ fmax=fmax,
+ fmin=fmin,
+ htk=True,
+ norm="slaney" if norm else None,
+ )
+ # num_mels x (N // 2 + 1)
+ return th.tensor(mel, dtype=th.float32)
+
+
+def speed_perturb_filter(
+ src_sr: int, dst_sr: int, cutoff_ratio: float = 0.95, num_zeros: int = 64
+) -> th.Tensor:
+ """
+ Return speed perturb filters, reference:
+ https://github.com/danpovey/filtering/blob/master/lilfilter/resampler.py
+ Args:
+ src_sr: sample rate of the source signal
+ dst_sr: sample rate of the target signal
+ Return:
+ weight (Tensor): coefficients of the filter
+ """
+ if src_sr == dst_sr:
+ raise ValueError(f"src_sr should not be equal to dst_sr: {src_sr}/{dst_sr}")
+ gcd = math.gcd(src_sr, dst_sr)
+ src_sr = src_sr // gcd
+ dst_sr = dst_sr // gcd
+ if src_sr == 1 or dst_sr == 1:
+ raise ValueError("do not support integer downsample/upsample")
+ zeros_per_block = min(src_sr, dst_sr) * cutoff_ratio
+ padding = 1 + int(num_zeros / zeros_per_block)
+ # dst_sr x src_sr x K
+ times = (
+ np.arange(dst_sr)[:, None, None] / float(dst_sr)
+ - np.arange(src_sr)[None, :, None] / float(src_sr)
+ - np.arange(2 * padding + 1)[None, None, :]
+ + padding
+ )
+ window = np.heaviside(1 - np.abs(times / padding), 0.0) * (
+ 0.5 + 0.5 * np.cos(times / padding * math.pi)
+ )
+ weight = np.sinc(times * zeros_per_block) * window * zeros_per_block / float(src_sr)
+ return th.tensor(weight, dtype=th.float32)
+
+
+def splice_feature(
+ feats: th.Tensor, lctx: int = 1, rctx: int = 1, op: str = "cat"
+) -> th.Tensor:
+ """
+ Splice feature
+ Args:
+ feats (Tensor): N x ... x T x F, original feature
+ lctx: left context
+ rctx: right context
+ op: operator on feature context
+ Return:
+ splice (Tensor): feature with context padded
+ """
+ if lctx + rctx == 0:
+ return feats
+ if op not in ["cat", "stack"]:
+ raise ValueError(f"Unknown op for feature splicing: {op}")
+ # [N x ... x T x F, ...]
+ ctx = []
+ T = feats.shape[-2]
+ for c in range(-lctx, rctx + 1):
+ idx = th.arange(c, c + T, device=feats.device, dtype=th.int64)
+ idx = th.clamp(idx, min=0, max=T - 1)
+ ctx.append(th.index_select(feats, -2, idx))
+ if op == "cat":
+ # N x ... x T x FD
+ splice = th.cat(ctx, -1)
+ else:
+ # N x ... x T x F x D
+ splice = th.stack(ctx, -1)
+ return splice
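+
+ # Illustrative sketch: with lctx=1, rctx=1 and op="cat", an (N, T, F) feature
+ # becomes (N, T, 3 * F); frames at the sequence edges reuse the clamped
+ # first/last frame as context.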
+
+
+def _forward_stft(
+ wav: th.Tensor,
+ kernel: th.Tensor,
+ window: th.Tensor,
+ return_polar: bool = False,
+ pre_emphasis: float = 0,
+ frame_hop: int = 256,
+ onesided: bool = False,
+ center: bool = False,
+ eps: float = EPSILON,
+) -> th.Tensor:
+ """
+ STFT function implemented by conv1d (not efficient, but we don't care during training)
+ Args:
+ wav (Tensor): N x (C) x S
+ kernel (Tensor): STFT transform kernels, from init_kernel(...)
+ return_polar: return [magnitude; phase] Tensor or [real; imag] Tensor
+ pre_emphasis: factor of preemphasis
+ frame_hop: frame hop size in number samples
+ onesided: return half FFT bins
+ center: if true, frames are assumed to be centered
+ Return:
+ transform (Tensor): STFT transform results
+ """
+ wav_dim = wav.dim()
+ if wav_dim not in [2, 3]:
+ raise RuntimeError(f"STFT expect 2D/3D tensor, but got {wav_dim:d}D")
+ # if N x S, reshape N x 1 x S
+ # else: reshape NC x 1 x S
+ N, S = wav.shape[0], wav.shape[-1]
+ wav = wav.view(-1, 1, S)
+ # NC x 1 x S+2P
+ if center:
+ pad = kernel.shape[-1] // 2
+ # NOTE: match with librosa
+ wav = tf.pad(wav, (pad, pad), mode="reflect")
+ # STFT
+ kernel = kernel * window
+ if pre_emphasis > 0:
+ # NC x W x T
+ frames = tf.unfold(
+ wav[:, None], (1, kernel.shape[-1]), stride=frame_hop, padding=0
+ )
+ # follow Kaldi's Preemphasize
+ frames[:, 1:] = frames[:, 1:] - pre_emphasis * frames[:, :-1]
+ frames[:, 0] *= 1 - pre_emphasis
+ # 1 x 2B x W, NC x W x T, NC x 2B x T
+ packed = th.matmul(kernel[:, 0][None, ...], frames)
+ else:
+ packed = tf.conv1d(wav, kernel, stride=frame_hop, padding=0)
+ # NC x 2B x T => N x C x 2B x T
+ if wav_dim == 3:
+ packed = packed.view(N, -1, packed.shape[-2], packed.shape[-1])
+ # N x (C) x B x T
+ real, imag = th.chunk(packed, 2, dim=-2)
+ # N x (C) x B/2+1 x T
+ if onesided:
+ num_bins = kernel.shape[0] // 4 + 1
+ real = real[..., :num_bins, :]
+ imag = imag[..., :num_bins, :]
+ if return_polar:
+ mag = (real ** 2 + imag ** 2 + eps) ** 0.5
+ pha = th.atan2(imag, real)
+ return th.stack([mag, pha], dim=-1)
+ else:
+ return th.stack([real, imag], dim=-1)
+
+
+def _inverse_stft(
+ transform: th.Tensor,
+ kernel: th.Tensor,
+ window: th.Tensor,
+ return_polar: bool = False,
+ frame_hop: int = 256,
+ onesided: bool = False,
+ center: bool = False,
+ eps: float = EPSILON,
+) -> th.Tensor:
+ """
+ iSTFT function implemented by conv1d
+ Args:
+ transform (Tensor): STFT transform results
+ kernel (Tensor): STFT transform kernels, from init_kernel(...)
+ return_polar (bool): keep same with the one in _forward_stft
+ frame_hop: frame hop size in number samples
+ onesided: return half FFT bins
+ center: used in _forward_stft
+ Return:
+ wav (Tensor), N x S
+ """
+ # (N) x F x T x 2
+ transform_dim = transform.dim()
+ # if F x T x 2, reshape 1 x F x T x 2
+ if transform_dim == 3:
+ transform = th.unsqueeze(transform, 0)
+ if transform_dim != 4:
+ raise RuntimeError(f"Expect 4D tensor, but got {transform_dim}D")
+
+ if return_polar:
+ real = transform[..., 0] * th.cos(transform[..., 1])
+ imag = transform[..., 0] * th.sin(transform[..., 1])
+ else:
+ real, imag = transform[..., 0], transform[..., 1]
+
+ if onesided:
+ # [self.num_bins - 2, ..., 1]
+ reverse = range(kernel.shape[0] // 4 - 1, 0, -1)
+ # extend matrix: N x B x T
+ real = th.cat([real, real[:, reverse]], 1)
+ imag = th.cat([imag, -imag[:, reverse]], 1)
+ # pack: N x 2B x T
+ packed = th.cat([real, imag], dim=1)
+ # N x 1 x T
+ wav = tf.conv_transpose1d(packed, kernel * window, stride=frame_hop, padding=0)
+ # normalized audio samples
+ # refer: https://github.com/pytorch/audio/blob/2ebbbf511fb1e6c47b59fd32ad7e66023fa0dff1/torchaudio/functional.py#L171
+ num_frames = packed.shape[-1]
+ win_length = window.shape[0]
+ # W x T
+ win = th.repeat_interleave(window[..., None] ** 2, num_frames, dim=-1)
+ # Do OLA on windows
+ # v1)
+ I = th.eye(win_length, device=win.device)[:, None]
+ denorm = tf.conv_transpose1d(win[None, ...], I, stride=frame_hop, padding=0)
+ # v2)
+ # num_samples = (num_frames - 1) * frame_hop + win_length
+ # denorm = tf.fold(win[None, ...], (num_samples, 1), (win_length, 1),
+ # stride=frame_hop)[..., 0]
+ if center:
+ pad = kernel.shape[-1] // 2
+ wav = wav[..., pad:-pad]
+ denorm = denorm[..., pad:-pad]
+ wav = wav / (denorm + eps)
+ # N x S
+ return wav.squeeze(1)
+
+
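+# A quick numerical sanity check for the overlap-add (OLA) normalization used in
+# _inverse_stft above (a sketch, not part of the module's API): squaring a
+# sqrt-Hann window gives a Hann window, whose shifted copies at 50% overlap sum
+# to a constant, so interior samples of the denominator are close to 1.0.
+#
+# >>> win = init_window("sqrthann", 512)
+# >>> ola = tf.conv_transpose1d(
+# ...     th.repeat_interleave(win[..., None] ** 2, 10, dim=-1)[None, ...],
+# ...     th.eye(512)[:, None],
+# ...     stride=256,
+# ... )
+# >>> # interior entries of `ola` are ~1.0; only the first/last half frame deviate
+
+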
+def _pytorch_stft(
+ wav: th.Tensor,
+ frame_len: int,
+ frame_hop: int,
+ n_fft: int = 512,
+ return_polar: bool = False,
+ window: str = "sqrthann",
+ normalized: bool = False,
+ onesided: bool = True,
+ center: bool = False,
+ eps: float = EPSILON,
+) -> th.Tensor:
+ """
+ Wrapper of PyTorch STFT function
+ Args:
+ wav (Tensor): source audio signal
+ frame_len: length of the frame
+ frame_hop: hop size between frames
+ n_fft: FFT size
+ return_polar: return the results in polar coordinates
+ window (Tensor): window tensor
+ center: same definition as the parameter in librosa.stft
+ normalized: use normalized DFT kernel
+ onesided: output onesided STFT
+ Return:
+ transform (Tensor), STFT transform results
+ """
+ if TORCH_VERSION < LooseVersion("1.7"):
+ raise RuntimeError("Can not use this function as TORCH_VERSION < 1.7")
+ wav_dim = wav.dim()
+ if wav_dim not in [2, 3]:
+ raise RuntimeError(f"STFT expect 2D/3D tensor, but got {wav_dim:d}D")
+ # if N x C x S, reshape NC x S
+ N, S = wav.shape[0], wav.shape[-1]
+ wav = wav.view(-1, S)
+ # STFT: N x F x T x 2
+ stft = th.stft(
+ wav,
+ n_fft,
+ hop_length=frame_hop,
+ win_length=window.shape[-1],
+ window=window,
+ center=center,
+ normalized=normalized,
+ onesided=onesided,
+ return_complex=False,
+ )
+ if wav_dim == 3:
+ F, T = stft.shape[-3], stft.shape[-2]
+ stft = stft.view(N, -1, F, T, 2)
+ # N x (C) x F x T x 2
+ if not return_polar:
+ return stft
+ # N x (C) x F x T
+ real, imag = stft[..., 0], stft[..., 1]
+ mag = (real ** 2 + imag ** 2 + eps) ** 0.5
+ pha = th.atan2(imag, real)
+ return th.stack([mag, pha], dim=-1)
+
+
+def _pytorch_istft(
+ transform: th.Tensor,
+ frame_len: int,
+ frame_hop: int,
+ window: th.Tensor,
+ n_fft: int = 512,
+ return_polar: bool = False,
+ normalized: bool = False,
+ onesided: bool = True,
+ center: bool = False,
+ eps: float = EPSILON,
+) -> th.Tensor:
+ """
+ Wrapper of PyTorch iSTFT function
+ Args:
+ transform (Tensor): results of STFT
+ frame_len: length of the frame
+ frame_hop: hop size between frames
+ window: window tensor
+ n_fft: FFT size
+ return_polar: keep consistent with _pytorch_stft
+ center: same definition as the parameter in librosa.stft
+ normalized: use normalized DFT kernel
+ onesided: whether the input transform is onesided
+ Return:
+ wav (Tensor): reconstructed audio
+ """
+ if TORCH_VERSION < LooseVersion("1.7"):
+ raise RuntimeError("Can not use this function as TORCH_VERSION < 1.7")
+
+ transform_dim = transform.dim()
+ # if F x T x 2, reshape 1 x F x T x 2
+ if transform_dim == 3:
+ transform = th.unsqueeze(transform, 0)
+ if transform_dim != 4:
+ raise RuntimeError(f"Expect 4D tensor, but got {transform_dim}D")
+
+ if return_polar:
+ real = transform[..., 0] * th.cos(transform[..., 1])
+ imag = transform[..., 0] * th.sin(transform[..., 1])
+ transform = th.stack([real, imag], -1)
+ # stft is a complex tensor of PyTorch
+ stft = th.view_as_complex(transform)
+ # (N) x S
+ wav = th.istft(
+ stft,
+ n_fft,
+ hop_length=frame_hop,
+ win_length=window.shape[-1],
+ window=window,
+ center=center,
+ normalized=normalized,
+ onesided=onesided,
+ return_complex=False,
+ )
+ return wav
+
+
+def forward_stft(
+ wav: th.Tensor,
+ frame_len: int,
+ frame_hop: int,
+ window: str = "sqrthann",
+ round_pow_of_two: bool = True,
+ return_polar: bool = False,
+ pre_emphasis: float = 0,
+ normalized: bool = False,
+ onesided: bool = True,
+ center: bool = False,
+ mode: str = "librosa",
+ eps: float = EPSILON,
+) -> th.Tensor:
+ """
+ STFT function implementation, equivalent to the STFT layer
+ Args:
+ wav: source audio signal
+ frame_len: length of the frame
+ frame_hop: hop size between frames
+ return_polar: return [magnitude; phase] Tensor or [real; imag] Tensor
+ window: window name
+ center: center flag (similar to that in librosa.stft)
+ round_pow_of_two: if True, round the FFT size up to the next power of two
+ pre_emphasis: factor of pre-emphasis
+ normalized: use normalized DFT kernel
+ onesided: output onesided STFT
+ mode: STFT mode, "kaldi" or "librosa" or "torch"
+ Return:
+ transform: results of STFT
+ """
+ window = init_window(window, frame_len, device=wav.device)
+ if mode == "torch":
+ n_fft = 2 ** math.ceil(math.log2(frame_len)) if round_pow_of_two else frame_len
+ return _pytorch_stft(
+ wav,
+ frame_len,
+ frame_hop,
+ n_fft=n_fft,
+ return_polar=return_polar,
+ window=window,
+ normalized=normalized,
+ onesided=onesided,
+ center=center,
+ eps=eps,
+ )
+ else:
+ kernel, window = init_kernel(
+ frame_len,
+ frame_hop,
+ window=window,
+ round_pow_of_two=round_pow_of_two,
+ normalized=normalized,
+ inverse=False,
+ mode=mode,
+ )
+ return _forward_stft(
+ wav,
+ kernel,
+ window,
+ return_polar=return_polar,
+ frame_hop=frame_hop,
+ pre_emphasis=pre_emphasis,
+ onesided=onesided,
+ center=center,
+ eps=eps,
+ )
+
+
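+# A minimal usage sketch for forward_stft (shapes assume a hypothetical batch of
+# 1-second, 16 kHz mono signals and the defaults above: sqrthann window,
+# round_pow_of_two=True, onesided=True, so F = 512 // 2 + 1 = 257):
+#
+# >>> wav = th.randn(4, 16000)                       # N x S
+# >>> spec = forward_stft(wav, 512, 256)             # N x F x T x 2 (real/imag)
+# >>> mag = forward_stft(wav, 512, 256, return_polar=True)[..., 0]  # magnitude
+
+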
+def inverse_stft(
+ transform: th.Tensor,
+ frame_len: int,
+ frame_hop: int,
+ return_polar: bool = False,
+ window: str = "sqrthann",
+ round_pow_of_two: bool = True,
+ normalized: bool = False,
+ onesided: bool = True,
+ center: bool = False,
+ mode: str = "librosa",
+ eps: float = EPSILON,
+) -> th.Tensor:
+ """
+ iSTFT function implementation, equivalent to the iSTFT layer
+ Args:
+ transform: results of STFT
+ frame_len: length of the frame
+ frame_hop: hop size between frames
+ return_polar: keep consistent with forward_stft(...)
+ window: window name
+ center: center flag (similar to that in librosa.stft)
+ round_pow_of_two: if True, round the FFT size up to the next power of two
+ normalized: use normalized DFT kernel
+ onesided: whether the input transform is onesided
+ mode: STFT mode, "kaldi" or "librosa" or "torch"
+ Return:
+ wav: synthetic signals
+ """
+ window = init_window(window, frame_len, device=transform.device)
+ if mode == "torch":
+ n_fft = 2 ** math.ceil(math.log2(frame_len)) if round_pow_of_two else frame_len
+ return _pytorch_istft(
+ transform,
+ frame_len,
+ frame_hop,
+ n_fft=n_fft,
+ return_polar=return_polar,
+ window=window,
+ normalized=normalized,
+ onesided=onesided,
+ center=center,
+ eps=eps,
+ )
+ else:
+ kernel, window = init_kernel(
+ frame_len,
+ frame_hop,
+ window,
+ round_pow_of_two=round_pow_of_two,
+ normalized=normalized,
+ inverse=True,
+ mode=mode,
+ )
+ return _inverse_stft(
+ transform,
+ kernel,
+ window,
+ return_polar=return_polar,
+ frame_hop=frame_hop,
+ onesided=onesided,
+ center=center,
+ eps=eps,
+ )
+
+
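+# Analysis/synthesis round trip (a sketch under the same assumptions as the
+# forward_stft example; with center=False the reconstruction can be up to one
+# frame shorter than the input, because trailing samples that do not fill a
+# complete frame are dropped by the analysis stage):
+#
+# >>> spec = forward_stft(wav, 512, 256)
+# >>> rec = inverse_stft(spec, 512, 256)             # N x S', S' <= S
+
+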
+class STFTBase(nn.Module):
+ """
+ Base layer for (i)STFT
+ Args:
+ frame_len: length of the frame
+ frame_hop: hop size between frames
+ window: window name
+ center: center flag (similar to that in librosa.stft)
+ round_pow_of_two: if True, round the FFT size up to the next power of two
+ normalized: use normalized DFT kernel
+ pre_emphasis: factor of pre-emphasis
+ mode: STFT mode, "kaldi" or "librosa" or "torch"
+ onesided: output onesided STFT
+ inverse: use the iDFT kernel (for the iSTFT layer)
+ """
+
+ def __init__(
+ self,
+ frame_len: int,
+ frame_hop: int,
+ window: str = "sqrthann",
+ round_pow_of_two: bool = True,
+ normalized: bool = False,
+ pre_emphasis: float = 0,
+ onesided: bool = True,
+ inverse: bool = False,
+ center: bool = False,
+ mode: str = "librosa",
+ ) -> None:
+ super(STFTBase, self).__init__()
+ if mode != "torch":
+ K, w = init_kernel(
+ frame_len,
+ frame_hop,
+ init_window(window, frame_len),
+ round_pow_of_two=round_pow_of_two,
+ normalized=normalized,
+ inverse=inverse,
+ mode=mode,
+ )
+ self.K = nn.Parameter(K, requires_grad=False)
+ self.w = nn.Parameter(w, requires_grad=False)
+ self.num_bins = self.K.shape[0] // 4 + 1
+ self.pre_emphasis = pre_emphasis
+ self.win_length = self.K.shape[2]
+ else:
+ self.K = None
+ w = init_window(window, frame_len)
+ self.w = nn.Parameter(w, requires_grad=False)
+ fft_size = (
+ 2 ** math.ceil(math.log2(frame_len)) if round_pow_of_two else frame_len
+ )
+ self.num_bins = fft_size // 2 + 1
+ self.pre_emphasis = 0
+ self.win_length = fft_size
+ self.frame_len = frame_len
+ self.frame_hop = frame_hop
+ self.window = window
+ self.normalized = normalized
+ self.onesided = onesided
+ self.center = center
+ self.mode = mode
+
+ def num_frames(self, wav_len: th.Tensor) -> th.Tensor:
+ """
+ Compute the number of frames for the given waveform length(s)
+ """
+ assert th.sum(wav_len <= self.win_length) == 0
+ if self.center:
+ # avoid modifying the caller's tensor in place
+ wav_len = wav_len + self.win_length
+ return (
+ th.div(wav_len - self.win_length, self.frame_hop, rounding_mode="trunc") + 1
+ )
+
+ def extra_repr(self) -> str:
+ str_repr = (
+ f"num_bins={self.num_bins}, win_length={self.win_length}, "
+ + f"stride={self.frame_hop}, window={self.window}, "
+ + f"center={self.center}, mode={self.mode}"
+ )
+ if not self.onesided:
+ str_repr += f", onesided={self.onesided}"
+ if self.pre_emphasis > 0:
+ str_repr += f", pre_emphasis={self.pre_emphasis}"
+ if self.normalized:
+ str_repr += f", normalized={self.normalized}"
+ return str_repr
+
+
+class STFT(STFTBase):
+ """
+ Short-time Fourier Transform as a Layer
+ """
+
+ def __init__(self, *args, **kwargs):
+ super(STFT, self).__init__(*args, inverse=False, **kwargs)
+
+ def forward(
+ self, wav: th.Tensor, return_polar: bool = False, eps: float = EPSILON
+ ) -> th.Tensor:
+ """
+ Accept a (single- or multi-channel) raw waveform and output its STFT,
+ as real/imag by default or magnitude/phase when return_polar is True
+ Args:
+ wav (Tensor): input signal, N x (C) x S
+ Return:
+ transform (Tensor): N x (C) x F x T x 2
+ """
+ if self.mode == "torch":
+ return _pytorch_stft(
+ wav,
+ self.frame_len,
+ self.frame_hop,
+ n_fft=(self.num_bins - 1) * 2,
+ return_polar=return_polar,
+ window=self.w,
+ normalized=self.normalized,
+ onesided=self.onesided,
+ center=self.center,
+ eps=eps,
+ )
+ else:
+ return _forward_stft(
+ wav,
+ self.K,
+ self.w,
+ return_polar=return_polar,
+ frame_hop=self.frame_hop,
+ pre_emphasis=self.pre_emphasis,
+ onesided=self.onesided,
+ center=self.center,
+ eps=eps,
+ )
+
+
+class iSTFT(STFTBase):
+ """
+ Inverse Short-time Fourier Transform as a Layer
+ """
+
+ def __init__(self, *args, **kwargs):
+ super(iSTFT, self).__init__(*args, inverse=True, **kwargs)
+
+ def forward(
+ self, transform: th.Tensor, return_polar: bool = False, eps: float = EPSILON
+ ) -> th.Tensor:
+ """
+ Accept phase & magnitude and output raw waveform
+ Args
+ transform (Tensor): STFT output, N x F x T x 2
+ Return
+ s (Tensor): N x S
+ """
+ if self.mode == "torch":
+ return _pytorch_istft(
+ transform,
+ self.frame_len,
+ self.frame_hop,
+ n_fft=(self.num_bins - 1) * 2,
+ return_polar=return_polar,
+ window=self.w,
+ normalized=self.normalized,
+ onesided=self.onesided,
+ center=self.center,
+ eps=eps,
+ )
+ else:
+ return _inverse_stft(
+ transform,
+ self.K,
+ self.w,
+ return_polar=return_polar,
+ frame_hop=self.frame_hop,
+ onesided=self.onesided,
+ center=self.center,
+ eps=eps,
+ )
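+
+
+# Layer-style usage (a sketch): unlike the functional interface, which rebuilds
+# the DFT kernels on every call, the STFT/iSTFT layers (for the default
+# librosa/kaldi modes) register them once as frozen nn.Parameters, so they move
+# with the parent module via .to(device).
+#
+# >>> stft, istft = STFT(512, 256), iSTFT(512, 256)
+# >>> spec = stft(wav)                               # N x F x T x 2
+# >>> rec = istft(spec)                              # N x S'
+# >>> stft.num_frames(th.tensor([16000]))            # frame count per utterance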
diff --git a/look2hear/utils/torch_utils.py b/look2hear/utils/torch_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c3ec70df9e7b6c89f9beb4f5aa1cb8f9ac4cf90
--- /dev/null
+++ b/look2hear/utils/torch_utils.py
@@ -0,0 +1,42 @@
+import torch
+import torch.nn as nn
+
+
+def pad_x_to_y(x, y, axis: int = -1):
+ """Right-pad (or trim) x along the last axis so it matches the length of y."""
+ if axis != -1:
+ raise NotImplementedError
+ inp_len = y.shape[axis]
+ output_len = x.shape[axis]
+ return nn.functional.pad(x, [0, inp_len - output_len])
+
+
+def shape_reconstructed(reconstructed, size):
+ """Squeeze the leading batch dimension if the original input was 1D."""
+ if len(size) == 1:
+ return reconstructed.squeeze(0)
+ return reconstructed
+
+
+def tensors_to_device(tensors, device):
+ """Transfer tensor, dict or list of tensors to device.
+
+ Args:
+ tensors (:class:`torch.Tensor`): May be a single, a list or a
+ dictionary of tensors.
+ device (:class:`torch.device`): the device on which to place the tensors.
+
+ Returns:
+ Union[:class:`torch.Tensor`, list, tuple, dict]:
+ Same as input but transferred to device.
+ Goes through lists and dicts and transfers the torch.Tensor to
+ device. Leaves the rest untouched.
+ """
+ if isinstance(tensors, torch.Tensor):
+ return tensors.to(device)
+ elif isinstance(tensors, (list, tuple)):
+ return [tensors_to_device(tens, device) for tens in tensors]
+ elif isinstance(tensors, dict):
+ for key in tensors.keys():
+ tensors[key] = tensors_to_device(tensors[key], device)
+ return tensors
+ else:
+ return tensors
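+
+
+# A small usage sketch (the nested batch structure below is hypothetical):
+#
+# >>> batch = {"mix": torch.randn(1, 16000), "refs": [torch.randn(1, 16000)]}
+# >>> batch = tensors_to_device(batch, torch.device("cpu"))
+# >>> est = torch.randn(1, 15872)
+# >>> est = pad_x_to_y(est, batch["mix"])            # zero-pad to 16000 samples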
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..3bef70d53061bf077728444136b9e52697ec2ab9
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,114 @@
+accelerated-scan==0.2.0
+aiohappyeyeballs==2.4.3
+aiohttp==3.11.7
+aiosignal==1.3.1
+antlr4-python3-runtime==4.9.3
+attrs==24.2.0
+audioread==3.0.1
+av==13.1.0
+blessed==1.20.0
+blessings==1.7
+certifi==2024.8.30
+cffi==1.17.1
+charset-normalizer==3.4.0
+click==8.1.7
+coloredlogs==15.0.1
+conda-pack==0.8.1
+ctranslate2==4.5.0
+datasets==3.2.0
+decorator==5.1.1
+dill==0.3.8
+docker-pycreds==0.4.0
+einops==0.8.0
+fast_bss_eval==0.1.4
+filelock==3.16.1
+flatbuffers==24.3.25
+frozenlist==1.5.0
+fsspec==2024.9.0
+gitdb==4.0.11
+GitPython==3.1.43
+gpustat==1.1.1
+huggingface-hub==0.26.2
+humanfriendly==10.0
+hydra-core==1.3.2
+HyperPyYAML==1.2.2
+idna==3.10
+Jinja2==3.1.4
+joblib==1.4.2
+lazy_loader==0.4
+librosa==0.10.2.post1
+lightning-utilities==0.11.9
+llvmlite==0.43.0
+markdown-it-py==3.0.0
+MarkupSafe==3.0.2
+mdurl==0.1.2
+modelscope==1.20.1
+mpmath==1.3.0
+msgpack==1.1.0
+multidict==6.1.0
+multiprocess==0.70.16
+networkx==3.4.2
+ninja==1.11.1.3
+numba==0.60.0
+numpy==2.0.2
+omegaconf==2.3.0
+onnxruntime==1.20.1
+packaging==24.2
+pandas==2.2.3
+pesq==0.0.4
+pillow==11.0.0
+pip-system-certs==4.0
+platformdirs==4.3.6
+pooch==1.8.2
+propcache==0.2.0
+protobuf==5.28.3
+psutil==6.1.0
+ptflops==0.7.4
+pyarrow==18.1.0
+pycparser==2.22
+Pygments==2.18.0
+pystoi==0.4.1
+python-dateutil==2.9.0.post0
+pytorch-lightning==2.0.2
+pytorch-ranger==0.1.1
+pytz==2024.2
+PyYAML==6.0.2
+regex==2024.11.6
+requests==2.32.3
+rich==13.9.4
+rotary-embedding-torch==0.8.5
+ruamel.yaml==0.18.6
+ruamel.yaml.clib==0.2.12
+safetensors==0.4.5
+scikit-learn==1.5.2
+scipy==1.14.1
+sentencepiece==0.2.0
+sentry-sdk==2.19.0
+setproctitle==1.3.4
+six==1.16.0
+smmap==5.0.1
+soundfile==0.12.1
+soxr==0.5.0.post1
+speechbrain==0.5.14
+sympy==1.13.1
+threadpoolctl==3.5.0
+tokenizers==0.21.0
+torch==2.5.1
+torch-complex==0.4.4
+torch-mir-eval==0.4
+torch-optimizer==0.3.0
+torchaudio==2.5.1
+torchmetrics==1.6.0
+torchvision==0.20.1
+tqdm==4.67.0
+transformers==4.47.1
+triton==3.1.0
+typeguard==2.13.3
+typing_extensions==4.12.2
+tzdata==2024.2
+urllib3==2.2.3
+wandb==0.18.7
+wcwidth==0.2.13
+wrapt==1.17.0
+xxhash==3.5.0
+yarl==1.18.0
diff --git a/test/mix.wav b/test/mix.wav
new file mode 100644
index 0000000000000000000000000000000000000000..eb063704154aff42d7a52bc909f5a04ac315ac3c
--- /dev/null
+++ b/test/mix.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:96096869c977120503b2ef4b53429a4bce510b344d6a274c3371457be854f38e
+size 192044
diff --git a/test/s1.wav b/test/s1.wav
new file mode 100644
index 0000000000000000000000000000000000000000..50ceb11a9e40dd8af0371e21f0d4c9e597bfaf62
--- /dev/null
+++ b/test/s1.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2bc86a053f8fd42703f5a402d6c10d71b78121854993896ae2f204646b251cd4
+size 384080
diff --git a/test/s2.wav b/test/s2.wav
new file mode 100644
index 0000000000000000000000000000000000000000..f24d83d5cb77a350720303969356274f9a37b17a
--- /dev/null
+++ b/test/s2.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a93d50682bd48181e0fc96f6bda41e923f8426402a29fb3f53f33919af6c381f
+size 384080
diff --git a/test/spk1.wav b/test/spk1.wav
new file mode 100644
index 0000000000000000000000000000000000000000..8163478833cf217a263577c5840d4d1031a459af
--- /dev/null
+++ b/test/spk1.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:19d94b8621ca6111b1f55fd663d357f202d50feeb4d480435e9285bcc0dcf5f7
+size 384080
diff --git a/test/spk2.wav b/test/spk2.wav
new file mode 100644
index 0000000000000000000000000000000000000000..a2b5258a69e599754aa0a72775e72a806fb5e011
--- /dev/null
+++ b/test/spk2.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:aba7da1920c8b736385378ea8e8b4b69e803eb44be353703e4babe60f600c4f0
+size 384080
diff --git a/test/test_mixture_466.wav b/test/test_mixture_466.wav
new file mode 100644
index 0000000000000000000000000000000000000000..c4b4e6f731bbacb07363640c5f26e2a63de3aeae
--- /dev/null
+++ b/test/test_mixture_466.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:27d8a4c46ea934c264af5ec04f8db4f0a3b5d59481d1071b1aa8086ea3a230fd
+size 5292044
diff --git a/test/test_target_dialog_466.wav b/test/test_target_dialog_466.wav
new file mode 100644
index 0000000000000000000000000000000000000000..2ec68e47368662a6ff4aa78391ac3ae89af6a930
--- /dev/null
+++ b/test/test_target_dialog_466.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ca08ef210dcdd4e705789cce4a5de300dbc9cee0c5498a810506cb7a4fdbcf74
+size 10584080
diff --git a/test/test_target_effect_466.wav b/test/test_target_effect_466.wav
new file mode 100644
index 0000000000000000000000000000000000000000..dd8969ae932de2aaab18d4b72ef2c87cf911b3c1
--- /dev/null
+++ b/test/test_target_effect_466.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b2ef06b888f17ec1e13562a2a03d419ba7e955f2ba0634b60f7cf1eabca46ba0
+size 10584080
diff --git a/test/test_target_music_466.wav b/test/test_target_music_466.wav
new file mode 100644
index 0000000000000000000000000000000000000000..7867202d9971f917e1c523c4ec445e7632bf5df2
--- /dev/null
+++ b/test/test_target_music_466.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cdb26e3ab09a25694d8654f047fde930a9a6e04a0746f4a42e628cf80207479f
+size 10584080