# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]

import torch.nn as nn

from vlmo.torchscale.architecture.decoder import Decoder
from vlmo.torchscale.architecture.encoder import Encoder


class EncoderDecoder(nn.Module):
    """Encoder-decoder Transformer assembled from the torchscale Encoder and Decoder."""

    def __init__(
        self,
        args,
        encoder_embed_tokens=None,
        encoder_embed_positions=None,
        decoder_embed_tokens=None,
        decoder_embed_positions=None,
        output_projection=None,
        **kwargs
    ):
        super().__init__()
        self.args = args
        # Sharing all embeddings also implies tying the decoder's input and output embeddings.
        if args.share_all_embeddings:
            args.share_decoder_input_output_embed = True

        self.encoder = Encoder(args, encoder_embed_tokens, encoder_embed_positions, is_encoder_decoder=True, **kwargs)

        # When all embeddings are shared, reuse the encoder's token embedding table in the decoder.
        if args.share_all_embeddings and decoder_embed_tokens is None:
            decoder_embed_tokens = self.encoder.embed_tokens

        self.decoder = Decoder(
            args, decoder_embed_tokens, decoder_embed_positions, output_projection, is_encoder_decoder=True, **kwargs
        )

    def forward(self, src_tokens, prev_output_tokens, return_all_hiddens=False, features_only=False, **kwargs):
        # Encode the source tokens, then decode conditioned on the encoder output.
        encoder_out = self.encoder(src_tokens, return_all_hiddens=return_all_hiddens)
        decoder_out = self.decoder(
            prev_output_tokens,
            encoder_out=encoder_out,
            features_only=features_only,
            return_all_hiddens=return_all_hiddens,
        )
        return decoder_out
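
# Minimal usage sketch (illustrative, not part of the original module). It assumes the
# vendored torchscale copy also ships EncoderDecoderConfig under
# vlmo.torchscale.architecture.config, mirroring upstream torchscale:
#
#   import torch
#   from vlmo.torchscale.architecture.config import EncoderDecoderConfig
#
#   config = EncoderDecoderConfig(vocab_size=64000, share_all_embeddings=True)
#   embed = torch.nn.Embedding(config.vocab_size, config.encoder_embed_dim)
#   model = EncoderDecoder(config, encoder_embed_tokens=embed)
#   src = torch.randint(0, config.vocab_size, (2, 16))   # (batch, src_len)
#   prev = torch.randint(0, config.vocab_size, (2, 16))  # (batch, tgt_len)
#   decoder_out = model(src, prev)                        # decoder output (logits plus extras)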