import wandb
import torch
import numpy as np
from monai.visualize import blend_images


class WandBModel:
    """
    Mixin that adds WandB logging features to a model via multiple inheritance
    """
    def __init__(self, *args, **kwargs):
        # placeholder defaults; the following attributes should be initialized
        # by class `BaseSegmentationModel`
        self.visual_pairs = None
        self.train_loss = None
        self.val_loss = None
        self.metric_meter = None
        self.name = None
        # the following attribute should be initialized by the child class
        self.val_table = None
        # continue cooperative initialization down the MRO so the other base
        # classes can overwrite the placeholders with real values
        super().__init__(*args, **kwargs)
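
    # A minimal composition sketch (the `SegModel` name and base order are
    # assumptions; only `WandBModel` and `BaseSegmentationModel` come from
    # this codebase):
    #
    #     class SegModel(WandBModel, BaseSegmentationModel):
    #         def __init__(self, *args, **kwargs):
    #             super().__init__(*args, **kwargs)
    #             # child-class responsibility, per the note above
    #             self.val_table = wandb.Table(
    #                 columns=['ID'] + [pair['type'] for pair in self.visual_pairs])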

    def volume2videos(self, time_dim=3, tag=''):
        """
        Convert 3D volumes to videos for WandB logging
        Args:
            time_dim: the spatial dimension to use as the time dimension; defaults to the axial axis (dim 3)
            tag: extra information appended to the video caption
        """
        videos = []
        for image_pair in self.visual_pairs:
            try:
                pair_name = getattr(self, image_pair['name'])
                image = getattr(self, image_pair['image'])
                mask = getattr(self, image_pair['mask'])
                vis_type = image_pair['type']
            except:
                continue
            for i in range(image.shape[0]):  # iterate over the batch dimension
                image2save = image[i, ...]
                mask2save = mask[i, ...]
                item_name = pair_name[i]
                # detach the tensors; shape [C, H, W, D]
                image_tensor = image2save.detach()
                mask_tensor = mask2save.detach()
                if mask_tensor.shape[0] > 1:
                    # collapse multi-channel (e.g. one-hot) masks to label indices
                    mask_tensor = torch.argmax(mask_tensor, dim=0, keepdim=True)
                # blend image and mask into an overlay: (C, H, W, D) torch.Tensor on device
                pair_blend = blend_images(image_tensor, mask_tensor, alpha=0.5) * 255
                # permute the axes to (time, channel, height, width); e.g. for a
                # (C, H, W, D) blend with time_dim=3, the order is [3, 0, 1, 2],
                # yielding (D, C, H, W)
                spatial_dim = list(range(1, pair_blend.ndim))
                spatial_dim.remove(time_dim)
                pair_blend = pair_blend.permute([time_dim, 0] + spatial_dim).cpu().numpy().astype(np.uint8)
                # wrap the frames in a wandb.Video for logging
                video = wandb.Video(pair_blend, fps=8, caption='{}_{}{}'.format(item_name, vis_type, tag))
                videos.append(video)
        return videos
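
    # For illustration, each entry of `self.visual_pairs` is expected to be a
    # dict whose 'name', 'image' and 'mask' values name attributes on the model
    # (the keys come from the loop above; the example values are hypothetical):
    #
    #     self.visual_pairs = [
    #         {'name': 'case_names', 'image': 'image', 'mask': 'label', 'type': 'gt'},
    #         {'name': 'case_names', 'image': 'image', 'mask': 'pred', 'type': 'pred'},
    #     ]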

    def log_scaler(self, key, value, step=None):
        """
        Log a manually defined scalar value
        """
        wandb.log({key: np.round(value, decimals=4)}, step=step)
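
    # usage sketch (hypothetical `optimizer`): any float-like value works, e.g.
    #     self.log_scaler('train/lr', optimizer.param_groups[0]['lr'], step=step)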

    def log_train_loss(self, step=None):
        """
        Log the training loss
        """
        data_dict = self.train_loss.pop_data(True)
        for key, value in data_dict.items():
            wandb.log({'train/{}'.format(key): value}, step=step)

    def log_val_loss(self, step=None):
        """
        Log the validation loss
        """
        data_dict = self.val_loss.pop_data(True)
        for key, value in data_dict.items():
            wandb.log({'val/{}'.format(key): value}, step=step)

    def log_metrics(self, step=None):
        """
        Log validation metrics as wandb.Table
        """
        df = self.metric_meter.to_df()
        wandb.log({'val/metrics': wandb.Table(dataframe=df)}, step=step)

    def log_vis(self, key, step=None, time_dim=3, tag=''):
        """
        Log intermediate visualizations during training
        """
        videos = self.volume2videos(time_dim, tag)
        wandb.log({key: videos}, step=step)

    def update_val_visualization(self, time_dim=3, tag=''):
        """
        Buffer the validation visualization; called at every evaluation step
        """
        videos = self.volume2videos(time_dim, tag)
        self.val_table.add_data(self.name, *videos)

    def log_val_visualization(self, step=None):
        """
        Log validation visualization
        """
        wandb.log({'val/visualization': self.val_table}, step=step)
        # re-initialize the table for the next logging round
        self.val_table = wandb.Table(columns=['ID'] + [pair['type'] for pair in self.visual_pairs])
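

# A hedged end-to-end sketch of the intended logging loop (`model`, `val_loader`,
# `epoch` and `global_step` are hypothetical names):
#
#     for batch in val_loader:
#         ...  # run inference; fill model.val_loss and model.metric_meter
#         model.update_val_visualization(tag='_epoch{}'.format(epoch))
#     model.log_val_loss(step=global_step)
#     model.log_metrics(step=global_step)
#     model.log_val_visualization(step=global_step)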