Yanisadel committed
Commit 5d37fe8 · verified
1 Parent(s): 57e9187

Delete chatNT.py

Files changed (1)
  1. chatNT.py +0 -1898
chatNT.py DELETED
@@ -1,1898 +0,0 @@
1
- # This file stores ChatNT and all associated layers and configs
2
-
3
- from dataclasses import asdict, dataclass, field
4
- from typing import Dict, List, Optional, Tuple
5
-
6
- import numpy as np
7
- import torch
8
- import torch.nn as nn
9
- import torch.nn.functional as F # noqa: N812
10
- from transformers import PretrainedConfig, PreTrainedModel
11
-
12
-
13
- @dataclass
14
- class RotaryEmbeddingConfig:
15
- """
16
- Rotary Positional Embedding configuration
17
- max_seq_len: The number of positions to encode and cache.
18
- dim: Dimension of RoPE.
19
- theta: Rotation angle.
20
- """
21
-
22
- max_seq_len: int
23
- dim: int
24
- theta: float
25
-
26
-
27
- @dataclass
28
- class PerceiverResamplerConfig:
29
- """
30
- Parameters to initialize a PerceiverResampler model. Based on the ESM architecture.
31
-
32
- Args:
33
- emb_layer_norm_before: Whether to use layer norm before the first attention
34
- layer.
35
- attention_heads: Number of attention heads.
36
- key_size: The dimension of the query, key, and values within each attention
37
- head. If not specified, it is set to embed_dim // attention_heads.
38
- It can be useful to set a custom key size if we want to impose the size of
39
- the query, key and value tensor ( for example, tensors shaped with
40
- power of 2 are more efficiently handled on TPUs ).
41
- Note: Parametrizing the model with a custom key size has been done in :
42
- Brown, Tom, et al. "Language models are few-shot learners."
43
- Advances in neural information processing systems 33 (2020): 1877-1901.
44
- embed_dim: Embedding dimension.
45
- ffn_embed_dim: Feed forward embedding dimension.
46
- num_layers: Number of attention blocks.
47
- ffn_activation_name: Activation function to be used in FFN block. Supported
48
- names are "gelu", "relu", "swish".
49
- use_glu_in_ffn: Whether to use Gated Linear Unit (GLU) in Feed
50
- Forward Network (FFN) block. To do a swiGLU (gated-swish) put this arg
51
- to True and use swish as ffn_activation_name.
52
- Same principle for a gated-relu. To keep the same number of parameters in
53
- the FFN block, one should multiply by 2/3 the ffn_embed_dim when using GLU.
54
- See https://arxiv.org/pdf/2002.05202.pdf for more details.
55
- resampled_length: length of the resampled output of the module
56
- use_gradient_checkpointing: Whether to use gradient checkpointing (checkpoint
57
- gradients in the forward pass to reduce the computation in the backward).
58
- """
59
-
60
- # architecture
61
- emb_layer_norm_before: bool = False
62
- attention_heads: int = 20
63
- key_size: Optional[int] = None
64
- embed_dim: int = 1280
65
- ffn_embed_dim: int = 5120
66
- num_layers: int = 24
67
- add_bias_kv: bool = False
68
- add_bias_ffn: bool = True
69
- ffn_activation_name: str = "gelu-no-approx"
70
- use_glu_in_ffn: bool = False
71
- resampled_length: int = 64
72
-
73
- # performance
74
- use_gradient_checkpointing: bool = False
75
-
76
- def __post_init__(self) -> None:
77
- """
78
- Checks that the given values are compatible.
79
- """
80
-
81
- if self.key_size is None:
82
- if not self.embed_dim % self.attention_heads == 0:
83
- raise ValueError(
84
- f"When no key size is provided, the embedding dimension should be "
85
- f"divisible by the number of heads, however provided embedding "
86
- f"dimension is {self.embed_dim} and the number of heads is "
87
- f"{self.attention_heads}."
88
- )
89
- self.key_size = self.embed_dim // self.attention_heads
90
-
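# A minimal usage sketch (illustrative, not from the original file): with the
# default PerceiverResamplerConfig, __post_init__ derives key_size as
# embed_dim // attention_heads = 1280 // 20 = 64.
resampler_cfg = PerceiverResamplerConfig()
assert resampler_cfg.key_size == 64  # 1280 // 20
assert resampler_cfg.resampled_length == 64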
91
-
92
- @dataclass
93
- class GptConfig:
94
- """
95
- Parameters to initialize a Gpt model.
96
-
97
- NOTE: the pad token is not defined
98
-
99
- Args:
100
- vocab_size: Token vocabulary.
101
- eos_token_id: used to stop sentence generation
102
- embed_dim: Embedding dimension.
103
- ffn_embed_dim: Feed forward embedding dimension.
104
- num_heads: Number of attention heads.
105
- num_kv_heads: Number of key and value heads to support Grouped-Query and
106
- Multi-Query Attention. If None, the number of key and value heads is
107
- equal to the number of attention heads.
108
- num_layers: Number of decoder layers.
109
- rope_config: The configuration for the rotary positional embeddings
110
- add_bias_ffn: Add bias in feed forward network block.
111
- ffn_activation_name: Activation function to be used in FFN block. Supported
112
- names are "gelu", "gelu-no-approx", "relu", "swish".
113
- use_glu_in_ffn: whether to use Gated Linear Unit (GLU) in Feed
114
- Forward Network (FFN) block.
115
- example: To do a swiGLU (gated-swish) put this arg
116
- to True and use swish as ffn_activation_name.
117
- Same principle for a gated-relu.
118
- add_bias_lm_head: whether to use bias in the final LM layer
119
- norm_type: The type of norm used ( pre normalization scheme ) used. can be
120
- one of ["layer_norm", "RMS_norm"]
121
- parallel_attention_ff: Whether to do the attention and the MLP in parallel,
122
- and then sum up the results as it is done in Gpt-NeoX :
123
- Black, Sid, et al. "Gpt-neox-20b: An open-source autoregressive
124
- language model." arXiv preprint arXiv:2204.06745 (2022).
125
- It is reported to improve training time by about 15% when compiling with JAX.
126
- use_gradient_checkpointing: Whether to use gradient checkpointing (checkpoint
127
- gradients in the forward pass to reduce the computation in the backward).
128
- add_bias_attn: Add bias to the attention mechanism (key, query, value, and
129
- output projections).
130
- """
131
-
132
- # vocabulary
133
- vocab_size: int
134
- eos_token_id: int
135
-
136
- # architecture
137
- embed_dim: int = 16
138
- ffn_embed_dim: int = 64
139
- num_heads: int = 2
140
- num_kv_heads: Optional[int] = None
141
- num_layers: int = 2
142
- rope_config: RotaryEmbeddingConfig = field(
143
- default_factory=lambda: RotaryEmbeddingConfig(
144
- max_seq_len=512, dim=8, theta=10000.0
145
- )
146
- )
147
- add_bias_ffn: bool = False
148
- ffn_activation_name: str = "swish"
149
- use_glu_in_ffn: bool = True
150
- add_bias_lm_head: bool = False
151
- norm_type: str = "RMS_norm"
152
- rms_norm_eps: float = 1e-6
153
- parallel_attention_ff: bool = True
154
-
155
- # inference / backward behavior
156
- use_gradient_checkpointing: bool = False
157
-
158
- # architecture params with default values
159
- add_bias_attn: bool = False
160
-
161
- def __post_init__(self) -> None:
162
- """
163
- Checks that the given values are compatible.
164
- """
165
- if not self.embed_dim % self.num_heads == 0:
166
- raise ValueError(
167
- f"The embedding dimension should be "
168
- f"divisible by the number of heads, however provided embedding "
169
- f"dimension is {self.embed_dim} and the number of heads is "
170
- f"{self.num_heads}."
171
- )
172
-
173
- if not self.embed_dim // self.num_heads > 1:
174
- raise ValueError(
175
- "embed_dim / num_heads must be higher than 2 to apply rotary embeddings"
176
- )
177
-
178
- if not self.embed_dim // self.num_heads >= self.rope_config.dim:
179
- raise ValueError(
180
- "embed_dim // num_heads must be higher than rope_config.dim "
181
- "to apply rotary embeddings"
182
- )
183
-
184
- def to_dict(self): # type: ignore
185
- output = asdict(self)
186
- output["rope_config"] = asdict(self.rope_config)
187
- return output
188
-
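# A minimal usage sketch (illustrative, not from the original file): a GptConfig
# only requires vocab_size and eos_token_id, and to_dict() flattens the nested
# RotaryEmbeddingConfig so the config can be serialized.
gpt_cfg = GptConfig(vocab_size=32000, eos_token_id=3)
assert gpt_cfg.embed_dim // gpt_cfg.num_heads >= gpt_cfg.rope_config.dim
gpt_cfg_dict = gpt_cfg.to_dict()
assert gpt_cfg_dict["rope_config"] == {"max_seq_len": 512, "dim": 8, "theta": 10000.0}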
189
-
190
- @dataclass
191
- class ESMTransformerConfig:
192
- """
193
- Parameters to initialize an ESM model. While the ESM architecture is an encoder-only
194
- model, different choices have been made for each version and this configuration aims
195
- to cover most of them.
196
-
197
- Args:
198
- alphabet_size: Token vocabulary.
199
- pad_token_id: ID of pad token.
200
- mask_token_id: ID of mask token.
201
- max_positions: Maximum sequence length.
202
- embed_scale: Correction ratio applied to the embeddings to make up for the
203
- norm difference between the input during training and inference.
204
- emb_layer_norm_before: Whether to use layer norm before the first attention
205
- layer.
206
- attention_heads: Number of attention heads.
207
- key_size: The dimension of the query, key, and values within each attention
208
- head. If not specified, it is set to embed_dim // attention_heads.
209
- It can be useful to set a custom key size if we want to impose the size of
210
- the query, key and value tensor ( for example, tensors shaped with
211
- power of 2 are more efficiently handled on TPUs ).
212
- Note: Parametrizing the model with a custom key size has been done in :
213
- Brown, Tom, et al. "Language models are few-shot learners."
214
- Advances in neural information processing systems 33 (2020): 1877-1901.
215
- embed_dim: Embedding dimension.
216
- ffn_embed_dim: Feed forward embedding dimension.
217
- num_layers: Number of attention blocks.
218
- positional_embedding: Type of positional embedding to use before the first
219
- attention layer. Options: "learned", "learned_standard", "sinusoidal",
220
- "alibi_dnabert_2" or None.
221
- NOTE: "learned" is the positional embedding of ESM, and "learned_standard"
222
- is a more standard one, used for example in DNAbert.
223
- lm_head: type of language model head. Options: "simple", "roberta" or None.
224
- add_bias_kv: Add bias in attention layer.
225
- add_bias_ffn: Add bias in feed forward network block.
226
- use_rotary_embedding: Whether to use rotary embeddings (for ESM2). Requires:
227
- positional_embeddings = None.
228
- rescaling_factor: Scaling factor to use for rotary embeddings.
229
- ffn_activation_name: Activation function to be used in FFN block. Supported
230
- names are "gelu", "relu", "swish".
231
- use_glu_in_ffn: Whether to use Gated Linear Unit (GLU) in Feed
232
- Forward Network (FFN) block. To do a swiGLU (gated-swish) put this arg
233
- to True and use swish as ffn_activation_name.
234
- Same principle for a gated-relu. To keep the same number of parameters in
235
- the FFN block, one should multiply by 2/3 the ffn_embed_dim when using GLU.
236
- See https://arxiv.org/pdf/2002.05202.pdf for more details.
237
- mask_before_attention: Use mask before attention layers (for ESM1b and ESM2).
238
- layer_norm_eps: the eps factor in the different layer norms of the model (refer
239
- to layer norm implementation)
240
- token_dropout: Token dropout.
241
- masking_ratio: Masking ratio (used if token dropout is enabled).
242
- masking_prob: Masking probability (used if token dropout is enabled).
243
- use_gradient_checkpointing: Whether to use gradient checkpointing (checkpoint
244
- gradients in the forward pass to reduce the computation in the backward).
245
- """
246
-
247
- alphabet_size: int
248
- pad_token_id: int
249
- mask_token_id: int
250
-
251
- max_positions: int = 1024
252
- embed_scale: float = 1.0
253
-
254
- # architecture
255
- emb_layer_norm_before: bool = False
256
- attention_heads: int = 20
257
- key_size: Optional[int] = None
258
- embed_dim: int = 1280
259
- ffn_embed_dim: int = 5120
260
- num_layers: int = 24
261
- positional_embedding: Optional[str] = "learned"
262
- lm_head: Optional[str] = "simple"
263
- add_bias_kv: bool = False
264
- add_bias_ffn: bool = True
265
- use_rotary_embedding: bool = False
266
- rescaling_factor: Optional[float] = None
267
- ffn_activation_name: str = "gelu-no-approx"
268
- use_glu_in_ffn: bool = False
269
- mask_before_attention: bool = False
270
- layer_norm_eps: float = 1e-5
271
- pre_layer_norm: bool = True
272
- bias_word_embedding: bool = False
273
-
274
- # dropout
275
- token_dropout: bool = False
276
- masking_ratio: float = 0.1
277
- masking_prob: float = 0.8
278
-
279
- # performance
280
- use_gradient_checkpointing: bool = False
281
-
282
- # return
283
- embeddings_layers_to_save: List[int] = field(default_factory=list)
284
- attention_maps_to_save: List[Tuple[int, int]] = field(default_factory=list)
285
-
286
- def __post_init__(self) -> None:
287
- """
288
- Checks that the given values are compatible.
289
- """
290
-
291
- if self.key_size is None:
292
- if not self.embed_dim % self.attention_heads == 0:
293
- raise ValueError(
294
- f"When no key size is provided, the embedding dimension should be "
295
- f"divisible by the number of heads, however provided embedding "
296
- f"dimension is {self.embed_dim} and the number of heads is "
297
- f"{self.attention_heads}."
298
- )
299
- self.key_size = self.embed_dim // self.attention_heads
300
- if self.positional_embedding is not None:
301
- if type(self.positional_embedding) != str:
302
- raise TypeError
303
-
304
- if self.positional_embedding not in [
305
- "learned",
306
- "sinusoidal",
307
- "learned_standard",
308
- "alibi_dnabert_2",
309
- ]:
310
- raise ValueError(
311
- "The positional_embedding argument should either be None,"
312
- "`learned`, `sinusoidal`, 'learned_standard' or 'alibi_dnabert_2'."
313
- )
314
- if self.lm_head is not None:
315
- if type(self.lm_head) != str:
316
- raise TypeError
317
-
318
- if self.lm_head not in ["simple", "roberta"]:
319
- raise ValueError(
320
- "The lm_head argument should either be None,"
321
- "`simple` or `roberta`."
322
- )
323
-
324
- if self.use_rotary_embedding and self.positional_embedding is not None:
325
- raise ValueError(
326
- "When using rotary embedding, positional_embedding must be set to none"
327
- )
328
-
329
- if self.add_bias_kv and self.use_rotary_embedding:
330
- raise ValueError(
331
- "Biases on key and values are not compatible with Rotary embeddings."
332
- )
333
-
334
- if self.positional_embedding == "alibi_dnabert_2":
335
- assert not self.add_bias_kv
336
-
337
-
338
- @dataclass
339
- class ChatNTConfig(PretrainedConfig):
340
- model_type = "ChatNT"
341
-
342
- def __init__(self, **kwargs): # type: ignore
343
- self.gpt_config: GptConfig = kwargs.get("gpt_config", GptConfig(32000, 3))
344
- self.esm_config: ESMTransformerConfig = kwargs.get(
345
- "esm_config", ESMTransformerConfig(4000, 1, 4)
346
- )
347
- self.perceiver_resampler_config: PerceiverResamplerConfig = kwargs.get(
348
- "perceiver_resampler_config", PerceiverResamplerConfig()
349
- )
350
- self.seq_token_id: int = kwargs.get("seq_token_id", 32000)
351
- self.bio_pad_token_id: int = kwargs.get("bio_pad_token_id", 1)
352
- self.english_pad_token_id: int = kwargs.get("english_pad_token_id", 2)
353
- super().__init__(**kwargs)
354
-
355
- def to_dict(self): # type: ignore
356
- output = super().to_dict()
357
-
358
- def serialize(obj): # type: ignore
359
- return obj.to_dict() if hasattr(obj, "to_dict") else vars(obj)
360
-
361
- output["gpt_config"] = serialize(self.gpt_config) # type: ignore
362
- output["esm_config"] = serialize(self.esm_config) # type: ignore
363
- output["perceiver_resampler_config"] = serialize( # type: ignore
364
- self.perceiver_resampler_config
365
- )
366
- return output
367
-
368
-
369
- class TorchBioBrainDecoder(nn.Module):
370
- def __init__(
371
- self,
372
- gpt_config: GptConfig,
373
- seq_token_id: int,
374
- ):
375
- """
376
- Initializes the BioBrain decoder, using a GPT model for text generation with
377
- bio embeddings.
378
-
379
- Args:
380
- gpt_config: Configuration for the GPT model
381
- seq_token_id: Index of the SEQ token
382
- """
383
- super(TorchBioBrainDecoder, self).__init__()
384
- self.gpt_config = gpt_config
385
- self.seq_token_id = seq_token_id
386
-
387
- # Initialize the GPT decoder used for text generation
388
- self.gpt_model = TorchGptDecoder(self.gpt_config)
389
-
390
- def forward(
391
- self, english_token_ids: torch.Tensor, projected_bio_embeddings: torch.Tensor
392
- ) -> torch.Tensor:
393
- """
394
- Forward pass through the model.
395
-
396
- Args:
397
- english_token_ids: Tensor of English token IDs with shape
398
- (batch_size, num_english_tokens).
399
- projected_bio_embeddings: Optional tensor of bio embeddings with shape
400
- (batch_size, num_bio_sequences, ?, embed_dim).
401
-
402
- Returns:
403
- torch.Tensor: The logits from the GPT model,
404
- shaped (batch_size, num_english_tokens, vocab_size).
405
- """
406
-
407
- # Compute English token embeddings
408
- tokens_embeddings = self.gpt_model.token_embed(english_token_ids)
409
-
410
- if projected_bio_embeddings is not None:
411
- (
412
- batch_size,
413
- num_bio_sequences,
414
- _,
415
- bio_embed_dim,
416
- ) = projected_bio_embeddings.shape
417
-
418
- # Insert the bio embeddings at the SEQ token positions
419
- processed_tokens_ids = english_token_ids.clone()
420
- for bio_seq_num in range(num_bio_sequences):
421
- tokens_embeddings, processed_tokens_ids = self.insert_embeddings(
422
- processed_tokens_ids,
423
- tokens_embeddings,
424
- projected_bio_embeddings[:, bio_seq_num, :, :],
425
- bio_seq_num=bio_seq_num,
426
- )
427
-
428
- # Regular GPT pass through
429
- embeddings = self.gpt_model.apply_transformer_layers(tokens_embeddings)
430
- embeddings = self.gpt_model.final_norm(embeddings)
431
-
432
- # Compute logits
433
- logits = self.gpt_model.lm_head(embeddings)
434
-
435
- if projected_bio_embeddings is not None:
436
- # Clean logits sequentially
437
- processed_tokens_ids = english_token_ids.clone()
438
- resampled_length = projected_bio_embeddings.shape[-2]
439
- for _ in range(num_bio_sequences):
440
- logits, processed_tokens_ids = self.cleanup_logits(
441
- tokens=processed_tokens_ids,
442
- logits=logits,
443
- resampled_length=resampled_length,
444
- )
445
-
446
- return logits
447
-
448
- def insert_embeddings(
449
- self,
450
- tokens: torch.Tensor,
451
- input_embeddings: torch.Tensor,
452
- resampled_embeddings: torch.Tensor,
453
- bio_seq_num: int,
454
- ) -> Tuple[torch.Tensor, torch.Tensor]:
455
- """
456
- Inserts resampled embeddings in input_embeddings, starting at the SEQ token
457
-
458
- Args:
459
- tokens (torch.Tensor): Shape (batch_size, num_tokens)
460
- input_embeddings (torch.Tensor): Shape (batch_size, num_tokens, embed_dim)
461
- resampled_embeddings (torch.Tensor):
462
- Shape (batch_size, num_bio_sequences, bio_sequence_length, embed_dim)
463
-
464
- Returns:
465
- Tuple[torch.Tensor, torch.Tensor]:
466
- - input_embeddings with resampled_embeddings inserted at the SEQ token
467
- - tokens with the SEQ token set to -1
468
- """
469
-
470
- def _insert(
471
- tokens_1d: torch.Tensor,
472
- input_embeddings_1d: torch.Tensor,
473
- resampled_embeddings_1d: torch.Tensor,
474
- ) -> Tuple[torch.Tensor, torch.Tensor]:
475
- """
476
- Args:
477
- tokens (torch.Tensor): Shape (num_tokens,)
478
- input_embeddings (torch.Tensor): Shape (num_tokens, embed_dim,)
479
- resampled_embeddings (torch.Tensor):
480
- Shape (bio_sequence_length, embed_dim,)
481
- """
482
- indices = torch.where(tokens_1d == self.seq_token_id)[0]
483
- if indices.numel() > 0:
484
- idx = indices[0].item()
485
- insertion_pos = idx + resampled_embeddings_1d.shape[-2] * bio_seq_num
486
- x = torch.cat(
487
- [
488
- input_embeddings_1d[:insertion_pos, :],
489
- resampled_embeddings_1d,
490
- input_embeddings_1d[insertion_pos:, :],
491
- ],
492
- dim=0,
493
- )[: tokens_1d.shape[0] + 1, :]
494
- x = torch.roll(torch.roll(x, shifts=-idx, dims=0), shifts=idx, dims=0)[
495
- :-1, :
496
- ]
497
- tokens_1d[idx] = -1
498
- return x, tokens_1d
499
- else:
500
- return (
501
- input_embeddings_1d,
502
- tokens_1d,
503
- ) # Return unchanged if seq_token_id is not found
504
-
505
- tokens_acc = []
506
- embeddings_acc = []
507
-
508
- for i in range(tokens.shape[0]):
509
- embeddings_out, tokens_out = _insert(
510
- tokens[i].clone(),
511
- input_embeddings[i].clone(),
512
- resampled_embeddings[i].clone(),
513
- )
514
- tokens_acc.append(tokens_out)
515
- embeddings_acc.append(embeddings_out)
516
- tokens_acc = torch.stack(tokens_acc)
517
- embeddings_acc = torch.stack(embeddings_acc)
518
-
519
- return embeddings_acc, tokens_acc
520
-
521
- def cleanup_logits(
522
- self, tokens: torch.Tensor, logits: torch.Tensor, resampled_length: int
523
- ) -> Tuple[torch.Tensor, torch.Tensor]:
524
- """
525
- Removes the logits corresponding to the unused embeddings.
526
-
527
- Args:
528
- tokens: Input english tokens.
529
- logits: Input logits.
530
-
531
- Returns:
532
- Cleaned logits, last values will be equal to 0.
533
- """
534
-
535
- def _clean(
536
- token: torch.Tensor, logit: torch.Tensor
537
- ) -> Tuple[torch.Tensor, torch.Tensor]:
538
- indices = torch.where(token == self.seq_token_id)[0]
539
- if indices.numel() > 0:
540
- idx = indices[0].item()
541
-
542
- mask_idx = (
543
- torch.arange(logit.shape[0] - resampled_length, device=logit.device)
544
- > idx
545
- )
546
- mask_idx = mask_idx.unsqueeze(1)
547
-
548
- # Remove values corresponding to bio tokens
549
- logit = (
550
- logit[:-resampled_length] * (~mask_idx)
551
- + logit[resampled_length:] * mask_idx
552
- )
553
-
554
- # Append zeros at the end
555
- logit = torch.cat(
556
- (
557
- logit,
558
- torch.zeros(
559
- (resampled_length, logit.shape[1]),
560
- dtype=logit.dtype,
561
- device=logit.device,
562
- ),
563
- )
564
- )
565
-
566
- # Update token
567
- token[idx] = -1
568
-
569
- return logit, token
570
-
571
- else:
572
- return logit, token
573
-
574
- tokens_acc = []
575
- logits_acc = []
576
-
577
- for i in range(tokens.shape[0]):
578
- logits_out, tokens_out = _clean(tokens[i].clone(), logits[i].clone())
579
- tokens_acc.append(tokens_out)
580
- logits_acc.append(logits_out)
581
- tokens_acc = torch.stack(tokens_acc)
582
- logits_acc = torch.stack(logits_acc)
583
-
584
- return logits_acc, tokens_acc
585
-
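# An illustrative, self-contained sketch (not from the original file) of the
# splicing idea behind insert_embeddings: the resampled bio embeddings are
# concatenated at the position of the SEQ token (the method above additionally
# truncates and rolls the result to keep a fixed sequence length).
seq_token_id_example = 5
tokens_1d = torch.tensor([7, 9, 5, 11])        # SEQ token at index 2
text_embeddings_1d = torch.randn(4, 8)         # (num_tokens, embed_dim)
resampled_1d = torch.randn(3, 8)               # (resampled_length, embed_dim)
seq_idx = int(torch.where(tokens_1d == seq_token_id_example)[0][0])
spliced = torch.cat(
    [text_embeddings_1d[:seq_idx], resampled_1d, text_embeddings_1d[seq_idx:]], dim=0
)
assert spliced.shape == (7, 8)  # 4 text positions + 3 resampled positions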
586
-
587
- class TorchMultiOmicsModel(PreTrainedModel):
588
- config_class = ChatNTConfig
589
-
590
- def __init__(self, config: ChatNTConfig) -> None:
591
- if isinstance(config, dict):
592
- # If config is a dictionary instead of ChatNTConfig (which can happen
593
- # depending how the config was saved), we convert it to the config
594
- config["gpt_config"]["rope_config"] = RotaryEmbeddingConfig(
595
- **config["gpt_config"]["rope_config"]
596
- )
597
- config["gpt_config"] = GptConfig(**config["gpt_config"])
598
- config["esm_config"] = ESMTransformerConfig(**config["esm_config"])
599
- config["perceiver_resampler_config"] = PerceiverResamplerConfig(
600
- **config["perceiver_resampler_config"]
601
- )
602
- config = ChatNTConfig(**config) # type: ignore
603
-
604
- else:
605
- if isinstance(config.gpt_config, dict):
606
- config.gpt_config["rope_config"] = RotaryEmbeddingConfig(
607
- **config.gpt_config["rope_config"]
608
- )
609
- config.gpt_config = GptConfig(**config.gpt_config)
610
-
611
- if isinstance(config.esm_config, dict):
612
- config.esm_config = ESMTransformerConfig(**config.esm_config)
613
-
614
- if isinstance(config.perceiver_resampler_config, dict):
615
- config.perceiver_resampler_config = PerceiverResamplerConfig(
616
- **config.perceiver_resampler_config
617
- )
618
-
619
- super().__init__(config=config)
620
- self.gpt_config = config.gpt_config
621
- self.esm_config = config.esm_config
622
- self.perceiver_resampler_config = config.perceiver_resampler_config
623
- self.seq_token_id = config.seq_token_id
624
- self.bio_pad_token_id = config.bio_pad_token_id
625
- self.english_pad_token_id = config.english_pad_token_id
626
-
627
- # Correct seq_token_id
628
- self.seq_token_id -= 1
629
-
630
- self.biobrain_encoder = TorchBioBrainEncoder(esm_config=self.esm_config)
631
- self.biobrain_decoder = TorchBioBrainDecoder(
632
- gpt_config=self.gpt_config, seq_token_id=self.seq_token_id
633
- )
634
- self.projection_model = TorchMultiModalPerceiverResamplerProjection(
635
- perceiver_resampler_config=self.perceiver_resampler_config,
636
- input_embed_dim=self.esm_config.embed_dim,
637
- embed_dim=self.gpt_config.embed_dim,
638
- english_vocab_size=self.gpt_config.vocab_size,
639
- bio_pad_token_id=self.bio_pad_token_id,
640
- english_pad_token_id=self.english_pad_token_id,
641
- )
642
-
643
- def forward(
644
- self,
645
- multi_omics_tokens_ids: tuple[torch.Tensor, torch.Tensor],
646
- projection_english_tokens_ids: torch.Tensor,
647
- projected_bio_embeddings: torch.Tensor = None,
648
- ) -> dict[str, torch.Tensor]:
649
- """
650
-
651
- Args:
652
- multi_omics_tokens_ids (Tuple[torch.Tensor, torch.Tensor]):
653
- english_tokens_ids: Represents the prompt tokens (english tokens)
654
- Shape (batch_size, num_english_tokens)
655
-
656
- bio_tokens_ids: Represents the bio sequences tokens
657
- Shape (batch_size, num_bio_sequences, num_bio_tokens)
658
-
659
- projection_english_tokens_ids (torch.Tensor):
660
- Shape (batch_size, num_english_tokens)
661
-
662
- projected_bio_embeddings (projected_bio_embeddings, optional):
663
- Shape (batch_size, num_bio_sequences, ?, embed_dim).
664
- Defaults to None.
665
-
666
- Returns:
667
- dict[str, torch.Tensor] containing:
668
- - logits:
669
- Shape (batch_size, num_tokens, vocab_size)
670
-
671
- - projected_bio_embeddings:
672
- Shape (batch_size, num_bio_sequences, ?, embed_dim)
673
- """
674
- english_token_ids, bio_token_ids = multi_omics_tokens_ids
675
- english_token_ids = english_token_ids.clone()
676
- bio_token_ids = bio_token_ids.clone()
677
- projection_english_tokens_ids = projection_english_tokens_ids.clone()
678
- if projected_bio_embeddings is not None:
679
- projected_bio_embeddings = projected_bio_embeddings.clone()
680
-
681
- # Replace config.vocab_size value in english tokens
682
- # The default vocab size (32000) does not match the tokenizer vocabulary
683
- # because an extra seq_token_id (=32000) was added.
684
- # As a workaround, seq_token_id is remapped to 31999
685
- # (and token 31999, the unknown token, is remapped to 0)
686
- # so that the vocab size in the config does not need to change.
687
- vocab_size = self.gpt_config.vocab_size
688
- # Replace vocab
689
- english_token_ids[english_token_ids == vocab_size - 1] = 0
690
- projection_english_tokens_ids[
691
- projection_english_tokens_ids == vocab_size - 1
692
- ] = 0
693
- english_token_ids[english_token_ids == vocab_size] = vocab_size - 1
694
- projection_english_tokens_ids[projection_english_tokens_ids == vocab_size] = (
695
- vocab_size - 1
696
- )
697
-
698
- if bio_token_ids is None:
699
- projected_bio_embeddings = None
700
- else:
701
- num_bio_sequences = bio_token_ids.shape[1]
702
-
703
- if projected_bio_embeddings is None:
704
- # Compute bio sequences embeddings
705
- bio_embeddings_list = [
706
- self.biobrain_encoder(bio_token_ids=bio_token_ids[:, bio_seq_num])
707
- for bio_seq_num in range(num_bio_sequences)
708
- ]
709
-
710
- # Project these embeddings
711
- projected_bio_embeddings = [
712
- self.projection_model(
713
- bio_token_ids=bio_token_ids[:, bio_seq_num],
714
- bio_embeddings=bio_embeddings,
715
- english_token_ids=projection_english_tokens_ids,
716
- )
717
- for bio_seq_num, bio_embeddings in enumerate(bio_embeddings_list)
718
- ]
719
- projected_bio_embeddings = torch.stack(projected_bio_embeddings, dim=1)
720
-
721
- # decode
722
- logits = self.biobrain_decoder(
723
- english_token_ids=english_token_ids,
724
- projected_bio_embeddings=projected_bio_embeddings,
725
- )
726
-
727
- outs = {"logits": logits, "projected_bio_embeddings": projected_bio_embeddings}
728
-
729
- return outs
730
-
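# A usage sketch (illustrative, not from the original file): expected input
# shapes for TorchMultiOmicsModel.forward. The model and tokenizers are assumed
# to be loaded elsewhere (e.g. via from_pretrained), so the call itself is only
# shown as a comment.
batch_size, num_english_tokens = 2, 32
num_bio_sequences, num_bio_tokens = 1, 128
english_tokens = torch.randint(0, 100, (batch_size, num_english_tokens))
bio_tokens = torch.randint(0, 100, (batch_size, num_bio_sequences, num_bio_tokens))
# outs = model(
#     multi_omics_tokens_ids=(english_tokens, bio_tokens),
#     projection_english_tokens_ids=english_tokens,
# )
# outs["logits"]: (batch_size, num_english_tokens, vocab_size)
# outs["projected_bio_embeddings"]: (batch_size, num_bio_sequences, resampled_length, embed_dim)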
731
-
732
- class TorchRotaryEmbedding(torch.nn.Module):
733
- def __init__(self, config: RotaryEmbeddingConfig):
734
- super().__init__()
735
-
736
- self.max_seq_len = config.max_seq_len
737
- self.dim = config.dim
738
- self.theta = config.theta
739
- self.sincos_cache = None
740
-
741
- def _create_sinusoidal_positions(self, device: torch.device) -> torch.Tensor:
742
- """
743
- Create the sines and cosines for the RoPE.
744
-
745
- Returns:
746
- Sinusoidal positions of shape (self.max_seq_len, self.dim).
747
- """
748
- # Create the inverse frequency based on theta and dim
749
- inv_freq = 1.0 / (
750
- self.theta
751
- ** (torch.arange(0, self.dim, 2, device=device).float() / self.dim)
752
- )
753
-
754
- # Compute sinusoidal input using the broadcasting
755
- sinusoid_inp = torch.einsum(
756
- "i,j->ij", torch.arange(self.max_seq_len, device=device).float(), inv_freq
757
- )
758
-
759
- # Apply sin and cos to the sinusoidal input
760
- sin, cos = sinusoid_inp.sin(), sinusoid_inp.cos()
761
-
762
- # Allocate a tensor for the final sin-cos values
763
- sincos = torch.zeros(
764
- (self.max_seq_len, self.dim), dtype=torch.float32, device=device
765
- )
766
-
767
- # Fill the sincos tensor with sin and cos values
768
- sentinel = self.dim // 2 + self.dim % 2
769
- sincos[:, :sentinel] = sin
770
- sincos[:, sentinel:] = cos
771
-
772
- return sincos
773
-
774
- def _rotate_every_two(self, x: torch.Tensor) -> torch.Tensor:
775
- """
776
- Prepare a tensor to apply the RoPE mechanism.
777
-
778
- Args:
779
- x: Tensor of shape (batch_size, seq_len, num_heads, head_dim),
780
- typically this is the key or query tensor.
781
-
782
- Returns:
783
- Each pair (x_{2i}, x_{2i+1}) in the last dimension is mapped to (-x_{2i+1}, x_{2i}).
784
- Tensor of shape (batch_size, seq_len, num_heads, head_dim).
785
- """
786
- # Split the tensor into two halves (odd and even indexed dimensions)
787
- rotate_half = torch.stack((-x[..., 1::2], x[..., ::2]), dim=-1)
788
-
789
- # Reshape the tensor to the original shape
790
- rotate_half = rotate_half.view(rotate_half.shape[:-2] + (-1,))
791
- return rotate_half
792
-
793
- def _apply_rotary_pos_emb(
794
- self, x: torch.Tensor, sincos: torch.Tensor
795
- ) -> torch.Tensor:
796
- """
797
- Applies rotary embeddings to x.
798
-
799
- Args:
800
- x: Tensor of shape (batch_size, seq_len, num_heads, head_dim),
801
- typically this is the key or query tensor.
802
- sincos: Tuple of sine and cosine tensors for position encoding.
803
-
804
- Returns:
805
- RoPE embeddings tensor.
806
- """
807
- sin_pos, cos_pos = sincos
808
-
809
- # Reshape the sin and cos tensors for broadcasting
810
- sin_pos = torch.repeat_interleave(sin_pos.unsqueeze(2), repeats=2, dim=-1)
811
- cos_pos = torch.repeat_interleave(cos_pos.unsqueeze(2), repeats=2, dim=-1)
812
-
813
- # Apply the rotary embedding mechanism
814
- return (x * cos_pos) + (self._rotate_every_two(x) * sin_pos)
815
-
816
- def __call__(
817
- self, k: torch.Tensor, q: torch.Tensor, positions: Optional[torch.Tensor] = None
818
- ) -> tuple[torch.Tensor, torch.Tensor]:
819
- """
820
- Applies rotary embeddings to k and q.
821
-
822
- Args:
823
- k: key tensor of shape (batch_size, seq_len, num_heads, head_dim),
824
- q: query tensor of shape (batch_size, seq_len, num_heads, head_dim),
825
- positions: optional positions offset useful when caching,
826
-
827
- Returns:
828
- RoPE embeddings for the keys and queries.
829
- """
830
- if self.sincos_cache is None:
831
- device = k.device
832
- self.sincos_cache = self._create_sinusoidal_positions(device=device)
833
-
834
- batch_size, seq_len, num_heads, head_dim = k.shape
835
-
836
- # Generate position ids
837
- position_ids = (
838
- torch.arange(seq_len, device=k.device).unsqueeze(0).expand(batch_size, -1)
839
- )
840
-
841
- if positions is not None:
842
- position_ids += positions
843
-
844
- # Retrieve sincos values using the position_ids
845
- sincos = self.sincos_cache[position_ids] # type: ignore
846
-
847
- # Split sincos into sin_pos and cos_pos
848
- sincos = torch.chunk(sincos, 2, dim=-1)
849
-
850
- # Apply rotary position embedding to key (k) and query (q)
851
- k_rot = self._apply_rotary_pos_emb(k[..., : self.dim], sincos)
852
- k_pass = k[..., self.dim :]
853
-
854
- q_rot = self._apply_rotary_pos_emb(q[..., : self.dim], sincos)
855
- q_pass = q[..., self.dim :]
856
-
857
- # Concatenate the rotated and non-rotated parts
858
- keys = torch.cat([k_rot, k_pass], dim=-1)
859
- queries = torch.cat([q_rot, q_pass], dim=-1)
860
-
861
- return keys, queries
862
-
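# A minimal sketch (illustrative, not from the original file): TorchRotaryEmbedding
# rotates the first `dim` channels of each head; with head_dim == dim the whole
# head is rotated and the output keeps the input shape.
rope = TorchRotaryEmbedding(RotaryEmbeddingConfig(max_seq_len=512, dim=8, theta=10000.0))
k = torch.randn(2, 16, 4, 8)  # (batch_size, seq_len, num_heads, head_dim)
q = torch.randn(2, 16, 4, 8)
k_rot, q_rot = rope(k, q)
assert k_rot.shape == k.shape and q_rot.shape == q.shape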
863
-
864
- class TorchGptGroupedQueryAttention(nn.Module):
865
- def __init__(
866
- self,
867
- embed_dim: int,
868
- num_heads: int,
869
- rope_config: RotaryEmbeddingConfig,
870
- num_kv_heads: int = None, # type: ignore
871
- head_dim: int = None, # type: ignore
872
- add_bias_attn: bool = False, # type: ignore
873
- ) -> None:
874
- super().__init__()
875
- self.num_heads = num_heads
876
- self.num_kv_heads = num_kv_heads or num_heads
877
- self.embed_dim = embed_dim
878
- self.head_dim = head_dim or (embed_dim // num_heads)
879
- self.add_bias_attn = add_bias_attn
880
- self.rope = TorchRotaryEmbedding(rope_config)
881
-
882
- self.query_linear = nn.Linear(
883
- embed_dim, self.num_heads * self.head_dim, bias=add_bias_attn
884
- )
885
- self.key_linear = nn.Linear(
886
- embed_dim, self.num_kv_heads * self.head_dim, bias=add_bias_attn
887
- )
888
- self.value_linear = nn.Linear(
889
- embed_dim, self.num_kv_heads * self.head_dim, bias=add_bias_attn
890
- )
891
- self.out_linear = nn.Linear(
892
- self.num_heads * self.head_dim, embed_dim, bias=add_bias_attn
893
- )
894
-
895
- def forward(
896
- self,
897
- query_inputs: torch.Tensor,
898
- key_inputs: torch.Tensor,
899
- value_inputs: torch.Tensor,
900
- attention_mask: torch.Tensor = None,
901
- ) -> torch.Tensor:
902
- batch_size, seq_len, _ = query_inputs.shape
903
-
904
- queries = self.query_linear(query_inputs).view( # noqa
905
- batch_size, seq_len, self.num_heads, self.head_dim
906
- )
907
- keys = self.key_linear(key_inputs).view( # noqa
908
- batch_size, seq_len, self.num_kv_heads, self.head_dim
909
- )
910
- values = self.value_linear(value_inputs).view( # noqa
911
- batch_size, seq_len, self.num_kv_heads, self.head_dim
912
- )
913
-
914
- keys, queries = self.rope(keys, queries)
915
-
916
- n_rep = self.num_heads // self.num_kv_heads
917
- keys = keys.repeat_interleave(n_rep, dim=2)
918
- values = values.repeat_interleave(n_rep, dim=2)
919
-
920
- attention_logits = torch.einsum("bthd,bThd->bhtT", queries, keys) / (
921
- self.head_dim**0.5
922
- )
923
-
924
- if attention_mask is not None:
925
- attention_logits = attention_logits.masked_fill(
926
- attention_mask == 0, float("-inf")
927
- )
928
-
929
- attention_weights = nn.functional.softmax(attention_logits, dim=-1)
930
-
931
- values = torch.einsum("bhtT,bThd->bthd", attention_weights, values)
932
- values = values.contiguous().view(batch_size, seq_len, -1)
933
-
934
- return self.out_linear(values)
935
-
936
-
937
- class TorchGptDecoder(nn.Module):
938
- def __init__(self, config: GptConfig, name: Optional[str] = None):
939
- super().__init__()
940
- self.config = config
941
-
942
- self.token_embed = nn.Embedding(config.vocab_size, config.embed_dim)
943
-
944
- if config.norm_type == "layer_norm":
945
- self.final_norm = nn.LayerNorm(config.embed_dim)
946
- elif config.norm_type == "RMS_norm":
947
- self.final_norm = TorchRMSNorm(config.embed_dim, eps=config.rms_norm_eps)
948
- else:
949
- raise ValueError(f"unrecognized norm_type in config {config.norm_type}")
950
-
951
- self.layers = nn.ModuleList(
952
- [
953
- TorchGptDecoderLayer(
954
- embed_dim=config.embed_dim,
955
- ffn_embed_dim=config.ffn_embed_dim,
956
- num_heads=config.num_heads,
957
- rope_config=config.rope_config,
958
- norm_type=config.norm_type,
959
- parallel_attention_ff=config.parallel_attention_ff,
960
- add_bias_ffn=config.add_bias_ffn,
961
- ffn_activation_name=config.ffn_activation_name,
962
- use_glu_in_ffn=config.use_glu_in_ffn,
963
- num_kv_heads=config.num_kv_heads, # type: ignore
964
- add_bias_attn=config.add_bias_attn,
965
- rms_norm_eps=config.rms_norm_eps,
966
- )
967
- for _ in range(config.num_layers)
968
- ]
969
- )
970
-
971
- self.lm_head = TorchSimpleLMHead(
972
- embed_dim=config.embed_dim,
973
- alphabet_size=config.vocab_size,
974
- add_bias_lm_head=config.add_bias_lm_head,
975
- )
976
-
977
- def apply_transformer_layers(
978
- self, embeddings: torch.Tensor, attention_mask: torch.Tensor = None
979
- ) -> torch.Tensor:
980
- if attention_mask is None:
981
- attention_mask = build_causal_attention_mask(
982
- 1, embeddings.shape[1], device=embeddings.device
983
- )
984
- for layer in self.layers:
985
- embeddings = layer(embeddings, attention_mask)
986
-
987
- return embeddings
988
-
989
- def forward(
990
- self, token_ids: torch.Tensor, attention_mask: torch.Tensor = None
991
- ) -> dict[str, torch.Tensor]:
992
- if attention_mask is None:
993
- attention_mask = build_causal_attention_mask(
994
- 1, token_ids.shape[1], device=token_ids.device
995
- )
996
-
997
- tokens_embeddings = self.token_embed(token_ids)
998
-
999
- after_transformer_embeddings = self.apply_transformer_layers(
1000
- tokens_embeddings, attention_mask=attention_mask
1001
- )
1002
-
1003
- embeddings = self.final_norm(after_transformer_embeddings)
1004
- logits = self.lm_head(embeddings)
1005
- return {"embeddings": embeddings, "logits": logits}
1006
-
1007
-
1008
- class TorchSimpleLMHead(nn.Module):
1009
- def __init__(
1010
- self, embed_dim: int, alphabet_size: int, add_bias_lm_head: bool = True
1011
- ) -> None:
1012
- super().__init__()
1013
- self.fc = nn.Linear(embed_dim, alphabet_size, bias=add_bias_lm_head)
1014
-
1015
- def forward(self, x: torch.Tensor) -> torch.Tensor:
1016
- return self.fc(x)
1017
-
1018
-
1019
- class TorchGptDecoderLayer(nn.Module):
1020
- def __init__(
1021
- self,
1022
- embed_dim: int,
1023
- ffn_embed_dim: int,
1024
- num_heads: int,
1025
- rope_config: RotaryEmbeddingConfig,
1026
- norm_type: str,
1027
- parallel_attention_ff: bool,
1028
- add_bias_ffn: bool,
1029
- ffn_activation_name: str,
1030
- use_glu_in_ffn: bool,
1031
- num_kv_heads: int,
1032
- add_bias_attn: bool,
1033
- rms_norm_eps: float = 1e-6,
1034
- ) -> None:
1035
- super().__init__()
1036
- self.num_heads = num_heads
1037
- self.parallel_attention_ff = parallel_attention_ff
1038
- self.use_glu_in_ffn = use_glu_in_ffn
1039
-
1040
- # Self-Attention layer
1041
- self.self_attn = TorchGptGroupedQueryAttention(
1042
- embed_dim=embed_dim,
1043
- num_heads=num_heads,
1044
- num_kv_heads=num_kv_heads,
1045
- rope_config=rope_config,
1046
- add_bias_attn=add_bias_attn,
1047
- )
1048
-
1049
- # Normalization layers
1050
- if norm_type == "layer_norm":
1051
- self.attn_norm = nn.LayerNorm(embed_dim)
1052
- if not self.parallel_attention_ff:
1053
- self.ffn_norm = nn.LayerNorm(embed_dim)
1054
- elif norm_type == "RMS_norm":
1055
- self.attn_norm = TorchRMSNorm(embed_dim, eps=rms_norm_eps)
1056
- if not self.parallel_attention_ff:
1057
- self.ffn_norm = TorchRMSNorm(embed_dim, eps=rms_norm_eps)
1058
- else:
1059
- raise ValueError(f"unrecognized norm_type: {norm_type}")
1060
-
1061
- # Feedforward network
1062
- self.activation = get_activation_fn(ffn_activation_name)
1063
- ffn_hidden_dim = ffn_embed_dim * (2 if use_glu_in_ffn else 1)
1064
- self.fc1 = nn.Linear(embed_dim, ffn_hidden_dim, bias=add_bias_ffn)
1065
- self.fc2 = nn.Linear(ffn_embed_dim, embed_dim, bias=add_bias_ffn)
1066
-
1067
- def forward(
1068
- self, embeddings: torch.Tensor, attention_mask: torch.Tensor
1069
- ) -> torch.Tensor:
1070
- residuals = embeddings
1071
-
1072
- if self.parallel_attention_ff:
1073
- # Parallel Attention + MLP
1074
- embeddings_normed = self.attn_norm(embeddings)
1075
-
1076
- attn_output = self.self_attn(
1077
- embeddings_normed,
1078
- embeddings_normed,
1079
- embeddings_normed,
1080
- attention_mask=attention_mask,
1081
- )
1082
- ffn_output = self.mlp(embeddings_normed) # type: ignore
1083
-
1084
- return residuals + attn_output + ffn_output
1085
- else:
1086
- # Sequential Attention + MLP
1087
- normed_embeddings = self.attn_norm(embeddings)
1088
-
1089
- attn_output = embeddings + self.self_attn(
1090
- normed_embeddings,
1091
- normed_embeddings,
1092
- normed_embeddings,
1093
- attention_mask=attention_mask,
1094
- )
1095
-
1096
- normed_embeddings2 = self.ffn_norm(attn_output)
1097
- ffn_output = self.mlp(normed_embeddings2) # type: ignore
1098
- return attn_output + ffn_output # Residual connection
1099
-
1100
- def mlp(self, x: torch.Tensor) -> torch.Tensor:
1101
- """Applies the feedforward network (MLP) with optional GLU."""
1102
- ffn_output = self.fc1(x)
1103
-
1104
- if self.use_glu_in_ffn:
1105
- ffn_output1, ffn_output2 = ffn_output.chunk(2, dim=-1)
1106
- ffn_output = self.activation(ffn_output1) * ffn_output2
1107
- else:
1108
- ffn_output = self.activation(ffn_output)
1109
-
1110
- return self.fc2(ffn_output)
1111
-
1112
-
1113
- class TorchRMSNorm(nn.Module):
1114
- def __init__(self, dim: int, eps: float = 1e-6) -> None:
1115
- super().__init__()
1116
- self.eps = eps
1117
- self.scale = nn.Parameter(torch.ones(dim))
1118
-
1119
- def forward(self, x: torch.Tensor) -> torch.Tensor:
1120
- return (
1121
- x
1122
- * self.scale
1123
- / torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)
1124
- )
1125
-
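# A minimal sketch (illustrative, not from the original file): TorchRMSNorm
# computes y = x * scale / sqrt(mean(x**2, dim=-1) + eps), i.e. normalization by
# the root-mean-square of the last dimension (no mean subtraction, unlike LayerNorm).
rms_norm = TorchRMSNorm(dim=8)
x = torch.randn(2, 8)
expected = x / torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
assert torch.allclose(rms_norm(x), expected)  # scale is initialized to ones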
1126
-
1127
- def get_activation_fn(activation_name: str): # type: ignore
1128
- activations = {
1129
- "gelu": nn.functional.gelu,
1130
- "relu": nn.functional.relu,
1131
- "swish": nn.functional.silu,
1132
- "silu": nn.functional.silu,
1133
- }
1134
- return activations.get(activation_name, nn.functional.relu)
1135
-
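# A minimal sketch (illustrative, not from the original file): names outside the
# dictionary fall back to ReLU, so e.g. "gelu-no-approx" resolves to relu here
# (the ESM self-attention blocks below handle that name separately).
assert get_activation_fn("swish") is nn.functional.silu
assert get_activation_fn("gelu-no-approx") is nn.functional.relu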
1136
-
1137
- def build_causal_attention_mask(
1138
- batch_size: int, seq_len: int, device: torch.device
1139
- ) -> torch.Tensor:
1140
- """
1141
- Builds a batch of causal masks of shape (batch_size, 1, seq_len, seq_len) to feed
1142
- to an attention layer.
1143
-
1144
- Args:
1145
- batch_size: Batch size.
1146
- seq_len: Length of the sequences.
1147
-
1148
- Returns:
1149
- Batch of causal masks.
1150
- """
1151
- mask = torch.ones((batch_size, 1, seq_len, seq_len), device=device)
1152
- causal_mask = torch.tril(mask)
1153
- return causal_mask
1154
-
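# A minimal sketch (illustrative, not from the original file): for seq_len = 3
# the causal mask is lower triangular, so position t can only attend to
# positions <= t.
causal = build_causal_attention_mask(batch_size=1, seq_len=3, device=torch.device("cpu"))
assert causal.shape == (1, 1, 3, 3)
assert torch.equal(
    causal[0, 0],
    torch.tensor([[1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [1.0, 1.0, 1.0]]),
)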
1155
-
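# A usage sketch (illustrative, not from the original file): a forward pass
# through a tiny TorchGptDecoder built from the default GptConfig architecture
# values, using the causal mask computed internally.
tiny_gpt_cfg = GptConfig(vocab_size=100, eos_token_id=3)
tiny_decoder = TorchGptDecoder(tiny_gpt_cfg)
token_ids = torch.randint(0, 100, (2, 10))  # (batch_size, seq_len)
decoder_out = tiny_decoder(token_ids)
assert decoder_out["logits"].shape == (2, 10, 100)
assert decoder_out["embeddings"].shape == (2, 10, tiny_gpt_cfg.embed_dim)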
1156
- @dataclass
1157
- class RotaryEmbeddingConfigBis:
1158
- """
1159
- Parameters to initialize the RotaryEmbedding layer. The rescaling factor allows
1160
- to adapt the rotary embeddings to larger lengths than what was used for training.
1161
- One such strategy is presented in the YaRN paper: https://arxiv.org/pdf/2309.00071.pdf. # noqa
1162
- Args:
- rescaling_factor: Rescaling factor used to adapt the rotary embeddings to longer sequences.
1163
- """
1164
-
1165
- rescaling_factor: Optional[float]
1166
-
1167
-
1168
- class RotaryEmbeddingBis(torch.nn.Module):
1169
- """
1170
- Rotary position embeddings based on those in
1171
- [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer).
1172
- Query and keys are transformed by rotation
1173
- matrices which depend on their relative positions.
1174
- """
1175
-
1176
- def __init__(self, dim: int, rotary_embedding_config: RotaryEmbeddingConfigBis):
1177
- super().__init__()
1178
-
1179
- # Extract argument from the config
1180
- self.rescaling_factor = rotary_embedding_config.rescaling_factor
1181
- self.upper_freq = 10000
1182
- self.dim = dim
1183
-
1184
- self._seq_len_cached = None
1185
- self._cos_cached = None
1186
- self._sin_cached = None
1187
-
1188
- def _apply_rotary_pos_emb(
1189
- self,
1190
- heads: torch.Tensor,
1191
- cos: torch.Tensor,
1192
- sin: torch.Tensor,
1193
- ) -> torch.Tensor:
1194
- """ """
1195
- x_first, x_second = (
1196
- heads[..., : heads.shape[-1] // 2],
1197
- heads[..., heads.shape[-1] // 2 :],
1198
- )
1199
-
1200
- first_part = x_first * cos - x_second * sin
1201
- second_part = x_second * cos + x_first * sin
1202
-
1203
- return torch.cat((first_part, second_part), dim=-1)
1204
-
1205
- def _compute_cos_sin_tables(
1206
- self, x: torch.Tensor, inv_freq: torch.Tensor, seq_dimension: int = 2
1207
- ) -> tuple[torch.Tensor, torch.Tensor]:
1208
- seq_len = x.shape[seq_dimension]
1209
- # Reset the tables if the sequence length has changed,
1210
- # or if we're on a new device (possibly due to tracing for instance)
1211
- self._seq_len_cached = seq_len
1212
- t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(inv_freq)
1213
- # freqs = torch.outer(t, inv_freq)
1214
- freqs = torch.einsum("i, j -> ij", t, inv_freq)
1215
-
1216
- self._cos_cached = torch.cos(freqs)[None, :, None, :]
1217
- self._sin_cached = torch.sin(freqs)[None, :, None, :]
1218
- # emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
1219
-
1220
- # self._cos_cached = emb.cos()[None, None, :, :]
1221
- # self._sin_cached = emb.sin()[None, None, :, :]
1222
-
1223
- return self._cos_cached, self._sin_cached
1224
-
1225
- def forward(
1226
- self, q: torch.Tensor, k: torch.Tensor
1227
- ) -> Tuple[torch.Tensor, torch.Tensor]:
1228
- if self.rescaling_factor is None:
1229
- inv_freq = 1.0 / (
1230
- self.upper_freq
1231
- ** (torch.arange(0, self.dim, 2, device=q.device).float() / self.dim)
1232
- )
1233
- else:
1234
- updated_base = self.upper_freq * (
1235
- self.rescaling_factor ** (self.dim / (self.dim - 2))
1236
- )
1237
- inv_freq = 1.0 / (
1238
- updated_base
1239
- ** (torch.arange(0, self.dim, 2, device=q.device).float() / self.dim)
1240
- )
1241
-
1242
- self._cos_cached, self._sin_cached = self._compute_cos_sin_tables(
1243
- q,
1244
- inv_freq,
1245
- seq_dimension=-3,
1246
- )
1247
-
1248
- return (
1249
- self._apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
1250
- self._apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
1251
- )
1252
-
1253
-
1254
- class MultiHeadAttention(nn.Module):
1255
- def __init__(
1256
- self,
1257
- num_heads: int,
1258
- key_size: int,
1259
- rotary_embedding_config: Optional[RotaryEmbeddingConfigBis] = None,
1260
- add_bias_kv: bool = False,
1261
- value_size: Optional[int] = None,
1262
- model_size: Optional[int] = None,
1263
- name: Optional[str] = None,
1264
- ):
1265
- super().__init__()
1266
- if not model_size:
1267
- model_size = key_size * num_heads
1268
- if not value_size:
1269
- value_size = key_size
1270
- self.model_size = model_size
1271
- self.key_size = key_size
1272
- self.value_size = value_size
1273
- self.add_bias_kv = add_bias_kv
1274
- self.name = name
1275
- self.num_heads = num_heads
1276
- self._rotary_embedding_config = rotary_embedding_config
1277
-
1278
- self.w_k = nn.Linear(self.model_size, self.num_heads * self.key_size)
1279
- self.w_q = nn.Linear(self.model_size, self.num_heads * self.key_size)
1280
- self.w_v = nn.Linear(self.model_size, self.num_heads * self.value_size)
1281
- self.output = nn.Linear(self.num_heads * self.value_size, self.model_size)
1282
- if self._rotary_embedding_config:
1283
- self._rotary_embedding = RotaryEmbeddingBis(
1284
- self.key_size, self._rotary_embedding_config
1285
- )
1286
-
1287
- def apply_rotary_embeddings(
1288
- self,
1289
- query: torch.Tensor,
1290
- key: torch.Tensor,
1291
- ) -> tuple[torch.Tensor, torch.Tensor]:
1292
- """ """
1293
- query, key = self._rotary_embedding(query, key)
1294
- return query, key
1295
-
1296
- def forward(
1297
- self,
1298
- query: torch.Tensor,
1299
- key: torch.Tensor,
1300
- value: torch.Tensor,
1301
- attention_mask: Optional[torch.Tensor] = None,
1302
- attention_weight_bias: Optional[torch.Tensor] = None,
1303
- ) -> dict[str, torch.Tensor]:
1304
- """
1305
- Returns:
1306
- dictionary containing attention weights
1307
- and outputs.
1308
- """
1309
- key_heads = self.w_k(key).reshape(
1310
- (*key.shape[:-1], self.num_heads, self.key_size)
1311
- )
1312
- query_heads = self.w_q(query).reshape(
1313
- (*query.shape[:-1], self.num_heads, self.key_size)
1314
- )
1315
- value_heads = self.w_v(value).reshape(
1316
- (*value.shape[:-1], self.num_heads, self.value_size)
1317
- )
1318
- if self._rotary_embedding_config:
1319
- query_heads, key_heads = self.apply_rotary_embeddings(
1320
- query_heads, key_heads
1321
- )
1322
- attention_weights = torch.einsum(
1323
- "...thd, ...Thd -> ...htT", query_heads, key_heads
1324
- )
1325
- sqrt_key_size = np.sqrt(self.key_size)
1326
- attention_weights = attention_weights / sqrt_key_size
1327
- if attention_mask is not None:
1328
- attention_weights = torch.where(attention_mask, attention_weights, -1e30)
1329
- if attention_weight_bias is not None:
1330
- attention_weights = F.softmax(
1331
- attention_weights + attention_weight_bias, dim=-1
1332
- )
1333
- else:
1334
- attention_weights = F.softmax(attention_weights, dim=-1)
1335
- value_out = torch.einsum(
1336
- "...htT, ...Thd->...thd", attention_weights, value_heads
1337
- )
1338
- value_out = value_out.reshape((*value_out.shape[:-2], -1))
1339
- embeddings = self.output(value_out)
1340
-
1341
- return {"attention_weights": attention_weights, "embeddings": embeddings}
1342
-
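# A minimal sketch (illustrative, not from the original file): MultiHeadAttention
# used as plain self-attention (no rotary embeddings, no mask); it returns both
# the attention weights and the output embeddings.
mha = MultiHeadAttention(num_heads=4, key_size=16, model_size=64)
x = torch.randn(2, 10, 64)  # (batch_size, seq_len, model_size)
mha_out = mha(query=x, key=x, value=x)
assert mha_out["embeddings"].shape == (2, 10, 64)
assert mha_out["attention_weights"].shape == (2, 4, 10, 10)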
1343
-
1344
- class SelfAttentionBlock(nn.Module):
1345
- def __init__(
1346
- self,
1347
- num_heads: int,
1348
- embed_dim: int,
1349
- ffn_embed_dim: int,
1350
- key_size: Optional[int] = None,
1351
- add_bias_kv: bool = False,
1352
- add_bias_fnn: bool = True,
1353
- ffn_activation_name: str = "gelu-no-approx",
1354
- use_glu_in_ffn: bool = False,
1355
- layer_norm_eps: float = 1e-5, # this is the default haiku value
1356
- pre_layer_norm: bool = True,
1357
- name: Optional[str] = None,
1358
- rotary_embedding_config: Optional[RotaryEmbeddingConfigBis] = None,
1359
- ):
1360
- super().__init__()
1361
- if key_size is None:
1362
- if embed_dim % num_heads != 0:
1363
- raise ValueError(
1364
- f"The embedding dimension should be divisible by the number of "
1365
- f"heads, however provided embedding dimension is {embed_dim} and "
1366
- f"the number of heads is {num_heads}."
1367
- )
1368
- else:
1369
- key_size = embed_dim // num_heads
1370
-
1371
- # Get ffn activation function
1372
- self._pre_layer_norm = pre_layer_norm
1373
- self._use_glu_in_fnn = use_glu_in_ffn
1374
- # Define layers
1375
- if use_glu_in_ffn:
1376
- # user should multiply ffn_embed_dim by 2/3 when using GLU
1377
- # to keep total number of parameters equal
1378
- # see https://arxiv.org/pdf/2002.05202.pdf. for more details
1379
- # we multiply by 2 here as the output will be split in 2 for GLU
1380
- self.fc1 = nn.Linear(embed_dim, int(2 * ffn_embed_dim), bias=add_bias_fnn)
1381
- else:
1382
- self.fc1 = nn.Linear(embed_dim, ffn_embed_dim, bias=add_bias_fnn)
1383
-
1384
- self.fc2 = nn.Linear(ffn_embed_dim, embed_dim, bias=add_bias_fnn)
1385
-
1386
- self.layer_norm_self_attention = nn.LayerNorm(
1387
- embed_dim,
1388
- )
1389
- self.layer_norm_mlp = nn.LayerNorm(embed_dim)
1390
- if ffn_activation_name == "swish":
1391
- self._ffn_activation_fn = nn.SiLU()
1392
- elif ffn_activation_name == "gelu-no-approx":
1393
- self._ffn_activation_fn = nn.GELU(approximate="tanh")
1394
- else:
1395
- self._ffn_activation_fn = getattr(torch.nn, ffn_activation_name)
1396
-
1397
- self.mha = MultiHeadAttention(
1398
- num_heads=num_heads,
1399
- key_size=key_size,
1400
- add_bias_kv=add_bias_kv,
1401
- model_size=embed_dim,
1402
- name="self_attention",
1403
- rotary_embedding_config=rotary_embedding_config,
1404
- )
1405
-
1406
- def mlp(self, embed: torch.Tensor) -> torch.Tensor:
1407
-
1408
- if self._pre_layer_norm:
1409
- x = self.layer_norm_mlp(embed)
1410
- else:
1411
- x = embed
1412
-
1413
- if self._use_glu_in_fnn:
1414
- x = self.fc1(x)
1415
- x1, x2 = torch.split(x, split_size_or_sections=x.shape[-1] // 2, dim=-1)
1416
- x = self._ffn_activation_fn(x1) * x2
1417
- else:
1418
- x = self._ffn_activation_fn(self.fc1(x))
1419
- x = self.fc2(x)
1420
-
1421
- if not self._pre_layer_norm:
1422
- x = self.layer_norm_mlp(x + embed)
1423
- return x
1424
-
1425
- def forward(
1426
- self,
1427
- x: torch.Tensor,
1428
- attention_mask: Optional[torch.Tensor] = None,
1429
- attention_weight_bias: Optional[torch.Tensor] = None,
1430
- ) -> dict[str, torch.Tensor]:
1431
-
1432
- res = x
1433
- if self._pre_layer_norm:
1434
- x = self.layer_norm_self_attention(x)
1435
-
1436
- output: dict[str, torch.Tensor] = self.mha(
1437
- x,
1438
- x,
1439
- x,
1440
- attention_mask=attention_mask,
1441
- attention_weight_bias=attention_weight_bias,
1442
- )
1443
-
1444
- if not self._pre_layer_norm:
1445
- output["embeddings"] = self.layer_norm_self_attention(
1446
- output["embeddings"] + res
1447
- )
1448
-
1449
- x = output["embeddings"]
1450
- else:
1451
- x = output["embeddings"]
1452
- x = res + x
1453
-
1454
- # MLP
1455
- if not self._pre_layer_norm:
1456
- x = self.mlp(x)
1457
- else:
1458
- x = x + self.mlp(x)
1459
-
1460
- output["embeddings"] = x
1461
- return output
1462
-
1463
-
1464
- class RobertaLMHead(nn.Module):
1465
- """
1466
- Roberta Language Model head. Transforms final attention layer output into a
1467
- distribution over tokens at each position.
1468
- """
1469
-
1470
- def __init__(self, embed_dim: int, alphabet_size: int):
1471
- """
1472
- Args:
1473
- embed_dim: Embedding dimension.
1474
- alphabet_size: Number of tokens in the alphabet.
1475
- """
1476
- super().__init__()
1477
- self.embed_dim = embed_dim
1478
- self.alphabet_size = alphabet_size
1479
-
1480
- # Define layers
1481
- self._first_layer_norm = nn.LayerNorm(embed_dim, elementwise_affine=True)
1482
- self._fc1 = nn.Linear(embed_dim, embed_dim)
1483
- self._second_layer_norm = nn.LayerNorm(embed_dim, elementwise_affine=True)
1484
- self._final_fc = nn.Linear(embed_dim, alphabet_size)
1485
-
1486
- def forward(self, x: torch.Tensor) -> dict:
1487
- x = self._first_layer_norm(x)
1488
- embeddings = x
1489
- x = self._fc1(x)
1490
- x = nn.functional.gelu(x)
1491
- x = self._second_layer_norm(x)
1492
- logits = self._final_fc(x)
1493
- return {"embeddings": embeddings, "logits": logits}
1494
-
1495
-
1496
- class TorchESMTransformer(nn.Module):
1497
- def __init__(
1498
- self,
1499
- esm_config: ESMTransformerConfig,
1500
- ):
1501
- super(TorchESMTransformer, self).__init__()
1502
- self.esm_config = esm_config
1503
-
1504
- # Other cases are not implemented
1505
- assert esm_config.positional_embedding is None
1506
- assert esm_config.lm_head == "roberta"
1507
- assert esm_config.use_rotary_embedding is True
1508
- assert esm_config.token_dropout is False
1509
- assert esm_config.emb_layer_norm_before is False
1510
- assert esm_config.mask_before_attention is False
1511
- assert esm_config.bias_word_embedding is False
1512
- assert esm_config.use_gradient_checkpointing is False
1513
-
1514
- self.embed_layer = nn.Embedding(esm_config.alphabet_size, esm_config.embed_dim)
1515
-
1516
- self.lm_head = RobertaLMHead(
1517
- embed_dim=esm_config.embed_dim,
1518
- alphabet_size=esm_config.alphabet_size,
1519
- )
1520
-
1521
- self.rotary_embedding_config = RotaryEmbeddingConfigBis(
1522
- rescaling_factor=esm_config.rescaling_factor
1523
- )
1524
-
1525
- self.attention_blocks = nn.ModuleList(
1526
- [
1527
- SelfAttentionBlock( # type: ignore
1528
- num_heads=esm_config.attention_heads,
1529
- embed_dim=esm_config.embed_dim,
1530
- key_size=esm_config.key_size,
1531
- ffn_embed_dim=esm_config.ffn_embed_dim,
1532
- add_bias_kv=esm_config.add_bias_kv,
1533
- add_bias_fnn=esm_config.add_bias_ffn,
1534
- ffn_activation_name=esm_config.ffn_activation_name,
1535
- use_glu_in_ffn=esm_config.use_glu_in_ffn,
1536
- rotary_embedding_config=self.rotary_embedding_config,
1537
- layer_norm_eps=esm_config.layer_norm_eps,
1538
- pre_layer_norm=esm_config.pre_layer_norm,
1539
- )
1540
- for _ in range(esm_config.num_layers)
1541
- ]
1542
- )
1543
-
1544
- def forward(
1545
- self, tokens: torch.Tensor, attention_mask: torch.Tensor = None
1546
- ) -> torch.Tensor:
1547
- """
1548
- Computes the embeddings based on the input tokens.
1549
-
1550
- Args:
1551
- tokens: Input tokens out of the tokenizer of shape (batch_size, seq_len).
1552
- attention_mask: Attention mask of shape (batch_size, 1, seq_len, seq_len).
1553
- If None, a default mask is computed, equal to 1 over all non-pad
1554
- tokens and 0 over pad tokens.
1555
-
1556
- Returns:
1557
- Tensor of final embeddings, of shape (batch_size, seq_len, embed_dim).
1558
- """
1559
- x = self.embed_layer(tokens)
1560
-
1561
- # Scale the embeddings (embed_scale correction factor)
1562
- x = self.esm_config.embed_scale * x
1563
-
1564
- if attention_mask is None:
1565
- attention_mask = build_padding_attention_mask(
1566
- tokens=tokens, pad_token_id=self.esm_config.pad_token_id
1567
- )
1568
-
1569
- for layer in self.attention_blocks:
1570
- x = layer(x, attention_mask)["embeddings"]
1571
-
1572
- assert self.esm_config.lm_head == "roberta"
1573
- x = self.lm_head(x)["embeddings"]
1574
-
1575
- return x
1576
-
1577
-
1578
- def build_padding_attention_mask(
1579
- tokens: torch.Tensor, pad_token_id: int
1580
- ) -> torch.Tensor:
1581
- """
1582
- Builds a padding mask from a sequence of tokens by masking <pad> in the attention.
1583
-
1584
- Args:
1585
- tokens: Batch of sequences of shape (batch_size, seq_len).
1586
- pad_token_id: Int corresponding to the <pad> token to mask.
1587
-
1588
- Returns:
1589
- Batch of attention masks, masking out <pad> tokens.
1590
- """
1591
- padding_mask = tokens != pad_token_id
1592
- padding_mask = padding_mask.unsqueeze(1)
1593
- padding_mask = torch.einsum("bhT, bht -> bhtT", padding_mask, padding_mask)
1594
- return padding_mask
1595
-
1596
-
1597
- class TorchBioBrainEncoder(nn.Module):
-     def __init__(
-         self,
-         esm_config: ESMTransformerConfig,
-     ):
-         super(TorchBioBrainEncoder, self).__init__()
-         self.esm_config = esm_config
-         self.esm_model = TorchESMTransformer(self.esm_config)
-
-     def forward(
-         self,
-         bio_token_ids: torch.Tensor,
-     ) -> torch.Tensor:
-         """
-         Args:
-             bio_token_ids (torch.Tensor):
-                 Shape (batch_size, num_bio_tokens)
-
-         Returns:
-             torch.Tensor:
-                 Shape (batch_size, num_bio_tokens, embed_dim)
-         """
-         bio_embeddings = self.esm_model(tokens=bio_token_ids)
-
-         return bio_embeddings
-
-
- class TorchMultiModalPerceiverResamplerBlock(nn.Module):
-     def __init__(
-         self,
-         num_heads: int,
-         embed_dim: int,
-         ffn_embed_dim: int,
-         key_size: Optional[int] = None,
-         add_bias_kv: bool = False,
-         add_bias_ffn: bool = True,
-         ffn_activation_name: str = "gelu",
-         use_glu_in_ffn: bool = False,
-     ):
-         super().__init__()
-
-         if key_size is None:
-             if embed_dim % num_heads != 0:
-                 raise ValueError(
-                     f"Embedding dimension {embed_dim} should be divisible by "
-                     f"num_heads {num_heads}."
-                 )
-             key_size = embed_dim // num_heads
-
-         self.num_heads = num_heads
-         self.embed_dim = embed_dim
-         self.ffn_embed_dim = ffn_embed_dim * 2 if use_glu_in_ffn else ffn_embed_dim
-         self.use_glu_in_ffn = use_glu_in_ffn
-
-         self.cross_attention_1 = MultiHeadAttention(
-             num_heads=num_heads, key_size=key_size, add_bias_kv=add_bias_kv
-         )
-         self.cross_attention_2 = MultiHeadAttention(
-             num_heads=num_heads, key_size=key_size, add_bias_kv=add_bias_kv
-         )
-
-         self.norm_cross_attention_1 = nn.LayerNorm(embed_dim)
-         self.norm_cross_attention_2 = nn.LayerNorm(embed_dim)
-         self.norm_mlp = nn.LayerNorm(embed_dim)
-
-         self.fc1 = nn.Linear(embed_dim, self.ffn_embed_dim, bias=add_bias_ffn)
-         self.fc2 = nn.Linear(self.ffn_embed_dim, embed_dim, bias=add_bias_ffn)
-
-         self.activation_fn = getattr(
-             nn.functional, ffn_activation_name, nn.functional.gelu
-         )
-
-     def mlp(self, x: torch.Tensor) -> torch.Tensor:
-         x = self.norm_mlp(x)
-         if self.use_glu_in_ffn:
-             x1, x2 = torch.chunk(self.fc1(x), 2, dim=-1)
-             x = self.activation_fn(x1) * x2
-         else:
-             x = self.activation_fn(self.fc1(x))
-         return self.fc2(x)
-
-     def forward(
-         self,
-         x: torch.Tensor,
-         cross_attention_embeddings_1: torch.Tensor,
-         cross_attention_embeddings_2: torch.Tensor,
-         attention_mask_1: Optional[torch.Tensor] = None,
-         attention_mask_2: Optional[torch.Tensor] = None,
-     ) -> Dict[str, torch.Tensor]:
-         res = x
-         x = self.norm_cross_attention_1(x)
-
-         attn_output = self.cross_attention_1(
-             query=x,
-             key=cross_attention_embeddings_1,
-             value=cross_attention_embeddings_1,
-             attention_mask=attention_mask_1,
-         )["embeddings"]
-         x = res + attn_output
-
-         res = x
-         x = self.norm_cross_attention_2(x)
-         attn_output = self.cross_attention_2(
-             query=x,
-             key=cross_attention_embeddings_2,
-             value=cross_attention_embeddings_2,
-             attention_mask=attention_mask_2,
-         )["embeddings"]
-         x = res + attn_output
-
-         x = x + self.mlp(x)
-
-         return {"embeddings": x}
-
-
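For intuition, a hedged usage sketch of this block, assuming chatNT.py is importable as a module named chatNT and that the file's MultiHeadAttention projects its output back to num_heads * key_size (toy sizes throughout): the latent queries are refined by two cross-attentions, one per input stream, followed by the MLP, and keep their length.

import torch
from chatNT import TorchMultiModalPerceiverResamplerBlock  # assumed import path

block = TorchMultiModalPerceiverResamplerBlock(
    num_heads=4, embed_dim=64, ffn_embed_dim=128
)
latents = torch.randn(2, 8, 64)    # (batch, resampled_length, embed_dim)
stream_1 = torch.randn(2, 28, 64)  # e.g. bio embeddings concatenated with the latents
stream_2 = torch.randn(2, 20, 64)  # e.g. English embeddings concatenated with the latents
out = block(latents, stream_1, stream_2)["embeddings"]
print(out.shape)  # torch.Size([2, 8, 64]): same shape as the latent queries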
- class TorchMultiModalPerceiverResampler(nn.Module):
-     """
-     Perceiver Resampler model, made of successive PerceiverResamplerBlocks.
-     """
-
-     def __init__(
-         self,
-         config: PerceiverResamplerConfig,
-         name: Optional[str] = None,
-     ):
-         """
-         Initialize a Perceiver Resampler model.
-
-         Args:
-             config: Dataclass containing model hyperparameters.
-             name: Name for module (custom will break weight loading).
-         """
-         super().__init__()
-         self.config = config
-         self.name = name
-         self.layers = nn.ModuleList(
-             [
-                 TorchMultiModalPerceiverResamplerBlock(
-                     num_heads=self.config.attention_heads,
-                     embed_dim=self.config.embed_dim,
-                     key_size=self.config.key_size,
-                     ffn_embed_dim=self.config.ffn_embed_dim,
-                     add_bias_kv=self.config.add_bias_kv,
-                     add_bias_ffn=self.config.add_bias_ffn,
-                     ffn_activation_name=self.config.ffn_activation_name,
-                     use_glu_in_ffn=self.config.use_glu_in_ffn,
-                 )
-                 for _ in range(self.config.num_layers)
-             ]
-         )
-
-         self.latent_queries = torch.nn.Parameter(
-             torch.randn(self.config.resampled_length, self.config.embed_dim)
-             * (
-                 1.0
-                 / torch.sqrt(torch.tensor(self.config.embed_dim, dtype=torch.float32))
-             )
-         )
-
-     def apply_attention_blocks(
-         self,
-         x: torch.Tensor,
-         xf_1: torch.Tensor,
-         xf_2: torch.Tensor,
-         outs: Dict[str, torch.Tensor],
-         attention_mask_1: Optional[torch.Tensor] = None,
-         attention_mask_2: Optional[torch.Tensor] = None,
-     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
- """
1766
- Create the blocks of attention layers and applies them.
1767
- """
1768
-         for layer in self.layers:
-             concat_input_1 = torch.cat([xf_1, x], dim=1)
-             concat_input_2 = torch.cat([xf_2, x], dim=1)
-
-             output = layer(
-                 x=x,
-                 cross_attention_embeddings_1=concat_input_1,
-                 cross_attention_embeddings_2=concat_input_2,
-                 attention_mask_1=attention_mask_1,
-                 attention_mask_2=attention_mask_2,
-             )
-             x = output["embeddings"]
-
-         return x, outs
-
-     def forward(
-         self,
-         input_embeddings_1: torch.Tensor,
-         input_embeddings_2: torch.Tensor,
-         attention_mask_1: Optional[torch.Tensor] = None,
-         attention_mask_2: Optional[torch.Tensor] = None,
-     ) -> Dict[str, torch.Tensor]:
- """
1791
- Computes the embeddings based on the input tokens.
1792
- """
1793
-         assert (
-             input_embeddings_1.shape[-1] == self.config.embed_dim
-         ), "The input embedding dim should match the model embed dim"
-         assert (
-             input_embeddings_2.shape[-1] == self.config.embed_dim
-         ), "The input embedding dim should match the model embed dim"
-
-         batch_size = input_embeddings_1.shape[0]
-
-         latent_queries = self.latent_queries.unsqueeze(0).repeat(batch_size, 1, 1)
-
-         outs: Dict[str, torch.Tensor] = {}
-         x = latent_queries
-
-         x, outs = self.apply_attention_blocks(
-             x=x,
-             xf_1=input_embeddings_1,
-             xf_2=input_embeddings_2,
-             outs=outs,
-             attention_mask_1=attention_mask_1,
-             attention_mask_2=attention_mask_2,
-         )
-
-         outs["embeddings"] = x
-
-         return outs
-
-
- class TorchMultiModalPerceiverResamplerProjection(nn.Module):
-     def __init__(
-         self,
-         perceiver_resampler_config: PerceiverResamplerConfig,
-         input_embed_dim: int,
-         embed_dim: int,
-         bio_pad_token_id: int,
-         english_pad_token_id: int,
-         english_vocab_size: int,
-     ):
-         super().__init__()
-         self.config = perceiver_resampler_config
-         self.input_embed_dim = input_embed_dim
-         self.embed_dim = embed_dim
-         self.bio_pad_token_id = bio_pad_token_id
-         self.english_pad_token_id = english_pad_token_id
-         self.english_vocab_size = english_vocab_size
-
-         self.bio_projection = nn.Linear(input_embed_dim, embed_dim)
-         self.token_embedding = nn.Embedding(english_vocab_size, embed_dim)
-         self.perceiver_resampler = TorchMultiModalPerceiverResampler(config=self.config)
-
-     def forward(
-         self,
-         bio_token_ids: torch.Tensor,
-         bio_embeddings: torch.Tensor,
-         english_token_ids: torch.Tensor,
-     ) -> torch.Tensor:
-         """
-         Args:
-             bio_token_ids (torch.Tensor):
-                 Shape (batch_size, num_bio_tokens)
-
-             bio_embeddings (torch.Tensor):
-                 Shape (batch_size, num_bio_tokens, input_embed_dim)
-
-             english_token_ids (torch.Tensor):
-                 Shape (batch_size, num_english_tokens)
-
-         Returns:
-             torch.Tensor:
-                 Shape (batch_size, resampled_length, embed_dim)
-         """
-         projected_bio_embeddings = self.bio_projection(bio_embeddings)
-         english_embeddings = self.token_embedding(english_token_ids)
-
-         bio_attention_mask = build_perceiver_padding_attention_mask(
-             bio_token_ids, self.config.resampled_length, self.bio_pad_token_id
-         )
-         english_attention_mask = build_perceiver_padding_attention_mask(
-             english_token_ids, self.config.resampled_length, self.english_pad_token_id
-         )
-
-         projected_embeddings = self.perceiver_resampler(
-             input_embeddings_1=projected_bio_embeddings,
-             attention_mask_1=bio_attention_mask,
-             input_embeddings_2=english_embeddings,
-             attention_mask_2=english_attention_mask,
-         )["embeddings"]
-
-         return projected_embeddings
-
-
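A hedged sketch of the projection module's data flow (assumed import path and toy sizes; it is assumed from the code above that MultiHeadAttention broadcasts the (batch, 1, query, key) masks over heads): bio embeddings are projected to the resampler width, English token ids are embedded, and both streams are resampled into resampled_length vectors.

import torch
from chatNT import (  # assumed import path
    PerceiverResamplerConfig,
    TorchMultiModalPerceiverResamplerProjection,
)

config = PerceiverResamplerConfig(
    attention_heads=4, embed_dim=64, ffn_embed_dim=128, num_layers=2, resampled_length=8
)
projection = TorchMultiModalPerceiverResamplerProjection(
    perceiver_resampler_config=config,
    input_embed_dim=32,  # width of the upstream bio encoder (toy value)
    embed_dim=64,        # resampler width (toy value)
    bio_pad_token_id=1,
    english_pad_token_id=0,
    english_vocab_size=100,
)

bio_token_ids = torch.randint(2, 10, (2, 30))      # no pad tokens in this toy batch
bio_embeddings = torch.randn(2, 30, 32)
english_token_ids = torch.randint(1, 100, (2, 12))
out = projection(bio_token_ids, bio_embeddings, english_token_ids)
print(out.shape)  # torch.Size([2, 8, 64])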
- def build_perceiver_padding_attention_mask(
-     tokens: torch.Tensor, resampled_length: int, pad_token_id: int
- ) -> torch.Tensor:
-     """
-     Builds the padding attention mask used by the Perceiver Resampler: each of the
-     resampled_length latent queries may attend to all non-pad input tokens and to the
-     resampled_length latent positions appended after them.
-
-     Args:
-         tokens: Batch of sequences of shape (batch_size, seq_len).
-         resampled_length: Number of latent queries of the resampler.
-         pad_token_id: Int corresponding to the <pad> token to mask.
-
-     Returns:
-         Boolean mask of shape (batch_size, 1, resampled_length, seq_len + resampled_length).
-     """
-     batch_size, seq_len = tokens.shape
-     padding_mask = tokens != pad_token_id  # (batch_size, seq_len)
-
-     padding_mask = torch.cat(
-         [
-             padding_mask,
-             torch.ones(
-                 (batch_size, resampled_length), dtype=torch.bool, device=tokens.device
-             ),
-         ],
-         dim=1,
-     )  # (batch_size, seq_len + resampled_length)
-
-     padding_mask = padding_mask[:, None, None, :]
-     padding_mask = padding_mask.repeat(1, 1, resampled_length, 1)  # noqa
-     return padding_mask
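The mask built here lets each of the resampled_length latent queries attend to every non-pad input token plus the appended block of latent positions (the ones concatenated above). A quick hedged check of the shape (assumed import path; pad id 0 is a placeholder):

import torch
from chatNT import build_perceiver_padding_attention_mask  # assumed import path

tokens = torch.tensor([[5, 6, 7, 0, 0]])  # 0 stands in for the pad token id
mask = build_perceiver_padding_attention_mask(tokens, resampled_length=3, pad_token_id=0)
print(mask.shape)     # torch.Size([1, 1, 3, 8]): 3 latent queries, 5 tokens + 3 latents as keys
print(mask[0, 0, 0])  # tensor([ True,  True,  True, False, False,  True,  True,  True])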