from BeamDiffusionModel.models.CoSeD.sequence_predictor import SoftAttention
import torch
# Load the pretrained CoSeD soft-attention checkpoint and switch to inference mode.
model = SoftAttention.load_from_checkpoint(
    "./BeamDiffusionModel/models/CoSeD/checkpoints/epoch=19-step=140.ckpt")
# Alternative checkpoint:
# "/user/home/vcc.ramos/latent_training/sft/reference_training/9q3eu8vi/checkpoints/epoch=7-step=15.ckpt"
model.eval()
def get_softmax(previous_steps_embeddings, previous_images_embeddings,
                current_steps_embeddings, current_images_embeddings):
    """Score the current candidate step/image embeddings against the previously
    selected ones with the SoftAttention model, returning the softmax scores
    (averaged over the last dimension when the output is multi-dimensional)."""
    # Concatenate each list of embeddings along dim 0, move to CPU,
    # and add a leading batch dimension.
    previous_steps_tensor = torch.cat(previous_steps_embeddings, dim=0).to("cpu").unsqueeze(0)
    previous_images_tensor = torch.cat(previous_images_embeddings, dim=0).to("cpu").unsqueeze(0)
    current_steps_tensor = torch.cat(current_steps_embeddings, dim=0).to("cpu").unsqueeze(0)
    current_images_tensor = torch.cat(current_images_embeddings, dim=0).to("cpu").unsqueeze(0)
    with torch.no_grad():
        # The raw logits returned alongside the softmax are not used here.
        softmax, _logits = model(current_steps_tensor,
                                 current_images_tensor,
                                 previous_steps_tensor,
                                 previous_images_tensor)
    # A 0-D or 1-D output already holds one score per candidate.
    if softmax.dim() <= 1:
        return softmax
    # Otherwise, average the softmax scores over the last dimension.
    return softmax.mean(dim=-1)
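

# Usage sketch: the embedding dimension (768) and the list lengths below are
# illustrative assumptions, not values taken from the repository; in practice
# the lists would hold embeddings produced by the actual CoSeD encoders, with
# whatever dimensionality the loaded SoftAttention checkpoint expects.
if __name__ == "__main__":
    embed_dim = 768  # hypothetical embedding size
    # Two previously selected step/image pairs, three current candidates.
    previous_steps = [torch.randn(1, embed_dim) for _ in range(2)]
    previous_images = [torch.randn(1, embed_dim) for _ in range(2)]
    current_steps = [torch.randn(1, embed_dim) for _ in range(3)]
    current_images = [torch.randn(1, embed_dim) for _ in range(3)]

    scores = get_softmax(previous_steps, previous_images,
                         current_steps, current_images)
    print(scores)  # one score per current candidate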