# Arcana-Qwen3-2.4B-A0.6B / qwen3moe_model.py
import os

import torch
from torch.nn import ModuleDict
from transformers import (
    PreTrainedModel,
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)
from transformers.modeling_outputs import CausalLMOutput
from typing import Optional, Tuple, Union

from .qwen3moe_configuration import Qwen3MoEConfig

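
# Expected Qwen3MoEConfig fields (a sketch: the attribute names are taken from
# the code below; the example values are hypothetical):
#   router_model_path:  "router"                       -> sequence-classification router
#   expert_model_paths: {"code": "experts/code", ...}  -> one causal LM per label
#   tokenizer_path:     "tokenizer"                    -> tokenizer shared by all experts
#   labels:             ["code", ...]                  -> index-aligned with router logits
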
class Qwen3MoEForCausalLM(PreTrainedModel):
    """Mixture-of-experts wrapper: a sequence-classification router picks one
    expert causal LM per request, and that expert generates the reply."""

    config_class = Qwen3MoEConfig

    def __init__(self, config: Qwen3MoEConfig):
        super().__init__(config)

        # Router: a sequence classifier whose predicted label selects an expert.
        self.router = AutoModelForSequenceClassification.from_pretrained(
            config.router_model_path,
            torch_dtype=config.torch_dtype,
            trust_remote_code=True,
            local_files_only=True,
        )
        self.router_tokenizer = AutoTokenizer.from_pretrained(
            config.router_model_path,
            trust_remote_code=True,
            local_files_only=True,
        )

        # Experts: one causal LM per router label.
        self.experts = ModuleDict({
            label: AutoModelForCausalLM.from_pretrained(
                path,
                torch_dtype=config.torch_dtype,
                trust_remote_code=True,
                local_files_only=True,
            )
            for label, path in config.expert_model_paths.items()
        })

        # All experts share a single tokenizer.
        self.expert_tokenizer = AutoTokenizer.from_pretrained(
            config.tokenizer_path,
            trust_remote_code=True,
            local_files_only=True,
        )
    @classmethod
    def from_pretrained(cls, pretrained_dir: str, config: Optional[Qwen3MoEConfig] = None, **kwargs):
        if config is None:
            config = Qwen3MoEConfig.from_pretrained(pretrained_dir)

        # The config stores the router, expert, and tokenizer paths relative to
        # the checkpoint directory; resolve them to full paths before loading.
        base = pretrained_dir
        config.router_model_path = os.path.join(base, config.router_model_path)
        config.expert_model_paths = {
            label: os.path.join(base, path)
            for label, path in config.expert_model_paths.items()
        }
        config.tokenizer_path = os.path.join(base, config.tokenizer_path)
        return cls(config)
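
    # Illustrative on-disk layout implied by the path resolution above
    # (the directory names here are hypothetical):
    #   <pretrained_dir>/config.json
    #   <pretrained_dir>/router/            <- config.router_model_path
    #   <pretrained_dir>/experts/<label>/   <- config.expert_model_paths[label]
    #   <pretrained_dir>/tokenizer/         <- config.tokenizer_path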
    def get_tokenizer(self):
        return self.expert_tokenizer
    def route(self, plain_text: str) -> str:
        """Classify the prompt with the router and return the expert label."""
        with torch.no_grad():
            inputs = self.router_tokenizer(plain_text, return_tensors="pt").to(self.router.device)
            logits = self.router(**inputs).logits
            if logits.dim() == 2:
                class_id = torch.argmax(logits, dim=-1).item()
                return self.config.labels[class_id]
        # Unexpected logits shape: fall back to the first label.
        return self.config.labels[0]
    def generate(
        self,
        text: str,
        max_new_tokens: int = 50,
        **kwargs
    ) -> torch.LongTensor:
        # 1. Route on the plain user message. If the prompt is ChatML-formatted,
        #    extract the last completed turn (this assumes the prompt ends with
        #    an opened assistant header) and strip its role line, e.g. "user".
        plain_text = text
        if "<|im_start|>" in plain_text:
            turn = plain_text.split("<|im_start|>")[-2]
            turn = turn[:turn.find("<|im_end|>")]
            plain_text = turn.split("\n", 1)[-1]
        label = self.route(plain_text)
        expert = self.experts[label]

        # 2. Tokenize the full prompt once with the shared expert tokenizer.
        inputs = self.expert_tokenizer(text, return_tensors="pt").to(expert.device)

        # 3. Generate using the selected expert.
        return expert.generate(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=max_new_tokens,
            **kwargs
        )
    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs
    ) -> Union[Tuple, CausalLMOutput]:
        # This wrapper is inference-only; a training forward pass would need
        # batch-level routing, which is not implemented here.
        raise NotImplementedError("Use `generate(text=...)` instead for inference.")
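

if __name__ == "__main__":
    # Minimal usage sketch. Assumption: "./Arcana-Qwen3-2.4B-A0.6B" is a
    # hypothetical local checkpoint directory laid out as the config expects;
    # point it at wherever the repository was actually downloaded.
    model = Qwen3MoEForCausalLM.from_pretrained("./Arcana-Qwen3-2.4B-A0.6B")
    tokenizer = model.get_tokenizer()

    # ChatML-formatted prompt ending with an opened assistant turn, matching
    # the format `generate` expects when extracting the user message.
    prompt = "<|im_start|>user\nWhat is 12 * 7?<|im_end|>\n<|im_start|>assistant\n"
    output_ids = model.generate(prompt, max_new_tokens=64)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))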