File size: 890 Bytes
f85e212
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from pathlib import Path 
import torch 
from medical_diffusion.models.embedders.latent_embedders import VQVAE, VQGAN, VAE, VAEGAN
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Run directory containing the trained VAEGAN checkpoint, plus the two
# checkpoint files this script reads and writes inside it.
path_root = Path('runs/2022_12_01_210017_patho_vaegan')
full_ckpt_path = path_root/'last.ckpt'
vae_ckpt_path = path_root/'last_vae.ckpt'

# Step 1: restore the complete VAEGAN from its checkpoint.
model = VAEGAN.load_from_checkpoint(full_ckpt_path)

# Step 2: write only the `vqvae` sub-module to a checkpoint of its own.
# NOTE: per the original author, saving the sub-module directly with
# `torch.save(model.vqvae, ...)` did not work, hence the workaround below.
# HACK: a throwaway Trainer is pointed at the sub-module (this touches the
# private `strategy._lightning_module` attribute) so that
# `Trainer.save_checkpoint` serializes the sub-module instead of a model
# that was actually fitted by this trainer.
vae_part = model.vqvae
checkpointing = ModelCheckpoint()
trainer = Trainer(callbacks=[checkpointing])
trainer.strategy._lightning_module = vae_part
trainer.model = vae_part
trainer.save_checkpoint(vae_ckpt_path)

# Step 3: reload the freshly written file as a standalone VAE; `model` is
# rebound to that VAE from here on.
# NOTE(review): presumably this serves as a sanity check that the exported
# checkpoint is loadable on its own - confirm with the author.
model = VAE.load_from_checkpoint(vae_ckpt_path)