# face-mogle/src/moe/mogle.py
import torch
import torch.nn as nn
from torch.nn import functional as F
from diffusers.models.embeddings import Timesteps, TimestepEmbedding


# Define the expert network: a small per-token MLP.
class Expert(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, use_softmax=False):
        super().__init__()
        self.use_softmax = use_softmax
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        return (
            self.net(x) if not self.use_softmax else torch.softmax(self.net(x), dim=1)
        )
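
# Illustrative usage (not part of the original file): each Expert is applied
# per token, so Expert(64, 32, 64) maps (bs, 1024, 64) -> (bs, 1024, 64):
# expert = Expert(input_dim=64, hidden_dim=32, output_dim=64)
# y = expert(torch.randn(2, 1024, 64))  # y.shape == (2, 1024, 64)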


class DynamicGatingNetwork(nn.Module):
    def __init__(self, hidden_dim=64, embed_dim=64, dtype=torch.bfloat16):
        super().__init__()
        # Process the diffusion timestep.
        self.time_proj = Timesteps(
            hidden_dim, flip_sin_to_cos=True, downscale_freq_shift=0
        )
        self.timestep_embedding = TimestepEmbedding(hidden_dim, embed_dim).to(
            dtype=dtype  # honor the dtype argument instead of hardcoding bfloat16
        )
        # Process noise_latent.
        self.noise_proj = nn.Linear(hidden_dim, hidden_dim)
        self.dtype = dtype
        # Weight computation: one logit per expert (20 in total, matching
        # MoGLE's default num_experts).
        self.gate = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 20),
        )

    def forward(self, condition_latents, noise_latent, timestep):
        """
        condition_latents: (bs, 1024, 64)
        noise_latent: (bs, 1024, 64)
        timestep: (bs,)
        """
        bs, seq_len, hidden_dim = condition_latents.shape
        # Process the timestep.
        time_emb = self.time_proj(timestep)  # (bs, hidden_dim)
        time_emb = time_emb.to(self.dtype)
        time_emb = self.timestep_embedding(time_emb)  # (bs, embed_dim)
        time_emb = time_emb.unsqueeze(1).expand(
            -1, seq_len, -1
        )  # (bs, 1024, embed_dim)
        # Process noise_latent.
        noise_emb = self.noise_proj(noise_latent)  # (bs, 1024, 64)
        # Fuse all inputs by summation (a concatenation variant is kept below).
        # fused_input = torch.cat([condition_latents, noise_emb, time_emb], dim=2)  # (bs, 1024, 64+64+128)
        fused_input = condition_latents + noise_emb + time_emb
        # Compute per-token expert weights.
        weight = self.gate(fused_input)  # (bs, 1024, 20)
        weight = F.softmax(weight, dim=2)  # normalize over the expert axis
        return weight
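
# Shape sketch (added for illustration; not in the original file). With the
# bfloat16 default, cast the module and latents before calling it:
# gating = DynamicGatingNetwork().to(torch.bfloat16)
# w = gating(
#     torch.randn(2, 1024, 64, dtype=torch.bfloat16),  # condition_latents
#     torch.randn(2, 1024, 64, dtype=torch.bfloat16),  # noise_latent
#     torch.randint(0, 1000, (2,)),                    # timestep
# )  # w.shape == (2, 1024, 20); rows sum to 1 over the expert axis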


class MoGLE(nn.Module):
    def __init__(
        self,
        num_experts=20,
        input_dim=64,
        hidden_dim=32,
        output_dim=64,
        has_expert=True,
        has_gating=True,
        weight_is_scale=False,
    ):
        super().__init__()
        # nn.Identity accepts and ignores constructor arguments, so it can
        # stand in for Expert when the expert MLPs are ablated.
        expert_model = Expert if has_expert else nn.Identity
        self.global_expert = expert_model(
            input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim
        )
        self.local_experts = nn.ModuleList(
            [
                expert_model(
                    input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim
                )
                for _ in range(num_experts - 1)
            ]
        )
        # self.gating = Gating(input_dim=input_dim, num_experts=num_experts)
        if has_gating:
            self.gating = DynamicGatingNetwork()
        else:
            self.gating = nn.Identity()
        self.weight_is_scale = weight_is_scale

    def forward(self, x: torch.Tensor, noise_latent, timestep):
        global_mask = x[:, 0]  # (bs, 1024, 64)
        local_mask = x[:, 1:]  # (bs, 19, 1024, 64)
        if not isinstance(self.gating, nn.Identity):
            weights = self.gating(
                global_mask, noise_latent=noise_latent, timestep=timestep
            )  # (bs, 1024, 20)
        _, num_local, h, w = local_mask.shape
        global_output = self.global_expert(global_mask).unsqueeze(1)  # (bs, 1, 1024, 64)
        local_outputs = torch.stack(
            [self.local_experts[i](local_mask[:, i]) for i in range(num_local)], dim=1
        )  # (bs, 19, 1024, 64)
        global_local_outputs = torch.cat(
            [global_output, local_outputs], dim=1
        )  # (bs, 20, 1024, 64)
        if isinstance(self.gating, nn.Identity):
            # No gating: simply sum the expert outputs.
            global_local_outputs = global_local_outputs.sum(dim=1)
            return global_local_outputs
        if self.weight_is_scale:
            # Collapse per-token weights to a single scale per expert.
            weights = torch.mean(weights, dim=1, keepdim=True)  # (bs, 1, 20)
        weights_expanded = weights.unsqueeze(-1)  # (bs, 1024, 20, 1)
        output = (global_local_outputs.permute(0, 2, 1, 3) * weights_expanded).sum(
            dim=2
        )
        return output  # (bs, 1024, 64)
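

if __name__ == "__main__":
    # Minimal smoke test (added for illustration; not part of the original
    # upload). x stacks one global mask plus 19 local masks, each tokenized
    # to (1024, 64), matching the shape comments above.
    torch.manual_seed(0)
    bs = 2
    model = MoGLE().to(torch.bfloat16)  # the gating network defaults to bfloat16
    x = torch.randn(bs, 20, 1024, 64, dtype=torch.bfloat16)
    noise_latent = torch.randn(bs, 1024, 64, dtype=torch.bfloat16)
    timestep = torch.randint(0, 1000, (bs,))
    out = model(x, noise_latent, timestep)
    print(out.shape)  # expected: torch.Size([2, 1024, 64])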