File size: 1,312 Bytes
173ea2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from BeamDiffusionModel.models.CoSeD.sequence_predictor import SoftAttention
import torch

# Sequence predictor used by get_softmax. It is loaded once at import time and switched to
# eval mode.
# map_location="cpu" keeps the weights on the CPU. This matches get_softmax, which moves
# every input tensor to the CPU. Without it, a checkpoint saved from a GPU can be restored
# onto the GPU, and the forward pass then fails with a device mismatch.
# NOTE(review): the checkpoint path is relative, so it resolves against the current
# working directory. This module must be imported from the repository root.
model = SoftAttention.load_from_checkpoint(
    "./BeamDiffusionModel/models/CoSeD/checkpoints/epoch=19-step=140.ckpt",
    map_location="cpu",
)
model.eval()


def _to_batched_cpu(tensors):
    """Concatenate *tensors* along dim 0, move to CPU and add a leading batch axis of 1."""
    return torch.cat(tensors, dim=0).to("cpu").unsqueeze(0)


def get_softmax(previous_steps_embeddings, previous_images_embeddings, current_steps_embeddings,
                current_images_embeddings, *, predictor=None):
    """Score the current step/image embeddings against the previous context.

    Each embedding argument is a sequence of tensors. The tensors are concatenated
    along dim 0, moved to the CPU and given a leading batch dimension of size 1
    before being passed to the predictor. The individual tensors are presumably
    ``(n_i, d)`` with a shared ``d``; TODO confirm against callers.

    Args:
        previous_steps_embeddings: Sequence of tensors for the previous steps.
        previous_images_embeddings: Sequence of tensors for the previous images.
        current_steps_embeddings: Sequence of tensors for the current step.
        current_images_embeddings: Sequence of tensors for the current images.
        predictor: Optional callable used instead of the module-level ``model``.
            It is called as ``predictor(current_steps, current_images,
            previous_steps, previous_images)`` and must return a
            ``(softmax, logit)`` pair. Defaults to the module-level ``model``.

    Returns:
        The predictor's softmax output unchanged when it has at most one
        dimension; otherwise its mean over the last dimension.

    Raises:
        RuntimeError: From ``torch.cat`` if an embedding sequence is empty or
            its tensors cannot be concatenated.
    """
    if predictor is None:
        predictor = model

    previous_steps_tensor = _to_batched_cpu(previous_steps_embeddings)
    previous_images_tensor = _to_batched_cpu(previous_images_embeddings)
    current_steps_tensor = _to_batched_cpu(current_steps_embeddings)
    current_images_tensor = _to_batched_cpu(current_images_embeddings)

    with torch.no_grad():
        # The predictor takes the current step/image tensors first, then the previous
        # context. The logit output is not used here.
        softmax, _logit = predictor(current_steps_tensor,
                                    current_images_tensor,
                                    previous_steps_tensor,
                                    previous_images_tensor)
        if softmax.dim() <= 1:
            return softmax
        # Average the softmax values over the last dimension.
        return softmax.mean(dim=-1)