File size: 537 Bytes
71f183c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from typing import Sequence

import torch
from ignite.metrics import Accuracy, Loss
from ignite.metrics.metric import reinit__is_reduced

class TopKAccuracy(Accuracy):
    """Fraction of samples whose top-k predicted classes match a target ranking.

    ``output`` is ``(y_pred, y_attack)``. A sample counts as correct only when
    the ``k`` highest-scoring classes of ``y_pred``, listed in descending score
    order, equal ``y_attack`` element-wise, where ``k = y_attack.shape[-1]``.
    Order matters: a target of ``[3, 7]`` is not matched by a prediction
    ranking ``[7, 3]``.

    Counts are accumulated into the ``_num_correct`` / ``_num_examples``
    counters of the ignite ``Accuracy`` base class; this class does not
    override ``reset`` or ``compute``, so those come from the base class.
    """

    @reinit__is_reduced  # required by ignite so distributed reduction state is reset after each update
    def update(self, output: Sequence[torch.Tensor], **kwargs) -> None:
        """Accumulate the exact-ranking matches of one batch.

        Args:
            output: ``(y_pred, y_attack)``. ``y_pred`` holds per-class scores
                of shape ``[N, C]``. ``y_attack`` holds target class indices of
                shape ``[N, k]`` (highest-ranked first), or ``[N]``, which is
                treated as ``k = 1``.
            **kwargs: Accepted for signature compatibility; ignored.

        Raises:
            ValueError: If ``y_pred`` is not 2-D, ``y_attack`` is not 1-D/2-D,
                the batch sizes differ, or ``k`` is not in ``[1, C]``.
        """
        y_pred, y_attack = output[0].detach(), output[1].detach()

        if y_attack.ndim == 1:
            # One label per sample is the k = 1 case. Without this, shape[-1]
            # would be the batch size and the comparison would broadcast
            # every row against the whole label vector.
            y_attack = y_attack.unsqueeze(-1)

        if y_pred.ndim != 2 or y_attack.ndim != 2:
            raise ValueError(
                "TopKAccuracy expects y_pred of shape [N, C] and y_attack of shape "
                f"[N, k] or [N], got {tuple(y_pred.shape)} and {tuple(y_attack.shape)}"
            )
        if y_pred.shape[0] != y_attack.shape[0]:
            raise ValueError(
                f"Batch size mismatch: y_pred has {y_pred.shape[0]} rows, "
                f"y_attack has {y_attack.shape[0]}"
            )

        num_classes = y_pred.shape[-1]
        k = y_attack.shape[-1]
        if not 1 <= k <= num_classes:
            raise ValueError(f"k must be in [1, {num_classes}], got {k}")

        # topk already returns the k largest scores sorted in descending order,
        # so a full argsort over all C classes is unnecessary. (Tie order is
        # implementation-defined for both topk and argsort.)
        top_k_indices = y_pred.topk(k, dim=-1).indices  # [N, k]

        # Position-wise comparison: the predicted ranking must equal the target
        # ranking exactly, in order.
        correct = (top_k_indices == y_attack.to(top_k_indices.device)).all(dim=-1)

        self._num_correct += torch.sum(correct).to(self._device)
        self._num_examples += correct.shape[0]