moarafa97 commited on
Commit
64bb0f2
·
verified ·
1 Parent(s): a804948

Upload custom_interface.py

Browse files
Files changed (1) hide show
  1. custom_interface.py +19 -0
custom_interface.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from speechbrain.pretrained import EncoderClassifier
2
+
3
+ class CustomEncoderWav2vec2Classifier(EncoderClassifier):
4
+ def compute_forward(self, batch, stage):
5
+ wavs, wav_lens = batch.sig
6
+ feats = self.mods.compute_features(wavs)
7
+ if self.mods.normalize:
8
+ feats = self.mods.normalize(feats, wav_lens)
9
+ x = self.mods.encoder(feats)
10
+ outputs = self.mods.classifier(x)
11
+ return outputs
12
+
13
+ def classify_file(self, path):
14
+ signal = self.load_audio(path)
15
+ batch = self.make_batch(signal)
16
+ probs = self.forward(batch)
17
+ score, index = probs.max(1)
18
+ label = self.hparams.label_encoder.decode(index)
19
+ return probs, score.item(), index.item(), label