File size: 2,610 Bytes
9665c2c
 
 
 
 
 
b0cf684
9665c2c
 
 
 
 
 
 
b0cf684
 
9665c2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# Copyright (c) Meta Platforms, Inc. and affiliates.

import argparse
from pathlib import Path
from typing import Optional, Tuple

from omegaconf import DictConfig, OmegaConf

from .. import logger
from ..data import KittiDataModule
from .run import evaluate

default_cfg_single = OmegaConf.create({})
# Sequential evaluation centers the map on the ground-truth (GT) location:
# random initial offsets would otherwise accumulate along the sequence until
# only the GT location remained inside a valid mask. This should have little
# impact on the results.
default_cfg_sequential = OmegaConf.create(
    dict(
        data=dict(
            # Keep the single-frame initial-error budget as the mask radius.
            mask_radius=KittiDataModule.default_cfg["max_init_error"],
            # Rotation prior range: one more than the single-frame rotation
            # error budget (NOTE(review): rationale for +1 not shown here).
            prior_range_rotation=(
                KittiDataModule.default_cfg["max_init_error_rotation"] + 1
            ),
            # No random initial offsets: the map is centered on the GT pose.
            max_init_error=0,
            max_init_error_rotation=0,
        ),
        chunking=dict(
            max_length=100,  # about 10s? (unverified frame-rate assumption)
        ),
    )
)


def run(
    split: str,
    experiment: str,
    cfg: Optional[DictConfig] = None,
    sequential: bool = False,
    thresholds: Tuple[int, ...] = (1, 3, 5),
    **kwargs,
):
    """Evaluate an experiment on the KITTI dataset and log recall metrics.

    Args:
        split: Dataset split to evaluate on (the CLI offers "test", "val",
            "train").
        experiment: Experiment identifier, forwarded to ``evaluate``.
        cfg: Configuration overrides, given as a ``DictConfig`` or a plain
            ``dict``. They are merged on top of ``default_cfg_sequential`` or
            ``default_cfg_single``, depending on ``sequential``.
        sequential: Run the sequential evaluation and additionally report the
            sequential error metrics.
        thresholds: Recall thresholds, any number of them. Per the log message
            they are in metres for the directional errors and degrees for the
            yaw errors.
        **kwargs: Forwarded unchanged to ``evaluate`` (the CLI passes
            ``output_dir`` and ``num``).

    Returns:
        The metrics mapping returned by ``evaluate``.
    """
    cfg = cfg or {}
    # Plain dicts are converted so that merging and ``.get`` behave uniformly.
    if isinstance(cfg, dict):
        cfg = OmegaConf.create(cfg)
    default = default_cfg_sequential if sequential else default_cfg_single
    # OmegaConf.merge returns a new config, so the module-level defaults are
    # never mutated.
    cfg = OmegaConf.merge(default, cfg)
    dataset = KittiDataModule(cfg.get("data", {}))

    metrics = evaluate(
        experiment,
        cfg,
        dataset,
        split=split,
        sequential=sequential,
        viz_kwargs=dict(show_dir_error=True, show_masked_prob=False),
        **kwargs,
    )

    keys = ["directional_error", "yaw_max_error"]
    if sequential:
        keys += ["directional_seq_error", "yaw_seq_error"]
    for k in keys:
        rec = metrics[k].recall(thresholds).double().numpy().round(2).tolist()
        logger.info("Recall %s: %s at %s m/°", k, rec, thresholds)
    return metrics


if __name__ == "__main__":
    # Command-line entry point: parse the options, treat any trailing
    # positional arguments as OmegaConf dot-list overrides, then evaluate.
    cli = argparse.ArgumentParser()
    cli.add_argument("--experiment", type=str, required=True)
    cli.add_argument(
        "--split", type=str, default="test", choices=["test", "val", "train"]
    )
    cli.add_argument("--sequential", action="store_true")
    cli.add_argument("--output_dir", type=Path)
    cli.add_argument("--num", type=int)
    cli.add_argument("dotlist", nargs="*")
    opts = cli.parse_args()
    run(
        split=opts.split,
        experiment=opts.experiment,
        cfg=OmegaConf.from_cli(opts.dotlist),
        sequential=opts.sequential,
        output_dir=opts.output_dir,
        num=opts.num,
    )