File size: 522 Bytes
393d3de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from accelerate import Accelerator
from datasets.core import TrajectoryDataset


class Workspace:
    """Container for the state of one offline-evaluation run.

    Holds the run configuration and working directory, an
    ``accelerate.Accelerator`` instance, and the encoder/projector models and
    dataset, which are attached after construction via :meth:`set_models` and
    :meth:`set_dataset`.
    """

    def __init__(self, cfg, work_dir) -> None:
        """Store the run settings and create an ``Accelerator``.

        Args:
            cfg: Run configuration. Stored as ``self.cfg``; not read anywhere
                else in this class.
            work_dir: Working directory for the run. Stored as
                ``self.work_dir``; not read anywhere else in this class.
        """
        self.cfg = cfg
        self.work_dir = work_dir
        self.accelerator = Accelerator()
        # Declare every attribute the setters populate up front, so reading one
        # before its setter has run yields None rather than AttributeError.
        # Previously only ``dataset`` had a default.
        self.encoder = None
        self.projector = None
        # Quoted: the value is None until set_dataset() is called.
        self.dataset: "TrajectoryDataset | None" = None

    def set_models(self, encoder, projector) -> None:
        """Attach the encoder and projector models to this workspace.

        Both objects are stored by reference as ``self.encoder`` and
        ``self.projector``. They are not validated here, and they are not
        passed through ``self.accelerator``.
        """
        self.encoder = encoder
        self.projector = projector

    def set_dataset(self, dataset: "TrajectoryDataset") -> None:
        """Attach the dataset used for evaluation (stored by reference)."""
        self.dataset = dataset

    def run_offline_eval(self) -> dict:
        """Run offline evaluation and return a metrics dict.

        Placeholder: no evaluation is performed yet. This always returns a
        new dict ``{"loss": 0}``, regardless of the attached models or
        dataset.
        """
        return {"loss": 0}