File size: 522 Bytes
393d3de |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
from typing import Optional

from accelerate import Accelerator
from datasets.core import TrajectoryDataset
class Workspace:
    """Hold a run's config, working directory, accelerator, models and dataset.

    Models and the dataset are attached after construction via
    ``set_models`` and ``set_dataset``; until then the corresponding
    attributes are ``None``. ``run_offline_eval`` is currently a placeholder.
    """

    def __init__(self, cfg, work_dir):
        """Store the configuration and working directory and create an Accelerator.

        Args:
            cfg: Run configuration; stored as-is on ``self.cfg``.
            work_dir: Working directory for this workspace; stored as-is on
                ``self.work_dir``.
        """
        self.cfg = cfg
        self.work_dir = work_dir
        self.accelerator = Accelerator()
        # Declare every attachable attribute up front so it always exists and
        # reads as None until the matching setter is called (previously
        # encoder/projector raised AttributeError before set_models()).
        self.encoder = None
        self.projector = None
        self.dataset: Optional[TrajectoryDataset] = None

    def set_models(self, encoder, projector):
        """Attach the encoder and projector to this workspace.

        Args:
            encoder: Stored as-is on ``self.encoder``.
            projector: Stored as-is on ``self.projector``.
        """
        self.encoder = encoder
        self.projector = projector

    def set_dataset(self, dataset):
        """Attach the dataset to this workspace.

        Args:
            dataset: Stored as-is on ``self.dataset`` (annotated in
                ``__init__`` as a ``TrajectoryDataset``; not validated here).
        """
        self.dataset = dataset

    def run_offline_eval(self):
        """Run offline evaluation and return its metrics.

        Returns:
            dict: A fresh ``{"loss": 0}`` on every call. This is a placeholder:
            it does not read the models or the dataset.
        """
        return {"loss": 0}
|