Spaces:
Runtime error
Runtime error
File size: 890 Bytes
f85e212 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
from pathlib import Path
import torch
from medical_diffusion.models.embedders.latent_embedders import VQVAE, VQGAN, VAE, VAEGAN
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
path_root = Path('runs/2022_12_01_210017_patho_vaegan')
# Load model
model = VAEGAN.load_from_checkpoint(path_root/'last.ckpt')
# model = torch.load(path_root/'last.ckpt')
# Save model-part
# torch.save(model.vqvae, path_root/'last_vae.ckpt') # Not working
# ------ Ugly workaround ----------
checkpointing = ModelCheckpoint()
trainer = Trainer(callbacks=[checkpointing])
trainer.strategy._lightning_module = model.vqvae
trainer.model = model.vqvae
trainer.save_checkpoint(path_root/'last_vae.ckpt')
# -----------------
model = VAE.load_from_checkpoint(path_root/'last_vae.ckpt')
# model = torch.load(path_root/'last_vae.ckpt') # load_state_dict |