"""Custom Gemma3 model for token classification with non-causal attention."""

import torch
import torch.nn as nn
from typing import Optional, Tuple, List, Dict, Any
import types

from transformers import PretrainedConfig
from transformers import Gemma3ForCausalLM
from transformers.models.gemma3.modeling_gemma3 import (
    Gemma3Attention,
    repeat_kv,
    apply_rotary_pos_emb,
    ALL_ATTENTION_FUNCTIONS,
    Cache,
)
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.utils import logging

logger = logging.get_logger(__name__)


class Gemma3PunctuationConfig(PretrainedConfig):
    """Configuration for the Gemma3 punctuation (token classification) model."""

    model_type = "cadence_punctuation"

    def __init__(
        self,
        num_labels: int = 31,
        classifier_dropout_prob: float = 0.0,
        use_non_causal_attention: bool = True,
        **kwargs
    ):
        self.num_labels = num_labels
        self.classifier_dropout_prob = classifier_dropout_prob
        self.use_non_causal_attention = use_non_causal_attention
        super().__init__(**kwargs)
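
# Usage sketch (values other than the defaults above are illustrative only):
#   config = Gemma3PunctuationConfig(num_labels=31, classifier_dropout_prob=0.1)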


def _extract_padding_mask_corrected(
    combined_mask_4d: Optional[torch.Tensor],
    debug_print: bool = False,
) -> Optional[torch.Tensor]:
    """Extract the padding-only mask from a combined (causal + padding) 4D attention mask.

    A key position is treated as padding only if it is masked out for every
    query position; causal masking alone never blocks an entire key column.
    """
    if combined_mask_4d is None:
        return None

    mask_value = torch.finfo(combined_mask_4d.dtype).min
    # Reduce over the query dimension (dim=2): a key column blocked for all
    # queries can only be padding.
    is_key_padding = (combined_mask_4d == mask_value).all(dim=2, keepdim=True)
    padding_only_mask = torch.where(
        is_key_padding.expand_as(combined_mask_4d),
        torch.full_like(combined_mask_4d, mask_value),
        torch.zeros_like(combined_mask_4d),
    )
    return padding_only_mask
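
# Worked sketch of the extraction, assuming the usual HF additive-mask convention
# (shape (batch, 1, q_len, kv_len); 0 = visible, m = dtype min = blocked). With
# q_len = kv_len = 3 and the last key padded, the combined mask rows are
#   [0, m, m]
#   [0, 0, m]
#   [0, 0, m]
# Only column 2 is blocked in every row, so the padding-only mask keeps just
#   [0, 0, m]  broadcast over all query rows; column 1 was causal-only and is freed.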


def non_causal_eager_attention_forward_with_padding(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    **kwargs: Any
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Eager attention with no causal mask; only the (padding) mask passed in is applied."""
    dropout = kwargs.get("dropout", 0.0)
    scaling = kwargs.get("scaling", None)
    softcap = kwargs.get("softcap", None)

    if scaling is None:
        head_dim = getattr(module, "head_dim", query.shape[-1])
        scaling = head_dim**-0.5

    # Expand grouped-query KV heads so they match the number of query heads.
    num_key_value_groups = getattr(module, "num_key_value_groups", 1)
    key_states = repeat_kv(key, num_key_value_groups)
    value_states = repeat_kv(value, num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling

    # Optional logit soft-capping (tanh squashing), as used in the Gemma family.
    if softcap is not None:
        attn_weights = attn_weights / softcap
        attn_weights = torch.tanh(attn_weights)
        attn_weights = attn_weights * softcap

    if attention_mask is not None:
        mask_slice = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + mask_slice

    # Softmax in float32 for numerical stability, then cast back to the input dtype.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    is_training = getattr(module, "training", False)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=is_training)
    attn_output = torch.matmul(attn_weights, value_states)
    # (bsz, heads, q_len, head_dim) -> (bsz, q_len, heads, head_dim), matching the HF interface.
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, attn_weights
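
# Shape sketch (hypothetical tensors; `module` is assumed to expose head_dim and
# num_key_value_groups the way Gemma3Attention does):
#   q = torch.randn(2, 8, 5, 64)        # (bsz, num_heads, q_len, head_dim)
#   k = v = torch.randn(2, 4, 5, 64)    # 4 KV heads, so num_key_value_groups == 2
#   out, w = non_causal_eager_attention_forward_with_padding(
#       module, q, k, v, attention_mask=None, scaling=64 ** -0.5)
#   # out: (2, 5, 8, 64) after the transpose; w: (2, 8, 5, 5), all keys visible.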


def modified_gemma3_attention_forward_non_causal(
    self: Gemma3Attention,
    hidden_states: torch.Tensor,
    position_embeddings: Tuple[torch.Tensor, torch.Tensor],
    attention_mask: Optional[torch.Tensor],
    past_key_value: Optional[Cache] = None,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs: Any
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Gemma3 attention forward patched for non-causal behavior: the causal part
    of the attention mask is discarded and only the padding mask is applied."""
    input_shape = hidden_states.shape[:-1]
    hidden_shape = (*input_shape, -1, self.head_dim)

    query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

    query_states = self.q_norm(query_states)
    key_states = self.k_norm(key_states)
    cos, sin = position_embeddings
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    if past_key_value is not None:
        cache_kwargs = {
            "sin": sin,
            "cos": cos,
            "cache_position": cache_position,
            "sliding_window": self.sliding_window,
        }
        key_states, value_states = past_key_value.update(
            key_states, value_states, self.layer_idx, cache_kwargs
        )

    effective_attn_implementation = self.config._attn_implementation
    output_attentions = kwargs.get("output_attentions", False)

    # SDPA and flash attention cannot return attention weights; fall back to eager.
    if output_attentions and effective_attn_implementation in ("sdpa", "flash_attention_2"):
        effective_attn_implementation = "eager"

    # Keep only the padding information from the combined mask; never apply causality.
    padding_only_mask = _extract_padding_mask_corrected(attention_mask)
    use_causal_flag = False

    if effective_attn_implementation == "eager":
        attention_interface = non_causal_eager_attention_forward_with_padding
    elif effective_attn_implementation == "sdpa":
        attention_interface = ALL_ATTENTION_FUNCTIONS.get(
            "sdpa", non_causal_eager_attention_forward_with_padding
        )
    elif effective_attn_implementation == "flash_attention_2":
        attention_interface = ALL_ATTENTION_FUNCTIONS.get(
            "flash_attention_2", non_causal_eager_attention_forward_with_padding
        )
    else:
        attention_interface = non_causal_eager_attention_forward_with_padding

    final_attention_mask = padding_only_mask
    if final_attention_mask is not None:
        final_attention_mask = final_attention_mask.to(query_states.device)

    attn_specific_kwargs: Dict[str, Any] = {}
    if attention_interface is non_causal_eager_attention_forward_with_padding:
        attn_specific_kwargs = {
            "dropout": 0.0,
            "scaling": self.scaling,
            "softcap": getattr(self, "softcap", None),
        }
    elif effective_attn_implementation == "sdpa":
        attn_specific_kwargs = {"is_causal": use_causal_flag}
    elif effective_attn_implementation == "flash_attention_2":
        attn_specific_kwargs = {
            "causal": use_causal_flag,
            "softcap": getattr(self, "softcap", None),
            "dropout": 0.0,
        }

    attn_output, attn_weights = attention_interface(
        self, query_states, key_states, value_states, final_attention_mask, **attn_specific_kwargs
    )

    attn_output = attn_output.reshape(*input_shape, -1).contiguous()
    attn_output = self.o_proj(attn_output)

    returned_weights = attn_weights if output_attentions and attn_weights is not None else None
    return attn_output, returned_weights


class Gemma3ForTokenClassification(Gemma3ForCausalLM):
    """Gemma3 model for token classification (punctuation prediction).

    Inherits from Gemma3ForCausalLM and replaces the LM head with a
    classification head.
    """

    config_class = Gemma3PunctuationConfig

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # Replace the LM head with dropout + linear over the hidden states.
        classifier_dropout_prob = getattr(config, "classifier_dropout_prob", 0.0)
        self.lm_head = nn.Sequential(
            nn.Dropout(classifier_dropout_prob),
            nn.Linear(config.hidden_size, config.num_labels),
        )

        self.post_init()

        if getattr(config, "use_non_causal_attention", True):
            self._patch_attention_layers()
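
    # Guard sketch (an addition, not in the original file): with the default
    # tie_word_embeddings=True, post_init() would try to tie `lm_head` to the
    # input embeddings, which does not apply to the Sequential classifier head.
    # Returning None opts this model out of HF weight tying.
    def get_output_embeddings(self):
        return None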

    def _patch_attention_layers(self):
        """Patch every decoder layer's self-attention to use non-causal attention."""
        count = 0

        if hasattr(self, "model") and hasattr(self.model, "layers"):
            target_layers = self.model.layers
        else:
            logger.warning("Could not find model.layers for attention patching")
            return

        for idx, layer in enumerate(target_layers):
            if hasattr(layer, "self_attn") and isinstance(layer.self_attn, Gemma3Attention):
                layer.self_attn.layer_idx = idx
                # Bind the non-causal forward as a method of this attention module.
                layer.self_attn.forward = types.MethodType(
                    modified_gemma3_attention_forward_non_causal,
                    layer.self_attn,
                )
                count += 1

        logger.info(f"Patched {count} attention layers for non-causal attention")
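
    # Verification sketch (hypothetical `model` instance): each patched layer's
    # bound forward should wrap the module-level function:
    #   assert (model.model.layers[0].self_attn.forward.__func__
    #           is modified_gemma3_attention_forward_non_causal)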

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> TokenClassifierOutput:
        """Forward pass for token classification."""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        sequence_output = outputs[0]
        # The replaced lm_head now produces per-token label logits.
        logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            # Positions labeled -100 (e.g. padding or non-first subwords) are ignored.
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


# Register the custom config/model so the Auto classes can resolve the
# "cadence_punctuation" model_type.
from transformers import AutoConfig, AutoModel

AutoConfig.register("cadence_punctuation", Gemma3PunctuationConfig)
AutoModel.register(Gemma3PunctuationConfig, Gemma3ForTokenClassification)
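

if __name__ == "__main__":
    # Minimal inference sketch. Assumptions (not part of the original file): a
    # checkpoint directory "path/to/checkpoint" saved with this class, with a
    # matching tokenizer alongside it.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("path/to/checkpoint")
    model = AutoModel.from_pretrained("path/to/checkpoint")
    model.eval()

    enc = tokenizer("hello world how are you", return_tensors="pt")
    with torch.no_grad():
        out = model(**enc)
    # One punctuation label id per input token.
    print(out.logits.argmax(dim=-1))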