File size: 51,538 Bytes
5491e76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
# Copyright 2024 FBK

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
# the code below contains parts copied from the Conformer implementation in
# https://github.com/hlt-mt/FBK-fairseq/blob/master/examples/speech_to_text/models/conformer.py
import math
from itertools import groupby
from typing import Union, Tuple, Optional

import torch
import transformers
from torch import nn, Tensor
from torch.nn import CrossEntropyLoss, functional as F

from transformers import Speech2TextPreTrainedModel, add_start_docstrings, GenerationMixin, Speech2TextProcessor, \
    Speech2TextTokenizer, Speech2TextFeatureExtractor
from transformers.modeling_outputs import Seq2SeqModelOutput, BaseModelOutput, Seq2SeqLMOutput
from transformers.models.speech_to_text.modeling_speech_to_text import Speech2TextDecoder, \
    SPEECH_TO_TEXT_INPUTS_DOCSTRING, shift_tokens_right
from transformers.utils import replace_return_docstrings, add_start_docstrings_to_model_forward, logging

from .configuration_conformer import Speech2TextConformerConfig


# Module-level logger, obtained through the `transformers.utils.logging` helper
# so that it follows the library-wide verbosity settings.
logger = logging.get_logger(__name__)

# Name of the configuration class as a plain string.
# NOTE(review): presumably consumed by the docstring decorators imported above
# (e.g. `replace_return_docstrings`) further down the file — not visible here.
_CONFIG_FOR_DOC = "Speech2TextConformerConfig"

# Shared introductory documentation text for the Conformer model classes.
# NOTE(review): presumably attached to the model classes through
# `add_start_docstrings` later in the file — the usage is not visible here.
# This is a runtime string: edit it only if the rendered documentation must change.
CONFORMER_START_DOCSTRING = r"""
    This model is an implementation of an attention-based autoregressive encoder-decoder model, in which the encoder
    is a Conformer Encoder and decoder is a Transformer Decoder. The encoder expects 80-feature spectrograms as input
    as the [`Speech2TextModel`] and its implementation follows that of the paper:
    
    `"When Good and Reproducible Results are a Giant with Feet of Clay: The Importance of Software Quality in NLP"
    (Papi, et al, ACL 2024) <https://aclanthology.org/2024.acl-long.200/>`_.
    
    This ensures consistency of results regardless of the presence of padding.
    
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`Speech2TextConformerConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


class Conv1dSubsampler(nn.Module):
    """Convolutional subsampler: a stack of 1D convolutions (along the temporal
    dimension), each followed by a gated linear unit non-linearity
    (https://arxiv.org/abs/1911.08460).

    Every convolution uses stride 2 and padding ``kernel_size // 2``. After each
    layer the time steps lying beyond the (subsampled) length of every sequence
    are set to zero, so that padding never leaks into the valid positions of the
    following layer.

    Args:
        config: configuration object; the attributes read here are
            ``conv_kernel_sizes``, ``input_feat_per_channel``, ``input_channels``,
            ``conv_channels`` and ``d_model``.
    """

    def __init__(self, config: Speech2TextConformerConfig):
        super(Conv1dSubsampler, self).__init__()
        self.n_layers = len(config.conv_kernel_sizes)
        in_channels = config.input_feat_per_channel * config.input_channels
        mid_channels = config.conv_channels
        out_channels = config.d_model
        # The GLU halves the channel dimension, hence inner layers receive
        # ``mid_channels // 2`` channels and the last one emits ``2 * d_model``.
        self.conv_layers = nn.ModuleList(
            nn.Conv1d(
                in_channels if i == 0 else mid_channels // 2,
                mid_channels if i < self.n_layers - 1 else out_channels * 2,
                k,
                stride=2,
                padding=k // 2,
            )
            for i, k in enumerate(config.conv_kernel_sizes)
        )

    @staticmethod
    def subsampled_sequence_len(seq_lens, kernel_size=5, padding=1, stride=2):
        """Return the lengths obtained after one convolution.

        Applies ``floor((len - kernel_size + 2 * padding) / stride + 1)`` to every
        element of ``seq_lens`` and returns a new ``long`` tensor; the input is
        not modified.
        """
        return ((seq_lens.float() - kernel_size + 2 * padding) / stride + 1).floor().long()

    @staticmethod
    def lengths_to_padding_mask(lens: torch.LongTensor, max_len: Optional[int] = None) -> torch.BoolTensor:
        """Build a boolean ``(batch, max_len)`` mask that is True on padding.

        Args:
            lens: 1-D tensor with the number of valid positions of each sequence.
            max_len: width of the returned mask. When omitted, the largest value
                in ``lens`` is used (the historical behaviour).

        Returns:
            Mask whose element ``[b, t]`` is True iff ``t >= lens[b]``.
        """
        if max_len is None:
            max_len = int(torch.max(lens).item())
        positions = torch.arange(max_len, device=lens.device).view(1, max_len)
        return positions >= lens.view(-1, 1)

    def forward(self, src_tokens: torch.FloatTensor, padding_mask: torch.IntTensor) -> torch.Tensor:
        """Subsample the input features.

        Args:
            src_tokens: input features laid out as B x T x (C x D).
            padding_mask: B x T tensor whose row sums give the number of valid
                frames of each sequence (i.e. non-zero on *valid* positions).

        Returns:
            The subsampled features laid out as T' x B x d_model, with zeros on
            the positions beyond each sequence's subsampled length.
        """
        x = src_tokens.transpose(1, 2).contiguous()  # B x T x (C x D) -> B x (C x D) x T
        actual_src_lengths = padding_mask.sum(dim=1)
        for conv in self.conv_layers:
            x = conv(x)
            x = nn.functional.glu(x, dim=1)
            actual_src_lengths = self.subsampled_sequence_len(
                actual_src_lengths,
                kernel_size=conv.kernel_size[0],
                padding=conv.padding[0],
                stride=conv.stride[0])
            # The mask width must follow the actual time dimension of ``x``: when
            # every sequence of the batch is padded, the longest valid length is
            # shorter than ``x`` and a mask sized on it could not be broadcast.
            layer_padding_mask = self.lengths_to_padding_mask(actual_src_lengths, max_len=x.size(2))
            x = x.masked_fill(layer_padding_mask.unsqueeze(1), 0)
        x = x.transpose(1, 2).transpose(0, 1).contiguous()  # -> T x B x (C x D)
        return x


class PositionalEncoding(nn.Module):
    """
    Positional Encoding proposed in "Attention Is All You Need".
    "Attention Is All You Need" use sine and cosine functions of different frequencies:
        PE_(pos, 2i)    =  sin(pos / power(10000, 2i / d_model))
        PE_(pos, 2i+1)  =  cos(pos / power(10000, 2i / d_model))
    The version implemented on Fairseq differs slightly from the paper, this implementation is faithful to the
    original one. Please see
    :func:`~fairseq.modules.sinusoidal_positional_embedding.SinusoidalPositionalEmbedding.get_embedding` for more
    details.

    The table is computed once for ``max_len`` positions and stored in the ``pe`` buffer with shape
    ``(1, max_len, d_model)``.

    Args:
        d_model (int): size of each positional vector (odd values are supported)
        max_len (int): number of positions that are precomputed
    """

    def __init__(self, d_model: int = 512, max_len: int = 10000) -> None:
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model, requires_grad=False)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        angles = position * div_term  # max_len x ceil(d_model / 2)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one cosine column less than sine columns,
        # so the angles are truncated accordingly (a no-op when d_model is even).
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, length: int) -> Tensor:
        """Return the encodings of the first ``length`` positions, shaped ``(1, length, d_model)``.

        If ``length`` exceeds ``max_len`` the slice silently stops at ``max_len``.
        """
        return self.pe[:, :length]


class RelativeMultiHeadAttention(nn.Module):
    """
    Multi-head attention with relative positional encoding.
    This concept was proposed in the `"Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
    <https://arxiv.org/pdf/1901.02860.pdf>`_.

    Args:
        d_model (int): The dimension of model
        num_heads (int): The number of attention heads.
        dropout_p (float): probability of dropout
        batch_unsafe_relative_shift (bool): if True, the relative shift is computed on the whole batch at once
            (faster, but the result of a sequence then depends on how much it is padded); if False (default),
            the shift is computed sequence by sequence taking padding into account.

    Inputs: query, key, value, pos_embedding, mask
        query (batch, time, dim): Tensor containing query vector
        key (batch, time, dim): Tensor containing key vector
        value (batch, time, dim): Tensor containing value vector
        pos_embedding (batch, time, dim): Positional embedding tensor
        mask (batch, time1, time2), optional: Boolean tensor that is True on the elements to be masked.
            When omitted, no element is considered padding.

    Returns:
        **outputs**: Tensor produces by relative multi head attention module.
        **attn**: (batch, num_heads, time1, time2) attention weights (after dropout).
    """

    def __init__(
            self,
            d_model: int = 512,
            num_heads: int = 16,
            dropout_p: float = 0.1,
            batch_unsafe_relative_shift: bool = False
    ):
        super(RelativeMultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model % num_heads should be zero."
        self.d_model = d_model
        self.d_head = d_model // num_heads
        self.num_heads = num_heads
        # NOTE: the scores are scaled by sqrt(d_model), not by sqrt(d_head).
        self.sqrt_dim = math.sqrt(d_model)

        self.query_proj = nn.Linear(d_model, d_model)
        nn.init.xavier_uniform_(self.query_proj.weight)
        nn.init.zeros_(self.query_proj.bias)
        self.key_proj = nn.Linear(d_model, d_model)
        nn.init.xavier_uniform_(self.key_proj.weight)
        nn.init.zeros_(self.key_proj.bias)
        self.value_proj = nn.Linear(d_model, d_model)
        nn.init.xavier_uniform_(self.value_proj.weight)
        nn.init.zeros_(self.value_proj.bias)
        self.pos_proj = nn.Linear(d_model, d_model, bias=False)
        nn.init.xavier_uniform_(self.pos_proj.weight)

        self.dropout = nn.Dropout(p=dropout_p)
        # u and v are the trainable parameters of the Transformer-XL attention computation
        self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
        self.v_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
        nn.init.xavier_uniform_(self.u_bias)
        nn.init.xavier_uniform_(self.v_bias)

        self.out_proj = nn.Linear(d_model, d_model)
        nn.init.xavier_uniform_(self.out_proj.weight)
        nn.init.zeros_(self.out_proj.bias)
        self.relative_shift_func = self._relative_shift_unsafe if batch_unsafe_relative_shift else self._relative_shift

    def forward(
            self,
            query: Tensor,
            key: Tensor,
            value: Tensor,
            pos_embedding: Tensor,
            mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        """Compute the attention output and the attention weights (see the class docstring for the shapes)."""
        batch_size = value.size(0)

        # Projections are split into heads: query/pos_embedding stay as B x T x H x d_head,
        # key/value are moved to B x H x T x d_head.
        query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)
        key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
        value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
        pos_embedding = self.pos_proj(pos_embedding).view(batch_size, -1, self.num_heads, self.d_head)

        # Attention weights computation using Q + u as in Transformer-XL
        content_score = torch.matmul((query + self.u_bias).transpose(1, 2), key.transpose(2, 3))
        # Relative positional weights computation using Q + v as in Transformer-XL
        pos_score = torch.matmul((query + self.v_bias).transpose(1, 2), pos_embedding.permute(0, 2, 3, 1))
        # Right shifting mechanism described in Transformer-XL
        pos_score = self.relative_shift_func(pos_score, mask)
        # Final attention weights obtained summing the attention with its relative positional embeddings
        score = (content_score + pos_score) / self.sqrt_dim

        if mask is not None:
            mask = mask.unsqueeze(1)  # broadcast over the heads
            # NOTE(review): the dtype test is made on the mask, which has to be boolean for
            # ``masked_fill``; hence -1e4 is always the value in use. It is left untouched
            # to keep the numerical behaviour of the module unchanged.
            score.masked_fill_(mask, -1e9 if mask.dtype == torch.float32 else -1e4)

        attn = F.softmax(score, dim=-1)
        # set to 0.0 all attention weights of padding elements
        if mask is not None:
            attn = attn.masked_fill(mask, 0.0)
        attn = self.dropout(attn)

        # Attention computation
        context = torch.matmul(attn, value).transpose(1, 2)
        context = context.contiguous().view(batch_size, -1, self.d_model)

        return self.out_proj(context), attn

    def _relative_shift(self, pos_score: Tensor, padding_mask: Optional[Tensor]) -> Tensor:
        """
        This methods performs the relative shift operation row-wise.
        Although inefficient, it enforces that each row is shifted accounting its padding,
        which enforces that the result does not change depending on whether a given row
        is padded or not.

        ``pos_score`` (B x H x T x T) is modified in place and returned. ``padding_mask``
        (B x T x T, True on padding) may be None, in which case no sequence is padded and
        the batched shift is used, as it gives the same result without padding.
        """
        batch_size, num_heads, seq_length1, seq_length2 = pos_score.size()
        assert seq_length1 == seq_length2, "Currently we support only self-attention"
        if padding_mask is None:
            return self._relative_shift_unsafe(pos_score, padding_mask)
        zeros = pos_score.new_zeros(batch_size, num_heads, seq_length1, 1)
        padded_pos_score = torch.cat([zeros, pos_score], dim=-1)

        # The first column of the mask flags the padded rows of each sequence,
        # so the number of valid positions is T minus their count.
        seq_lengths = (seq_length1 - (padding_mask[:, :, 0]).sum(-1)).tolist()
        for b_i in range(batch_size):
            padded_batch_pos_scores = padded_pos_score[b_i, :, :seq_lengths[b_i], :seq_lengths[b_i] + 1]
            padded_batch_pos_scores = padded_batch_pos_scores.reshape(num_heads, seq_lengths[b_i] + 1, seq_lengths[b_i])
            pos_score[b_i, :, :seq_lengths[b_i], :seq_lengths[b_i]] = padded_batch_pos_scores[:, 1:, :]
        pos_score.masked_fill_(padding_mask.unsqueeze(1), 0.0)
        return pos_score

    def _relative_shift_unsafe(self, pos_score: Tensor, padding_mask: Optional[Tensor]) -> Tensor:
        """
         This implementation reflects other open source ones (e.g. fairseq), which
         shift the values from the row above in the batch. Although efficient,
         this leads to inconsistencies in the results, as the same row has different
         values according to whether it is padded (and how much it is) or not.

         ``padding_mask`` is accepted only for signature parity and is ignored.
         """
        batch_size, num_heads, seq_length1, seq_length2 = pos_score.size()
        zeros = pos_score.new_zeros(batch_size, num_heads, seq_length1, 1)
        padded_pos_score = torch.cat([zeros, pos_score], dim=-1)

        padded_pos_score = padded_pos_score.view(batch_size, num_heads, seq_length2 + 1, seq_length1)
        pos_score = padded_pos_score[:, :, 1:].view_as(pos_score)

        return pos_score


class MultiHeadedSelfAttentionModule(nn.Module):
    """
    Conformer employ multi-headed self-attention (MHSA) while integrating an important technique from Transformer-XL,
    the relative sinusoidal positional encoding scheme. The relative positional encoding allows the self-attention
    module to generalize better on different input length and the resulting encoder is more robust to the variance of
    the utterance length. Conformer use prenorm residual units with dropout which helps training
    and regularizing deeper models.

    Args:
        d_model (int): The dimension of model
        num_heads (int): The number of attention heads.
        dropout_p (float): probability of dropout
        batch_unsafe_relative_shift (bool): forwarded to :class:`RelativeMultiHeadAttention`

    Inputs: x, encoder_padding_mask, output_attention
        x (batch, time, dim): Tensor containing input vector
        encoder_padding_mask (batch, time), optional: Tensor that is True (non-zero) on padding elements.
            When omitted, no element is treated as padding.
        output_attention (bool): whether the attention weights have to be returned

    Returns:
        **outputs** (batch, time, dim): Tensor produces by relative multi headed self attention module.
        **attn**: the attention weights if ``output_attention`` is set, None otherwise.
    """
    def __init__(self, d_model: int, num_heads: int, dropout_p: float = 0.1, batch_unsafe_relative_shift: bool = False):
        super(MultiHeadedSelfAttentionModule, self).__init__()
        self.positional_encoding = PositionalEncoding(d_model)
        self.layer_norm = nn.LayerNorm(d_model)
        self.attention = RelativeMultiHeadAttention(d_model, num_heads, dropout_p, batch_unsafe_relative_shift)
        self.dropout = nn.Dropout(p=dropout_p)

    def forward(
            self, x: Tensor, encoder_padding_mask: Optional[Tensor] = None, output_attention: bool = False
    ) -> Tuple[Tensor, Optional[Tensor]]:
        batch_size, seq_length, _ = x.size()
        pos_embedding = self.positional_encoding(seq_length)
        pos_embedding = pos_embedding.repeat(batch_size, 1, 1)
        # we need attention padding mask (attn_mask) to be applied during the attention calculation,
        # we obtain it from the encoder_padding_mask (B x T): element (b, i, j) is masked when either
        # position i or position j of sequence b is padding, so both T x T dimensions are masked
        if encoder_padding_mask is None:
            # without a padding mask every position is valid
            att_mask = torch.zeros(batch_size, seq_length, seq_length, dtype=torch.bool, device=x.device)
        else:
            att_mask = encoder_padding_mask.unsqueeze(1).logical_or(encoder_padding_mask.unsqueeze(2))  # B x T x T

        x = self.layer_norm(x)
        outputs, attn = self.attention(x, x, x, pos_embedding=pos_embedding, mask=att_mask)

        return self.dropout(outputs), (attn if output_attention else None)


class FeedForwardModule(nn.Module):
    """
    Feed forward module of the Conformer block, written as a pre-norm unit: the input goes through layer
    normalization, a linear expansion, the Swish (SiLU) activation, dropout, a linear projection back to the
    encoder dimension and a final dropout. The residual connection is *not* applied here.

    Args:
        encoder_dim (int): Dimension of conformer encoder
        expansion_factor (int): Expansion factor of feed forward module.
        dropout_p (float): Ratio of dropout

    Inputs: inputs
        x (batch, time, dim): Tensor contains input sequences

    Outputs: outputs
        **outputs** (batch, time, dim): Tensor produces by feed forward module.
    """

    def __init__(
            self,
            encoder_dim: int = 512,
            expansion_factor: int = 4,
            dropout_p: float = 0.1,
    ) -> None:
        super().__init__()
        inner_dim = encoder_dim * expansion_factor
        self.layernorm = nn.LayerNorm(encoder_dim)
        self.dropout_module = nn.Dropout(p=dropout_p)
        self.first_linear = self._xavier_linear(encoder_dim, inner_dim)
        self.second_linear = self._xavier_linear(inner_dim, encoder_dim)

    @staticmethod
    def _xavier_linear(in_features: int, out_features: int) -> nn.Linear:
        """Create a biased linear layer with Xavier-uniform weights and a zero bias."""
        layer = nn.Linear(in_features, out_features, bias=True)
        nn.init.xavier_uniform_(layer.weight)
        nn.init.zeros_(layer.bias)
        return layer

    def forward(self, x: Tensor) -> Tensor:
        expanded = F.silu(self.first_linear(self.layernorm(x)))
        projected = self.second_linear(self.dropout_module(expanded))
        return self.dropout_module(projected)


class ConformerConvModule(nn.Module):
    """
    Convolution module of the Conformer block. After layer normalization, a pointwise convolution doubles the
    channels and a gated linear unit (GLU) halves them back; a single 1-D depthwise convolution follows, then
    batch normalization, the Swish (SiLU) activation, a second pointwise convolution and dropout.
    When a padding mask is given, the padded time steps are zeroed after every one of these stages
    (except the activation), so that padding cannot leak into valid positions.

    Args:
        in_channels (int): Number of channels in the input
        kernel_size (int or tuple, optional): Size of the convolving kernel Default: 31
        expansion_factor (int, optional): channel expansion of the first pointwise convolution, only 2 is supported
        dropout_p (float, optional): probability of dropout
        no_syncbatchnorm (bool, optional): use ``BatchNorm1d`` in place of ``SyncBatchNorm``

    Inputs: inputs
        x (batch, time, dim): Tensor contains input sequences
        encoder_padding_mask (batch, time) or None: non-zero on the padding elements

    Outputs: outputs
        **outputs** (batch, time, dim): Tensor produces by conformer convolution module.
    """
    def __init__(
            self,
            in_channels: int,
            kernel_size: int = 31,
            expansion_factor: int = 2,
            dropout_p: float = 0.1,
            no_syncbatchnorm: bool = False,
    ) -> None:
        super().__init__()
        assert (kernel_size - 1) % 2 == 0, "kernel_size should be a odd number for 'SAME' padding"
        assert expansion_factor == 2, "Currently, only supports expansion_factor 2"
        self.layernorm = nn.LayerNorm(in_channels)
        if no_syncbatchnorm:
            self.batchnorm = nn.BatchNorm1d(in_channels)
        else:
            self.batchnorm = nn.SyncBatchNorm(in_channels)
        # settings shared by the two pointwise (kernel size 1) convolutions
        pointwise_kwargs = dict(kernel_size=(1, ), stride=(1, ), padding=0, bias=True)
        self.first_pointwise_conv1d = nn.Conv1d(
            in_channels=in_channels, out_channels=in_channels * expansion_factor, **pointwise_kwargs)
        self.second_pointwise_conv1d = nn.Conv1d(
            in_channels=in_channels, out_channels=in_channels, **pointwise_kwargs)
        # one filter per channel (groups == channels), 'SAME' padding
        self.depthwise_conv1d = nn.Conv1d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=(kernel_size, ),
            stride=(1, ),
            groups=in_channels,
            padding=(kernel_size - 1) // 2,
            bias=False,
        )
        self.dropout_module = nn.Dropout(p=dropout_p)

    @staticmethod
    def _zero_padding(features: Tensor, bool_padding_mask: Optional[Tensor]) -> Tensor:
        """Cast to float and zero the masked time steps; return the input untouched without a mask."""
        if bool_padding_mask is None:
            return features
        return features.float().masked_fill(bool_padding_mask, 0.0)

    def forward(self, x: Tensor, encoder_padding_mask: Tensor) -> Tensor:
        if encoder_padding_mask is None:
            pad = None
        else:
            pad = encoder_padding_mask.unsqueeze(1).bool()  # B x 1 x T, broadcast over channels

        channels_first = self.layernorm(x).transpose(1, 2)  # B x T x C -> B x C x T
        gated = F.glu(self.first_pointwise_conv1d(channels_first), dim=1)
        gated = self._zero_padding(gated, pad)
        convolved = self._zero_padding(self.depthwise_conv1d(gated), pad)
        normalized = self._zero_padding(self.batchnorm(convolved), pad)
        projected = self.second_pointwise_conv1d(F.silu(normalized))
        projected = self._zero_padding(projected, pad)
        return self.dropout_module(projected).transpose(1, 2)


class ConformerEncoderLayer(nn.Module):
    """
    A Conformer block: two Feed Forward modules sandwiching the Multi-Headed Self-Attention module and the
    Convolution module, followed by a final layer normalization. The sandwich structure is inspired by
    Macaron-Net, which replaces the feed-forward layer of the Transformer block with two half-step
    feed-forward layers, one before the attention layer and one after.

    Args:
        config: configuration object. The attributes read here are ``d_model``, ``encoder_attention_heads``,
            ``feed_forward_expansion_factor``, ``conv_expansion_factor``, ``conformer_feedforward_dropout``,
            ``conformer_attention_dropout``, ``conformer_conv_dropout``, ``conformer_conv_kernel_size``,
            ``conformer_half_step_residual``, ``no_syncbatchnorm`` and, optionally,
            ``batch_unsafe_relative_shift`` (False when missing).

    Inputs: x, encoder_padding_mask, output_attentions
        x (time, batch, dim): Tensor containing input vector
        encoder_padding_mask (batch, time): Tensor flagging the padding elements
        output_attentions (bool): whether the attention weights have to be returned

    Returns: outputs, attn
        **outputs** (time, batch, dim): Tensor produces by conformer block.
        **attn**: what the attention module returns as weights (None unless ``output_attentions`` is set).
    """

    def __init__(self, config: Speech2TextConformerConfig):
        super().__init__()
        self.encoder_dim = config.d_model
        self.num_attention_heads = config.encoder_attention_heads
        self.feed_forward_expansion_factor = config.feed_forward_expansion_factor
        self.conv_expansion_factor = config.conv_expansion_factor
        self.feed_forward_dropout_p = config.conformer_feedforward_dropout
        self.attention_dropout_p = config.conformer_attention_dropout
        self.conv_dropout_p = config.conformer_conv_dropout
        self.conv_kernel_size = config.conformer_conv_kernel_size
        self.half_step_residual = config.conformer_half_step_residual
        self.no_syncbatchnorm = config.no_syncbatchnorm
        self.batch_unsafe_relative_shift = getattr(config, 'batch_unsafe_relative_shift', False)

        # weight given to the feed-forward outputs when they are added to the residual stream
        self.feed_forward_residual_factor = 0.5 if self.half_step_residual else 1

        # both feed-forward modules share the same hyper-parameters
        feed_forward_kwargs = dict(
            encoder_dim=self.encoder_dim,
            expansion_factor=self.feed_forward_expansion_factor,
            dropout_p=self.feed_forward_dropout_p,
        )

        self.first_feed_forward = FeedForwardModule(**feed_forward_kwargs)
        self.attention = MultiHeadedSelfAttentionModule(
            d_model=self.encoder_dim,
            num_heads=self.num_attention_heads,
            dropout_p=self.attention_dropout_p,
            batch_unsafe_relative_shift=self.batch_unsafe_relative_shift,
        )
        self.conv_module = ConformerConvModule(
            in_channels=self.encoder_dim,
            kernel_size=self.conv_kernel_size,
            expansion_factor=self.conv_expansion_factor,
            dropout_p=self.conv_dropout_p,
            no_syncbatchnorm=self.no_syncbatchnorm,
        )
        self.second_feed_forward = FeedForwardModule(**feed_forward_kwargs)
        self.layernorm = nn.LayerNorm(self.encoder_dim)

    def forward(
            self, x: Tensor, encoder_padding_mask: Tensor, output_attentions: bool = False
    ) -> Tuple[Tensor, Optional[Tensor]]:
        ff_scale = self.feed_forward_residual_factor
        hidden = x.transpose(0, 1)  # T x B x C -> B x T x C

        # first (half-step) feed-forward
        hidden = self.first_feed_forward(hidden) * ff_scale + hidden
        # self-attention
        attended, attn = self.attention(hidden, encoder_padding_mask, output_attentions)
        hidden = attended + hidden
        # convolution
        hidden = self.conv_module(hidden, encoder_padding_mask) + hidden
        # second (half-step) feed-forward
        hidden = self.second_feed_forward(hidden) * ff_scale + hidden

        # final normalization, then back to T x B x C
        return self.layernorm(hidden).transpose(1, 0), attn


class CTCCompressStrategy:
    """
    Strategies building the weights used to compress a sequence according to its CTC predictions.

    Every strategy is a static method taking `(prob_ctc, predicted, dtype, device)`, where:

    - `prob_ctc` is indexed as `[batch, time, label]`;
    - `predicted` contains, for each sample, the list of `(label, num_consecutive_frames)` tuples;
    - `dtype` and `device` are those of the returned weights.

    It returns `(weights_matrix, new_lengths)`: `weights_matrix` has shape `(batch, time, new_max_len)` and
    its column `j` holds the weights of the input frames merged into the output position `j`, while
    `new_lengths` is the list with the compressed length of each sample.
    """
    # Number of consecutive frames merged by `fixed`. Being a class attribute, it is shared process-wide.
    FIXED_RATIO = 4

    @staticmethod
    def new_lengths(batch_predicted):
        """Return the number of predicted runs (i.e. the compressed length) of each sample."""
        return [len(p) for p in batch_predicted]

    @staticmethod
    def _spans(pred):
        """Yield `(out_idx, start, end, label)` for each `(label, count)` run of a sample."""
        start = 0
        for out_idx, (label, count) in enumerate(pred):
            end = start + count
            yield out_idx, start, end, label
            start = end

    @staticmethod
    def avg(prob_ctc, predicted, dtype, device):
        """Assign the same weight to all the frames of a run, so that they are plainly averaged."""
        new_lengths = CTCCompressStrategy.new_lengths(predicted)
        new_maxlen = max(new_lengths)
        # The matrix is filled before being moved to `device`: the values do not depend on `prob_ctc`
        weights_matrix = torch.zeros((prob_ctc.shape[0], prob_ctc.shape[1], new_maxlen), dtype=dtype)
        for b_idx, pred in enumerate(predicted):
            for out_idx, start, end, _ in CTCCompressStrategy._spans(pred):
                weights_matrix[b_idx, start:end, out_idx] = 1.0 / (end - start)
        return weights_matrix.to(device), new_lengths

    @staticmethod
    def weighted(prob_ctc, predicted, dtype, device):
        """Weight each frame of a run by the probability it assigns to the predicted label."""
        new_lengths = CTCCompressStrategy.new_lengths(predicted)
        new_maxlen = max(new_lengths)
        weights_matrix = torch.zeros((prob_ctc.shape[0], prob_ctc.shape[1], new_maxlen), dtype=dtype, device=device)
        for b_idx, pred in enumerate(predicted):
            for out_idx, start, end, label in CTCCompressStrategy._spans(pred):
                # Get the probabilities of the prediction for the different time steps as weight
                weights = prob_ctc[b_idx, start:end, label]
                weights_matrix[b_idx, start:end, out_idx] = weights / weights.sum()
        return weights_matrix, new_lengths

    @staticmethod
    def softmax(prob_ctc, predicted, dtype, device):
        """Weight each frame of a run by the softmax (over the run) of the predicted label probability."""
        new_lengths = CTCCompressStrategy.new_lengths(predicted)
        new_maxlen = max(new_lengths)
        weights_matrix = torch.zeros((prob_ctc.shape[0], prob_ctc.shape[1], new_maxlen), dtype=dtype, device=device)
        for b_idx, pred in enumerate(predicted):
            for out_idx, start, end, label in CTCCompressStrategy._spans(pred):
                # The slice is 1-D: the dimension is passed explicitly, as relying on the implicit
                # choice of `F.softmax` is deprecated and emits a warning at every call
                weights = F.softmax(prob_ctc[b_idx, start:end, label], dim=-1)
                weights_matrix[b_idx, start:end, out_idx] = weights / weights.sum()
        return weights_matrix, new_lengths

    @staticmethod
    def fixed(prob_ctc, predicted, dtype, device):
        """
        Average every `FIXED_RATIO` consecutive frames, regardless of the predicted labels.

        `predicted` is used only to recover the number of valid frames of each sample. A sample
        without frames gets length 0 and all-zero weights.
        """
        ratio = CTCCompressStrategy.FIXED_RATIO
        new_maxlen = math.ceil(prob_ctc.shape[1] / ratio)
        weights_matrix = torch.zeros((prob_ctc.shape[0], prob_ctc.shape[1], new_maxlen), dtype=dtype)
        new_lengths = []
        for b_idx, pred in enumerate(predicted):
            original_len = sum(x[1] for x in pred)
            new_len = 0
            for out_idx in range(new_maxlen):
                start = out_idx * ratio
                # All the valid frames have been consumed (this also covers empty samples,
                # for which computing the weight would be a division by zero)
                if start >= original_len:
                    break
                end = min(start + ratio, original_len)
                weights_matrix[b_idx, start:end, out_idx] = 1.0 / (end - start)
                new_len += 1
            new_lengths.append(new_len)
        return weights_matrix.to(device), new_lengths


class ConformerEncoderDecoderPreTrainedModel(Speech2TextPreTrainedModel):
    """
    Common base class of the Conformer models defined in this file: it extends
    [`Speech2TextPreTrainedModel`] only by setting [`Speech2TextConformerConfig`] as configuration class.
    """
    config_class = Speech2TextConformerConfig


class ConformerEncoder(ConformerEncoderDecoderPreTrainedModel):
    """
    Conformer encoder consisting of *config.encoder_layers* layers. Each layer is a
    [`ConformerEncoderLayer`].

    If `config.ctc_compress_strategy` is not `"none"`, the output of the `config.ctc_encoder_layer`-th
    layer is projected onto the source vocabulary and the consecutive time steps with the same
    predicted label are merged with the selected [`CTCCompressStrategy`] method.

    Args:
        config: Speech2TextConformerConfig
    """

    def __init__(self, config: Speech2TextConformerConfig):
        super().__init__(config)

        self.dropout = config.dropout

        embed_dim = config.d_model
        self.padding_idx = config.pad_token_id
        self.max_source_positions = config.max_source_positions
        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

        self.subsample = Conv1dSubsampler(config)

        self.layers = nn.ModuleList([ConformerEncoderLayer(config) for _ in range(config.encoder_layers)])

        self.ctc_flag = config.ctc_compress_strategy != "none"
        if self.ctc_flag:
            self.ctc_fc = nn.Linear(config.encoder_embed_dim, config.src_vocab_size)
            self.ctc_layer = config.ctc_encoder_layer
            # The strategy name is resolved to the homonymous static method of CTCCompressStrategy
            self.ctc_compress_method = getattr(CTCCompressStrategy, config.ctc_compress_strategy)
            self.ctc_compress_max_out_size = config.ctc_compress_max_out_size
            # NOTE: this is a class attribute, so the value is shared by all the encoders in the process
            CTCCompressStrategy.FIXED_RATIO = config.ctc_compress_fixed_ratio

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def ensure_max_ctc_out_len(self, batch_predicted):
        """
        Ensures that the output of the CTC compression is not longer than the ctc_compress_max_out_size.
        If there are samples violating this constraint, consecutive predictions are merged so as to shorten
        the sequence. E.g. if the ctc_compress_max_out_size is set to 3, and the output of the CTC compression
        would be 5 elements long, the first and second predictions are merged, as well as the third and the
        fourth. So, the corresponding vectors will be merged according to the CTC compression strategy.

        The list is updated in place and returned. Nothing is done if ctc_compress_max_out_size is not positive.
        """
        if self.ctc_compress_max_out_size > 0:

            def merge_sublist(elements):
                """
                Takes a list of Tuples (predicted_element, num_corresponding_vectors) and returns
                a single tuple with the predicted_element having the highest number of corresponding_vectors
                (in case of a tie, the first is returned) and the total sum of the num_corresponding_vectors
                E.g. if the input is [(a, 3), (b, 5), (c, 6), (a, 4)], the output will be (a, 18).
                """
                total_vectors = 0
                best_element = None
                best_count = 0
                counts = {}
                for predicted_element, num_corresponding_vectors in elements:
                    counts[predicted_element] = counts.get(predicted_element, 0) + num_corresponding_vectors
                    # Strict comparison, so that ties are won by the element that reached the count first
                    if counts[predicted_element] > best_count:
                        best_count = counts[predicted_element]
                        best_element = predicted_element
                    total_vectors += num_corresponding_vectors
                return best_element, total_vectors

            for b_idx, p in enumerate(batch_predicted):
                pred_len = len(p)
                if pred_len > self.ctc_compress_max_out_size:
                    # Number of consecutive predictions to merge so as to respect the maximum length
                    reduction_factor = math.ceil(pred_len / self.ctc_compress_max_out_size)
                    batch_predicted[b_idx] = [
                        merge_sublist(p[i:i + reduction_factor]) for i in range(0, pred_len, reduction_factor)
                    ]

        return batch_predicted

    def average_same_ctc_features(self, x_ctc, x, input_lengths):
        """
        Merges the consecutive vectors of `x` for which the CTC module predicts the same label.

        Args:
            x_ctc: CTC logits (T x B x D).
            x: features to compress (T x B x C).
            input_lengths: tensor with the number of valid time steps of each sample.

        Returns:
            The compressed features (T' x B x C) and a tensor with the new lengths, having the same
            dtype and device as `input_lengths`.
        """
        # The weights are built from the predictions only: no gradient flows through them
        with torch.no_grad():
            batch_predicted = []
            prob_ctc = F.softmax(x_ctc, dim=-1).transpose(0, 1)  # from T x B x D to B x T x D
            for b in range(prob_ctc.shape[0]):
                predicted = prob_ctc[b][: input_lengths[b]].argmax(-1).tolist()
                # Run-length encoding of the predictions: (label, number of consecutive time steps)
                batch_predicted.append([(label, sum(1 for _ in run)) for label, run in groupby(predicted)])
            batch_predicted = self.ensure_max_ctc_out_len(batch_predicted)
            weights_matrix, new_lengths = self.ctc_compress_method(
                prob_ctc, batch_predicted, x.dtype, x.device)
        # x is T x B x C -> B x C x T; weights_matrix is B x T x T'
        compressed_output = x.permute(1, 2, 0).bmm(weights_matrix)  # B x C x T'
        return compressed_output.permute(2, 0, 1), input_lengths.new_tensor(new_lengths)

    @staticmethod
    def lengths_to_padding_mask(lens: torch.LongTensor) -> Tensor:
        """
        Converts the lengths into a boolean mask of shape `(len(lens), max(lens))` that is `True`
        on the padding positions (those with index greater than or equal to the sample length).
        """
        bsz, max_lens = lens.size(0), torch.max(lens).item()
        positions = torch.arange(max_lens, device=lens.device).view(1, max_lens)
        return positions.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_lens)

    def apply_ctc(self, x, input_lengths):
        """
        Projects `x` with the CTC layer and compresses it according to the resulting predictions.

        Returns:
            The compressed features, the CTC logits computed on the uncompressed features and
            the padding mask of the compressed features.
        """
        x_ctc = self.ctc_fc(x)
        x, input_lengths = self.average_same_ctc_features(x_ctc, x, input_lengths)
        padding_mask = ConformerEncoder.lengths_to_padding_mask(input_lengths)
        return x, x_ctc, padding_mask

    def forward(
        self,
        input_features,
        attention_mask=None,
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        Args:
            input_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, feature_size)`):
                Float values of fbank features extracted from the raw speech waveform. Raw speech waveform can be
                obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a
                `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
                `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the fbank features,
                padding and conversion into a tensor of type `torch.FloatTensor`. See
                [`~Speech2TextFeatureExtractor.__call__`]
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing convolution and attention on padding token indices. Mask values selected in
                `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Not implemented yet for this model:
                it must be left to `None`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        inputs_embeds = self.subsample(input_features, attention_mask)
        inputs_embeds = self.embed_scale * inputs_embeds

        # subsample attention mask if necessary
        if attention_mask is not None:
            attention_mask = self._get_feature_vector_attention_mask(inputs_embeds.shape[0], attention_mask)

        hidden_states = nn.functional.dropout(inputs_embeds, p=self.dropout, training=self.training)

        # build the padding mask (B x T, non-zero on padding)
        if attention_mask is not None:
            padding_mask = attention_mask.ne(1).long()
        else:
            # inputs_embeds is T x B x C, while the mask has to be B x T as in the other branch
            seq_len, bsz = inputs_embeds.shape[:2]
            padding_mask = torch.zeros((bsz, seq_len), dtype=torch.long, device=inputs_embeds.device)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # TODO: implement head mask
        assert head_mask is None, "Head masking is not yet implemented for Conformer model"

        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states.transpose(0, 1),)
            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    encoder_layer.__call__,
                    hidden_states,
                    padding_mask,
                    output_attentions,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    padding_mask,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

            if self.ctc_flag and self.ctc_layer == idx + 1:
                # The lengths are read from the padding mask (always available) rather than from
                # attention_mask, which is None when the caller does not provide it
                input_lengths = padding_mask.eq(0).sum(dim=1)
                hidden_states, _, padding_mask = self.apply_ctc(hidden_states, input_lengths)

        hidden_states = hidden_states.transpose(0, 1)  # T x B x C -> B x T x C
        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )


@add_start_docstrings(
    "The bare Conformer Model outputting raw hidden-states without any specific head on top.",
    CONFORMER_START_DOCSTRING,
)
class ConformerEncoderDecoderModel(ConformerEncoderDecoderPreTrainedModel):
    def __init__(self, config: Speech2TextConformerConfig):
        super().__init__(config)

        # Conformer encoder paired with the Speech2Text decoder
        self.encoder = ConformerEncoder(config)
        self.decoder = Speech2TextDecoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embeddings of the decoder."""
        return self.decoder.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embeddings of the decoder with `value`."""
        self.decoder.embed_tokens = value

    def get_encoder(self):
        """Return the encoder module."""
        return self.encoder

    def get_decoder(self):
        """Return the decoder module."""
        return self.decoder

    @add_start_docstrings_to_model_forward(SPEECH_TO_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_features: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
        past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
        r"""
        Returns:

        Example:

         ```python
         >>> import torch
         >>> from transformers import AutoFeatureExtractor, AutoModel
         >>> from datasets import load_dataset

         >>> model = AutoModel.from_pretrained("FBK-MT/balbetto-asr-small-test")
         >>> feature_extractor = AutoFeatureExtractor.from_pretrained("FBK-MT/balbetto-asr-small-test")
         >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
         >>> inputs = feature_extractor(
         ...     ds[0]["audio"]["array"], sampling_rate=ds[0]["audio"]["sampling_rate"], return_tensors="pt"
         ... )
         >>> input_features = inputs.input_features
         >>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
         >>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
         >>> list(last_hidden_state.shape)
         [1, 2, 256]
         ```"""
        # Options that were not given explicitly fall back to the model configuration
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if use_cache is None:
            use_cache = self.config.use_cache
        if return_dict is None:
            return_dict = self.config.use_return_dict

        if encoder_outputs is None:
            # No precomputed encoder outputs: run the encoder on the input features
            encoder_outputs = self.encoder(
                input_features,
                attention_mask=attention_mask,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
            num_encoder_outputs = len(encoder_outputs)
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if num_encoder_outputs > 1 else None,
                attentions=encoder_outputs[2] if num_encoder_outputs > 2 else None,
            )

        encoder_hidden_states = encoder_outputs[0]

        # downsample encoder attention mask
        encoder_attention_mask = None
        if attention_mask is not None:
            encoder_attention_mask = self._get_feature_vector_attention_mask(
                encoder_hidden_states.shape[1], attention_mask
            )

        # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if return_dict:
            return Seq2SeqModelOutput(
                last_hidden_state=decoder_outputs.last_hidden_state,
                past_key_values=decoder_outputs.past_key_values,
                decoder_hidden_states=decoder_outputs.hidden_states,
                decoder_attentions=decoder_outputs.attentions,
                cross_attentions=decoder_outputs.cross_attentions,
                encoder_last_hidden_state=encoder_outputs.last_hidden_state,
                encoder_hidden_states=encoder_outputs.hidden_states,
                encoder_attentions=encoder_outputs.attentions,
            )

        # Plain tuple output: decoder values first, then the encoder ones
        return decoder_outputs + encoder_outputs


@add_start_docstrings(
    "The Conformer Model with a language modeling head.",
    CONFORMER_START_DOCSTRING,
)
class ConformerEncoderDecoderForConditionalGeneration(ConformerEncoderDecoderPreTrainedModel, GenerationMixin):
    base_model_prefix = "model"
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: Speech2TextConformerConfig):
        super().__init__(config)
        self.model = ConformerEncoderDecoderModel(config)
        # Bias-free projection of the decoder states onto the target vocabulary
        self.lm_head = nn.Linear(config.d_model, self.config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_encoder(self):
        """Return the encoder of the wrapped model."""
        return self.model.get_encoder()

    def get_decoder(self):
        """Return the decoder of the wrapped model."""
        return self.model.get_decoder()

    def get_output_embeddings(self):
        """Return the language modeling head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the language modeling head with `new_embeddings`."""
        self.lm_head = new_embeddings

    @add_start_docstrings_to_model_forward(SPEECH_TO_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_features: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]`
            or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is
            only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> import torch
        >>> import transformers
        >>> from datasets import load_dataset

        >>> pipe = transformers.pipeline(
        ...     "automatic-speech-recognition",
        ...     model='FBK-MT/balbetto-asr-small-test',
        ...     feature_extractor='FBK-MT/balbetto-asr-small-test',
        ...     trust_remote_code=True)


        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

        >>> generated_ids = pipe(ds[0]["audio"])

        >>> transcription = pipe.feature_extractor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        >>> transcription
        'mister quilter is the apostle of the middle classes and we are glad to welcome his gospel'
        ```"""
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # With labels and no explicit decoder input, the decoder is fed with the right-shifted labels
        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            decoder_input_ids = shift_tokens_right(
                labels, self.config.pad_token_id, self.config.decoder_start_token_id
            )

        outputs = self.model(
            input_features,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        lm_logits = self.lm_head(outputs[0])

        # Cross-entropy between the flattened logits and labels, computed only when labels are given
        loss = None
        if labels is not None:
            criterion = CrossEntropyLoss()
            flat_logits = lm_logits.view(-1, self.config.vocab_size)
            loss = criterion(flat_logits, labels.view(-1))

        if return_dict:
            return Seq2SeqLMOutput(
                loss=loss,
                logits=lm_logits,
                past_key_values=outputs.past_key_values,
                decoder_hidden_states=outputs.decoder_hidden_states,
                decoder_attentions=outputs.decoder_attentions,
                cross_attentions=outputs.cross_attentions,
                encoder_last_hidden_state=outputs.encoder_last_hidden_state,
                encoder_hidden_states=outputs.encoder_hidden_states,
                encoder_attentions=outputs.encoder_attentions,
            )

        # Tuple output: the logits replace the first element, and the loss (if any) is prepended
        output = (lm_logits,) + outputs[1:]
        return output if loss is None else (loss,) + output

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Select, for every cached state of every layer, the entries at `beam_idx` along dimension 0."""
        return tuple(
            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
            for layer_past in past_key_values
        )


# Import-time registration of the custom classes defined in this file.

# Flag the config and the model class for the Auto* API.
# NOTE(review): register_for_auto_class is called twice on the same model class; presumably the second
# call replaces the value set by the first one - confirm against the installed transformers version.
Speech2TextConformerConfig.register_for_auto_class()
ConformerEncoderDecoderForConditionalGeneration.register_for_auto_class("AutoModel")
ConformerEncoderDecoderForConditionalGeneration.register_for_auto_class("AutoModelForSpeechSeq2Seq")

# Register the "conformer_encoder_decoder" model type and map its config class to the model and
# processor classes in the in-process Auto* registries.
transformers.AutoConfig.register("conformer_encoder_decoder", Speech2TextConformerConfig)
transformers.AutoModel.register(
    Speech2TextConformerConfig, ConformerEncoderDecoderForConditionalGeneration)
transformers.AutoModelForSpeechSeq2Seq.register(
    Speech2TextConformerConfig, ConformerEncoderDecoderForConditionalGeneration)
transformers.AutoProcessor.register(Speech2TextConformerConfig, Speech2TextProcessor)
# Add the model type -> class name entry to the speech seq2seq names mapping
# (presumably needed by code that looks models up by name, e.g. pipelines - TODO confirm).
transformers.models.auto.modeling_auto.MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES['conformer_encoder_decoder'] = \
    "ConformerEncoderDecoderForConditionalGeneration"
# The tokenizer entry is a pair whose second element is left to None
# (presumably the (slow, fast) tokenizer classes - TODO confirm).
transformers.TOKENIZER_MAPPING.register(Speech2TextConformerConfig, (Speech2TextTokenizer, None))
transformers.FEATURE_EXTRACTOR_MAPPING.register(Speech2TextConformerConfig, Speech2TextFeatureExtractor)