import collections
import random
from typing import Callable, Optional

from torchdata.datapipes.iter import IterDataPipe


def get_second_entry(sample):
    """Default label getter: return the second element of ``sample``."""
    return sample[1]


class UnderSamplerIterDataPipe(IterDataPipe):
    """DataPipe wrapper for under-sampling.

    Copied from: https://github.com/MaxHalford/pytorch-resample/blob/master/pytorch_resample/under.py # noqa
    Modified to work with multiple labels.

    MIT License

    Copyright (c) 2020 Max Halford

    This method is based on rejection sampling: each incoming sample is
    accepted with a probability that shrinks the more its label is
    over-represented relative to the desired distribution.

    Parameters:
        dataset: The source DataPipe to draw samples from.
        desired_dist: The desired class distribution.
            The keys are the classes whilst the
            values are the desired class proportions.
            The values are normalised so that they sum up
            to 1. Every label returned by ``label_getter``
            must be a key of this dict.
        label_getter: A function that takes a sample and returns its label.
            Defaults to ``get_second_entry``, i.e. ``sample[1]``.
        seed: Random seed for reproducibility.

    Attributes:
        actual_dist: The counts of the observed sample labels, accumulated
            across all iterations over this DataPipe.
        rng: A random number generator instance.

    References:
        - https://www.wikiwand.com/en/Rejection_sampling

    """

    def __init__(
        self,
        dataset: IterDataPipe,
        desired_dist: dict,
        label_getter: Callable = get_second_entry,
        seed: Optional[int] = None,
    ):

        self.dataset = dataset
        # Normalise the desired distribution so that its values sum to 1
        total = sum(desired_dist.values())
        self.desired_dist = {c: p / total for c, p in desired_dist.items()}
        self.label_getter = label_getter
        self.seed = seed

        self.actual_dist = collections.Counter()
        self.rng = random.Random(seed)
        self._pivot = None
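    def __iter__(self):

        for dp in self.dataset:
            y = self.label_getter(dp)

            self.actual_dist[y] += 1

            # To ease notation
            f = self.desired_dist
            g = self.actual_dist

            # Samples whose label is the current pivot are always kept
            if y == self._pivot:
                yield dp
                continue

            # Otherwise update the pivot: the label whose desired share is
            # largest relative to its observed count
            self._pivot = max(g.keys(), key=lambda c: f[c] / g[c])

            # A sample whose label has just become the pivot has an
            # acceptance ratio of 1, so it is kept unconditionally
            if y == self._pivot:
                yield dp
                continue

            # Accept with probability f[y] / (M * g[y]), where
            # M = f[pivot] / g[pivot] >= f[y] / g[y], so the ratio is <= 1
            M = f[self._pivot] / g[self._pivot]
            ratio = f[y] / (M * g[y])

            if self.rng.random() < ratio:
                yield dp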

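    @classmethod
    def expected_size(cls, n, desired_dist, actual_dist):
        """Estimate how many of ``n`` samples survive under-sampling.

        ``desired_dist`` and ``actual_dist`` map each class to its desired
        and observed proportion, respectively. A sample of class ``k`` is
        kept with probability ``f[k] / (M * g[k])``, where
        ``M = max_k f[k] / g[k]``, so in the long run the kept fraction is
        ``sum_k f[k] / M``, i.e. ``1 / M`` when the desired proportions
        sum to 1.
        """
        # Only classes that actually occur contribute; a class missing from
        # ``desired_dist`` has a desired share of 0.
        M = max(
            desired_dist.get(k, 0.0) / actual_dist[k]
            for k in actual_dist
            if actual_dist[k] > 0
        )
        return int(n / M)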
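

# ---------------------------------------------------------------------------
# Illustrative helper (a sketch, not part of the upstream pytorch-resample
# code; the name ``_acceptance_probabilities`` is hypothetical). For a
# snapshot of observed label counts it computes the probability with which a
# sample of each label is kept under the rejection-sampling rule used above,
# f[k] / (M * g[k]) with M = max_k f[k] / g[k]. It assumes ``counts`` has at
# least one positive entry and treats labels missing from ``desired_dist`` as
# having a desired share of 0.
#
# Worked example: desired_dist={0: 1, 1: 1} (50/50) and counts={0: 900, 1: 100}
# give M = 0.5 / 100 = 0.005, so label 1 is kept with probability 1 and
# label 0 with probability 0.5 / (0.005 * 900) = 1/9. The expected number of
# kept samples is 900 * 1/9 + 100 * 1 = 200, the same figure that
# expected_size(1000, {0: 0.5, 1: 0.5}, {0: 0.9, 1: 0.1}) estimates
# (1000 / M with M = 0.5 / 0.1 = 5).
# ---------------------------------------------------------------------------
def _acceptance_probabilities(desired_dist, counts):
    """Return ``{label: keep probability}`` for the labels seen in ``counts``."""
    # Normalise the desired distribution so that it sums to 1
    total = sum(desired_dist.values())
    f = {c: p / total for c, p in desired_dist.items()}
    # Only labels that have actually been observed take part
    seen = {c: n for c, n in counts.items() if n > 0}
    # M is the largest desired-to-observed ratio; its label is the pivot
    M = max(f.get(c, 0.0) / n for c, n in seen.items())
    if M == 0:
        # No observed label is wanted at all, so nothing is kept
        return {c: 0.0 for c in seen}
    # Each label is kept with probability f / (M * g), which is 1 for the pivot
    return {c: f.get(c, 0.0) / (M * n) for c, n in seen.items()}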