import torch
import torch.nn as nn
import torch.nn.functional as F

# Use apex's fused LayerNorm kernel when available; fall back to the stock
# PyTorch implementation otherwise.
try:
    from apex.normalization import FusedLayerNorm as LayerNorm
except ModuleNotFoundError:
    from torch.nn import LayerNorm


class set_torch_seed(object):
    """Context manager that seeds the torch (and CUDA) RNGs on entry and
    restores the previous RNG state on exit, keeping the seeding scoped."""

    def __init__(self, seed):
        assert isinstance(seed, int)
        # Snapshot the current RNG state before reseeding, so it can be
        # restored in __exit__.
        self.rng_state = self.get_rng_state()

        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)

    def get_rng_state(self):
        state = {"torch_rng_state": torch.get_rng_state()}
        if torch.cuda.is_available():
            state["cuda_rng_state"] = torch.cuda.get_rng_state()
        return state

    def set_rng_state(self, state):
        torch.set_rng_state(state["torch_rng_state"])
        if torch.cuda.is_available():
            torch.cuda.set_rng_state(state["cuda_rng_state"])

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        # Restore whatever RNG state was active before the block.
        self.set_rng_state(self.rng_state)
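

# Illustrative usage (not part of the original module): draws inside the
# `with` block are deterministic for a given seed, and the surrounding RNG
# state is restored on exit. `_demo_set_torch_seed` is a hypothetical helper.
def _demo_set_torch_seed():
    with set_torch_seed(1234):
        a = torch.randn(2)
    with set_torch_seed(1234):
        b = torch.randn(2)
    assert torch.equal(a, b)  # same seed within the scope -> same draw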


def make_experts(args, embed_dim, expert_ffn_dim):
    world_size = (
        1
        if not torch.distributed.is_initialized()
        else torch.distributed.get_world_size()
    )
    expert_list = []
    ddp_rank = args.ddp_rank
    # Base seed drawn from the (typically rank-synchronized) global RNG;
    # the per-expert offsets below give each expert its own seed.
    start_seed = torch.randint(1000000, (1,)).item()

    if args.moe_expert_count >= world_size:
        # At least one expert per rank: shard the experts evenly across ranks.
        assert (
            args.moe_expert_count % world_size == 0
        ), f"{args.moe_expert_count}, {world_size}"
        local_moe_expert_count = args.moe_expert_count // world_size
        for i in range(local_moe_expert_count):
            # Seed each expert individually so no two experts share weights.
            with set_torch_seed(start_seed + ddp_rank * local_moe_expert_count + i):
                expert_list.append(
                    FeedForwardNetwork(
                        embed_dim,
                        expert_ffn_dim,
                        args.activation_fn,
                        args.dropout,
                        args.activation_dropout,
                        args.layernorm_eps,
                        args.subln,
                    )
                )
    else:
        # Fewer experts than ranks: each expert is shared by a group of ranks,
        # so ranks in the same group must seed (and initialize) identically.
        assert (
            world_size % args.moe_expert_count == 0
        ), f"{world_size}, {args.moe_expert_count}"

        with set_torch_seed(start_seed + ddp_rank % args.moe_expert_count):
            expert_list.append(
                FeedForwardNetwork(
                    embed_dim,
                    expert_ffn_dim,
                    args.activation_fn,
                    args.dropout,
                    args.activation_dropout,
                    args.layernorm_eps,
                    args.subln,
                )
            )
    experts = nn.ModuleList(expert_list)
    return experts
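

# Illustrative single-process call (not part of the original module): the
# config fields mirror the attributes read above, with assumed values;
# SimpleNamespace stands in for the real args object.
def _demo_make_experts():
    from types import SimpleNamespace

    args = SimpleNamespace(
        ddp_rank=0,
        moe_expert_count=4,
        activation_fn="gelu",
        dropout=0.1,
        activation_dropout=0.0,
        layernorm_eps=1e-5,
        subln=True,
    )
    experts = make_experts(args, embed_dim=512, expert_ffn_dim=2048)
    assert len(experts) == 4  # world_size == 1, so every expert is local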


def get_activation_fn(activation):
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu
    else:
        raise NotImplementedError(f"Unsupported activation: {activation}")


class FeedForwardNetwork(nn.Module):
    """Position-wise feed-forward block: fc1 -> activation -> optional
    sub-LayerNorm -> fc2, with activation dropout and output dropout."""

    def __init__(
        self,
        embed_dim,
        ffn_dim,
        activation_fn,
        dropout,
        activation_dropout,
        layernorm_eps,
        subln=False,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.activation_fn = get_activation_fn(activation=str(activation_fn))
        self.activation_dropout_module = torch.nn.Dropout(activation_dropout)
        self.dropout_module = torch.nn.Dropout(dropout)
        self.fc1 = nn.Linear(self.embed_dim, ffn_dim)
        self.fc2 = nn.Linear(ffn_dim, self.embed_dim)
        # Sub-LayerNorm between the two projections, enabled by `subln`.
        self.ffn_layernorm = LayerNorm(ffn_dim, eps=layernorm_eps) if subln else None

    def reset_parameters(self):
        self.fc1.reset_parameters()
        self.fc2.reset_parameters()
        if self.ffn_layernorm is not None:
            self.ffn_layernorm.reset_parameters()

    def forward(self, x):
        x = self.fc1(x)
        x = self.activation_fn(x)
        x = self.activation_dropout_module(x)
        if self.ffn_layernorm is not None:
            x = self.ffn_layernorm(x)
        x = self.fc2(x)
        x = self.dropout_module(x)
        return x
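

# Minimal forward-pass sketch (illustrative shapes and hyperparameters, not
# part of the original module).
def _demo_feed_forward_network():
    ffn = FeedForwardNetwork(
        embed_dim=512,
        ffn_dim=2048,
        activation_fn="gelu",
        dropout=0.1,
        activation_dropout=0.0,
        layernorm_eps=1e-5,
        subln=True,
    )
    x = torch.randn(8, 16, 512)  # (batch, seq_len, embed_dim)
    y = ffn(x)
    assert y.shape == x.shape  # the FFN preserves the embedding dimension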