fix: updates

modeling_hf_nomic_bert.py  (+220 -21, changed)
@@ -41,14 +41,22 @@ from transformers.modeling_outputs import (
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
 from transformers.utils.hub import cached_file, get_checkpoint_shard_files
 
-from …
+from configuration_hf_nomic_bert import NomicBertConfig
+logger = logging.getLogger(__name__)
 
 try:
     from torch.nn.functional import scaled_dot_product_attention
 except ImportError:
+    logger.warning("scaled_dot_product_attention not available, using torch.matmul instead")
     scaled_dot_product_attention = None
 
-…
+try:
+    from megablocks.layers import dmoe
+    from megablocks.layers.arguments import Arguments
+except ImportError:
+    logger.warning("megablocks not available, falling back to the pure-PyTorch MoE implementation")
+    dmoe = None
+
 
 
 # adapted from flash attention, added safe serialization option for hf models
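Both new dependencies are optional: a failed import leaves a sentinel (`None`) behind rather than raising. As a minimal sketch of how one might confirm at runtime which MoE backend will be picked up (the module name `modeling_hf_nomic_bert` is assumed here purely for illustration):

```python
# Illustrative only: inspect the sentinels left by the optional imports above.
import modeling_hf_nomic_bert as m

if getattr(m, "dmoe", None) is not None:
    print("MoE blocks will use megablocks' dMoE kernels")
else:
    print("MoE blocks will fall back to the pure-PyTorch NomicMoELayer")

if m.scaled_dot_product_attention is None:
    print("attention will fall back to explicit torch.matmul")
```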
@@ -1083,6 +1091,171 @@ class NomciBertGatedMLP(nn.Module):
         y = self.fc2(y)
         return y if not self.return_residual else (y, x)
 
+
+class NomicRouter(nn.Module):
+    def __init__(self, hidden_size: int, moe_num_experts: int, moe_top_k: int,
+                 moe_jitter_eps: Optional[float] = None,
+                 moe_normalize_expert_weights: Optional[float] = None,
+                 uniform_expert_assignment: bool = False):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.moe_num_experts = moe_num_experts
+        self.moe_top_k = moe_top_k
+        self.moe_jitter_eps = moe_jitter_eps
+        self.moe_normalize_expert_weights = moe_normalize_expert_weights
+        self.uniform_expert_assignment = uniform_expert_assignment
+
+        self.layer = nn.Linear(self.hidden_size,
+                               self.moe_num_experts,
+                               bias=False)
+
+    def jitter(self, x: torch.Tensor) -> torch.Tensor:
+        if self.moe_jitter_eps is None:
+            raise RuntimeError('The router does not have moe_jitter_eps set.')
+        low = 1.0 - self.moe_jitter_eps
+        high = 1.0 + self.moe_jitter_eps
+        noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
+        return low + noise * (high - low)
+
+    def forward(
+        self, x: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor]:
+        if self.training and self.moe_jitter_eps is not None:
+            x = x * self.jitter(x)
+
+        # softmax over all experts for every token, then keep the top-k experts
+        weights = self.layer(x.view(-1,
+                                    x.shape[-1])).softmax(dim=-1,
+                                                          dtype=torch.float32)
+        top_weights, top_experts = torch.topk(weights, self.moe_top_k, dim=-1)
+
+        if self.moe_normalize_expert_weights:
+            top_weights = top_weights / torch.norm(
+                top_weights,
+                p=self.moe_normalize_expert_weights,
+                dim=-1,
+                keepdim=True)
+
+        if self.uniform_expert_assignment:
+            with torch.no_grad():
+                uniform_tensor = torch.arange(
+                    0,
+                    top_experts.numel(),
+                    device=top_experts.device,
+                    dtype=top_experts.dtype) % self.moe_num_experts
+                top_experts = uniform_tensor.reshape(top_experts.shape)
+                # Note, weights and top_weights are not changed
+
+        weights = weights.to(x.dtype)
+        top_weights = top_weights.to(x.dtype)
+        return weights, top_weights, top_experts  # type: ignore
+
+
+class NomicExpertMLP(nn.Module):
+
+    def __init__(self, hidden_size: int, ffn_hidden_size: int,
+                 moe_num_experts: int, ffn_act_fn: dict):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.ffn_hidden_size = ffn_hidden_size
+        self.moe_num_experts = moe_num_experts
+
+        self.w1 = nn.Parameter(
+            torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))
+        self.w2 = nn.Parameter(
+            torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))
+        self.activation_fn = ffn_act_fn
+
+    def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor:
+        expert_w1 = self.w1.view(self.moe_num_experts, self.ffn_hidden_size,
+                                 self.hidden_size)[expert_idx]
+        expert_w2 = self.w2.view(self.moe_num_experts, self.ffn_hidden_size,
+                                 self.hidden_size)[expert_idx]
+
+        x1 = x.matmul(expert_w1.t())
+        act_out = self.activation_fn(x1)
+        x2 = act_out.matmul(expert_w2)
+        return x2
+
+
+class NomicExperts(nn.Module):
+    def __init__(self, config, hidden_size: int, ffn_hidden_size: int,
+                 moe_num_experts: int):
+        super().__init__()
+        self.moe_num_experts = moe_num_experts
+        activation = (
+            F.sigmoid
+            if config.activation_function == "glu"
+            else (F.silu if config.activation_function == "swiglu" else F.gelu)
+        )
+        self.mlp = NomicExpertMLP(
+            hidden_size=config.n_embd,
+            ffn_hidden_size=config.n_inner,
+            moe_num_experts=moe_num_experts,
+            ffn_act_fn=activation,
+        )
+        self.bias = nn.Parameter(torch.zeros(config.n_embd))
+
+    def forward(self, x: torch.Tensor, weights: torch.Tensor,
+                top_weights: torch.Tensor,
+                top_experts: torch.LongTensor) -> torch.Tensor:
+        bsz, q_len, hidden_size = x.shape
+        x = x.view(-1, hidden_size)
+        out = torch.zeros_like(x)
+
+        # dispatch each token to its selected experts and accumulate the weighted outputs
+        expert_mask = nn.functional.one_hot(
+            top_experts, num_classes=self.moe_num_experts).permute(2, 1, 0)
+        for expert_idx in range(0, self.moe_num_experts):
+            topk_idx, token_idx = torch.where(expert_mask[expert_idx])
+            if token_idx.shape[0] == 0:
+                continue
+
+            token_list = token_idx.tolist()
+            topk_list = topk_idx.tolist()
+
+            expert_tokens = x[None, token_list].reshape(-1, hidden_size)
+            expert_out = self.mlp(
+                expert_tokens, expert_idx) * top_weights[token_list, topk_list,
+                                                         None]
+
+            out.index_add_(0, token_idx, expert_out)
+
+        out = out.reshape(bsz, q_len, hidden_size)
+        return out + self.bias
+
+
+class NomicMoELayer(nn.Module):
+
+    def __init__(self, config: NomicBertConfig):
+        super().__init__()
+
+        self.router = NomicRouter(
+            config.n_embd,
+            moe_num_experts=config.num_experts,
+            moe_top_k=config.moe_top_k,
+        )
+
+        self.experts = NomicExperts(
+            config,
+            hidden_size=config.n_embd,
+            ffn_hidden_size=config.n_inner,
+            moe_num_experts=config.num_experts,
+        )
+
+    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
+        batch_size, seq_len, hidden_dim = x.shape
+        if attention_mask is not None:
+            # only route the non-padding tokens through the experts
+            valid_indices = attention_mask.bool().view(-1)
+            x = x.view(-1, hidden_dim)[valid_indices].unsqueeze(0)
+
+        weights, top_weights, top_experts = self.router(x)
+        out = self.experts(x, weights, top_weights, top_experts)
+
+        if attention_mask is not None:
+            # scatter the expert outputs back to their original token positions
+            full_out = torch.zeros(batch_size * seq_len, hidden_dim, dtype=out.dtype, device=out.device)
+            full_out[valid_indices] = out.squeeze(0)
+            out = full_out.view(batch_size, seq_len, hidden_dim)
+
+        return out
+
 
 def rotate_half(x, interleaved=False):
     if not interleaved:
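To make the routing contract concrete, here is a small shape check for `NomicRouter` as defined above; the hidden size, expert count, and top-k values are made up for illustration:

```python
# Hedged usage sketch: shapes produced by NomicRouter.forward for a toy input.
import torch

hidden_size, num_experts, top_k = 768, 8, 2
router = NomicRouter(hidden_size, moe_num_experts=num_experts, moe_top_k=top_k)

x = torch.randn(4, 128, hidden_size)               # (batch, seq_len, hidden)
weights, top_weights, top_experts = router(x)

print(weights.shape)      # torch.Size([512, 8])   softmax over all experts per token
print(top_weights.shape)  # torch.Size([512, 2])   weights of the k selected experts
print(top_experts.shape)  # torch.Size([512, 2])   indices of the k selected experts
```

`NomicExperts` then consumes these flattened per-token weights, runs each token through its selected expert MLPs, and adds the weighted results back in place.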
@@ -1431,6 +1604,7 @@ class NomicBertBlock(NomicBertPreTrainedModel):
     def __init__(
         self,
         config,
+        moe=False,
     ):
         super().__init__(config=config)
         self.prenorm = config.prenorm
@@ -1442,25 +1616,46 @@ class NomicBertBlock(NomicBertPreTrainedModel):
             if config.activation_function == "glu"
             else (F.silu if config.activation_function == "swiglu" else F.gelu)
         )
-        if config.activation_function in ["glu", "swiglu", "geglu"]:
-            self.mlp = NomciBertGatedMLP(
-                config.n_embd,
-                hidden_features=config.n_inner,
-                bias1=config.mlp_fc1_bias,
-                bias2=config.mlp_fc2_bias,
-                activation=activation,
-                fused_bias_fc=config.fused_bias_fc,
-                norm_layer=getattr(config, "norm_mlp", False),
-            )
+        if moe:
+            if dmoe is not None:
+                megablocks_args = Arguments(
+                    moe_num_experts=config.num_experts,
+                    moe_top_k=config.moe_top_k,
+                    hidden_size=config.n_embd,
+                    ffn_hidden_size=config.n_inner,
+                    num_layers=config.n_layer,
+                    moe_normalize_expert_weights=config.moe_normalize_expert_weights,
+                    activation_fn=activation,
+                    mlp_type="glu" if config.activation_function == "swiglu" else "mlp",
+                    fp16=True,
+                    bf16=False,
+                    return_bias=False,
+                )
+                self.mlp = dmoe.dMoE(megablocks_args)
+            else:
+                self.mlp = NomicMoELayer(
+                    config
+                )
         else:
-            self.mlp = NomicBertMLP(
-                config.n_embd,
-                hidden_features=config.n_inner,
-                bias1=config.mlp_fc1_bias,
-                bias2=config.mlp_fc2_bias,
-                activation=activation,
-                fused_bias_fc=config.fused_bias_fc,
-            )
+            if config.activation_function in ["glu", "swiglu", "geglu"]:
+                self.mlp = NomciBertGatedMLP(
+                    config.n_embd,
+                    hidden_features=config.n_inner,
+                    bias1=config.mlp_fc1_bias,
+                    bias2=config.mlp_fc2_bias,
+                    activation=activation,
+                    fused_bias_fc=config.fused_bias_fc,
+                    norm_layer=getattr(config, "norm_mlp", False),
+                )
+            else:
+                self.mlp = NomicBertMLP(
+                    config.n_embd,
+                    hidden_features=config.n_inner,
+                    bias1=config.mlp_fc1_bias,
+                    bias2=config.mlp_fc2_bias,
+                    activation=activation,
+                    fused_bias_fc=config.fused_bias_fc,
+                )
 
         self.dropout1 = nn.Dropout(config.resid_pdrop)
         self.norm1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
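The MoE branch reads several fields off the config (`num_experts`, `moe_top_k`, `moe_normalize_expert_weights`, plus the usual `n_embd`, `n_inner`, `n_layer`). A hedged sketch of how a block might be instantiated on that path; the field values are illustrative, and the real defaults live in `configuration_hf_nomic_bert.NomicBertConfig`:

```python
# Illustrative only: values are made up; extra fields are assumed to be
# accepted as config kwargs, with real defaults in configuration_hf_nomic_bert.
config = NomicBertConfig(
    n_embd=768,
    n_inner=3072,
    n_layer=12,
    activation_function="swiglu",
    num_experts=8,                      # experts per MoE block
    moe_top_k=2,                        # experts activated per token
    moe_normalize_expert_weights=1.0,   # p-norm used to renormalize top-k weights
)

# With moe=True the block builds either dmoe.dMoE (if megablocks is installed)
# or the pure-PyTorch NomicMoELayer; with moe=False it keeps the dense MLP.
moe_block = NomicBertBlock(config, moe=True)
dense_block = NomicBertBlock(config, moe=False)
```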
@@ -1530,7 +1725,11 @@ class NomicBertBlock(NomicBertPreTrainedModel):
 class NomicBertEncoder(nn.Module):
     def __init__(self, config: GPT2Config):
         super().__init__()
-        self.layers = nn.ModuleList([NomicBertBlock(config) for _ in range(config.n_layer)])
+        if getattr(config, "moe_every_n_layers", 0) > 0:
+            every_n = config.moe_every_n_layers
+            self.layers = nn.ModuleList([NomicBertBlock(config, moe=i % every_n == 1) for i in range(config.n_layer)])
+        else:
+            self.layers = nn.ModuleList([NomicBertBlock(config, moe=False) for _ in range(config.n_layer)])
         self.gradient_checkpointing = False
         self.config = config
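Finally, the encoder interleaves MoE and dense blocks with the rule `moe = (i % every_n == 1)`. A quick illustration of which layer indices become MoE blocks, assuming a 12-layer model with `moe_every_n_layers = 2`:

```python
# Illustrative: which layers get an MoE block under the interleaving rule above.
n_layer, every_n = 12, 2
moe_layer_indices = [i for i in range(n_layer) if i % every_n == 1]
print(moe_layer_indices)  # [1, 3, 5, 7, 9, 11] -> every second layer, starting at layer 1
```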