OpenEfficientAI committed on
Commit 5ffa5c6 · verified · 1 Parent(s): 7dd1c42

Upload 9 files

added_tokens.json ADDED
@@ -0,0 +1,3 @@
+ {
+   "[PAD]": 32000
+ }
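
This file records the single token appended on top of the 32,000-entry base LLaMA vocabulary: `[PAD]` at id 32000, which is why `config.json` below reports `vocab_size: 32001`. A minimal sketch of how such an entry is typically produced before fine-tuning (the paths are placeholders, not files from this upload):

```python
# Sketch only: how a "[PAD]": 32000 entry typically ends up in added_tokens.json.
# "base_ckpt" and "output_dir" are placeholders, not part of this commit.
from transformers import AutoModelForCausalLM, AutoTokenizer

base_ckpt = "path/to/base-llama-checkpoint"

tokenizer = AutoTokenizer.from_pretrained(base_ckpt)
model = AutoModelForCausalLM.from_pretrained(base_ckpt)

# A new special token gets the next free id (32000 for a 32,000-token vocabulary);
# saving the tokenizer afterwards writes it to added_tokens.json.
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
model.resize_token_embeddings(len(tokenizer))  # embedding / lm_head rows: 32000 -> 32001

tokenizer.save_pretrained("output_dir")
```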
config.json ADDED
@@ -0,0 +1,27 @@
+ {
+   "_name_or_path": "/data/guojialong/output/llama_prepbn_bnfp32/output/alpca_out_test1/checkpoint-25000",
+   "architectures": [
+     "LlamaForCausalLM"
+   ],
+   "auto_map": {
+     "AutoModelForCausalLM": "modeling_llama.LlamaForCausalLM",
+     "AutoModelForSequenceClassification": "modeling_llama.LlamaForSequenceClassification"
+   },
+   "bos_token_id": 1,
+   "eos_token_id": 2,
+   "hidden_act": "silu",
+   "hidden_size": 1024,
+   "initializer_range": 0.02,
+   "intermediate_size": 4096,
+   "max_position_embeddings": 2048,
+   "model_type": "llama",
+   "num_attention_heads": 8,
+   "num_hidden_layers": 18,
+   "pad_token_id": 0,
+   "rms_norm_eps": 1e-06,
+   "tie_word_embeddings": false,
+   "torch_dtype": "float32",
+   "transformers_version": "4.29.1",
+   "use_cache": false,
+   "vocab_size": 32001
+ }
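
Since `auto_map` routes both auto classes to the bundled `modeling_llama.py`, the checkpoint has to be loaded with `trust_remote_code=True` so that the customized `LlamaForCausalLM` from this repository is used instead of the stock Transformers implementation. A minimal loading sketch; `repo_id` is a placeholder for the actual Hub id or a local clone of these files:

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM

repo_id = "namespace/model-name"  # placeholder

config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
print(config.hidden_size, config.num_hidden_layers, config.num_attention_heads)  # 1024 18 8

# trust_remote_code=True downloads and executes the modeling_llama.py shipped with the
# checkpoint, so only enable it for sources you trust.
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    torch_dtype=torch.float32,  # matches "torch_dtype" in config.json
    trust_remote_code=True,
)

# Roughly 368M parameters, consistent with the ~1.47 GB float32 pytorch_model.bin below.
print(sum(p.numel() for p in model.parameters()))
```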
generation_config.json ADDED
@@ -0,0 +1,8 @@
+ {
+   "_from_model_config": true,
+   "bos_token_id": 1,
+   "eos_token_id": 2,
+   "pad_token_id": 0,
+   "transformers_version": "4.29.1",
+   "use_cache": false
+ }
modeling_llama.py ADDED
@@ -0,0 +1,1229 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ PyTorch LLaMA model."""
21
+ import math
22
+ import numpy as np
23
+ from typing import List, Optional, Tuple, Union
24
+ import scipy as sp
25
+
26
+ import torch
27
+ import torch.utils.checkpoint
28
+ from torch import nn
29
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
30
+ import torch.nn.functional as F
31
+
32
+ from transformers.activations import ACT2FN
33
+ from transformers.modeling_outputs import BaseModelOutputWithPast, MoEModelOutputWithPastAndCrossAttentions, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
34
+ from transformers.modeling_utils import PreTrainedModel
35
+ from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
36
+ from transformers.models.llama.configuration_llama import LlamaConfig
37
+ from functools import partial
38
+ from torch.nn.modules._functions import SyncBatchNorm as sync_batch_norm
39
+
40
+
41
+ logger = logging.get_logger(__name__)
42
+
43
+ _CONFIG_FOR_DOC = "LlamaConfig"
44
+
45
+
46
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
47
+ def _make_causal_mask(
48
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
49
+ ):
50
+ """
51
+ Make causal mask used for bi-directional self-attention.
52
+ """
53
+ bsz, tgt_len = input_ids_shape
54
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
55
+ mask_cond = torch.arange(mask.size(-1), device=device)
56
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
57
+ mask = mask.to(dtype)
58
+
59
+ if past_key_values_length > 0:
60
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
61
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
62
+
63
+
64
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
65
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
66
+ """
67
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
68
+ """
69
+ bsz, src_len = mask.size()
70
+ tgt_len = tgt_len if tgt_len is not None else src_len
71
+
72
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
73
+
74
+ inverted_mask = 1.0 - expanded_mask
75
+
76
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
77
+
78
+
79
+ class LlamaRMSNorm(nn.Module):
80
+ def __init__(self, hidden_size, eps=1e-6):
81
+ """
82
+ LlamaRMSNorm is equivalent to T5LayerNorm
83
+ """
84
+ super().__init__()
85
+ self.weight = nn.Parameter(torch.ones(hidden_size))
86
+ self.variance_epsilon = eps
87
+
88
+ def forward(self, hidden_states, pad_mask=None):
89
+ input_dtype = hidden_states.dtype
90
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
91
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
92
+
93
+ return (self.weight * hidden_states).to(input_dtype)
94
+
95
+
96
+ class MaskSyncBatchNorm(nn.Module):
97
+ """
98
+ An implementation of masked batch normalization, used for testing the numerical
99
+ stability.
100
+ """
101
+
102
+ def __init__(self, num_features, eps=1e-5, momentum=0.1, \
103
+ affine=True, track_running_stats=True, sync_bn=True, process_group=None):
104
+ super().__init__()
105
+
106
+ self.num_features = num_features
107
+ self.eps = eps
108
+ self.momentum = momentum
109
+ self.affine = affine
110
+ self.track_running_stats = track_running_stats
111
+ if self.affine:
112
+ self.weight = nn.Parameter(torch.Tensor(num_features))
113
+ self.bias = nn.Parameter(torch.Tensor(num_features))
114
+ else:
115
+ self.register_parameter('weight', None)
116
+ self.register_parameter('bias', None)
117
+ if self.track_running_stats:
118
+ self.register_buffer('running_mean', torch.zeros(num_features))
119
+ self.register_buffer('running_var', torch.ones(num_features))
120
+ self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
121
+ else:
122
+ self.register_parameter('running_mean', None)
123
+ self.register_parameter('running_var', None)
124
+ self.register_parameter('num_batches_tracked', None)
125
+ self.sync_bn = sync_bn
126
+ # gpu_size is set through DistributedDataParallel initialization. This is to ensure that SyncBatchNorm is used
127
+ # under supported condition (single GPU per process)
128
+ self.process_group = process_group
129
+ self.ddp_gpu_size = 4
130
+ # self.lp = LayerScaling1D()
131
+ self.reset_parameters()
132
+
133
+ def _specify_ddp_gpu_num(self, gpu_size):
134
+ if gpu_size > 1:
135
+ raise ValueError('SyncBatchNorm is only supported for DDP with single GPU per process')
136
+ self.ddp_gpu_size = gpu_size
137
+
138
+ def reset_running_stats(self):
139
+ if self.track_running_stats:
140
+ self.running_mean.zero_()
141
+ self.running_var.fill_(1)
142
+ self.num_batches_tracked.zero_()
143
+
144
+ def reset_parameters(self):
145
+ self.reset_running_stats()
146
+ if self.affine:
147
+ # nn.init.ones_(self.weight)
148
+ nn.init.zeros_(self.weight)
149
+ nn.init.zeros_(self.bias)
150
+
151
+ def extra_repr(self):
152
+ return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \
153
+ 'track_running_stats={track_running_stats}'.format(**self.__dict__)
154
+
155
+ def forward(self, input, pad_mask=None, is_encoder=False, update_run=True):
156
+ """
157
+ input: B x C x T
158
+ pad_mask: B x T (padding is False)
159
+ """
160
+ input_dtype = input.dtype
161
+ input = input.to(torch.float32)
162
+ shaped_input = (len(input.shape) == 2)
163
+ if shaped_input:
164
+ input = input.unsqueeze(0)
165
+
166
+ B, T, C = input.shape
167
+ # construct the mask_input, size to be (BxL) x C: L is the real length here
168
+ if pad_mask is None or not self.training:
169
+ mask_input = input.contiguous().view(-1, C)
170
+ else:
171
+ bn_mask = (pad_mask == 1)
172
+ mask_input = input[bn_mask, :]
173
+ mask_input = mask_input.contiguous().view(-1, C)
174
+
175
+ if self.momentum is None:
176
+ exponential_average_factor = 0.0
177
+ else:
178
+ exponential_average_factor = self.momentum
179
+
180
+ if self.training and self.track_running_stats:
181
+ self.num_batches_tracked += 1
182
+ if self.momentum is None: # use cumulative moving average
183
+ exponential_average_factor = 1.0 / self.num_batches_tracked.item()
184
+ else: # use exponential moving average
185
+ exponential_average_factor = self.momentum
186
+
187
+ if not update_run:
188
+ exponential_average_factor = 0.0
189
+
190
+ mean_copy = self.running_mean.clone()
191
+ var_copy = self.running_var.clone()
192
+ flag = False
193
+ flag_mean = mask_input.mean(0)
194
+ if torch.isnan(flag_mean).any() and self.training:
195
+ exponential_average_factor = 0.0
196
+ flag = True
197
+ # mask_input = torch.nan_to_num(mask_input)
198
+
199
+ need_sync = self.training and self.sync_bn # self.training #or not self.track_running_stats
200
+ if need_sync:
201
+ process_group = torch.distributed.group.WORLD
202
+ if self.process_group:
203
+ process_group = self.process_group
204
+ world_size = torch.distributed.get_world_size(process_group)
205
+ need_sync = world_size > 1
206
+
207
+ if torch.isnan(self.running_var).any() and self.training:
208
+ exponential_average_factor = 1.0
209
+ self.running_mean = torch.nan_to_num(self.running_mean)
210
+ self.running_var = torch.nan_to_num(self.running_var)
211
+
212
+ if not need_sync:
213
+ z = F.batch_norm(
214
+ mask_input.to(torch.float32), self.running_mean.to(torch.float32), self.running_var.to(torch.float32), self.weight.to(torch.float32), self.bias.to(torch.float32),
215
+ self.training or not self.track_running_stats,
216
+ exponential_average_factor, self.eps)
217
+ # z = F.batch_norm(
218
+ # mask_input, self.running_mean, self.running_var, self.weight, self.bias,
219
+ # True,
220
+ # exponential_average_factor, self.eps)
221
+ else:
222
+ # if not self.ddp_gpu_size:
223
+ # raise AttributeError('SyncBatchNorm is only supported within torch.nn.parallel.DistributedDataParallel')
224
+ z = sync_batch_norm.apply(
225
+ mask_input.to(torch.float32), self.weight.to(torch.float32), self.bias.to(torch.float32), self.running_mean.to(torch.float32), self.running_var.to(torch.float32),
226
+ self.eps, exponential_average_factor, process_group, world_size)
227
+
228
+ if flag:
229
+ self.running_mean = mean_copy
230
+ self.running_var = var_copy
231
+
232
+ if pad_mask is None or not self.training:
233
+ output = z
234
+ else:
235
+ output = input.clone()
236
+ output[bn_mask, :] = z
237
+ # output = z
238
+ output = output.view(B, T, C)
239
+ # Reshape it.
240
+ if shaped_input:
241
+ output = output.squeeze(0)
242
+ return output.to(input_dtype)
243
+
244
+
245
+ # class RepBN(nn.Module):
246
+ # def __init__(self, channels, eps=1e-5):
247
+ # super(RepBN, self).__init__()
248
+ # self.alpha = nn.Parameter(torch.zeros(1))
249
+ # self.bn = MaskSyncBatchNorm(channels, eps=eps)
250
+ #
251
+ # def forward(self, x, pad_mask=None, is_encoder=False, update_run=True):
252
+ # x = self.bn(x, pad_mask, is_encoder, update_run) + self.alpha * x
253
+ # return x
254
+
255
+
256
+ class RepBN(nn.Module):
257
+ def __init__(self, channels, eps=1e-5):
258
+ super(RepBN, self).__init__()
259
+ self.alpha = nn.Parameter(torch.zeros(1))
260
+ self.bn = nn.BatchNorm1d(channels, eps=eps, momentum=0.1)
261
+ self.reset_parameters()
262
+
263
+ def reset_parameters(self):
264
+ # nn.init.ones_(self.weight)
265
+ nn.init.zeros_(self.bn.weight)
266
+ nn.init.zeros_(self.bn.bias)
267
+
268
+ def forward(self, x, pad_mask=None):
269
+ B, T, C = x.shape
270
+ # construct the mask_input, size to be (BxL) x C: L is the real length here
271
+ if pad_mask is None or not self.training:
272
+ mask_input = x.contiguous().view(-1, C)
273
+ else:
274
+ bn_mask = (pad_mask == 1)
275
+ mask_input = x[bn_mask, :]
276
+ mask_input = mask_input.contiguous().view(-1, C)
277
+
278
+ o_bn = self.bn(mask_input)
279
+
280
+ if pad_mask is None or not self.training:
281
+ output = o_bn.view(B, T, C)
282
+ else:
283
+ output = x.clone()
284
+ output[bn_mask, :] = o_bn
285
+
286
+ x = output + self.alpha * x
287
+ return x
288
+
289
+
290
+ class LinearNorm(nn.Module):
291
+ def __init__(self, dim, norm1, norm2, eps=1e-5, warm=10000, step=18000, r0=1.0):
292
+ super(LinearNorm, self).__init__()
293
+ self.register_buffer('num_step', torch.tensor(0))
294
+ self.register_buffer('warm', torch.tensor(warm))
295
+ self.register_buffer('iter', torch.tensor(step))
296
+ self.register_buffer('total_step', torch.tensor(step))
297
+ self.r0 = r0
298
+ self.norm1 = norm1(dim, eps)
299
+ self.norm2 = norm2(dim, eps)
300
+
301
+ def forward(self, x, pad_mask=None):
302
+ if self.training:
303
+ if self.warm > 0:
304
+ if self.num_step % 16 == 0:
305
+ self.warm.copy_(self.warm - 1)
306
+ x = self.norm1(x)
307
+ else:
308
+ lamda = self.r0 * self.iter / self.total_step
309
+ if self.iter > 0:
310
+ if self.num_step % 16 == 0:
311
+ self.iter.copy_(self.iter - 1)
312
+ x1 = self.norm1(x)
313
+ x2 = self.norm2(x, pad_mask)
314
+ x = lamda * x1 + (1 - lamda) * x2
315
+ self.num_step.copy_((self.num_step + 1) % 16)
316
+ else:
317
+ x = self.norm2(x, pad_mask)
318
+ return x
319
+
320
+
321
+ linearnorm = partial(LinearNorm, norm1=LlamaRMSNorm, norm2=RepBN)
322
+ # linearnorm = LlamaRMSNorm
323
+ # linearnorm = RepBN
324
+
325
+
326
+ class LlamaRotaryEmbedding(torch.nn.Module):
327
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
328
+ super().__init__()
329
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
330
+ self.register_buffer("inv_freq", inv_freq)
331
+
332
+ # Build here to make `torch.jit.trace` work.
333
+ self.max_seq_len_cached = max_position_embeddings
334
+ # t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
335
+ t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
336
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
337
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
338
+ emb = torch.cat((freqs, freqs), dim=-1)
339
+ dtype = torch.get_default_dtype()
340
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
341
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
342
+
343
+ def forward(self, x, seq_len=None):
344
+ # x: [bs, num_attention_heads, seq_len, head_size]
345
+ # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
346
+ if seq_len > self.max_seq_len_cached:
347
+ self.max_seq_len_cached = seq_len
348
+ # t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
349
+ t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32)
350
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
351
+
352
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
353
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
354
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
355
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
356
+ return (
357
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
358
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
359
+ )
360
+
361
+
362
+ def rotate_half(x):
363
+ """Rotates half the hidden dims of the input."""
364
+ x1 = x[..., : x.shape[-1] // 2]
365
+ x2 = x[..., x.shape[-1] // 2 :]
366
+ return torch.cat((-x2, x1), dim=-1)
367
+
368
+
369
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
370
+ # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
371
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
372
+ sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
373
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
374
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
375
+ q_embed = (q * cos) + (rotate_half(q) * sin)
376
+ k_embed = (k * cos) + (rotate_half(k) * sin)
377
+ return q_embed, k_embed
378
+
379
+
380
+ class LlamaMLP(nn.Module):
381
+ def __init__(
382
+ self,
383
+ hidden_size: int,
384
+ intermediate_size: int,
385
+ hidden_act: str,
386
+ ):
387
+ super().__init__()
388
+ self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=True)
389
+ self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
390
+ self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=True)
391
+ self.act_fn = ACT2FN[hidden_act]
392
+
393
+ def forward(self, x):
394
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
395
+
396
+
397
+ class LlamaAttention(nn.Module):
398
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
399
+
400
+ def __init__(self, config: LlamaConfig):
401
+ super().__init__()
402
+ self.config = config
403
+ self.hidden_size = config.hidden_size
404
+ self.num_heads = config.num_attention_heads
405
+ self.head_dim = self.hidden_size // self.num_heads
406
+ self.max_position_embeddings = config.max_position_embeddings
407
+
408
+ if (self.head_dim * self.num_heads) != self.hidden_size:
409
+ raise ValueError(
410
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
411
+ f" and `num_heads`: {self.num_heads})."
412
+ )
413
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
414
+ self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
415
+ self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
416
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
417
+ self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
418
+
419
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
420
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
421
+
422
+ def forward(
423
+ self,
424
+ hidden_states: torch.Tensor,
425
+ attention_mask: Optional[torch.Tensor] = None,
426
+ position_ids: Optional[torch.LongTensor] = None,
427
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
428
+ output_attentions: bool = False,
429
+ use_cache: bool = False,
430
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
431
+ bsz, q_len, _ = hidden_states.size()
432
+
433
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
434
+ key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
435
+ value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
436
+
437
+ kv_seq_len = key_states.shape[-2]
438
+ if past_key_value is not None:
439
+ kv_seq_len += past_key_value[0].shape[-2]
440
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
441
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
442
+ # [bsz, nh, t, hd]
443
+
444
+ if past_key_value is not None:
445
+ # reuse k, v, self_attention
446
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
447
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
448
+
449
+ past_key_value = (key_states, value_states) if use_cache else None
450
+
451
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
452
+
453
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
454
+ raise ValueError(
455
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
456
+ f" {attn_weights.size()}"
457
+ )
458
+
459
+ if attention_mask is not None:
460
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
461
+ raise ValueError(
462
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
463
+ )
464
+ attn_weights = attn_weights + attention_mask
465
+ attn_weights = torch.max(
466
+ attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
467
+ )
468
+
469
+ # upcast attention to fp32
470
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
471
+
472
+ attn_output = torch.matmul(attn_weights, value_states)
473
+
474
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
475
+ raise ValueError(
476
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
477
+ f" {attn_output.size()}"
478
+ )
479
+
480
+ attn_output = attn_output.transpose(1, 2)
481
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
482
+
483
+ attn_output = self.o_proj(attn_output)
484
+
485
+ if not output_attentions:
486
+ attn_weights = None
487
+
488
+ return attn_output, attn_weights, past_key_value
489
+
490
+
491
+ class LlamaDecoderLayer(nn.Module):
492
+ def __init__(self, config: LlamaConfig):
493
+ super().__init__()
494
+ self.hidden_size = config.hidden_size
495
+ self.intermediate_size=config.intermediate_size
496
+ self.self_attn = LlamaAttention(config=config)
497
+ self.mlp = LlamaMLP(
498
+ hidden_size=self.hidden_size,
499
+ intermediate_size=config.intermediate_size,
500
+ hidden_act=config.hidden_act,
501
+ )
502
+ # self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
503
+ # self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
504
+ # self.input_layernorm = linearnorm(config.hidden_size, eps=config.rms_norm_eps)
505
+ # self.post_attention_layernorm = linearnorm(config.hidden_size, eps=config.rms_norm_eps)
506
+ self.input_layernorm = nn.Identity()
507
+ self.post_attention_layernorm = nn.Identity()
508
+
509
+ def merge_bn(self):
510
+ # Attention
511
+ miu = self.input_layernorm.norm2.bn.running_mean
512
+ sigma2 = self.input_layernorm.norm2.bn.running_var
513
+ gamma = self.input_layernorm.norm2.bn.weight
514
+ beta = self.input_layernorm.norm2.bn.bias
515
+ eps = self.input_layernorm.norm2.bn.eps
516
+ alpha = self.input_layernorm.norm2.alpha
517
+
518
+ w_q = self.self_attn.q_proj.weight.data.transpose(0, 1)
519
+ w_k = self.self_attn.k_proj.weight.data.transpose(0, 1)
520
+ w_v = self.self_attn.v_proj.weight.data.transpose(0, 1)
521
+
522
+ self.self_attn.q_proj = nn.Linear(self.self_attn.hidden_size, self.self_attn.num_heads * self.self_attn.head_dim, bias=True).to(w_q.device)
523
+ self.self_attn.k_proj = nn.Linear(self.self_attn.hidden_size, self.self_attn.num_heads * self.self_attn.head_dim, bias=True).to(w_q.device)
524
+ self.self_attn.v_proj = nn.Linear(self.self_attn.hidden_size, self.self_attn.num_heads * self.self_attn.head_dim, bias=True).to(w_q.device)
525
+
526
+ a = gamma / torch.sqrt(sigma2 + eps) + alpha
527
+ b = beta - gamma * miu / torch.sqrt(sigma2 + eps)
528
+ a = torch.diag(a)
529
+
530
+ w_q_n = (a @ w_q).transpose(0, 1)
531
+ b_q_n = (b.unsqueeze(0) @ w_q).squeeze(0)
532
+ self.self_attn.q_proj.weight.data.copy_(w_q_n)
533
+ self.self_attn.q_proj.bias.data.copy_(b_q_n)
534
+ w_k_n = (a @ w_k).transpose(0, 1)
535
+ b_k_n = (b.unsqueeze(0) @ w_k).squeeze(0)
536
+ self.self_attn.k_proj.weight.data.copy_(w_k_n)
537
+ self.self_attn.k_proj.bias.data.copy_(b_k_n)
538
+ w_v_n = (a @ w_v).transpose(0, 1)
539
+ b_v_n = (b.unsqueeze(0) @ w_v).squeeze(0)
540
+ self.self_attn.v_proj.weight.data.copy_(w_v_n)
541
+ self.self_attn.v_proj.bias.data.copy_(b_v_n)
542
+ self.input_layernorm = nn.Identity()
543
+
544
+ # mlp
545
+ miu = self.post_attention_layernorm.norm2.bn.running_mean
546
+ sigma2 = self.post_attention_layernorm.norm2.bn.running_var
547
+ gamma = self.post_attention_layernorm.norm2.bn.weight
548
+ beta = self.post_attention_layernorm.norm2.bn.bias
549
+ eps = self.post_attention_layernorm.norm2.bn.eps
550
+ alpha = self.post_attention_layernorm.norm2.alpha
551
+
552
+ w_g = self.mlp.gate_proj.weight.data.transpose(0, 1)
553
+ w_u = self.mlp.up_proj.weight.data.transpose(0, 1)
554
+
555
+ self.mlp.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=True)
556
+ self.mlp.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=True)
557
+
558
+ a = gamma / torch.sqrt(sigma2 + eps) + alpha
559
+ b = beta - gamma * miu / torch.sqrt(sigma2 + eps)
560
+ a = torch.diag(a)
561
+
562
+ w_g_n = (a @ w_g).transpose(0, 1)
563
+ b_g_n = (b.unsqueeze(0) @ w_g).squeeze(0)
564
+ self.mlp.gate_proj.weight.data.copy_(w_g_n)
565
+ self.mlp.gate_proj.bias.data.copy_(b_g_n)
566
+ w_u_n = (a @ w_u).transpose(0, 1)
567
+ b_u_n = (b.unsqueeze(0) @ w_u).squeeze(0)
568
+ self.mlp.up_proj.weight.data.copy_(w_u_n)
569
+ self.mlp.up_proj.bias.data.copy_(b_u_n)
570
+ self.post_attention_layernorm = nn.Identity()
571
+ return
572
+
573
+ def forward(
574
+ self,
575
+ hidden_states: torch.Tensor,
576
+ attention_mask: Optional[torch.Tensor] = None,
577
+ pad_mask: Optional[torch.Tensor] = None,
578
+ position_ids: Optional[torch.LongTensor] = None,
579
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
580
+ output_attentions: Optional[bool] = False,
581
+ use_cache: Optional[bool] = False,
582
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
583
+ """
584
+ Args:
585
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
586
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
587
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
588
+ output_attentions (`bool`, *optional*):
589
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
590
+ returned tensors for more detail.
591
+ use_cache (`bool`, *optional*):
592
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
593
+ (see `past_key_values`).
594
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
595
+ """
596
+
597
+ residual = hidden_states
598
+
599
+ # hidden_states = self.input_layernorm(hidden_states, pad_mask)
600
+ hidden_states = self.input_layernorm(hidden_states)
601
+
602
+ # Self Attention
603
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
604
+ hidden_states=hidden_states,
605
+ attention_mask=attention_mask,
606
+ position_ids=position_ids,
607
+ past_key_value=past_key_value,
608
+ output_attentions=output_attentions,
609
+ use_cache=use_cache,
610
+ )
611
+ hidden_states = residual + hidden_states
612
+
613
+ # Fully Connected
614
+ residual = hidden_states
615
+ # hidden_states = self.post_attention_layernorm(hidden_states, pad_mask)
616
+ hidden_states = self.post_attention_layernorm(hidden_states)
617
+ hidden_states = self.mlp(hidden_states)
618
+ hidden_states = residual + hidden_states
619
+
620
+ outputs = (hidden_states,)
621
+
622
+ if output_attentions:
623
+ outputs += (self_attn_weights,)
624
+
625
+ if use_cache:
626
+ outputs += (present_key_value,)
627
+ # breakpoint()
628
+ return outputs
629
+
630
+
631
+ LLAMA_START_DOCSTRING = r"""
632
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
633
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
634
+ etc.)
635
+
636
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
637
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
638
+ and behavior.
639
+
640
+ Parameters:
641
+ config ([`LlamaConfig`]):
642
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
643
+ load the weights associated with the model, only the configuration. Check out the
644
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
645
+ """
646
+
647
+
648
+ @add_start_docstrings(
649
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
650
+ LLAMA_START_DOCSTRING,
651
+ )
652
+ class LlamaPreTrainedModel(PreTrainedModel):
653
+ config_class = LlamaConfig
654
+ base_model_prefix = "model"
655
+ supports_gradient_checkpointing = True
656
+ _no_split_modules = ["LlamaDecoderLayer"]
657
+ _skip_keys_device_placement = "past_key_values"
658
+
659
+ def _init_weights(self, module):
660
+ std = self.config.initializer_range
661
+ if isinstance(module, nn.Linear):
662
+ module.weight.data.normal_(mean=0.0, std=std)
663
+ if module.bias is not None:
664
+ module.bias.data.zero_()
665
+ elif isinstance(module, nn.Embedding):
666
+ module.weight.data.normal_(mean=0.0, std=std)
667
+ if module.padding_idx is not None:
668
+ module.weight.data[module.padding_idx].zero_()
669
+
670
+ def _set_gradient_checkpointing(self, module, value=False):
671
+ if isinstance(module, LlamaModel):
672
+ module.gradient_checkpointing = value
673
+
674
+
675
+ LLAMA_INPUTS_DOCSTRING = r"""
676
+ Args:
677
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
678
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
679
+ it.
680
+
681
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
682
+ [`PreTrainedTokenizer.__call__`] for details.
683
+
684
+ [What are input IDs?](../glossary#input-ids)
685
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
686
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
687
+
688
+ - 1 for tokens that are **not masked**,
689
+ - 0 for tokens that are **masked**.
690
+
691
+ [What are attention masks?](../glossary#attention-mask)
692
+
693
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
694
+ [`PreTrainedTokenizer.__call__`] for details.
695
+
696
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
697
+ `past_key_values`).
698
+
699
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
700
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
701
+ information on the default strategy.
702
+
703
+ - 1 indicates the head is **not masked**,
704
+ - 0 indicates the head is **masked**.
705
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
706
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
707
+ config.n_positions - 1]`.
708
+
709
+ [What are position IDs?](../glossary#position-ids)
710
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
711
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
712
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
713
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
714
+
715
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
716
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
717
+
718
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
719
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
720
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
721
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
722
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
723
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
724
+ model's internal embedding lookup matrix.
725
+ use_cache (`bool`, *optional*):
726
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
727
+ `past_key_values`).
728
+ output_attentions (`bool`, *optional*):
729
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
730
+ tensors for more detail.
731
+ output_hidden_states (`bool`, *optional*):
732
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
733
+ more detail.
734
+ return_dict (`bool`, *optional*):
735
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
736
+ """
737
+
738
+
739
+ @add_start_docstrings(
740
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
741
+ LLAMA_START_DOCSTRING,
742
+ )
743
+ class LlamaModel(LlamaPreTrainedModel):
744
+ """
745
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
746
+
747
+ Args:
748
+ config: LlamaConfig
749
+ """
750
+
751
+ def __init__(self, config: LlamaConfig):
752
+ super().__init__(config)
753
+ self.padding_idx = config.pad_token_id
754
+ self.vocab_size = config.vocab_size
755
+
756
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
757
+ self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
758
+ # self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
759
+ self.norm = linearnorm(config.hidden_size, eps=config.rms_norm_eps)
760
+
761
+ self.gradient_checkpointing = False
762
+ # Initialize weights and apply final processing
763
+ self.post_init()
764
+
765
+ def get_input_embeddings(self):
766
+ return self.embed_tokens
767
+
768
+ def set_input_embeddings(self, value):
769
+ self.embed_tokens = value
770
+
771
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
772
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
773
+ # create causal mask
774
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
775
+ combined_attention_mask = None
776
+ if input_shape[-1] > 1:
777
+ combined_attention_mask = _make_causal_mask(
778
+ input_shape,
779
+ inputs_embeds.dtype,
780
+ device=inputs_embeds.device,
781
+ past_key_values_length=past_key_values_length,
782
+ )
783
+ if attention_mask is not None:
784
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
785
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
786
+ inputs_embeds.device
787
+ )
788
+ combined_attention_mask = (
789
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
790
+ )
791
+
792
+ return combined_attention_mask
793
+
794
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
795
+ def forward(
796
+ self,
797
+ input_ids: torch.LongTensor = None,
798
+ attention_mask: Optional[torch.Tensor] = None,
799
+ position_ids: Optional[torch.LongTensor] = None,
800
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
801
+ inputs_embeds: Optional[torch.FloatTensor] = None,
802
+ use_cache: Optional[bool] = None,
803
+ output_attentions: Optional[bool] = None,
804
+ output_hidden_states: Optional[bool] = None,
805
+ return_dict: Optional[bool] = None,
806
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
807
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
808
+ output_hidden_states = (
809
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
810
+ )
811
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
812
+
813
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
814
+
815
+ # retrieve input_ids and inputs_embeds
816
+ if input_ids is not None and inputs_embeds is not None:
817
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
818
+ elif input_ids is not None:
819
+ batch_size, seq_length = input_ids.shape
820
+ elif inputs_embeds is not None:
821
+ batch_size, seq_length, _ = inputs_embeds.shape
822
+ else:
823
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
824
+
825
+ seq_length_with_past = seq_length
826
+ past_key_values_length = 0
827
+
828
+ if past_key_values is not None:
829
+ past_key_values_length = past_key_values[0][0].shape[2]
830
+ seq_length_with_past = seq_length_with_past + past_key_values_length
831
+
832
+ if position_ids is None:
833
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
834
+ position_ids = torch.arange(
835
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
836
+ )
837
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
838
+ else:
839
+ position_ids = position_ids.view(-1, seq_length).long()
840
+
841
+ if inputs_embeds is None:
842
+ inputs_embeds = self.embed_tokens(input_ids)
843
+ # embed positions
844
+ if attention_mask is None:
845
+ attention_mask = torch.ones(
846
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
847
+ )
848
+ pad_mask = attention_mask
849
+ attention_mask = self._prepare_decoder_attention_mask(
850
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
851
+ )
852
+
853
+ hidden_states = inputs_embeds
854
+
855
+ if self.gradient_checkpointing and self.training:
856
+ if use_cache:
857
+ logger.warning_once(
858
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
859
+ )
860
+ use_cache = False
861
+
862
+ # decoder layers
863
+ all_hidden_states = () if output_hidden_states else None
864
+ all_self_attns = () if output_attentions else None
865
+ next_decoder_cache = () if use_cache else None
866
+
867
+ for idx, decoder_layer in enumerate(self.layers):
868
+ if output_hidden_states:
869
+ all_hidden_states += (hidden_states,)
870
+
871
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
872
+
873
+ if self.gradient_checkpointing and self.training:
874
+
875
+ def create_custom_forward(module):
876
+ def custom_forward(*inputs):
877
+ # None for past_key_value
878
+ return module(*inputs, output_attentions, None)
879
+
880
+ return custom_forward
881
+
882
+ layer_outputs = torch.utils.checkpoint.checkpoint(
883
+ create_custom_forward(decoder_layer),
884
+ hidden_states,
885
+ attention_mask,
886
+ pad_mask,
887
+ position_ids,
888
+ None,
889
+ )
890
+ else:
891
+ layer_outputs = decoder_layer(
892
+ hidden_states,
893
+ attention_mask=attention_mask,
894
+ pad_mask=pad_mask,
895
+ position_ids=position_ids,
896
+ past_key_value=past_key_value,
897
+ output_attentions=output_attentions,
898
+ use_cache=use_cache,
899
+ )
900
+
901
+ hidden_states = layer_outputs[0]
902
+
903
+ if use_cache:
904
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
905
+
906
+ if output_attentions:
907
+ all_self_attns += (layer_outputs[1],)
908
+
909
+ hidden_states = self.norm(hidden_states, pad_mask)
910
+
911
+ # add hidden states from the last decoder layer
912
+ if output_hidden_states:
913
+ all_hidden_states += (hidden_states,)
914
+
915
+ next_cache = next_decoder_cache if use_cache else None
916
+ if not return_dict:
917
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
918
+ return BaseModelOutputWithPast(
919
+ last_hidden_state=hidden_states,
920
+ past_key_values=next_cache,
921
+ hidden_states=all_hidden_states,
922
+ attentions=all_self_attns,
923
+ )
924
+
925
+
926
+ class LlamaForCausalLM(LlamaPreTrainedModel):
927
+ _tied_weights_keys = ["lm_head.weight"]
928
+
929
+ def __init__(self, config):
930
+ super().__init__(config)
931
+ self.model = LlamaModel(config)
932
+
933
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
934
+
935
+ # Initialize weights and apply final processing
936
+ self.post_init()
937
+
938
+ def compress(self):
939
+ for name, module in self.model.named_modules():
940
+ if isinstance(module, LlamaAttention):
941
+ module.compress()
942
+
943
+ def get_input_embeddings(self):
944
+ return self.model.embed_tokens
945
+
946
+ def set_input_embeddings(self, value):
947
+ self.model.embed_tokens = value
948
+
949
+ def get_output_embeddings(self):
950
+ return self.lm_head
951
+
952
+ def set_output_embeddings(self, new_embeddings):
953
+ self.lm_head = new_embeddings
954
+
955
+ def set_decoder(self, decoder):
956
+ self.model = decoder
957
+
958
+ def get_decoder(self):
959
+ return self.model
960
+
961
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
962
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
963
+ def forward(
964
+ self,
965
+ input_ids: torch.LongTensor = None,
966
+ attention_mask: Optional[torch.Tensor] = None,
967
+ position_ids: Optional[torch.LongTensor] = None,
968
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
969
+ inputs_embeds: Optional[torch.FloatTensor] = None,
970
+ labels: Optional[torch.LongTensor] = None,
971
+ use_cache: Optional[bool] = None,
972
+ output_attentions: Optional[bool] = None,
973
+ output_hidden_states: Optional[bool] = None,
974
+ return_dict: Optional[bool] = None,
975
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
976
+ r"""
977
+ Args:
978
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
979
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
980
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
981
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
982
+
983
+ Returns:
984
+
985
+ Example:
986
+
987
+ ```python
988
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
989
+
990
+ >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
991
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
992
+
993
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
994
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
995
+
996
+ >>> # Generate
997
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
998
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
999
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1000
+ ```"""
1001
+
1002
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1003
+ output_hidden_states = (
1004
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1005
+ )
1006
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1007
+
1008
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1009
+ outputs = self.model(
1010
+ input_ids=input_ids,
1011
+ attention_mask=attention_mask,
1012
+ position_ids=position_ids,
1013
+ past_key_values=past_key_values,
1014
+ inputs_embeds=inputs_embeds,
1015
+ use_cache=use_cache,
1016
+ output_attentions=output_attentions,
1017
+ output_hidden_states=output_hidden_states,
1018
+ return_dict=return_dict,
1019
+ )
1020
+
1021
+ hidden_states = outputs[0]
1022
+ logits = self.lm_head(hidden_states)
1023
+
1024
+ loss = None
1025
+ if labels is not None:
1026
+ # Shift so that tokens < n predict n
1027
+ shift_logits = logits[..., :-1, :].contiguous()
1028
+ shift_labels = labels[..., 1:].contiguous()
1029
+ # Flatten the tokens
1030
+ loss_fct = CrossEntropyLoss()
1031
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1032
+ shift_labels = shift_labels.view(-1)
1033
+ # Enable model parallelism
1034
+ shift_labels = shift_labels.to(shift_logits.device)
1035
+ loss = loss_fct(shift_logits, shift_labels)
1036
+
1037
+ if not return_dict:
1038
+ output = (logits,) + outputs[1:]
1039
+ return (loss,) + output if loss is not None else output
1040
+
1041
+ return CausalLMOutputWithPast(
1042
+ loss=loss,
1043
+ logits=logits,
1044
+ past_key_values=outputs.past_key_values,
1045
+ hidden_states=outputs.hidden_states,
1046
+ attentions=outputs.attentions,
1047
+ )
1048
+
1049
+ def prepare_inputs_for_generation(
1050
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1051
+ ):
1052
+ if past_key_values:
1053
+ input_ids = input_ids[:, -1:]
1054
+
1055
+ position_ids = kwargs.get("position_ids", None)
1056
+ if attention_mask is not None and position_ids is None:
1057
+ # create position_ids on the fly for batch generation
1058
+ position_ids = attention_mask.long().cumsum(-1) - 1
1059
+ position_ids.masked_fill_(attention_mask == 0, 1)
1060
+ if past_key_values:
1061
+ position_ids = position_ids[:, -1].unsqueeze(-1)
1062
+
1063
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1064
+ if inputs_embeds is not None and past_key_values is None:
1065
+ model_inputs = {"inputs_embeds": inputs_embeds}
1066
+ else:
1067
+ model_inputs = {"input_ids": input_ids}
1068
+
1069
+ model_inputs.update(
1070
+ {
1071
+ "position_ids": position_ids,
1072
+ "past_key_values": past_key_values,
1073
+ "use_cache": kwargs.get("use_cache"),
1074
+ "attention_mask": attention_mask,
1075
+ }
1076
+ )
1077
+ return model_inputs
1078
+
1079
+ @staticmethod
1080
+ def _reorder_cache(past_key_values, beam_idx):
1081
+ reordered_past = ()
1082
+ for layer_past in past_key_values:
1083
+ reordered_past += (
1084
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1085
+ )
1086
+ return reordered_past
1087
+
1088
+
1089
+ @add_start_docstrings(
1090
+ """
1091
+ The LLaMa Model transformer with a sequence classification head on top (linear layer).
1092
+
1093
+ [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1094
+ (e.g. GPT-2) do.
1095
+
1096
+ Since it does classification on the last token, it requires to know the position of the last token. If a
1097
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1098
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1099
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1100
+ each row of the batch).
1101
+ """,
1102
+ LLAMA_START_DOCSTRING,
1103
+ )
1104
+ class LlamaForSequenceClassification(LlamaPreTrainedModel):
1105
+ def __init__(self, config):
1106
+ super().__init__(config)
1107
+ self.num_labels = config.num_labels
1108
+ self.model = LlamaModel(config)
1109
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1110
+
1111
+ # Initialize weights and apply final processing
1112
+ self.post_init()
1113
+
1114
+ def get_input_embeddings(self):
1115
+ return self.model.embed_tokens
1116
+
1117
+ def set_input_embeddings(self, value):
1118
+ self.model.embed_tokens = value
1119
+
1120
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1121
+ def forward(
1122
+ self,
1123
+ input_ids: torch.LongTensor = None,
1124
+ attention_mask: Optional[torch.Tensor] = None,
1125
+ position_ids: Optional[torch.LongTensor] = None,
1126
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1127
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1128
+ labels: Optional[torch.LongTensor] = None,
1129
+ use_cache: Optional[bool] = None,
1130
+ output_attentions: Optional[bool] = None,
1131
+ output_hidden_states: Optional[bool] = None,
1132
+ return_dict: Optional[bool] = None,
1133
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1134
+ r"""
1135
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1136
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1137
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1138
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1139
+ """
1140
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1141
+
1142
+ transformer_outputs = self.model(
1143
+ input_ids,
1144
+ attention_mask=attention_mask,
1145
+ position_ids=position_ids,
1146
+ past_key_values=past_key_values,
1147
+ inputs_embeds=inputs_embeds,
1148
+ use_cache=use_cache,
1149
+ output_attentions=output_attentions,
1150
+ output_hidden_states=output_hidden_states,
1151
+ return_dict=return_dict,
1152
+ )
1153
+ hidden_states = transformer_outputs[0]
1154
+ logits = self.score(hidden_states)
1155
+
1156
+ if input_ids is not None:
1157
+ batch_size = input_ids.shape[0]
1158
+ else:
1159
+ batch_size = inputs_embeds.shape[0]
1160
+
1161
+ if self.config.pad_token_id is None and batch_size != 1:
1162
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1163
+ if self.config.pad_token_id is None:
1164
+ sequence_lengths = -1
1165
+ else:
1166
+ if input_ids is not None:
1167
+ sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
1168
+ else:
1169
+ sequence_lengths = -1
1170
+
1171
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1172
+
1173
+ loss = None
1174
+ #added for avoiding OOM
1175
+ if labels is not None:
1176
+ # Shift so that tokens < n predict n
1177
+ shift_logits = logits[..., :-1, :].contiguous()
1178
+ shift_labels = labels[..., 1:].contiguous()
1179
+ # Flatten the tokens
1180
+ if self.training:
1181
+ loss_fct = CrossEntropyLoss()
1182
+ else:
1183
+ loss_fct = CrossEntropyLoss(reduction="none") # to calculate ppl
1184
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1185
+ shift_labels = shift_labels.view(-1)
1186
+ # Enable model parallelism
1187
+ shift_labels = shift_labels.to(shift_logits.device)
1188
+ loss = loss_fct(shift_logits, shift_labels)
1189
+ if not self.training:
1190
+ loss_reshaped = loss.view(labels.size(0), -1) # shape (bs, seq_len-1)
1191
+ logits = loss_reshaped.mean(dim=-1) # return example wise (chunked sequence) CE loss to calculate ppl and avoid GPU OOM
1192
+ loss = loss.mean()
1193
+ return CausalLMOutputWithPast(
1194
+ loss=loss,
1195
+ logits=logits,
1196
+ )
1197
+ # if labels is not None:
1198
+ # labels = labels.to(logits.device)
1199
+ # if self.config.problem_type is None:
1200
+ # if self.num_labels == 1:
1201
+ # self.config.problem_type = "regression"
1202
+ # elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1203
+ # self.config.problem_type = "single_label_classification"
1204
+ # else:
1205
+ # self.config.problem_type = "multi_label_classification"
1206
+ #
1207
+ # if self.config.problem_type == "regression":
1208
+ # loss_fct = MSELoss()
1209
+ # if self.num_labels == 1:
1210
+ # loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1211
+ # else:
1212
+ # loss = loss_fct(pooled_logits, labels)
1213
+ # elif self.config.problem_type == "single_label_classification":
1214
+ # loss_fct = CrossEntropyLoss()
1215
+ # loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1216
+ # elif self.config.problem_type == "multi_label_classification":
1217
+ # loss_fct = BCEWithLogitsLoss()
1218
+ # loss = loss_fct(pooled_logits, labels)
1219
+ if not return_dict:
1220
+ output = (pooled_logits,) + transformer_outputs[1:]
1221
+ return ((loss,) + output) if loss is not None else output
1222
+
1223
+ return SequenceClassifierOutputWithPast(
1224
+ loss=loss,
1225
+ logits=pooled_logits,
1226
+ past_key_values=transformer_outputs.past_key_values,
1227
+ hidden_states=transformer_outputs.hidden_states,
1228
+ attentions=transformer_outputs.attentions,
1229
+ )
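
The main departures from the stock `transformers` LLaMA implementation are the normalization classes: `RepBN` (a `BatchNorm1d` plus a learnable `alpha * x` residual branch), `MaskSyncBatchNorm` (a padding-aware, optionally synchronized variant), and `LinearNorm`, which linearly anneals from `LlamaRMSNorm` to `RepBN` over training steps. `LlamaDecoderLayer.merge_bn` then folds the trained batch-norm statistics and affine parameters into the q/k/v and gate/up projections, after which the per-layer norms become `nn.Identity()`, which is how they already appear in this upload. Below is a small self-contained check (random stand-in tensors, not weights from this checkpoint) that the folding algebra used in `merge_bn` is exact at inference time:

```python
# Verify: eval-mode RepBN (BatchNorm with running stats, plus alpha * x) followed by a
# bias-free Linear equals a single Linear with column-rescaled weight and a new bias.
import torch

torch.manual_seed(0)
C_in, C_out, N = 1024, 1024, 8

mu, var = torch.randn(C_in), torch.rand(C_in) + 0.5   # running_mean, running_var
gamma, beta = torch.randn(C_in), torch.randn(C_in)    # bn.weight, bn.bias
alpha, eps = torch.randn(()), 1e-5                    # RepBN residual scale, bn.eps
W = torch.randn(C_out, C_in)                          # e.g. q_proj.weight (no bias before merging)

x = torch.randn(N, C_in)

# Reference path: RepBN output, then the projection.
bn = gamma * (x - mu) / torch.sqrt(var + eps) + beta
ref = (bn + alpha * x) @ W.t()

# Folded path, mirroring merge_bn:
#   a = gamma / sqrt(var + eps) + alpha,  b = beta - gamma * mu / sqrt(var + eps)
a = gamma / torch.sqrt(var + eps) + alpha
b = beta - gamma * mu / torch.sqrt(var + eps)
W_merged = W * a        # same as W @ torch.diag(a): scales the input columns
bias_merged = W @ b     # the folded-in constant becomes a bias

out = x @ W_merged.t() + bias_merged
print(torch.allclose(ref, out, atol=1e-3))  # True, up to float32 round-off
```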
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:628c42a30011666dab2ad5c4f8530420aac31fb52ab084cb57efd423eeb4d0bc
+ size 1471029801
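
This entry and the other binary files below (`tokenizer.model`, `training_args.bin`) are Git LFS pointer files; the actual binaries (about 1.47 GB for the float32 weights) are only materialized when the repository is fetched with LFS support, for example via `huggingface_hub` (the repo id below is a placeholder):

```python
from huggingface_hub import snapshot_download

# Downloads the full snapshot, resolving the LFS pointers to the real binary files.
local_dir = snapshot_download(repo_id="namespace/model-name")  # placeholder repo id
print(local_dir)
```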
special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
+ {
+   "bos_token": {
+     "content": "<s>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   },
+   "eos_token": {
+     "content": "</s>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": "[PAD]",
+   "unk_token": {
+     "content": "<unk>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
tokenizer.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
+ size 499723
tokenizer_config.json ADDED
@@ -0,0 +1,35 @@
+ {
+   "add_bos_token": true,
+   "add_eos_token": false,
+   "bos_token": {
+     "__type": "AddedToken",
+     "content": "<s>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   },
+   "clean_up_tokenization_spaces": false,
+   "eos_token": {
+     "__type": "AddedToken",
+     "content": "</s>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   },
+   "model_max_length": 2048,
+   "pad_token": null,
+   "padding_side": "right",
+   "sp_model_kwargs": {},
+   "tokenizer_class": "LlamaTokenizer",
+   "unk_token": {
+     "__type": "AddedToken",
+     "content": "<unk>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   },
+   "use_fast": true
+ }
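
Note that `pad_token` is `null` in this file, while `special_tokens_map.json` maps it to `[PAD]` and `added_tokens.json` assigns that token id 32000, so the pad token should be resolved from those files when the tokenizer is loaded. A quick sanity check, with `model_dir` as a placeholder for a local copy of this repository:

```python
from transformers import AutoTokenizer

model_dir = "path/to/local/clone"  # placeholder

tokenizer = AutoTokenizer.from_pretrained(model_dir)

print(tokenizer.pad_token)         # expected: "[PAD]"
print(tokenizer.pad_token_id)      # expected: 32000
print(tokenizer.model_max_length)  # expected: 2048
print(len(tokenizer))              # expected: 32001, matching "vocab_size" in config.json
```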
training_args.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2c1bba2234b662d6033876ac76b9ea5e11bc611ad8237abeefbaddbe96ef7b32
+ size 3963