from abc import ABC, abstractmethod
from typing import Dict, List
import logging

import torch
from torch.nn.functional import conv1d

from multi_token.modalities.base_modality import Modality
from multi_token.model_utils import MultiTaskType


class LMMMetaModel:
    def __init__(self, config):
        super(LMMMetaModel, self).__init__(config)
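
    # Checkpoints saved through a PEFT-style wrapper prefix parameter names with
    # "base_model.model.model." (23 characters); stripping that prefix lets the
    # projector weights match the attribute names registered by the
    # initialize_* methods below.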
    def _load_projector_weights(self, weights: Dict):
        weights = {
            (k[23:] if k.startswith("base_model.model.model.") else k): v
            for k, v in weights.items()
        }
        logging.info(f"Loading pretrained weights: {list(weights.keys())}")
        load_result = self.load_state_dict(weights, strict=False)
        assert (
            len(load_result.unexpected_keys) == 0
        ), "Unexpected weights, is this the right model?"
    def initialize_pretrained_modules(
        self, modalities: List[Modality], weights: Dict
    ):
        for m in modalities:
            projector = m.build_projector(self.config.hidden_size)
            if m.use_multi_task == MultiTaskType.SIMPLE_MULTI_TASK:
                for task_name in m.tasks["task_heads"].keys():
                    task_model = projector[task_name]
                    setattr(self, m.name + "_" + task_name, task_model)
            else:
                setattr(self, m.name + "_lmm_projector", projector)
        self._load_projector_weights(weights)

    def initialize_modules(self, modalities: List[Modality], weights: Dict):
        names = [m.name for m in modalities]
        self.config.modalities = names
        for m in modalities:
            projector = m.build_projector(self.config.hidden_size)
            if m.use_multi_task == MultiTaskType.SIMPLE_MULTI_TASK:
                for task_name in m.tasks["task_heads"].keys():
                    task_model = projector[task_name]
                    setattr(self, m.name + "_" + task_name, task_model)
            else:
                setattr(self, m.name + "_lmm_projector", projector)
        self._load_projector_weights(weights)


class LMMMetaForCausalLM(ABC):
    @abstractmethod
    def get_model(self) -> "LMMMetaForCausalLM":
        pass

    def prepare_inputs_labels_for_multimodal(
        self, input_ids, attention_mask, past_key_values, labels, **kwargs
    ):
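        """
        Build `inputs_embeds` for a multimodal batch: text token ids are embedded
        with the language model's embedding table, and each run of modality
        placeholder tokens is overwritten with the corresponding projected
        modality tensor. Returns (None, attention_mask, past_key_values,
        inputs_embeds, labels, task_vals); `input_ids` comes back as None,
        presumably because the caller forwards `inputs_embeds` instead.
        """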
        model = self.get_model()

        batch_size, seq_len = input_ids.shape
        # batch_size x seq_len x embedding_hidden_size
        inputs_embeds = torch.zeros(
            (batch_size, seq_len, self.config.hidden_size),
            dtype=self.dtype,
            device=self.device,
        )

        # modality x batch_size x instance_idx x modality_token_width x embedding_hidden_size
        projected_tensors = []
        # assuming that if caching is enabled, we'll never have past_key_values
        # AND need to encode the instruction modality values
        task_vals = {}
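        # encode and project the raw modality inputs once per forward pass;
        # auxiliary task-head outputs are accumulated in task_vals by task name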
        if past_key_values is None:
            for m in self.modalities:
                m_vals = m.forward(kwargs.get(m.name))
                mp_vals = []
                if m.use_multi_task == MultiTaskType.SIMPLE_MULTI_TASK:
                    proj = {}
                    for task_name in m.tasks["task_heads"].keys():
                        proj[task_name] = getattr(model, m.name + "_" + task_name)
                else:
                    proj = getattr(model, m.name + "_lmm_projector")
                # project each batch into language model token space
                for m_val in m_vals:
                    if m.use_multi_task == MultiTaskType.SIMPLE_MULTI_TASK:
                        for task_name in m.tasks["task_heads"].keys():
                            if task_name == "lmm_projector":
                                mp_vals.append(proj[task_name](m_val))
                            else:
                                if task_name not in task_vals:
                                    task_vals[task_name] = [proj[task_name](m_val)]
                                else:
                                    task_vals[task_name].append(proj[task_name](m_val))
                    elif m.use_multi_task == MultiTaskType.PROJECTED_MULTI_TASK:
                        task_outputs = proj(m_val)
                        mp_vals.append(task_outputs.pop("projectors"))
                        for task_name in task_outputs.keys():
                            if task_name not in task_vals:
                                task_vals[task_name] = [task_outputs[task_name]]
                            else:
                                task_vals[task_name].append(task_outputs[task_name])
                    else:
                        mp_vals.append(proj(m_val))
                assert all(
                    mp_val.shape[1:] == (m.token_width, self.config.hidden_size)
                    for mp_val in mp_vals
                ), (
                    "Modality tensors have incorrect shape, check your projector implementation "
                    + str([mp_val.shape[1:] for mp_val in mp_vals])
                    + " vs expected "
                    + str((m.token_width, self.config.hidden_size))
                )
                projected_tensors.append(mp_vals)
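
        # Splice the projected tensors into each sample: text positions carry
        # ordinary (non-negative) token ids, while modality placeholders use the
        # modality's token_idx, which is assumed here to be a negative id so the
        # `>= 0` mask below selects only real text tokens.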
        indices = None
        for i, input_ids_sample in enumerate(input_ids):
            is_text_mask = input_ids_sample >= 0

            # fill in all the LLM-based text embeddings
            inputs_embeds[i, is_text_mask] = model.embed_tokens(
                input_ids_sample[is_text_mask]
            )

            # skip if all tokens are text tokens
            if is_text_mask.sum() == seq_len:
                continue

            assert (
                past_key_values is None
            ), "We shouldn't have cached keys if this is the first instruction pass"
            for mi, m in enumerate(self.modalities):
                # locate the group of tokens for this modality
                m_mask = (input_ids_sample == m.token_idx).float()
                m_kernel = torch.tensor(
                    [-1] * m.token_width, dtype=m_mask.dtype, device=m_mask.device
                )
                m_conv = conv1d(
                    m_mask.unsqueeze(0).unsqueeze(0),
                    m_kernel.unsqueeze(0).unsqueeze(0),
                )
                # where do we see `token_width`-tokens in a row?
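                # Worked example (token_width = 3): a mask like [0, 1, 1, 1, 0]
                # convolved with the all-(-1) kernel gives [-2, -3, -2], so a
                # position equal to -token_width marks where a full run of
                # `token_width` placeholder tokens begins.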
                indices = (m_conv[0, 0] == -m.token_width).nonzero(as_tuple=True)[0]
                # fill these embeddings with the projected modality tensor
                last_covered_idx = -1
                k = 0
                for possible_token_idx in indices:
                    if possible_token_idx <= last_covered_idx:
                        # make sure we don't overwrite an instance we've already covered
                        # handles bug caused by back-to-back tokens
                        continue
                    batch_modality_tensor = projected_tensors[mi][i][k]
                    inputs_embeds[
                        i, possible_token_idx : possible_token_idx + m.token_width
                    ] = batch_modality_tensor
                    last_covered_idx = possible_token_idx + m.token_width - 1
                    k += 1

        return None, attention_mask, past_key_values, inputs_embeds, labels, task_vals
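

# Usage sketch (illustrative only; the concrete class names below are hypothetical
# and not defined in this module). Following the LLaVA-style mixin pattern this
# code implies, a concrete language model would combine these classes roughly as:
#
#   class MyLMMModel(LMMMetaModel, SomeHFModel):
#       ...
#
#   class MyLMMForCausalLM(SomeHFForCausalLM, LMMMetaForCausalLM):
#       def get_model(self):
#           return self.model
#
#       def forward(self, input_ids=None, attention_mask=None,
#                   past_key_values=None, labels=None, **modality_kwargs):
#           (_, attention_mask, past_key_values, inputs_embeds, labels,
#            task_vals) = self.prepare_inputs_labels_for_multimodal(
#               input_ids, attention_mask, past_key_values, labels,
#               **modality_kwargs,
#           )
#           ...  # run the underlying transformer on inputs_embeds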