File size: 385 Bytes
7cabed8
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
import torch

from torch import nn


class DummyPooling(nn.Module):
    """Pass-through "pooling" module that performs no reduction.

    ``forward`` returns the tensor stored under ``'token_embeddings'``
    unchanged, exposed under the key ``'sentence_embedding'``. The class
    defines no parameters or buffers, so ``save`` writes nothing and ``load``
    simply builds a fresh instance.
    """

    def forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]:
        """Expose the token embeddings as the sentence embedding.

        Args:
            features: Mapping that must contain ``'token_embeddings'``. Any
                other entries are ignored and are NOT carried into the result.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            A new dict whose only key is ``'sentence_embedding'``, bound to the
            very same tensor object as ``features['token_embeddings']``.

        Raises:
            KeyError: If ``features`` has no ``'token_embeddings'`` entry.
        """
        return {'sentence_embedding': features['token_embeddings']}

    def save(self, save_dir: str, **kwargs) -> None:
        """Do nothing; this module defines no state to persist.

        Args:
            save_dir: Target directory (ignored).
            **kwargs: Accepted for interface compatibility; ignored.
        """

    @staticmethod
    def load(load_dir: str, **kwargs) -> "DummyPooling":
        """Create a new ``DummyPooling``; nothing is read from disk.

        Declared as a static method so it works both as
        ``DummyPooling.load(path)`` and when called on an instance. Without
        the decorator, an instance call would bind the instance to
        ``load_dir`` and fail with ``TypeError``.

        Args:
            load_dir: Source directory (ignored).
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            A freshly constructed ``DummyPooling``.
        """
        return DummyPooling()