File size: 3,386 Bytes
1865436
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# Copyright (c) Facebook, Inc. and its affiliates.

import json
import os
import tempfile
import time
import unittest
import torch
from torch import nn

from detectron2.config import configurable, get_cfg
from detectron2.engine import DefaultTrainer, SimpleTrainer, hooks
from detectron2.modeling.meta_arch import META_ARCH_REGISTRY
from detectron2.utils.events import CommonMetricPrinter, JSONWriter


@META_ARCH_REGISTRY.register()
class _SimpleModel(nn.Module):
    """Minimal registered model used by the trainer tests in this file.

    It owns a single ``nn.Linear(10, 20)`` layer and its ``forward`` returns a
    dict holding one ``"loss"`` entry.
    """

    @configurable
    def __init__(self, sleep_sec=0):
        """
        Args:
            sleep_sec: number of seconds to sleep at the start of every
                ``forward`` call; no sleep happens unless it is positive.
        """
        super().__init__()
        self.sleep_sec = sleep_sec
        self.mod = nn.Linear(10, 20)

    @classmethod
    def from_config(cls, cfg):
        """Return an empty dict; nothing is read from ``cfg``."""
        return dict()

    def forward(self, x):
        """Return ``{"loss": ...}`` built from the input and the parameters.

        The loss is the sum of ``x`` plus the mean of every parameter tensor
        of this module.
        """
        delay = self.sleep_sec
        if delay > 0:
            time.sleep(delay)
        # Distinct name for the loop variable so it is not confused with the input `x`.
        param_term = sum([p.mean() for p in self.parameters()])
        return {"loss": x.sum() + param_term}


class TestTrainer(unittest.TestCase):
    """Tests for ``SimpleTrainer``, the writer hooks and ``DefaultTrainer``."""

    def _data_loader(self, device):
        """Yield random 3x3 tensors moved to ``device``, without ever stopping."""
        target = torch.device(device)
        while True:
            sample = torch.rand(3, 3)
            yield sample.to(target)

    def test_simple_trainer(self, device="cpu"):
        """Call ``SimpleTrainer.train(0, 10)`` with the model placed on ``device``."""
        net = _SimpleModel().to(device=device)
        loader = self._data_loader(device)
        optimizer = torch.optim.SGD(net.parameters(), 0.1)
        SimpleTrainer(net, loader, optimizer).train(0, 10)

    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_simple_trainer_cuda(self):
        """Same as ``test_simple_trainer`` but on CUDA; skipped when CUDA is unavailable."""
        self.test_simple_trainer("cuda")

    def test_writer_hooks(self):
        """Check the output of the JSON writer and the metric printer after training.

        Expects JSON records at iterations 19, 39, 49 and 50 (the last one
        carrying the eval metric) and three log messages from the printer,
        the last of which reports a zero ETA.
        """
        total_iters = 50
        expected_log_iters = [19, 39, 49]

        net = _SimpleModel(sleep_sec=0.1)
        optimizer = torch.optim.SGD(net.parameters(), 0.1)
        trainer = SimpleTrainer(net, self._data_loader("cpu"), optimizer)

        with tempfile.TemporaryDirectory(prefix="detectron2_test") as tmp_dir:
            metrics_path = os.path.join(tmp_dir, "metrics.json")
            printer = CommonMetricPrinter(total_iters)
            json_writer = JSONWriter(metrics_path)

            eval_hook = hooks.EvalHook(0, lambda: {"metric": 100})
            writer_hook = hooks.PeriodicWriter([printer, json_writer])
            trainer.register_hooks([eval_hook, writer_hook])

            # Capture what the printer logs while training runs.
            with self.assertLogs(printer.logger) as captured:
                trainer.train(0, total_iters)

            with open(metrics_path, "r") as fp:
                records = [json.loads(raw.strip()) for raw in fp]
                logged_iters = [rec["iteration"] for rec in records]
                self.assertEqual(logged_iters, [19, 39, 49, 50])
                # the eval metric is in the last line with iter 50
                self.assertIn("metric", records[-1], "Eval metric must be in last line of JSON!")

            # test logged messages from CommonMetricPrinter
            messages = captured.output
            self.assertEqual(len(messages), 3)
            for message, expected_iter in zip(messages, expected_log_iters):
                self.assertIn(f"iter: {expected_iter}", message)

            self.assertIn("eta: 0:00:00", messages[-1], "Last ETA must be 0!")

    @unittest.skipIf(os.environ.get("CI"), "Require COCO data.")
    def test_default_trainer(self):
        """Check that ``DefaultTrainer.model`` stays the same object as ``_trainer.model``.

        Skipped when the ``CI`` environment variable is set.
        """
        cfg = get_cfg()
        cfg.DATASETS.TRAIN = ("coco_2017_val_100",)
        cfg.MODEL.META_ARCHITECTURE = "_SimpleModel"
        with tempfile.TemporaryDirectory(prefix="detectron2_test") as out_dir:
            cfg.OUTPUT_DIR = out_dir
            trainer = DefaultTrainer(cfg)

            # test property
            self.assertIs(trainer.model, trainer._trainer.model)
            # Assigning a new model must be reflected in the wrapped trainer as well.
            trainer.model = _SimpleModel()
            self.assertIs(trainer.model, trainer._trainer.model)