Fill-Mask
Transformers
PyTorch
Safetensors
English
nomic_bert
custom_code
zpn committed · Commit d999b9f · 1 Parent(s): 40b9839

fix: updates

Files changed (1)
  1. modeling_hf_nomic_bert.py +220 -21
modeling_hf_nomic_bert.py CHANGED
@@ -41,14 +41,22 @@ from transformers.modeling_outputs import (
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
 from transformers.utils.hub import cached_file, get_checkpoint_shard_files
 
-from .configuration_hf_nomic_bert import NomicBertConfig
+from configuration_hf_nomic_bert import NomicBertConfig
+logger = logging.getLogger(__name__)
 
 try:
     from torch.nn.functional import scaled_dot_product_attention
 except ImportError:
+    logger.warning("scaled_dot_product_attention not available, using torch.matmul instead")
     scaled_dot_product_attention = None
 
-logger = logging.getLogger(__name__)
+try:
+    from megablocks.layers import dmoe
+    from megablocks.layers.arguments import Arguments
+except ImportError:
+    logger.warning("megablocks not available, falling back to the pure PyTorch MoE implementation")
+    dmoe = None
+
 
 
 # adapted from flash attention, added safe serialization option for hf models
@@ -1083,6 +1091,171 @@ class NomciBertGatedMLP(nn.Module):
         y = self.fc2(y)
         return y if not self.return_residual else (y, x)
 
+class NomicRouter(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        moe_num_experts: int,
+        moe_top_k: int,
+        moe_jitter_eps: Optional[float] = None,
+        moe_normalize_expert_weights: Optional[float] = None,
+        uniform_expert_assignment: bool = False,
+    ):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.moe_num_experts = moe_num_experts
+        self.moe_top_k = moe_top_k
+        self.moe_jitter_eps = moe_jitter_eps
+        self.moe_normalize_expert_weights = moe_normalize_expert_weights
+        self.uniform_expert_assignment = uniform_expert_assignment
+
+        self.layer = nn.Linear(self.hidden_size, self.moe_num_experts, bias=False)
+
+    def jitter(self, x: torch.Tensor) -> torch.Tensor:
+        if self.moe_jitter_eps is None:
+            raise RuntimeError("The router does not have moe_jitter_eps set.")
+        low = 1.0 - self.moe_jitter_eps
+        high = 1.0 + self.moe_jitter_eps
+        noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+        return low + noise * (high - low)
+
+    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor]:
+        if self.training and self.moe_jitter_eps is not None:
+            x = x * self.jitter(x)
+
+        # Per-token routing scores over all experts, computed on flattened tokens.
+        weights = self.layer(x.view(-1, x.shape[-1])).softmax(dim=-1, dtype=torch.float32)
+        top_weights, top_experts = torch.topk(weights, self.moe_top_k, dim=-1)
+
+        if self.moe_normalize_expert_weights:
+            top_weights = top_weights / torch.norm(
+                top_weights, p=self.moe_normalize_expert_weights, dim=-1, keepdim=True
+            )
+
+        if self.uniform_expert_assignment:
+            with torch.no_grad():
+                uniform_tensor = (
+                    torch.arange(0, top_experts.numel(), device=top_experts.device, dtype=top_experts.dtype)
+                    % self.moe_num_experts
+                )
+                top_experts = uniform_tensor.reshape(top_experts.shape)
+                # Note: weights and top_weights are not changed
+
+        weights = weights.to(x.dtype)
+        top_weights = top_weights.to(x.dtype)
+        return weights, top_weights, top_experts  # type: ignore
+
+
+class NomicExpertMLP(nn.Module):
+    def __init__(self, hidden_size: int, ffn_hidden_size: int, moe_num_experts: int, ffn_act_fn):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.ffn_hidden_size = ffn_hidden_size
+        self.moe_num_experts = moe_num_experts
+
+        # All experts share one flat parameter; each expert is a contiguous slice.
+        self.w1 = nn.Parameter(torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))
+        self.w2 = nn.Parameter(torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))
+        self.activation_fn = ffn_act_fn
+
+    def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor:
+        expert_w1 = self.w1.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
+        expert_w2 = self.w2.view(self.moe_num_experts, self.ffn_hidden_size, self.hidden_size)[expert_idx]
+
+        x1 = x.matmul(expert_w1.t())
+        act_out = self.activation_fn(x1)
+        x2 = act_out.matmul(expert_w2)
+        return x2
+
+
+class NomicExperts(nn.Module):
+    def __init__(self, config, hidden_size: int, ffn_hidden_size: int, moe_num_experts: int):
+        super().__init__()
+        self.moe_num_experts = moe_num_experts
+        activation = (
+            F.sigmoid
+            if config.activation_function == "glu"
+            else (F.silu if config.activation_function == "swiglu" else F.gelu)
+        )
+        self.mlp = NomicExpertMLP(
+            hidden_size=config.n_embd,
+            ffn_hidden_size=config.n_inner,
+            moe_num_experts=moe_num_experts,
+            ffn_act_fn=activation,
+        )
+        self.bias = nn.Parameter(torch.zeros(config.n_embd))
+
+    def forward(
+        self, x: torch.Tensor, weights: torch.Tensor, top_weights: torch.Tensor, top_experts: torch.LongTensor
+    ) -> torch.Tensor:
+        bsz, q_len, hidden_size = x.shape
+        x = x.view(-1, hidden_size)
+        out = torch.zeros_like(x)
+
+        # For each expert, gather the tokens routed to it, run the expert MLP,
+        # weight the result by the router score, and scatter it back.
+        expert_mask = nn.functional.one_hot(top_experts, num_classes=self.moe_num_experts).permute(2, 1, 0)
+        for expert_idx in range(0, self.moe_num_experts):
+            topk_idx, token_idx = torch.where(expert_mask[expert_idx])
+            if token_idx.shape[0] == 0:
+                continue
+
+            token_list = token_idx.tolist()
+            topk_list = topk_idx.tolist()
+
+            expert_tokens = x[None, token_list].reshape(-1, hidden_size)
+            expert_out = self.mlp(expert_tokens, expert_idx) * top_weights[token_list, topk_list, None]
+
+            out.index_add_(0, token_idx, expert_out)
+
+        out = out.reshape(bsz, q_len, hidden_size)
+        return out + self.bias
+
+
+class NomicMoELayer(nn.Module):
+    def __init__(self, config: NomicBertConfig):
+        super().__init__()
+
+        self.router = NomicRouter(
+            config.n_embd,
+            moe_num_experts=config.num_experts,
+            moe_top_k=config.moe_top_k,
+        )
+
+        self.experts = NomicExperts(
+            config,
+            hidden_size=config.n_embd,
+            ffn_hidden_size=config.n_inner,
+            moe_num_experts=config.num_experts,
+        )
+
+    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
+        batch_size, seq_len, hidden_dim = x.shape
+        if attention_mask is not None:
+            # Route only the non-padding tokens, then scatter the expert outputs
+            # back to their original positions.
+            valid_indices = attention_mask.bool().view(-1)
+            x_valid = x.view(-1, hidden_dim)[valid_indices].unsqueeze(0)
+
+            weights, top_weights, top_experts = self.router(x_valid)
+            out = self.experts(x_valid, weights, top_weights, top_experts).squeeze(0)
+
+            full_out = torch.zeros(batch_size * seq_len, hidden_dim, dtype=out.dtype, device=out.device)
+            full_out[valid_indices] = out
+            out = full_out.view(batch_size, seq_len, hidden_dim)
+        else:
+            weights, top_weights, top_experts = self.router(x)
+            out = self.experts(x, weights, top_weights, top_experts)
+
+        return out
+
 
 def rotate_half(x, interleaved=False):
     if not interleaved:
@@ -1431,6 +1604,7 @@ class NomicBertBlock(NomicBertPreTrainedModel):
     def __init__(
         self,
         config,
+        moe=False,
     ):
         super().__init__(config=config)
         self.prenorm = config.prenorm
@@ -1442,25 +1616,46 @@ class NomicBertBlock(NomicBertPreTrainedModel):
             if config.activation_function == "glu"
             else (F.silu if config.activation_function == "swiglu" else F.gelu)
         )
-        if config.activation_function in ["glu", "swiglu", "geglu"]:
-            self.mlp = NomciBertGatedMLP(
-                config.n_embd,
-                hidden_features=config.n_inner,
-                bias1=config.mlp_fc1_bias,
-                bias2=config.mlp_fc2_bias,
-                activation=activation,
-                fused_bias_fc=config.fused_bias_fc,
-                norm_layer=getattr(config, "norm_mlp", False),
-            )
+        if moe:
+            if dmoe is not None:
+                # Use the fused megablocks dMoE kernel when the library is installed.
+                megablocks_args = Arguments(
+                    moe_num_experts=config.num_experts,
+                    moe_top_k=config.moe_top_k,
+                    hidden_size=config.n_embd,
+                    ffn_hidden_size=config.n_inner,
+                    num_layers=config.n_layer,
+                    moe_normalize_expert_weights=config.moe_normalize_expert_weights,
+                    activation_fn=activation,
+                    mlp_type="glu" if config.activation_function == "swiglu" else "mlp",
+                    fp16=True,
+                    bf16=False,
+                    return_bias=False,
+                )
+                self.mlp = dmoe.dMoE(megablocks_args)
+            else:
+                self.mlp = NomicMoELayer(config)
         else:
-            self.mlp = NomicBertMLP(
-                config.n_embd,
-                hidden_features=config.n_inner,
-                bias1=config.mlp_fc1_bias,
-                bias2=config.mlp_fc2_bias,
-                activation=activation,
-                fused_bias_fc=config.fused_bias_fc,
-            )
+            if config.activation_function in ["glu", "swiglu", "geglu"]:
+                self.mlp = NomciBertGatedMLP(
+                    config.n_embd,
+                    hidden_features=config.n_inner,
+                    bias1=config.mlp_fc1_bias,
+                    bias2=config.mlp_fc2_bias,
+                    activation=activation,
+                    fused_bias_fc=config.fused_bias_fc,
+                    norm_layer=getattr(config, "norm_mlp", False),
+                )
+            else:
+                self.mlp = NomicBertMLP(
+                    config.n_embd,
+                    hidden_features=config.n_inner,
+                    bias1=config.mlp_fc1_bias,
+                    bias2=config.mlp_fc2_bias,
+                    activation=activation,
+                    fused_bias_fc=config.fused_bias_fc,
+                )
 
         self.dropout1 = nn.Dropout(config.resid_pdrop)
         self.norm1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
@@ -1530,7 +1725,11 @@ class NomicBertBlock(NomicBertPreTrainedModel):
 class NomicBertEncoder(nn.Module):
     def __init__(self, config: GPT2Config):
         super().__init__()
-        self.layers = nn.ModuleList([NomicBertBlock(config) for _ in range(config.n_layer)])
+        if getattr(config, "moe_every_n_layers", 0) > 0:
+            every_n = config.moe_every_n_layers
+            self.layers = nn.ModuleList([NomicBertBlock(config, moe=i % every_n == 1) for i in range(config.n_layer)])
+        else:
+            self.layers = nn.ModuleList([NomicBertBlock(config, moe=False) for _ in range(config.n_layer)])
         self.gradient_checkpointing = False
         self.config = config
 
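
As a quick sanity check of the router added in this commit, the sketch below runs NomicRouter on a random batch and prints the routing shapes. It only assumes the classes from this file are importable; the hidden size, expert count, and top-k values are illustrative, not taken from any released config.

import torch
from modeling_hf_nomic_bert import NomicRouter  # the module modified in this commit

# Illustrative sizes, not the shipped configuration.
router = NomicRouter(hidden_size=768, moe_num_experts=8, moe_top_k=2)
x = torch.randn(2, 16, 768)  # (batch, seq_len, hidden_size)

weights, top_weights, top_experts = router(x)
print(weights.shape)       # torch.Size([32, 8])  softmax over experts for each flattened token
print(top_weights.shape)   # torch.Size([32, 2])  routing weight of each selected expert
print(top_experts.shape)   # torch.Size([32, 2])  index of each selected expert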
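
One detail of the encoder change worth calling out: when moe_every_n_layers is set, the blocks that get an MoE MLP are those whose index i satisfies i % every_n == 1, so the first MoE block is layer 1 rather than layer 0. A small check of the pattern, with an assumed layer count and stride rather than the released config:

n_layer, every_n = 12, 2  # assumed values for illustration
moe_layers = [i for i in range(n_layer) if i % every_n == 1]
print(moe_layers)  # [1, 3, 5, 7, 9, 11] -> MoE in every second block, starting at layer 1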