File size: 487 Bytes
393d3de |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
import torch
import torch.nn as nn
class IdentityProjector(nn.Module):
    """Projector that hands its first input back unchanged.

    The constructor takes ``input_dim`` and ``output_dim`` but never reads
    them. They appear to exist only so this class can be swapped in where a
    real projector with the same constructor signature is expected (TODO
    confirm against callers).

    One zero-initialised parameter is registered so that ``self.parameters()``
    is never empty. ``configure_optimizers`` passes those parameters to
    ``torch.optim.AdamW``, and torch optimizers reject an empty parameter
    list. That is presumably the reason the parameter exists.
    """

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        # Dummy one-element trainable tensor; ``forward`` never touches it.
        self.placeholder_param = nn.Parameter(torch.zeros(1))

    def forward(self, *args):
        """Return ``args[0]`` as-is; any further positional arguments are ignored.

        Raises:
            IndexError: if called with no positional arguments.
        """
        passthrough = args[0]
        return passthrough

    def configure_optimizers(self, weight_decay, lr, betas):
        """Build an AdamW optimizer over this module's parameters.

        Args:
            weight_decay: Decoupled weight-decay coefficient for AdamW.
            lr: Learning rate.
            betas: ``(beta1, beta2)`` coefficients for AdamW's running averages.

        Returns:
            A ``torch.optim.AdamW`` instance over ``self.parameters()``.
        """
        trainable = self.parameters()
        return torch.optim.AdamW(
            trainable,
            weight_decay=weight_decay,
            lr=lr,
            betas=betas,
        )
|