Spaces:
Running
Running
import torch | |
from torch import nn | |
def to(x, device):
    """Move ``x`` to ``device`` and return it.

    Args:
        x: Either a dict or any object exposing a ``.to(device)`` method
            (e.g. a ``torch.Tensor``).
        device: Passed through unchanged to ``.to(...)``.

    Returns:
        For a dict, the *same* dict object: every value that is a
        ``torch.Tensor`` is replaced in place by ``value.to(device)``; all
        other values are left untouched. For anything else, the result of
        ``x.to(device)``.
    """
    if isinstance(x, dict):
        # Only existing keys are reassigned, so the dict's size does not change
        # while it is being iterated.
        for key, value in x.items():
            if isinstance(value, torch.Tensor):
                x[key] = value.to(device)
        return x

    return x.to(device)
def get_cur_acc(testset, hyps, model, shuffle, iter_index):
    """Return ``model.get_accuracy`` evaluated on one split of ``testset``.

    Args:
        testset: Dataset handed to ``data.split_dataset``.
        hyps: Mapping that must contain the keys ``'val_batch_size'``,
            ``'train_batch_size'`` and ``'num_workers'``.
        model: Object exposing ``get_accuracy(dataloader)``.
        shuffle: Forwarded as the last positional argument of
            ``data.build_dataloader``.
        iter_index: Forwarded as the third argument of ``data.split_dataset``
            (presumably selects which batch-sized slice is evaluated -- confirm
            against ``split_dataset``).

    Returns:
        Whatever ``model.get_accuracy`` returns for the constructed dataloader.
    """
    # Function-scope import of the project-local `data` module, as in the
    # original code.
    from data import split_dataset, build_dataloader

    # Only the first element of the split result is used.
    cur_test_batch_dataset = split_dataset(testset, hyps['val_batch_size'], iter_index)[0]

    # NOTE(review): the split uses 'val_batch_size' while the loader uses
    # 'train_batch_size' -- kept as-is; confirm whether this is intentional.
    # The meaning of the positional `False` depends on build_dataloader's
    # signature, which is not visible here.
    cur_test_batch_dataloader = build_dataloader(
        cur_test_batch_dataset,
        hyps['train_batch_size'],
        hyps['num_workers'],
        False,
        shuffle,
    )

    return model.get_accuracy(cur_test_batch_dataloader)