import torch.nn as nn


class Pooler(nn.Module):
    """Reduce a sequence of hidden states to one vector per example.

    The vector at index 0 of dimension 1 is sent through a square linear
    projection and then squashed with ``tanh``.

    Args:
        hidden_size: width of the last dimension of the incoming hidden
            states; the projection preserves this width.
    """

    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        # Attribute names and order are part of the checkpoint format
        # (they become state_dict keys), so keep them stable.
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        """Project and squash the first-position vector.

        Args:
            hidden_states: tensor indexed at position 0 along dim 1;
                presumably shaped (batch, seq_len, hidden_size) — TODO
                confirm with callers.

        Returns:
            ``tanh(dense(hidden_states[:, 0]))``; its last dimension is
            ``hidden_size``.
        """
        return self.activation(self.dense(hidden_states[:, 0]))


class ITCHead(nn.Module):
    """Bias-free linear projection from ``hidden_size`` to ``out_size``.

    NOTE(review): "ITC" presumably means image-text contrastive, with the
    projected features consumed by a contrastive loss elsewhere — confirm
    against callers.

    Args:
        hidden_size: size of the last dimension of the input.
        out_size: size of the last dimension of the output.
    """

    def __init__(self, hidden_size: int, out_size: int) -> None:
        super().__init__()
        # No bias term: the output is a pure linear map of the input.
        self.fc = nn.Linear(hidden_size, out_size, bias=False)

    def forward(self, x):
        """Apply ``self.fc`` over the last dimension of *x*."""
        return self.fc(x)