jeffacce
initial commit
393d3de
raw
history blame contribute delete
487 Bytes
import torch
import torch.nn as nn
class IdentityProjector(nn.Module):
def __init__(
self,
input_dim: int,
output_dim: int,
):
super().__init__()
self.placeholder_param = nn.Parameter(torch.zeros(1))
def forward(self, *args):
return args[0]
def configure_optimizers(self, weight_decay, lr, betas):
return torch.optim.AdamW(
self.parameters(), lr=lr, betas=betas, weight_decay=weight_decay
)