File size: 582 Bytes
f85e212
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import torch 
from medical_diffusion.models.embedders.latent_embedders import VQVAE, VQGAN


# Smoke test: run one training-mode `_step` of the 3-D latent embedder on random data
# and print the result.
#
# Random volume shaped [B, C, D, H, W]: batch 1, 3 channels, 16 slices of 128x128.
# (5-D because the models below are built with spatial_dims=3.)
# Named `x` rather than `input` to avoid shadowing the Python builtin.
x = torch.randn((1, 3, 16, 128, 128))


model = VQVAE(in_channels=3, out_channels=3, spatial_dims=3, emb_channels=1, deep_supervision=True)
# The batch is passed as a dict keyed by 'source'.
# NOTE(review): the positional args (1, 'train', 1, 1) are forwarded unchanged to `_step`.
# Their meaning is defined by `_step`'s signature, which is not visible here — confirm there.
loss = model._step({'source': x}, 1, 'train', 1, 1)
print(loss)


# Same smoke test for VQGAN (currently disabled):
# model = VQGAN(in_channels=3, out_channels=3, spatial_dims=3, emb_channels=1, deep_supervision=True)
# loss = model._step({'source': x}, 1, 'train', 1, 1)
# print(loss)