# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple

try:
    import lap
except ImportError:
    lap = None
import numpy as np
import torch
from addict import Dict
from mmengine.structures import InstanceData

from mmdet.registry import MODELS
from mmdet.structures import DetDataSample
from mmdet.structures.bbox import (bbox_cxcyah_to_xyxy, bbox_overlaps,
                                   bbox_xyxy_to_cxcyah)
from .sort_tracker import SORTTracker


@MODELS.register_module()
class OCSORTTracker(SORTTracker):
    """Tracker for OC-SORT.

    Args:
        motion (dict): Configuration of motion. Defaults to None.
        obj_score_thr (float): Detection score threshold for matching objects.
            Defaults to 0.3.
        init_track_thr (float): Detection score threshold for initializing a
            new tracklet. Defaults to 0.7.
        weight_iou_with_det_scores (bool): Whether to use detection scores to
            weight the IoU used for matching. Defaults to True.
        match_iou_thr (float): IoU distance threshold for matching between two
            frames. Defaults to 0.3.
        num_tentatives (int, optional): Number of continuous frames to confirm
            a track. Defaults to 3.
        vel_consist_weight (float): Weight of the velocity consistency term in
            association (the OCM term in the paper). Defaults to 0.2.
        vel_delta_t (int): The time-step gap used when estimating the velocity
            direction of tracklets. Defaults to 3.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
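
    Examples:
        >>> # A minimal config sketch, not part of the original file; it
        >>> # assumes the standard ``KalmanFilter`` motion module registered
        >>> # in MMDetection is available.
        >>> cfg = dict(
        ...     type='OCSORTTracker',
        ...     motion=dict(type='KalmanFilter'),
        ...     obj_score_thr=0.3,
        ...     init_track_thr=0.7,
        ...     weight_iou_with_det_scores=True,
        ...     match_iou_thr=0.3,
        ...     num_tentatives=3,
        ...     vel_consist_weight=0.2,
        ...     vel_delta_t=3)
        >>> tracker = MODELS.build(cfg)  # doctest: +SKIP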
    """

    def __init__(self,
                 motion: Optional[dict] = None,
                 obj_score_thr: float = 0.3,
                 init_track_thr: float = 0.7,
                 weight_iou_with_det_scores: bool = True,
                 match_iou_thr: float = 0.3,
                 num_tentatives: int = 3,
                 vel_consist_weight: float = 0.2,
                 vel_delta_t: int = 3,
                 **kwargs):
        if lap is None:
            raise RuntimeError('lap is not installed, '
                               'please install it by: pip install lap')
        super().__init__(motion=motion, **kwargs)
        self.obj_score_thr = obj_score_thr
        self.init_track_thr = init_track_thr

        self.weight_iou_with_det_scores = weight_iou_with_det_scores
        self.match_iou_thr = match_iou_thr
        self.vel_consist_weight = vel_consist_weight
        self.vel_delta_t = vel_delta_t

        self.num_tentatives = num_tentatives

    @property
    def unconfirmed_ids(self):
        """Unconfirmed ids in the tracker."""
        ids = [id for id, track in self.tracks.items() if track.tentative]
        return ids

    def init_track(self, id: int, obj: Tuple[torch.Tensor]):
        """Initialize a track."""
        super().init_track(id, obj)
        if self.tracks[id].frame_ids[-1] == 0:
            self.tracks[id].tentative = False
        else:
            self.tracks[id].tentative = True
        bbox = bbox_xyxy_to_cxcyah(self.tracks[id].bboxes[-1])  # size = (1, 4)
        assert bbox.ndim == 2 and bbox.shape[0] == 1
        bbox = bbox.squeeze(0).cpu().numpy()
        self.tracks[id].mean, self.tracks[id].covariance = self.kf.initiate(
            bbox)
        # track.obs maintains the history of detections associated with this
        # track
        self.tracks[id].obs = []
        bbox_id = self.memo_items.index('bboxes')
        self.tracks[id].obs.append(obj[bbox_id])
        # a placeholder to save the mean/covariance before the track is lost;
        # parameters to save: mean, covariance, measurement
        self.tracks[id].tracked = True
        self.tracks[id].saved_attr = Dict()
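        # (-1, -1) marks "no velocity estimate yet"; this sentinel is later
        # detected via the ``velocity.sum() == -2`` check in ``ocm_assign_ids``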
        self.tracks[id].velocity = torch.tensor(
            (-1, -1)).to(obj[bbox_id].device)  # placeholder

    def update_track(self, id: int, obj: Tuple[torch.Tensor]):
        """Update a track."""
        super().update_track(id, obj)
        if self.tracks[id].tentative:
            if len(self.tracks[id]['bboxes']) >= self.num_tentatives:
                self.tracks[id].tentative = False
        bbox = bbox_xyxy_to_cxcyah(self.tracks[id].bboxes[-1])  # size = (1, 4)
        assert bbox.ndim == 2 and bbox.shape[0] == 1
        bbox = bbox.squeeze(0).cpu().numpy()
        self.tracks[id].mean, self.tracks[id].covariance = self.kf.update(
            self.tracks[id].mean, self.tracks[id].covariance, bbox)
        self.tracks[id].tracked = True
        bbox_id = self.memo_items.index('bboxes')
        self.tracks[id].obs.append(obj[bbox_id])

        bbox1 = self.k_step_observation(self.tracks[id])
        bbox2 = obj[bbox_id]
        self.tracks[id].velocity = self.vel_direction(bbox1, bbox2).to(
            obj[bbox_id].device)

    def vel_direction(self, bbox1: torch.Tensor, bbox2: torch.Tensor):
        """Estimate the direction vector between two boxes."""
        if bbox1.sum() < 0 or bbox2.sum() < 0:
            return torch.tensor((-1, -1))
        cx1, cy1 = (bbox1[0] + bbox1[2]) / 2.0, (bbox1[1] + bbox1[3]) / 2.0
        cx2, cy2 = (bbox2[0] + bbox2[2]) / 2.0, (bbox2[1] + bbox2[3]) / 2.0
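        # NOTE: the direction is stored as (dy, dx), consistent with
        # ``vel_direction_batch`` below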
        speed = torch.tensor([cy2 - cy1, cx2 - cx1])
        norm = torch.sqrt((speed[0])**2 + (speed[1])**2) + 1e-6
        return speed / norm

    def vel_direction_batch(self, bboxes1: torch.Tensor,
                            bboxes2: torch.Tensor):
        """Estimate the direction vector given two batches of boxes."""
        cx1, cy1 = (bboxes1[:, 0] + bboxes1[:, 2]) / 2.0, (bboxes1[:, 1] +
                                                           bboxes1[:, 3]) / 2.0
        cx2, cy2 = (bboxes2[:, 0] + bboxes2[:, 2]) / 2.0, (bboxes2[:, 1] +
                                                           bboxes2[:, 3]) / 2.0
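        # broadcast every track center against every detection center; after
        # the normalization below, the result is a (num_tracks, num_dets, 2)
        # array of unit (dy, dx) directions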
        speed_diff_y = cy2[None, :] - cy1[:, None]
        speed_diff_x = cx2[None, :] - cx1[:, None]
        speed = torch.cat((speed_diff_y[..., None], speed_diff_x[..., None]),
                          dim=-1)
        norm = torch.sqrt((speed[:, :, 0])**2 + (speed[:, :, 1])**2) + 1e-6
        speed[:, :, 0] /= norm
        speed[:, :, 1] /= norm
        return speed

    def k_step_observation(self, track: Dict):
        """return the observation k step away before."""
        obs_seqs = track.obs
        num_obs = len(obs_seqs)
        if num_obs == 0:
            # unreachable in practice (``init_track`` always appends one
            # observation), but avoid indexing the empty list for the device
            return torch.tensor(
                (-1, -1, -1, -1)).to(track.bboxes[-1].device)
        elif num_obs > self.vel_delta_t:
            if obs_seqs[num_obs - 1 - self.vel_delta_t] is not None:
                return obs_seqs[num_obs - 1 - self.vel_delta_t]
            else:
                return self.last_obs(track)
        else:
            return self.last_obs(track)

    def ocm_assign_ids(self,
                       ids: List[int],
                       det_bboxes: torch.Tensor,
                       det_labels: torch.Tensor,
                       det_scores: torch.Tensor,
                       weight_iou_with_det_scores: Optional[bool] = False,
                       match_iou_thr: Optional[float] = 0.5):
        """Apply Observation-Centric Momentum (OCM) to assign ids.

        OCM adds movement direction consistency into the association cost
        matrix. This term requires no additional assumption but from the
        same linear motion assumption as the canonical Kalman Filter in SORT.

        Args:
            ids (list[int]): Tracking ids.
            det_bboxes (Tensor): of shape (N, 4)
            det_labels (Tensor): of shape (N,)
            det_scores (Tensor): of shape (N,)
            weight_iou_with_det_scores (bool, optional): Whether using
                detection scores to weight IOU which is used for matching.
                Defaults to False.
            match_iou_thr (float, optional): Matching threshold.
                Defaults to 0.5.

        Returns:
            tuple(int): The assigning ids.

        OC-SORT uses velocity consistency besides IoU for association
        """
        # get track_bboxes
        track_bboxes = np.zeros((0, 4))
        for id in ids:
            track_bboxes = np.concatenate(
                (track_bboxes, self.tracks[id].mean[:4][None]), axis=0)
        track_bboxes = torch.from_numpy(track_bboxes).to(det_bboxes)
        track_bboxes = bbox_cxcyah_to_xyxy(track_bboxes)

        # compute distance
        ious = bbox_overlaps(track_bboxes, det_bboxes)
        if weight_iou_with_det_scores:
            ious *= det_scores

        # support multi-class association
        track_labels = torch.tensor([
            self.tracks[id]['labels'][-1] for id in ids
        ]).to(det_bboxes.device)
        cate_match = det_labels[None, :] == track_labels[:, None]
        # avoid matching detections and tracks of different categories
        cate_cost = (1 - cate_match.int()) * 1e6

        dists = (1 - ious + cate_cost).cpu().numpy()

        if len(ids) > 0 and len(det_bboxes) > 0:
            track_velocities = torch.stack(
                [self.tracks[id].velocity for id in ids]).to(det_bboxes.device)
            k_step_observations = torch.stack([
                self.k_step_observation(self.tracks[id]) for id in ids
            ]).to(det_bboxes.device)
            # valid1: if the track has previous observations to estimate speed
            # valid2: if the associated observation k steps ago is a detection
            valid1 = track_velocities.sum(dim=1) != -2
            valid2 = k_step_observations.sum(dim=1) != -4
            valid = valid1 & valid2

            vel_to_match = self.vel_direction_batch(k_step_observations,
                                                    det_bboxes)
            track_velocities = track_velocities[:, None, :].repeat(
                1, det_bboxes.shape[0], 1)

            angle_cos = (vel_to_match * track_velocities).sum(dim=-1)
            angle_cos = torch.clamp(angle_cos, min=-1, max=1)
            angle = torch.acos(angle_cos)  # [0, pi]
            norm_angle = (angle - np.pi / 2.) / np.pi  # [-0.5, 0.5]
            valid_matrix = valid[:, None].int().repeat(1, det_bboxes.shape[0])
            # set non-valid entries 0
            valid_norm_angle = norm_angle * valid_matrix

            dists += valid_norm_angle.cpu().numpy() * self.vel_consist_weight

        # bipartite match
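        # ``extend_cost=True`` pads the rectangular cost matrix;
        # ``cost_limit=1 - match_iou_thr`` leaves pairs whose cost exceeds
        # that bound unmatched (their row/col entries become -1)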
        if dists.size > 0:
            cost, row, col = lap.lapjv(
                dists, extend_cost=True, cost_limit=1 - match_iou_thr)
        else:
            row = np.zeros(len(ids)).astype(np.int32) - 1
            col = np.zeros(len(det_bboxes)).astype(np.int32) - 1
        return row, col

    def last_obs(self, track: Dict):
        """extract the last associated observation."""
        for bbox in track.obs[::-1]:
            if bbox is not None:
                return bbox

    def ocr_assign_ids(self,
                       track_obs: torch.Tensor,
                       last_track_labels: torch.Tensor,
                       det_bboxes: torch.Tensor,
                       det_labels: torch.Tensor,
                       det_scores: torch.Tensor,
                       weight_iou_with_det_scores: Optional[bool] = False,
                       match_iou_thr: Optional[float] = 0.5):
        """association for Observation-Centric Recovery.

        As try to recover tracks from being lost whose estimated velocity is
        out- to-date, we use IoU-only matching strategy.

        Args:
            track_obs (Tensor): the list of historical associated
                detections of tracks
            det_bboxes (Tensor): of shape (N, 5), unmatched detections
            det_labels (Tensor): of shape (N,)
            det_scores (Tensor): of shape (N,)
            weight_iou_with_det_scores (bool, optional): Whether using
                detection scores to weight IOU which is used for matching.
                Defaults to False.
            match_iou_thr (float, optional): Matching threshold.
                Defaults to 0.5.

        Returns:
            tuple(int): The assigning ids.
        """
        # compute distance
        ious = bbox_overlaps(track_obs, det_bboxes)
        if weight_iou_with_det_scores:
            ious *= det_scores

        # support multi-class association
        cate_match = det_labels[None, :] == last_track_labels[:, None]
        # avoid matching detections and tracks of different categories
        cate_cost = (1 - cate_match.int()) * 1e6

        dists = (1 - ious + cate_cost).cpu().numpy()

        # bipartite match
        if dists.size > 0:
            cost, row, col = lap.lapjv(
                dists, extend_cost=True, cost_limit=1 - match_iou_thr)
        else:
            row = np.zeros(len(track_obs)).astype(np.int32) - 1
            col = np.zeros(len(det_bboxes)).astype(np.int32) - 1
        return row, col

    def online_smooth(self, track: Dict, obj: torch.Tensor):
        """Once a track is recovered from being lost, online smooth its
        parameters to fix the error accumulated during being lost.

        NOTE: you can use different virtual trajectory generation
        strategies, we adopt the naive linear interpolation as default
        """
        last_match_bbox = self.last_obs(track)
        new_match_bbox = obj
        unmatch_len = 0
        for bbox in track.obs[::-1]:
            if bbox is None:
                unmatch_len += 1
            else:
                break
        bbox_shift_per_step = (new_match_bbox - last_match_bbox) / (
            unmatch_len + 1)
        track.mean = track.saved_attr.mean
        track.covariance = track.saved_attr.covariance
        for i in range(unmatch_len):
            virtual_bbox = last_match_bbox + (i + 1) * bbox_shift_per_step
            virtual_bbox = bbox_xyxy_to_cxcyah(virtual_bbox[None, :])
            virtual_bbox = virtual_bbox.squeeze(0).cpu().numpy()
            track.mean, track.covariance = self.kf.update(
                track.mean, track.covariance, virtual_bbox)

    def track(self, data_sample: DetDataSample, **kwargs) -> InstanceData:
        """Tracking forward function.
        NOTE: this implementation is slightly different from the original
        OC-SORT implementation (https://github.com/noahcao/OC_SORT)that we
        do association between detections and tentative/non-tentative tracks
        independently while the original implementation combines them together.

        Args:
            data_sample (:obj:`DetDataSample`): The data sample.
                It includes information such as `pred_instances`.

        Returns:
            :obj:`InstanceData`: Tracking results of the input images.
            Each InstanceData usually contains ``bboxes``, ``labels``,
            ``scores`` and ``instances_id``.
        """
        metainfo = data_sample.metainfo
        bboxes = data_sample.pred_instances.bboxes
        labels = data_sample.pred_instances.labels
        scores = data_sample.pred_instances.scores
        frame_id = metainfo.get('frame_id', -1)
        if frame_id == 0:
            self.reset()
        if not hasattr(self, 'kf'):
            self.kf = self.motion

        if self.empty or bboxes.size(0) == 0:
            valid_inds = scores > self.init_track_thr
            scores = scores[valid_inds]
            bboxes = bboxes[valid_inds]
            labels = labels[valid_inds]
            num_new_tracks = bboxes.size(0)
            ids = torch.arange(self.num_tracks,
                               self.num_tracks + num_new_tracks).to(labels)
            self.num_tracks += num_new_tracks
        else:
            # 0. init
            ids = torch.full((bboxes.size(0), ),
                             -1,
                             dtype=labels.dtype,
                             device=labels.device)

            # get the detection bboxes for the first association
            det_inds = scores > self.obj_score_thr
            det_bboxes = bboxes[det_inds]
            det_labels = labels[det_inds]
            det_scores = scores[det_inds]
            det_ids = ids[det_inds]

            # 1. predict by Kalman Filter
            for id in self.confirmed_ids:
                # track is lost in previous frame
                if self.tracks[id].frame_ids[-1] != frame_id - 1:
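                    # zero the height-velocity component (the KF state is
                    # [cx, cy, a, h, vcx, vcy, va, vh]) so the stale velocity
                    # does not keep extrapolating while the track was lost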
                    self.tracks[id].mean[7] = 0
                if self.tracks[id].tracked:
                    self.tracks[id].saved_attr.mean = self.tracks[id].mean
                    self.tracks[id].saved_attr.covariance = self.tracks[
                        id].covariance
                (self.tracks[id].mean,
                 self.tracks[id].covariance) = self.kf.predict(
                     self.tracks[id].mean, self.tracks[id].covariance)

            # 2. match detections and tracks' predicted locations
            match_track_inds, raw_match_det_inds = self.ocm_assign_ids(
                self.confirmed_ids, det_bboxes, det_labels, det_scores,
                self.weight_iou_with_det_scores, self.match_iou_thr)
            # '-1' means a detection box is not matched with any tracklet
            # from the previous frame
            valid = raw_match_det_inds > -1
            det_ids[valid] = torch.tensor(
                self.confirmed_ids)[raw_match_det_inds[valid]].to(labels)

            match_det_bboxes = det_bboxes[valid]
            match_det_labels = det_labels[valid]
            match_det_scores = det_scores[valid]
            match_det_ids = det_ids[valid]
            assert (match_det_ids > -1).all()

            # unmatched tracks and detections
            unmatch_det_bboxes = det_bboxes[~valid]
            unmatch_det_labels = det_labels[~valid]
            unmatch_det_scores = det_scores[~valid]
            unmatch_det_ids = det_ids[~valid]
            assert (unmatch_det_ids == -1).all()

            # 3. use unmatched detection bboxes from the first match to match
            # the unconfirmed tracks
            (tentative_match_track_inds,
             tentative_match_det_inds) = self.ocm_assign_ids(
                 self.unconfirmed_ids, unmatch_det_bboxes, unmatch_det_labels,
                 unmatch_det_scores, self.weight_iou_with_det_scores,
                 self.match_iou_thr)
            valid = tentative_match_det_inds > -1
            unmatch_det_ids[valid] = torch.tensor(self.unconfirmed_ids)[
                tentative_match_det_inds[valid]].to(labels)

            match_det_bboxes = torch.cat(
                (match_det_bboxes, unmatch_det_bboxes[valid]), dim=0)
            match_det_labels = torch.cat(
                (match_det_labels, unmatch_det_labels[valid]), dim=0)
            match_det_scores = torch.cat(
                (match_det_scores, unmatch_det_scores[valid]), dim=0)
            match_det_ids = torch.cat((match_det_ids, unmatch_det_ids[valid]),
                                      dim=0)
            assert (match_det_ids > -1).all()

            unmatch_det_bboxes = unmatch_det_bboxes[~valid]
            unmatch_det_labels = unmatch_det_labels[~valid]
            unmatch_det_scores = unmatch_det_scores[~valid]
            unmatch_det_ids = unmatch_det_ids[~valid]
            assert (unmatch_det_ids == -1).all()

            all_track_ids = [id for id, _ in self.tracks.items()]
            unmatched_track_inds = torch.tensor(
                [ind for ind in all_track_ids if ind not in match_det_ids])

            if len(unmatched_track_inds) > 0:
                # 4. still some tracks not associated yet, perform OCR
                last_observations = []
                for id in unmatched_track_inds:
                    last_box = self.last_obs(self.tracks[id.item()])
                    last_observations.append(last_box)
                last_observations = torch.stack(last_observations)
                last_track_labels = torch.tensor([
                    self.tracks[id.item()]['labels'][-1]
                    for id in unmatched_track_inds
                ]).to(det_bboxes.device)

                remain_det_ids = torch.full((unmatch_det_bboxes.size(0), ),
                                            -1,
                                            dtype=labels.dtype,
                                            device=labels.device)

                _, ocr_match_det_inds = self.ocr_assign_ids(
                    last_observations, last_track_labels, unmatch_det_bboxes,
                    unmatch_det_labels, unmatch_det_scores,
                    self.weight_iou_with_det_scores, self.match_iou_thr)

                valid = ocr_match_det_inds > -1
                remain_det_ids[valid] = unmatched_track_inds.clone()[
                    ocr_match_det_inds[valid]].to(labels)

                ocr_match_det_bboxes = unmatch_det_bboxes[valid]
                ocr_match_det_labels = unmatch_det_labels[valid]
                ocr_match_det_scores = unmatch_det_scores[valid]
                ocr_match_det_ids = remain_det_ids[valid]
                assert (ocr_match_det_ids > -1).all()

                ocr_unmatch_det_bboxes = unmatch_det_bboxes[~valid]
                ocr_unmatch_det_labels = unmatch_det_labels[~valid]
                ocr_unmatch_det_scores = unmatch_det_scores[~valid]
                ocr_unmatch_det_ids = remain_det_ids[~valid]
                assert (ocr_unmatch_det_ids == -1).all()

                unmatch_det_bboxes = ocr_unmatch_det_bboxes
                unmatch_det_labels = ocr_unmatch_det_labels
                unmatch_det_scores = ocr_unmatch_det_scores
                unmatch_det_ids = ocr_unmatch_det_ids
                match_det_bboxes = torch.cat(
                    (match_det_bboxes, ocr_match_det_bboxes), dim=0)
                match_det_labels = torch.cat(
                    (match_det_labels, ocr_match_det_labels), dim=0)
                match_det_scores = torch.cat(
                    (match_det_scores, ocr_match_det_scores), dim=0)
                match_det_ids = torch.cat((match_det_ids, ocr_match_det_ids),
                                          dim=0)

            # 5. summarize the track results
            for i in range(len(match_det_ids)):
                det_bbox = match_det_bboxes[i]
                track_id = match_det_ids[i].item()
                if not self.tracks[track_id].tracked:
                    # the track is lost before this step
                    self.online_smooth(self.tracks[track_id], det_bbox)

            for track_id in all_track_ids:
                if track_id not in match_det_ids:
                    self.tracks[track_id].tracked = False
                    self.tracks[track_id].obs.append(None)

            bboxes = torch.cat((match_det_bboxes, unmatch_det_bboxes), dim=0)
            labels = torch.cat((match_det_labels, unmatch_det_labels), dim=0)
            scores = torch.cat((match_det_scores, unmatch_det_scores), dim=0)
            ids = torch.cat((match_det_ids, unmatch_det_ids), dim=0)
            # 6. assign new ids
            new_track_inds = ids == -1

            ids[new_track_inds] = torch.arange(
                self.num_tracks,
                self.num_tracks + new_track_inds.sum()).to(labels)
            self.num_tracks += new_track_inds.sum()

        self.update(
            ids=ids,
            bboxes=bboxes,
            labels=labels,
            scores=scores,
            frame_ids=frame_id)

        # update pred_track_instances
        pred_track_instances = InstanceData()
        pred_track_instances.bboxes = bboxes
        pred_track_instances.labels = labels
        pred_track_instances.scores = scores
        pred_track_instances.instances_id = ids
        return pred_track_instances