from typing import Optional

from transformers.configuration_utils import PretrainedConfig


class RetNetConfig(PretrainedConfig):
    model_type = "retnet"
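    # attribute_map aliases the Hugging Face standard config names (e.g.
    # config.hidden_size) to the RetNet-specific attributes defined below.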
    attribute_map = {
        "hidden_size": "decoder_embed_dim",
        "intermediate_size": "decoder_ffn_embed_dim",
        "num_attention_heads": "decoder_retention_heads",
        "num_hidden_layers": "decoder_layers",
    }

    def __init__(
            self,
            vocab_size: int = 50257,
            initializer_range: float = 0.02,
            is_decoder: bool = True,
            bos_token_id: Optional[int] = None,
            pad_token_id: Optional[int] = None,
            eos_token_id: Optional[int] = None,
            output_retentions: bool = False,
            use_cache: bool = True,
            forward_impl: str = 'parallel',  # retention mode: parallel, recurrent, or chunkwise
            activation_fn: str = "gelu",
            dropout: float = 0.0,  # dropout probability
            activation_dropout: float = 0.0,  # dropout probability after activation in FFN.
            drop_path_rate: float = 0.0,
            decoder_embed_dim: int = 768,  # decoder embedding dimension
            decoder_value_embed_dim: int = 1280,  # decoder value embedding dimension
            decoder_ffn_embed_dim: int = 1280,  # decoder embedding dimension for FFN
            decoder_layers: int = 12,  # num decoder layers
            decoder_retention_heads: int = 3,  # num decoder retention heads
            decoder_normalize_before: bool = True,  # apply layernorm before each decoder block
            layernorm_embedding: bool = False,  # add layernorm to embedding
            no_scale_embedding: bool = True,  # if True, don't scale embeddings
            recurrent_chunk_size: int = 512,  # chunk length for the chunkwise recurrent forward
            use_glu: bool = True,  # use GLU instead of FFN
            z_loss_coeff: float = 0.0,  # coefficient for z-loss (TODO: 1e-4)
            use_lm_decay: bool = False,
            deepnorm: bool = False,
            subln: bool = True,
            use_rms_norm: bool = True,
            groupnorm_affine: bool = False,
            layernorm_eps: float = 1e-6,
            tie_word_embeddings: bool = False,
            use_bias: bool = False,
            parallel_residual: bool = False,
            rotary_percentage: float = 1.0,
            **kwargs):
        self.vocab_size = vocab_size
        self.initializer_range = initializer_range
        self.output_retentions = output_retentions
        self.use_lm_decay = use_lm_decay
        self.use_glu = use_glu
        self.z_loss_coeff = z_loss_coeff
        # size related
        self.decoder_embed_dim = decoder_embed_dim
        self.decoder_value_embed_dim = decoder_value_embed_dim
        self.decoder_retention_heads = decoder_retention_heads
        self.decoder_ffn_embed_dim = decoder_ffn_embed_dim
        self.decoder_layers = decoder_layers
        # normalization related
        self.decoder_normalize_before = decoder_normalize_before
        self.activation_fn = activation_fn
        self.dropout = dropout
        self.drop_path_rate = drop_path_rate
        self.activation_dropout = activation_dropout
        self.no_scale_embedding = no_scale_embedding
        self.layernorm_embedding = layernorm_embedding
        self.deepnorm = deepnorm
        self.subln = subln
        self.use_rms_norm = use_rms_norm
        self.layernorm_eps = layernorm_eps
        self.use_bias = use_bias
        self.groupnorm_affine = groupnorm_affine
        self.parallel_residual = parallel_residual
        # Blockwise
        self.recurrent_chunk_size = recurrent_chunk_size
        self.forward_impl = forward_impl
        # rope
        self.rotary_percentage = rotary_percentage

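        # deepnorm and subln are mutually exclusive normalization schemes;
        # deepnorm takes precedence, and each fixes decoder_normalize_before.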
        if self.deepnorm:
            self.decoder_normalize_before = False
            self.subln = False
        if self.subln:
            self.decoder_normalize_before = True
            self.deepnorm = False

        super().__init__(is_decoder=is_decoder,
                         bos_token_id=bos_token_id,
                         pad_token_id=pad_token_id,
                         eos_token_id=eos_token_id,
                         use_cache=use_cache,
                         tie_word_embeddings=tie_word_embeddings,
                         **kwargs)

    def override(self, args):
        """Override config attributes with any matching, non-None attributes of `args`."""
        for hp in self.__dict__.keys():
            value = getattr(args, hp, None)
            if value is not None:
                self.__dict__[hp] = value
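

# Minimal usage sketch (an illustration added here, not part of the original
# file): instantiate the config and check that the Hugging Face alias names
# from `attribute_map` resolve to the RetNet-specific fields.
if __name__ == "__main__":
    config = RetNetConfig(
        decoder_layers=2,
        decoder_embed_dim=256,
        decoder_value_embed_dim=512,
        decoder_ffn_embed_dim=512,
        decoder_retention_heads=4,
    )
    # PretrainedConfig routes aliased names through attribute_map.
    assert config.hidden_size == config.decoder_embed_dim
    assert config.num_hidden_layers == config.decoder_layers
    print(config)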