"""Qwen3MoE: a mixture-of-experts wrapper assembled from full pretrained models.

A sequence-classification router labels each incoming prompt, and the expert
causal LM registered under that label performs the generation.
"""

import os
from typing import Optional, Tuple, Union

import torch
from torch.nn import ModuleDict
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    PreTrainedModel,
)
from transformers.modeling_outputs import CausalLMOutput

from .qwen3moe_configuration import Qwen3MoEConfig


class Qwen3MoEForCausalLM(PreTrainedModel):
    """Causal LM that routes each prompt to one of several expert models.

    The router's predicted class id indexes `config.labels`, and the matching
    entry in `self.experts` runs the actual generation.
    """

    config_class = Qwen3MoEConfig
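
    # Fields this wrapper reads from Qwen3MoEConfig (see
    # qwen3moe_configuration.py for the authoritative definition):
    #   router_model_path   -- relative path to the router checkpoint
    #   expert_model_paths  -- dict mapping expert label -> relative path
    #   tokenizer_path      -- relative path to the shared expert tokenizer
    #   labels              -- list mapping router class id -> expert label
    #   torch_dtype         -- dtype used when loading the sub-models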

    def __init__(self, config: Qwen3MoEConfig):
        super().__init__(config)

        # Sequence classifier that predicts which expert should answer.
        self.router = AutoModelForSequenceClassification.from_pretrained(
            config.router_model_path,
            torch_dtype=config.torch_dtype,
            trust_remote_code=True,
            local_files_only=True,
        )
        self.router_tokenizer = AutoTokenizer.from_pretrained(
            config.router_model_path,
            trust_remote_code=True,
            local_files_only=True,
        )

        # One expert causal LM per routing label.
        self.experts = ModuleDict({
            label: AutoModelForCausalLM.from_pretrained(
                path,
                torch_dtype=config.torch_dtype,
                trust_remote_code=True,
                local_files_only=True,
            )
            for label, path in config.expert_model_paths.items()
        })

        # Tokenizer shared by all experts.
        self.expert_tokenizer = AutoTokenizer.from_pretrained(
            config.tokenizer_path,
            trust_remote_code=True,
            local_files_only=True,
        )

    @classmethod
    def from_pretrained(cls, pretrained_dir: str, config: Optional[Qwen3MoEConfig] = None, **kwargs):
        """Load the wrapper from a directory holding the router, experts, and
        tokenizer as subfolders."""
        if config is None:
            config = Qwen3MoEConfig.from_pretrained(pretrained_dir)

        # The config stores sub-model paths relative to the checkpoint root;
        # resolve them before loading.
        base = pretrained_dir
        config.router_model_path = os.path.join(base, config.router_model_path)
        config.expert_model_paths = {
            label: os.path.join(base, path)
            for label, path in config.expert_model_paths.items()
        }
        config.tokenizer_path = os.path.join(base, config.tokenizer_path)

        return cls(config)
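
    # Expected checkpoint layout (illustrative; the real subfolder names are
    # whatever the config fields point at):
    #
    #   pretrained_dir/
    #       config.json     <- Qwen3MoEConfig with relative sub-model paths
    #       router/         <- config.router_model_path
    #       expert_<label>/ <- config.expert_model_paths[label], one per label
    #       tokenizer/      <- config.tokenizer_path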

    def get_tokenizer(self):
        """Return the tokenizer used to prepare expert inputs."""
        return self.expert_tokenizer

    def route(self, plain_text: str) -> str:
        """Classify `plain_text` with the router and return the expert label."""
        with torch.no_grad():
            # Truncate so long prompts still fit the router's context window.
            inputs = self.router_tokenizer(
                plain_text, return_tensors="pt", truncation=True
            ).to(self.router.device)
            logits = self.router(**inputs).logits

        # Standard classifier output of shape (batch_size, num_labels).
        if logits.dim() == 2:
            class_id = torch.argmax(logits, dim=-1).item()
            return self.config.labels[class_id]

        # Unexpected logits shape: fall back to the first label.
        return self.config.labels[0]

    def generate(
        self,
        text: str,
        max_new_tokens: int = 50,
        **kwargs,
    ) -> torch.LongTensor:
        """Route `text` to an expert and let that expert generate."""
        # For ChatML-formatted prompts, route on the last user message rather
        # than on the full template.
        plain_text = text
        if "<|im_start|>" in plain_text:
            # The second-to-last "<|im_start|>" segment is the final user turn,
            # shaped like "user\n<message><|im_end|>\n".
            temp = plain_text.split("<|im_start|>")[-2]
            end = temp.find("<|im_end|>")
            if end != -1:
                temp = temp[:end]
            # Strip the "user\n" role header, keeping only the message body.
            plain_text = temp.split("\n", 1)[-1]

        label = self.route(plain_text)
        expert = self.experts[label]

        # The selected expert sees the full, untrimmed prompt.
        inputs = self.expert_tokenizer(text, return_tensors="pt").to(expert.device)

        return expert.generate(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=max_new_tokens,
            **kwargs,
        )
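
    # For reference, the ChatML shape the pre-processing in `generate` expects
    # (illustrative prompt, not taken from the source):
    #
    #   <|im_start|>system\n...<|im_end|>\n
    #   <|im_start|>user\nWhat is 2+2?<|im_end|>\n
    #   <|im_start|>assistant\n
    #
    # With that input, the router sees only "What is 2+2?".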

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutput]:
        # Routing operates on raw text, so a tensor-level forward pass is not
        # supported; call `generate(text=...)` for inference.
        raise NotImplementedError("Use `generate(text=...)` instead for inference.")
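

# Minimal usage sketch (assumes a checkpoint directory laid out as described
# above; the path below is hypothetical):
#
#   model = Qwen3MoEForCausalLM.from_pretrained("/path/to/qwen3moe-checkpoint")
#   tokenizer = model.get_tokenizer()
#   output_ids = model.generate("Explain mixture-of-experts routing.", max_new_tokens=64)
#   print(tokenizer.decode(output_ids[0], skip_special_tokens=True))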