File size: 745 Bytes
173ea2b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
import torch
from typing import List
class Singleton(type):
    """Metaclass that gives every class using it exactly one instance.

    The first ``SomeClass(...)`` call builds the instance with the supplied
    arguments and caches it. Every later call returns that cached object and
    ignores its arguments; ``__init__`` is not run again.
    """

    # Registry mapping each class created with this metaclass to its sole
    # instance. It lives on the metaclass, so all such classes share one dict.
    _instances = {}

    def __call__(cls, *args, **kwargs):
        registry = cls._instances
        if cls in registry:
            # Already constructed once: hand back the cached instance.
            return registry[cls]
        # First call for this class: build it normally, then remember it.
        instance = super().__call__(*args, **kwargs)
        registry[cls] = instance
        return instance
class Latents(metaclass=Singleton):
    """Shared store that accumulates latent tensors in insertion order.

    The class uses the ``Singleton`` metaclass, so every ``Latents()`` call
    returns the same object and all callers share one ``history`` list.
    """

    def __init__(self) -> None:
        # Latents recorded via ``add_latents``, oldest first.
        self.history: List[torch.FloatTensor] = []

    def is_empty(self) -> bool:
        """Return True when no latents have been recorded since the last clear."""
        # ``history`` is always a list (``__init__`` and ``clear`` both assign
        # ``[]``), so check its contents. The previous ``is None`` test could
        # never be True.
        return not self.history

    def add_latents(self, latents: torch.FloatTensor) -> None:
        """Append ``latents`` to the end of the history."""
        self.history.append(latents)

    def clear(self) -> None:
        """Drop all recorded latents by rebinding ``history`` to a new list."""
        self.history = []

    def dump_and_clear(self) -> List[torch.FloatTensor]:
        """Return the recorded latents and reset the store.

        The returned list is the previous ``history`` object itself. Because
        ``clear`` rebinds ``history`` to a fresh list rather than emptying it
        in place, later ``add_latents`` calls do not affect the returned list.
        """
        history = self.history
        self.clear()
        return history
|