File size: 385 Bytes
7cabed8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 |
import torch
from torch import nn
class DummyPooling(nn.Module):
    """Pass-through "pooling" module that performs no pooling.

    It exposes the ``forward`` / ``save`` / ``load`` trio so it can stand in
    for a real pooling layer, but it simply re-labels the token embeddings
    as the sentence embedding. The module has no parameters or state.
    """

    def forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]:
        """Return the token embeddings unchanged under ``'sentence_embedding'``.

        Args:
            features: Mapping that must contain a ``'token_embeddings'`` tensor.
                NOTE(review): presumably shaped (batch, seq_len, dim); no
                reduction over the sequence axis is applied here.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            A new dict whose only key is ``'sentence_embedding'``, mapped to the
            very same tensor object as ``features['token_embeddings']``. All
            other entries of ``features`` are not carried over.

        Raises:
            KeyError: If ``'token_embeddings'`` is missing from ``features``.
        """
        return {'sentence_embedding': features['token_embeddings']}

    def save(self, save_dir: str, **kwargs) -> None:
        """Do nothing: the module is stateless, so there is nothing to persist.

        Args:
            save_dir: Target directory (unused).
            **kwargs: Accepted for interface compatibility and ignored.
        """

    @staticmethod
    def load(load_dir: str, **kwargs) -> "DummyPooling":
        """Return a fresh ``DummyPooling``; nothing is read from ``load_dir``.

        Declared as a static method so it works both as
        ``DummyPooling.load(path)`` and when invoked through an instance.
        Without the decorator, an instance call would bind the instance to
        ``load_dir`` and fail.

        Args:
            load_dir: Source directory (unused).
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            A new ``DummyPooling`` instance.
        """
        return DummyPooling()