# BeamDiffusion/models/CoSeD/cross_attention.py
# Uploaded by Gui28F ("uploaded all project files", commit 173ea2b, verified).
from BeamDiffusionModel.models.CoSeD.sequence_predictor import SoftAttention
import torch
# Checkpoint of the CoSeD SoftAttention predictor used by get_softmax().
# NOTE(review): this path is resolved relative to the current working
# directory (leading "./"), so importing this module presumably only works
# when the process is started from the repository root — confirm.
_CHECKPOINT_PATH = "./BeamDiffusionModel/models/CoSeD/checkpoints/epoch=19-step=140.ckpt"

# Loaded once at import time; nn.Module.eval() returns the module itself,
# so `model` is the loaded predictor, already switched to inference mode.
model = SoftAttention.load_from_checkpoint(_CHECKPOINT_PATH).eval()
def _batch_of_one(tensors, device):
    """Concatenate *tensors* along dim 0, move to *device*, prepend a batch dim of 1."""
    return torch.cat(tensors, dim=0).to(device).unsqueeze(0)


def get_softmax(previous_steps_embeddings, previous_images_embeddings,
                current_steps_embeddings, current_images_embeddings,
                *, predictor=None):
    """Run the CoSeD predictor on current vs. previous step/image embeddings.

    Each ``*_embeddings`` argument is a sequence of tensors. Each sequence is
    concatenated along dim 0 and given a leading batch dimension of size 1
    before being passed to the predictor.
    NOTE(review): presumably every tensor is (n_i, embed_dim) — confirm with
    callers.

    Args:
        previous_steps_embeddings: step-text embeddings of earlier steps.
        previous_images_embeddings: image embeddings of earlier steps.
        current_steps_embeddings: step-text embeddings of the current step.
        current_images_embeddings: image embeddings of the current step.
        predictor: optional callable used instead of the module-level
            ``model``. It is called as ``predictor(current_steps,
            current_images, previous_steps, previous_images)`` and must
            return a ``(softmax, logit)`` pair. Defaults to ``model``.

    Returns:
        The predictor's ``softmax`` output unchanged if it has at most one
        dimension; otherwise its mean over the last dimension.

    Raises:
        Whatever ``torch.cat`` raises if any embedding sequence is empty.
    """
    if predictor is None:
        predictor = model

    # Follow the predictor's device when it exposes one (LightningModule
    # does); otherwise fall back to the CPU, which was previously hard-coded.
    device = getattr(predictor, "device", "cpu")

    previous_steps = _batch_of_one(previous_steps_embeddings, device)
    previous_images = _batch_of_one(previous_images_embeddings, device)
    current_steps = _batch_of_one(current_steps_embeddings, device)
    current_images = _batch_of_one(current_images_embeddings, device)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        # Note the argument order: current inputs first, then previous ones.
        softmax, _logit = predictor(current_steps, current_images,
                                    previous_steps, previous_images)

    if softmax.dim() <= 1:
        return softmax
    # Average the scores over the last dimension.
    return softmax.mean(dim=-1)