import logging
from abc import ABC, abstractmethod
from typing import Dict, List

import torch
from torch.nn.functional import conv1d
from torchviz import make_dot  # only used by the commented-out make_dot debug calls below

from multi_token.modalities.base_modality import Modality
from multi_token.model_utils import MultiTaskType
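
# Note: these classes are written as mixins. LMMMetaModel calls
# super().__init__(config), so it expects to be combined (via multiple
# inheritance) with a transformer backbone whose __init__ takes a config;
# LMMMetaForCausalLM likewise relies on self.config, self.dtype, self.device
# and self.modalities being provided by the concrete model class that mixes
# it in.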


class LMMMetaModel:
    def __init__(self, config):
        super(LMMMetaModel, self).__init__(config)

    def _load_projector_weights(self, weights: Dict):
        weights = {
            (k[23:] if k.startswith("base_model.model.model.") else k): v
            for k, v in weights.items()
        }
        logging.info(f"Loading pretrained weights: {list(weights.keys())}")
        load_result = self.load_state_dict(weights, strict=False)
        assert (
            len(load_result.unexpected_keys) == 0
        ), "Unexpected weights, is this the right model?"

    def initialize_pretrained_modules(self, modalities: List[Modality], weights: Dict):
        for m in modalities:
            projector = m.build_projector(self.config.hidden_size)
            if m.use_multi_task == MultiTaskType.SIMPLE_MULTI_TASK:
                # register each task head as its own submodule
                for task_name in m.tasks["task_heads"].keys():
                    task_model = projector[task_name]
                    setattr(self, m.name + "_" + task_name, task_model)
            else:
                # single projector module for this modality
                setattr(self, m.name + "_lmm_projector", projector)
        self._load_projector_weights(weights)
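
    # Attribute naming sketch (hypothetical modality name "audio"): with
    # SIMPLE_MULTI_TASK, each task head becomes e.g. self.audio_lmm_projector,
    # self.audio_<task_name>, ...; otherwise a single module is registered as
    # self.audio_lmm_projector. prepare_inputs_labels_for_multimodal below
    # looks the modules up again under exactly these names.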

    def initialize_modules(self, modalities: List[Modality], weights: Dict):
        names = [m.name for m in modalities]
        self.config.modalities = names
        for m in modalities:
            projector = m.build_projector(self.config.hidden_size)
            if m.use_multi_task == MultiTaskType.SIMPLE_MULTI_TASK:
                for task_name in m.tasks["task_heads"].keys():
                    task_model = projector[task_name]
                    setattr(self, m.name + "_" + task_name, task_model)
            else:
                setattr(self, m.name + "_lmm_projector", projector)
        self._load_projector_weights(weights)
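

# Usage sketch (illustrative only; the variable names and checkpoint path are
# hypothetical, assuming a concrete class that mixes LMMMetaModel into a
# HuggingFace-style backbone):
#
#     modalities = [...]  # Modality instances, e.g. an audio encoder
#     projector_weights = torch.load("lmm_projector.bin")
#     model.get_model().initialize_modules(modalities, projector_weights)
#
# initialize_pretrained_modules differs only in that it does not write
# self.config.modalities, presumably because a pretrained checkpoint's config
# already records them.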


class LMMMetaForCausalLM(ABC):
    @abstractmethod
    def get_model(self) -> "LMMMetaForCausalLM":
        pass
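
    # Concrete subclasses are expected to return the inner model that owns the
    # projector submodules and the token embedding table; it is looked up below
    # via getattr(model, ...) and used through model.embed_tokens(...).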

    def prepare_inputs_labels_for_multimodal(
        self, input_ids, attention_mask, past_key_values, labels, **kwargs
    ):
        model = self.get_model()

        batch_size, seq_len = input_ids.shape

        # batch_size x seq_len x embedding_hidden_size
        inputs_embeds = torch.zeros(
            (batch_size, seq_len, self.config.hidden_size),
            dtype=self.dtype,
            device=self.device,
        )
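        # inputs_embeds starts as a zero buffer: text positions are filled from
        # the embedding table below, and modality placeholder positions are
        # overwritten with the projected modality features.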

        # modality x batch_size x instance_idx x modality_token_width x embedding_hidden_size
        projected_tensors = []
        # assuming that if caching is enabled, we'll never have past_key_values
        # AND need to encode the instruction modality values
        task_vals = {}
        if past_key_values is None:
            for m in self.modalities:
                m_vals = m.forward(kwargs.get(m.name))
                mp_vals = []
                if m.use_multi_task == MultiTaskType.SIMPLE_MULTI_TASK:
                    proj = {}
                    for task_name in m.tasks["task_heads"].keys():
                        proj[task_name] = getattr(model, m.name + "_" + task_name)
                else:
                    proj = getattr(model, m.name + "_lmm_projector")
                # project each batch into language model token space
                for m_val in m_vals:
                    if m.use_multi_task == MultiTaskType.SIMPLE_MULTI_TASK:
                        for task_name in m.tasks["task_heads"].keys():
                            if task_name == "lmm_projector":
                                mp_vals.append(proj[task_name](m_val))
                                # make_dot(mp_vals[-1], params=dict(list(model.named_parameters()))).render(task_name, format="png")
                            else:
                                if task_name not in task_vals:
                                    task_vals[task_name] = [proj[task_name](m_val)]
                                else:
                                    task_vals[task_name].append(proj[task_name](m_val))
                                # make_dot(task_vals[task_name], params=dict(list(model.named_parameters()))).render(task_name, format="png")
                    elif m.use_multi_task == MultiTaskType.PROJECTED_MULTI_TASK:
                        task_outputs = proj(m_val)
                        mp_vals.append(task_outputs.pop("projectors"))
                        for task_name in task_outputs.keys():
                            if task_name not in task_vals:
                                task_vals[task_name] = [task_outputs[task_name]]
                            else:
                                task_vals[task_name].append(task_outputs[task_name])
                    else:
                        mp_vals.append(proj(m_val))
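                # Three projection modes are handled above: SIMPLE_MULTI_TASK
                # uses a dict of per-task heads (the "lmm_projector" head feeds
                # the LLM, the rest are collected in task_vals);
                # PROJECTED_MULTI_TASK uses a single module that returns a dict
                # whose "projectors" entry feeds the LLM; otherwise a plain
                # projector maps encoder features straight to token embeddings.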
                assert all(
                    mp_val.shape[1:] == (m.token_width, self.config.hidden_size)
                    for mp_val in mp_vals
                ), (
                    "Modality tensors have incorrect shape, check your projector implementation "
                    + str([mp_val.shape[1:] for mp_val in mp_vals])
                    + " vs expected "
                    + str((m.token_width, self.config.hidden_size))
                )
                projected_tensors.append(mp_vals)
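
        # At this point projected_tensors[mi][i][k] holds the k-th projected
        # instance of modality mi for batch sample i, each of shape
        # (token_width, hidden_size); it is consumed by the fill loop below.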
        indices = None
        for i, input_ids_sample in enumerate(input_ids):
            is_text_mask = input_ids_sample >= 0
            # fill in all the LLM-based text embeddings
            inputs_embeds[i, is_text_mask] = model.embed_tokens(
                input_ids_sample[is_text_mask]
            )
            # skip if all tokens are text tokens
            if is_text_mask.sum() == seq_len:
                continue
            assert (
                past_key_values is None
            ), "We shouldn't have cached keys if this is the first instruction pass"
            for mi, m in enumerate(self.modalities):
                # locate the group of tokens for this modality
                m_mask = (input_ids_sample == m.token_idx).float()
                m_kernel = torch.tensor(
                    [-1] * m.token_width, dtype=m_mask.dtype, device=m_mask.device
                )
                m_conv = conv1d(
                    m_mask.unsqueeze(0).unsqueeze(0),
                    m_kernel.unsqueeze(0).unsqueeze(0),
                )
                # where do we see `token_width`-tokens in a row?
                indices = (m_conv[0, 0] == -m.token_width).nonzero(as_tuple=True)[0]
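                # Worked example (illustrative, token_width == 2): for
                #   m_mask = [0, 1, 1, 0, 1, 1]
                # the kernel [-1, -1] gives a valid convolution of
                #   m_conv = [-1, -2, -1, -1, -2]
                # so indices == [1, 4], the start positions of each complete
                # run of modality placeholder tokens.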
                # fill these embeddings with the projected modality tensor
                last_covered_idx = -1
                k = 0
                for possible_token_idx in indices:
                    if possible_token_idx <= last_covered_idx:
                        # make sure we don't overwrite an instance we've already covered
                        # handles overlapping matches caused by back-to-back token groups
                        continue
                    batch_modality_tensor = projected_tensors[mi][i][k]
                    inputs_embeds[
                        i, possible_token_idx : possible_token_idx + m.token_width
                    ] = batch_modality_tensor
                    last_covered_idx = possible_token_idx + m.token_width - 1
                    k += 1
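
        # input_ids is returned as None because the language model should
        # consume inputs_embeds directly; task_vals carries the auxiliary
        # task-head outputs (presumably used for multi-task losses upstream).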
        return None, attention_mask, past_key_values, inputs_embeds, labels, task_vals