File size: 745 Bytes
173ea2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
from typing import List


class Singleton(type):
    """Metaclass that gives each class using it exactly one instance.

    The first call to such a class builds the instance normally (so
    ``__init__`` runs with the supplied arguments). Every later call returns
    that same cached object, and the arguments of those later calls are
    ignored.
    """

    # Registry shared by the metaclass, keyed by the concrete class object,
    # so every class that uses this metaclass gets its own cached instance.
    _instances = {}

    def __call__(cls, *args, **kwargs):
        registry = cls._instances
        if cls in registry:
            # Already constructed: hand back the cached object untouched.
            return registry[cls]
        instance = super().__call__(*args, **kwargs)
        registry[cls] = instance
        return instance


class Latents(metaclass=Singleton):
    """Shared accumulator for latent tensors.

    Because the class uses the ``Singleton`` metaclass, every ``Latents()``
    call returns the same object. Latents appended through one reference are
    therefore visible through all others.
    """

    def __init__(self) -> None:
        # Recorded latents, in the order they were added.
        self.history: List[torch.FloatTensor] = []

    def is_empty(self) -> bool:
        """Return True when no latents are currently recorded."""
        # ``history`` is a list (``__init__`` and ``clear`` both assign ``[]``),
        # so emptiness must be tested by truthiness; an ``is None`` check would
        # never be true. ``not`` also treats a ``None`` history as empty.
        return not self.history

    def add_latents(self, latents: torch.FloatTensor) -> None:
        """Append *latents* to the end of the history."""
        self.history.append(latents)

    def clear(self) -> None:
        """Discard all recorded latents by replacing the history with a new list."""
        self.history = []

    def dump_and_clear(self) -> List[torch.FloatTensor]:
        """Return the recorded latents and reset the history.

        The returned list is the previous ``history`` object. Since ``clear``
        rebinds ``history`` to a fresh list rather than emptying it in place,
        later additions do not modify the returned list.
        """
        history = self.history
        self.clear()
        return history