import torch
import torch.nn as nn
from torch.nn import functional as F

from diffusers.models.embeddings import TimestepEmbedding, Timesteps
# Define the Expert network: a two-layer MLP applied independently to each token.
class Expert(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, use_softmax=False):
        super().__init__()
        self.use_softmax = use_softmax
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        out = self.net(x)
        # Optional softmax over the feature dimension; off by default.
        return torch.softmax(out, dim=-1) if self.use_softmax else out
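
# Hedged usage sketch: a quick shape check for Expert. The batch size and
# sequence length below are illustrative assumptions, not values fixed by
# this module.
def _demo_expert():
    expert = Expert(input_dim=64, hidden_dim=32, output_dim=64)
    tokens = torch.randn(2, 1024, 64)  # (bs, seq_len, input_dim)
    out = expert(tokens)               # (bs, seq_len, output_dim)
    assert out.shape == (2, 1024, 64)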
class DynamicGatingNetwork(nn.Module):
    def __init__(self, hidden_dim=64, embed_dim=64, dtype=torch.bfloat16):
        super().__init__()
        # Timestep processing: sinusoidal projection followed by an MLP embedding.
        self.time_proj = Timesteps(
            hidden_dim, flip_sin_to_cos=True, downscale_freq_shift=0
        )
        self.timestep_embedding = TimestepEmbedding(hidden_dim, embed_dim).to(
            dtype=dtype
        )
        # noise_latent processing
        self.noise_proj = nn.Linear(hidden_dim, hidden_dim)
        self.dtype = dtype
        # Weight computation: one gating weight per expert (20 in total).
        self.gate = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 20),
        )
    def forward(self, condition_latents, noise_latent, timestep):
        """
        condition_latents: (bs, 1024, 64)
        noise_latent: (bs, 1024, 64)
        timestep: (bs,)
        """
        bs, seq_len, hidden_dim = condition_latents.shape
        # Process the timestep.
        time_emb = self.time_proj(timestep)  # (bs, hidden_dim)
        time_emb = time_emb.to(self.dtype)
        time_emb = self.timestep_embedding(time_emb)  # (bs, embed_dim)
        time_emb = time_emb.unsqueeze(1).expand(
            -1, seq_len, -1
        )  # (bs, 1024, embed_dim)
        # Process noise_latent.
        noise_emb = self.noise_proj(noise_latent)  # (bs, 1024, 64)
        # Fuse the inputs by summation (concatenation is a possible alternative):
        # fused_input = torch.cat([condition_latents, noise_emb, time_emb], dim=2)
        fused_input = condition_latents + noise_emb + time_emb
        # Compute per-token expert weights.
        weight = self.gate(fused_input)  # (bs, 1024, 20)
        weight = F.softmax(weight, dim=2)  # normalize over experts
        return weight
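
# Hedged usage sketch for the gating network: feed matching condition and
# noise latents plus integer timesteps, and confirm the per-token expert
# weights sum to one. The shapes are assumptions consistent with the
# docstring above, not values fixed elsewhere in this project.
def _demo_gating():
    gating = DynamicGatingNetwork(hidden_dim=64, embed_dim=64)
    cond = torch.randn(2, 1024, 64)
    noise = torch.randn(2, 1024, 64)
    t = torch.randint(0, 1000, (2,))
    w = gating(cond, noise, t)  # (bs, 1024, 20)
    assert w.shape == (2, 1024, 20)
    assert torch.allclose(w.sum(dim=-1), torch.ones(2, 1024), atol=1e-3)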
class MoGLE(nn.Module):
    def __init__(
        self,
        num_experts=20,
        input_dim=64,
        hidden_dim=32,
        output_dim=64,
        has_expert=True,
        has_gating=True,
        weight_is_scale=False,
    ):
        super().__init__()
        # nn.Identity accepts and ignores constructor kwargs, so it can stand
        # in for Expert when has_expert is False.
        expert_model = Expert if has_expert else nn.Identity
        self.global_expert = expert_model(
            input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim
        )
        self.local_experts = nn.ModuleList(
            [
                expert_model(
                    input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim
                )
                for _ in range(num_experts - 1)
            ]
        )
        self.gating = DynamicGatingNetwork() if has_gating else nn.Identity()
        self.weight_is_scale = weight_is_scale
    def forward(self, x: torch.Tensor, noise_latent, timestep):
        """
        x: (bs, num_experts, 1024, 64); x[:, 0] is the global mask and
        x[:, 1:] are the local masks.
        """
        global_mask = x[:, 0]  # (bs, 1024, 64)
        local_mask = x[:, 1:]  # (bs, 19, 1024, 64)
        if not isinstance(self.gating, nn.Identity):
            weights = self.gating(
                global_mask, noise_latent=noise_latent, timestep=timestep
            )  # (bs, 1024, 20)
        _, num_local, _, _ = local_mask.shape
        global_output = self.global_expert(global_mask).unsqueeze(1)  # (bs, 1, 1024, 64)
        local_outputs = torch.stack(
            [self.local_experts[i](local_mask[:, i]) for i in range(num_local)], dim=1
        )  # (bs, 19, 1024, 64)
        global_local_outputs = torch.cat(
            [global_output, local_outputs], dim=1
        )  # (bs, 20, 1024, 64)
        # Without gating, fall back to an unweighted sum over experts.
        if isinstance(self.gating, nn.Identity):
            return global_local_outputs.sum(dim=1)
        if self.weight_is_scale:
            # Average the per-token weights into a single scale per expert.
            weights = torch.mean(weights, dim=1, keepdim=True)  # (bs, 1, 20)
        weights_expanded = weights.unsqueeze(-1)  # (bs, 1024, 20, 1)
        # Weighted sum over experts: (bs, 1024, 20, 64) -> (bs, 1024, 64).
        output = (global_local_outputs.permute(0, 2, 1, 3) * weights_expanded).sum(
            dim=2
        )
        return output  # (bs, 1024, 64)
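
# Hedged end-to-end smoke test, runnable as a script. The 20-way mask stack
# (one global mask plus 19 local masks) and the 1024-token sequence length
# are assumptions matching the shape comments above.
if __name__ == "__main__":
    model = MoGLE(num_experts=20, input_dim=64, hidden_dim=32, output_dim=64)
    x = torch.randn(2, 20, 1024, 64)         # (bs, num_experts, seq_len, dim)
    noise_latent = torch.randn(2, 1024, 64)  # (bs, seq_len, dim)
    timestep = torch.randint(0, 1000, (2,))
    out = model(x, noise_latent, timestep)
    assert out.shape == (2, 1024, 64)
    _demo_expert()
    _demo_gating()
    print("MoGLE forward OK:", tuple(out.shape))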