import math

import torch

from peft.utils.integrations import gather_params_ctx

from .config import PromptTuningInit

|
class PromptEmbedding(torch.nn.Module):
    """
    The model to encode virtual tokens into prompt embeddings.

    Args:
        config ([`PromptTuningConfig`]): The configuration of the prompt embedding.
        word_embeddings (`torch.nn.Module`): The word embeddings of the base transformer model.

    **Attributes**:
        - **embedding** (`torch.nn.Embedding`) -- The embedding layer of the prompt embedding.

    Example:

    ```py
    >>> from peft import PromptEmbedding, PromptTuningConfig

    >>> config = PromptTuningConfig(
    ...     peft_type="PROMPT_TUNING",
    ...     task_type="SEQ_2_SEQ_LM",
    ...     num_virtual_tokens=20,
    ...     token_dim=768,
    ...     num_transformer_submodules=1,
    ...     num_attention_heads=12,
    ...     num_layers=12,
    ...     prompt_tuning_init="TEXT",
    ...     prompt_tuning_init_text="Predict if sentiment of this review is positive, negative or neutral",
    ...     tokenizer_name_or_path="t5-base",
    ... )

    >>> # t5_model.shared is the word embeddings of the base model
    >>> prompt_embedding = PromptEmbedding(config, t5_model.shared)
    ```

    Input Shape: (`batch_size`, `total_virtual_tokens`)

    Output Shape: (`batch_size`, `total_virtual_tokens`, `token_dim`)
    """
|
    def __init__(self, config, word_embeddings):
        super().__init__()

        total_virtual_tokens = config.num_virtual_tokens * config.num_transformer_submodules
        self.embedding = torch.nn.Embedding(total_virtual_tokens, config.token_dim)
        if config.prompt_tuning_init == PromptTuningInit.TEXT and not config.inference_mode:
            from transformers import AutoTokenizer

            tokenizer_kwargs = config.tokenizer_kwargs or {}
            tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name_or_path, **tokenizer_kwargs)
            init_text = config.prompt_tuning_init_text
            init_token_ids = tokenizer(init_text)["input_ids"]
            # Trim or repeat the init-text token ids until they cover total_virtual_tokens.
            num_text_tokens = len(init_token_ids)
            if num_text_tokens > total_virtual_tokens:
                init_token_ids = init_token_ids[:total_virtual_tokens]
            elif num_text_tokens < total_virtual_tokens:
                num_reps = math.ceil(total_virtual_tokens / num_text_tokens)
                init_token_ids = init_token_ids * num_reps
            init_token_ids = init_token_ids[:total_virtual_tokens]
            init_token_ids = torch.LongTensor(init_token_ids).to(word_embeddings.weight.device)
            # Gather the word-embedding weights if they are sharded (e.g. DeepSpeed ZeRO-3)
            # before looking up the embeddings of the init tokens.
            with gather_params_ctx(word_embeddings.parameters()):
                word_embedding_weights = word_embeddings(init_token_ids).detach().clone()
            word_embedding_weights = word_embedding_weights.to(torch.float32)
            self.embedding.weight = torch.nn.Parameter(word_embedding_weights)
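
    # Worked example of the TEXT-init length matching above (illustrative numbers,
    # not from a real tokenizer run): with total_virtual_tokens = 20 and an init
    # text that tokenizes to 7 ids, the ids are repeated math.ceil(20 / 7) = 3
    # times (21 ids) and then truncated back to the first 20; an init text longer
    # than 20 tokens is simply truncated to its first 20 ids.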
|
    def forward(self, indices):
        # Look up the trainable prompt embeddings for the given virtual token indices.
        prompt_embeddings = self.embedding(indices)
        return prompt_embeddings
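

# Minimal usage sketch (illustrative, not part of the library): a toy
# ``torch.nn.Embedding`` stands in for a base model's word embeddings (such as
# ``t5_model.shared``), and the default RANDOM init is used so no tokenizer is
# needed. In practice ``PeftModel`` builds the virtual-token indices and calls
# this module internally; the shapes below match the docstring above.
if __name__ == "__main__":
    from peft import PromptTuningConfig

    config = PromptTuningConfig(
        task_type="SEQ_2_SEQ_LM",
        num_virtual_tokens=20,
        token_dim=768,
        num_transformer_submodules=1,
    )
    # Stand-in for the base model's word embedding matrix (vocab_size x token_dim).
    word_embeddings = torch.nn.Embedding(32128, 768)
    prompt_embedding = PromptEmbedding(config, word_embeddings)

    batch_size = 4
    total_virtual_tokens = config.num_virtual_tokens * config.num_transformer_submodules
    # One row of virtual-token indices per example in the batch.
    indices = torch.arange(total_virtual_tokens).unsqueeze(0).expand(batch_size, -1)
    prompt_embeddings = prompt_embedding(indices)
    print(prompt_embeddings.shape)  # torch.Size([4, 20, 768])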