File size: 841 Bytes
c42fe7e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from torch import nn


class CategorizedModule(nn.Module):
    """Base ``nn.Module`` for models whose checkpoints carry a category tag.

    Subclasses override :attr:`category` to report the category they belong
    to; :meth:`check_category` compares that value against the category
    recorded in a checkpoint and raises on a missing or mismatched tag.
    """

    @property
    def category(self):
        """Category identifier of this module; must be overridden.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def check_category(self, category):
        """Ensure a checkpoint's category is present and matches this module.

        Args:
            category: the category recorded in the checkpoint, or ``None``
                when the checkpoint does not record one.

        Raises:
            RuntimeError: if ``category`` is ``None`` (the message points at
                the migration script), or if it differs from
                ``self.category``.
        """
        # Handle the missing tag before consulting self.category, so the
        # migration hint is reported even when the property is unimplemented.
        if category is None:
            missing_msg = (
                'Category is not specified in this checkpoint.\n'
                'If this is a checkpoint in the old format, please consider '
                'migrating it to the new format via the following command:\n'
                'python scripts/migrate.py ckpt <INPUT_CKPT> <OUTPUT_CKPT>'
            )
            raise RuntimeError(missing_msg)

        required = self.category
        if category != required:
            mismatch_msg = (
                "Category mismatches!\n"
                f"This checkpoint is of the category '{category}', "
                f"but a checkpoint of category '{required}' is required."
            )
            raise RuntimeError(mismatch_msg)