File size: 10,601 Bytes
427d150
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
"""

Customized dataset. Extended from vanilla PANet script by Wang et al.

"""

import os
import random
import torch
import numpy as np

from dataloaders.common import ReloadPairedDataset, ValidationDataset
from dataloaders.ManualAnnoDatasetv2 import ManualAnnoDataset

def attrib_basic(_sample, class_id):
    """
    Attach the basic class-id attribute to a sample.

    Args:
        _sample: data sample (unused; kept to satisfy the attribute-callback interface)
        class_id: class label associated with the data
            (sometimes indicating from which subset the data are drawn)

    Returns:
        dict with the single key 'class_id'
    """
    return dict(class_id=class_id)

def getMaskOnly(label, class_id, class_ids):
    """
    Generate FG/BG masks from a dense segmentation mask.

    Args:
        label:
            semantic mask (integer tensor of class ids)
        class_id:
            semantic class of interest (foreground)
        class_ids:
            all class ids appearing in this episode; every one of them is
            excluded from the background mask

    Returns:
        dict with 'fg_mask' (1 where label == class_id) and 'bg_mask'
        (1 where label differs from every id in class_ids)
    """
    # Dense foreground mask: 1 exactly where the class of interest appears
    fg_mask = torch.where(label == class_id,
                          torch.ones_like(label), torch.zeros_like(label))
    # Start from "everything that is not the FG class" ...
    bg_mask = torch.where(label != class_id,
                          torch.ones_like(label), torch.zeros_like(label))
    # ... then also remove every other episode class from the background.
    # FIX: loop variable renamed — it previously shadowed the `class_id` parameter.
    for cid in class_ids:
        bg_mask[label == cid] = 0

    return {'fg_mask': fg_mask,
            'bg_mask': bg_mask}

def getMasks(*args, **kwargs):
    """Scribble-based mask generation from the original PANet; unsupported here."""
    raise NotImplementedError

def fewshot_pairing(paired_sample, n_ways, n_shots, cnt_query, coco=False, mask_only = True):
    """
    Postprocess paired sample for fewshot settings.

    For now only 1-way is tested but we leave multi-way possible (inherited from
    original PANet).

    Args:
        paired_sample:
            data sample from a PairedDataset
        n_ways:
            n-way few-shot learning
        n_shots:
            n-shot few-shot learning
        cnt_query:
            number of query images for each class in the support set
        coco:
            MS COCO dataset. From the original PANet; kept for further extension.
            When True, per-image labels are dicts keyed by class id.
        mask_only:
            only give masks and no scribbles / instances. Suitable for medical
            images (for now). Only True is supported; False raises.

    Returns:
        dict with class ids, support images/masks, (empty) scribble/instance
        placeholders, query images, episode-level query labels, per-class query
        masks and per-query class indices.
    """
    if not mask_only:
        raise NotImplementedError
    ###### Compose the support and query image list ######
    # For way i, supports occupy [cumsum_idx[i], cumsum_idx[i]+n_shots) and the
    # cnt_query[i] queries fill the rest of [cumsum_idx[i], cumsum_idx[i+1]).
    cumsum_idx = np.cumsum([0,] + [n_shots + x for x in cnt_query]) # seperation for supports and queries

    # support class ids
    class_ids = [paired_sample[cumsum_idx[i]]['basic_class_id'] for i in range(n_ways)] # class ids for each image (support and query)

    # support images
    support_images = [[paired_sample[cumsum_idx[i] + j]['image'] for j in range(n_shots)]
                      for i in range(n_ways)] # fetch support images for each class

    # support image labels
    if coco:
        support_labels = [[paired_sample[cumsum_idx[i] + j]['label'][class_ids[i]]
                           for j in range(n_shots)] for i in range(n_ways)]
    else:
        support_labels = [[paired_sample[cumsum_idx[i] + j]['label'] for j in range(n_shots)]
                          for i in range(n_ways)]

    # NOTE: this branch is currently unreachable (mask_only=False raises above);
    # kept to mirror the original PANet interface.
    if not mask_only:
        support_scribbles = [[paired_sample[cumsum_idx[i] + j]['scribble'] for j in range(n_shots)]
                             for i in range(n_ways)]
        support_insts = [[paired_sample[cumsum_idx[i] + j]['inst'] for j in range(n_shots)]
                         for i in range(n_ways)]
    else:
        support_insts = []
        # FIX: support_scribbles was previously left undefined on this path,
        # which made the return statement below raise a NameError.
        support_scribbles = []

    # query images, masks and class indices (queries sit at the tail of each way's span)
    query_images = [paired_sample[cumsum_idx[i+1] - j - 1]['image'] for i in range(n_ways)
                    for j in range(cnt_query[i])]
    if coco:
        query_labels = [paired_sample[cumsum_idx[i+1] - j - 1]['label'][class_ids[i]]
                        for i in range(n_ways) for j in range(cnt_query[i])]
    else:
        query_labels = [paired_sample[cumsum_idx[i+1] - j - 1]['label'] for i in range(n_ways)
                        for j in range(cnt_query[i])]
    # For each query: episode class indices present in it (0 is always BG)
    query_cls_idx = [sorted([0,] + [class_ids.index(x) + 1
                                    for x in set(np.unique(query_label)) & set(class_ids)])
                     for query_label in query_labels]

    ###### Generate support image masks ######
    if not mask_only:
        support_mask = [[getMasks(support_labels[way][shot], support_scribbles[way][shot],
                                 class_ids[way], class_ids)
                         for shot in range(n_shots)] for way in range(n_ways)]
    else:
        support_mask = [[getMaskOnly(support_labels[way][shot],
                                 class_ids[way], class_ids)
                         for shot in range(n_shots)] for way in range(n_ways)]

    ###### Generate query label (class indices in one episode, i.e. the ground truth)######
    # Remap semantic ids to episode indices 1..n_ways; 255 stays as ignore label.
    query_labels_tmp = [torch.zeros_like(x) for x in query_labels]
    for i, query_label_tmp in enumerate(query_labels_tmp):
        query_label_tmp[query_labels[i] == 255] = 255
        for j in range(n_ways):
            query_label_tmp[query_labels[i] == class_ids[j]] = j + 1

    ###### Generate query mask for each semantic class (including BG) ######
    # BG class
    query_masks = [[torch.where(query_label == 0,
                                torch.ones_like(query_label),
                                torch.zeros_like(query_label))[None, ...],]
                   for query_label in query_labels]
    # Other classes in query image
    for i, query_label in enumerate(query_labels):
        for idx in query_cls_idx[i][1:]:
            mask = torch.where(query_label == class_ids[idx - 1],
                               torch.ones_like(query_label),
                               torch.zeros_like(query_label))[None, ...]
            query_masks[i].append(mask)


    return {'class_ids': class_ids,
            'support_images': support_images,
            'support_mask': support_mask,
            'support_inst': support_insts, # leave these interfaces
            'support_scribbles': support_scribbles, 

            'query_images': query_images,
            'query_labels': query_labels_tmp,
            'query_masks': query_masks,
            'query_cls_idx': query_cls_idx,
           }


def med_fewshot(dataset_name, base_dir, idx_split, mode, scan_per_load,

        transforms, act_labels, n_ways, n_shots, max_iters_per_load, min_fg = '', n_queries=1, fix_parent_len = None, exclude_list = None, **kwargs):
    """
    Dataset wrapper

    Args:
        dataset_name:
            indicates what dataset to use
        base_dir:
            dataset directory
        mode:
            which mode to use
            choose from ('train', 'val', 'trainval', 'trainaug')
        idx_split:
            index of split
        scan_per_load:
            number of scans to load into memory as the dataset is large
            use that together with reload_buffer
        transforms:
            transformations to be performed on images/masks
        act_labels:
            active labels involved in training process. Should be a subset of all labels
        n_ways:
            n-way few-shot learning, should be no more than # of object class labels
        n_shots:
            n-shot few-shot learning
        max_iters_per_load:
            number of pairs per load (epoch size)
        min_fg:
            minimum foreground requirement forwarded to the dataset
        n_queries:
            number of query images
        fix_parent_len:
            fixed length of the parent dataset
        exclude_list:
            scan ids to exclude; defaults to no exclusions

    Returns:
        (paired episode dataset, underlying parent dataset)
    """
    # FIX: the default was a mutable `[]`, which is shared across calls in
    # Python; use a None sentinel and materialize a fresh list per call.
    exclude_list = [] if exclude_list is None else exclude_list

    med_set = ManualAnnoDataset


    mydataset = med_set(which_dataset = dataset_name, base_dir=base_dir, idx_split = idx_split, mode = mode,\
         scan_per_load = scan_per_load, transforms=transforms, min_fg = min_fg, fix_length = fix_parent_len,\
         exclude_list = exclude_list, **kwargs)

    mydataset.add_attrib('basic', attrib_basic, {})

    # Create sub-datasets and add class_id attribute. Here the class file is internally loaded and reloaded inside
    subsets = mydataset.subsets([{'basic': {'class_id': ii}}
        for ii, _ in enumerate(mydataset.label_name)])

    # Choose the classes of queries
    cnt_query = np.bincount(random.choices(population=range(n_ways), k=n_queries), minlength=n_ways)
    # Number of queries for each way
    # Set the number of images for each class
    n_elements = [n_shots + x for x in cnt_query] # <n_shot> supports + <cnt_query>[i] queries
    # Create paired dataset. We do not include background.
    paired_data = ReloadPairedDataset([subsets[ii] for ii in act_labels], n_elements=n_elements, curr_max_iters=max_iters_per_load, 
                                pair_based_transforms=[
                                    (fewshot_pairing, {'n_ways': n_ways, 'n_shots': n_shots,
                                        'cnt_query': cnt_query, 'mask_only': True})])
    return paired_data, mydataset

def update_loader_dset(loader, parent_set):
    """
    Update data loader and the parent dataset behind it.

    Reloads the parent dataset's in-memory buffer, then refreshes the
    loader's dataset index so subsequent iteration sees the new scans.

    Args:
        loader: actual dataloader (its .dataset must expose update_index())
        parent_set: parent dataset which actually stores the data
    """
    parent_set.reload_buffer()
    loader.dataset.update_index()
    # FIX: was an f-string with no placeholders (lint F541)
    print('###### Loader and dataset have been updated ######')

def med_fewshot_val(dataset_name, base_dir, idx_split, scan_per_load, act_labels, npart, fix_length = None, nsup = 1, transforms=None, mode='val', **kwargs):
    """
    Build the validation set for medical images.

    Args:
        dataset_name:
            indicates what dataset to use
        base_dir:
            SABS dataset directory
        mode: (original split)
            which split to use
            choose from ('train', 'val', 'trainval', 'trainaug')
        idx_split:
            index of split
        scan_per_load:
            number of scans to load into memory as the dataset is large
            use that together with reload_buffer
        act_labels:
            actual labels involved in training process. Should be a subset of all labels
        npart:
            number of chunks for splitting a 3d volume
        fix_length:
            fixed length of the parent dataset
        nsup:
            number of support scans, equivalent to nshot

    Returns:
        (validation dataset wrapper, underlying parent dataset)
    """
    parent_dataset = ManualAnnoDataset(
        which_dataset=dataset_name,
        base_dir=base_dir,
        idx_split=idx_split,
        mode=mode,
        scan_per_load=scan_per_load,
        transforms=transforms,
        min_fg=1,
        fix_length=fix_length,
        nsup=nsup,
        **kwargs)
    parent_dataset.add_attrib('basic', attrib_basic, {})

    wrapped_valset = ValidationDataset(parent_dataset, test_classes=act_labels, npart=npart)

    return wrapped_valset, parent_dataset