import torch
from medical_diffusion.models.embedders.latent_embedders import VQVAE, VQGAN

# Dummy 3D volume: [B, C, D, H, W]
input = torch.randn((1, 3, 16, 128, 128))

# Sanity-check the VQVAE latent embedder in 3D mode.
model = VQVAE(in_channels=3, out_channels=3, spatial_dims=3, emb_channels=1, deep_supervision=True)
# output = model(input)
# print(output)

# Run a single training step and print the resulting loss.
loss = model._step({'source': input}, 1, 'train', 1, 1)
print(loss)
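
# Hedged sketch (assumption): inspect the plain reconstruction pass. The
# commented-out lines above suggest model(input) is the supported forward
# call; its exact return type (a tensor, or a tuple including the
# deep-supervision outputs) depends on the medical_diffusion version, so we
# only inspect it generically here.
output = model(input)
print(type(output), getattr(output, 'shape', None))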
# Same sanity check for the VQGAN variant (left disabled):
# model = VQGAN(in_channels=3, out_channels=3, spatial_dims=3, emb_channels=1, deep_supervision=True)
# loss = model._step({'source': input}, 1, 'train', 1, 1)
# print(loss)