Spaces:
Sleeping
Sleeping
File size: 841 Bytes
c42fe7e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
from torch import nn
class CategorizedModule(nn.Module):
    """``nn.Module`` base class that carries a category tag.

    Subclasses provide the tag by overriding :attr:`category`.
    :meth:`check_category` compares a category taken from a checkpoint
    against that tag and raises if they disagree.
    """

    @property
    def category(self):
        """Category tag of this module.

        The base implementation always raises ``NotImplementedError``;
        concrete subclasses are expected to override it.
        """
        raise NotImplementedError()

    def check_category(self, category):
        """Ensure a checkpoint's category matches this module's category.

        Args:
            category: the category recorded in a checkpoint, or ``None``
                when the checkpoint records no category (the error text
                suggests this happens for old-format checkpoints).

        Raises:
            RuntimeError: if ``category`` is ``None``, or if it is not
                equal to ``self.category``.
        """
        # Guard 1: no category at all -> point the user at the migration script.
        if category is None:
            missing_msg = ('Category is not specified in this checkpoint.\n'
                           'If this is a checkpoint in the old format, please consider '
                           'migrating it to the new format via the following command:\n'
                           'python scripts/migrate.py ckpt <INPUT_CKPT> <OUTPUT_CKPT>')
            raise RuntimeError(missing_msg)

        # Guard 2: a category is present but differs from the required one.
        if category != self.category:
            raise RuntimeError('Category mismatches!\n'
                               f'This checkpoint is of the category \'{category}\', '
                               f'but a checkpoint of category \'{self.category}\' is required.')
|