import logging
import math
from typing import Any, Optional

import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR

from pytorch_lightning.utilities.types import STEP_OUTPUT
from timm.optim.optim_factory import param_groups_weight_decay

from strhub.models.base import CrossEntropySystem
from strhub.models.utils import init_weights

from .model_abinet_iter import ABINetIterModel as Model

log = logging.getLogger(__name__)


class ABINet(CrossEntropySystem):
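    """PyTorch Lightning system wrapping the iterative ABINet model (vision, language, and alignment submodels).

    Training proceeds in two phases: the vision and language submodels are trained first (see
    ``_pretraining``), after which the alignment submodel is re-initialized and the full iterative
    model is trained end to end. Setting ``lm_only=True`` trains the language model alone, with the
    vision and alignment submodels frozen.
    """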

    def __init__(
        self,
        charset_train: str,
        charset_test: str,
        max_label_length: int,
        batch_size: int,
        lr: float,
        warmup_pct: float,
        weight_decay: float,
        iter_size: int,
        d_model: int,
        nhead: int,
        d_inner: int,
        dropout: float,
        activation: str,
        v_loss_weight: float,
        v_attention: str,
        v_attention_mode: str,
        v_backbone: str,
        v_num_layers: int,
        l_loss_weight: float,
        l_num_layers: int,
        l_detach: bool,
        l_use_self_attn: bool,
        l_lr: float,
        a_loss_weight: float,
        lm_only: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(charset_train, charset_test, batch_size, lr, warmup_pct, weight_decay)
        self.scheduler = None
        self.save_hyperparameters()
        self.max_label_length = max_label_length
        self.num_classes = len(self.tokenizer) - 2  # exclude BOS and PAD; they are not predicted
        self.model = Model(
            max_label_length,
            self.eos_id,
            self.num_classes,
            iter_size,
            d_model,
            nhead,
            d_inner,
            dropout,
            activation,
            v_loss_weight,
            v_attention,
            v_attention_mode,
            v_backbone,
            v_num_layers,
            l_loss_weight,
            l_num_layers,
            l_detach,
            l_use_self_attn,
            a_loss_weight,
        )
        self.model.apply(init_weights)
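        # One-shot flags, cleared once the pretraining phase ends
        # (see on_train_batch_start() and training_step()).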
        self._reset_alignment = True
        self._reset_optimizers = True
        self.l_lr = l_lr
        self.lm_only = lm_only
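        # LM-only mode: the LM uses the main LR, and the vision and alignment submodels are frozen.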
        if lm_only:
            self.l_lr = lr
            self.model.vision.requires_grad_(False)
            self.model.alignment.requires_grad_(False)

    @property
    def _pretraining(self):
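        # The first 8/18 of the total optimization steps are treated as the pretraining phase;
        # the split mirrors ABINet's original schedule of 8 vision-model pretraining epochs
        # followed by 10 epochs of full training.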
        total_steps = self.trainer.estimated_stepping_batches * self.trainer.accumulate_grad_batches
        return self.global_step < (8 / (8 + 10)) * total_steps

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'model.language.proj.weight'}

    def _add_weight_decay(self, model: nn.Module, skip_list=()):
        if self.weight_decay:
            return param_groups_weight_decay(model, self.weight_decay, skip_list)
        else:
            return [{'params': model.parameters()}]

    def configure_optimizers(self):
        agb = self.trainer.accumulate_grad_batches
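        # Scale the base LRs by the gradient-accumulation factor, the square root of the device
        # count, and the per-device batch size relative to a reference batch size of 256.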
        lr_scale = agb * math.sqrt(self.trainer.num_devices) * self.batch_size / 256.0
        lr = lr_scale * self.lr
        l_lr = lr_scale * self.l_lr
        params = []
        params.extend(self._add_weight_decay(self.model.vision))
        params.extend(self._add_weight_decay(self.model.alignment))
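        # The language model gets its own learning rate; its proj.weight is excluded from weight decay.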
        for p in self._add_weight_decay(self.model.language, ('proj.weight',)):
            p['lr'] = l_lr
            params.append(p)
        max_lr = [p.get('lr', lr) for p in params]
        optim = AdamW(params, lr)
        self.scheduler = OneCycleLR(
            optim, max_lr, self.trainer.estimated_stepping_batches, pct_start=self.warmup_pct, cycle_momentum=False
        )
        return {'optimizer': optim, 'lr_scheduler': {'scheduler': self.scheduler, 'interval': 'step'}}

    def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor:
        max_length = self.max_label_length if max_length is None else min(max_length, self.max_label_length)
        logits = self.model.forward(images)[0]['logits']
        return logits[:, : max_length + 1]  # truncate to max_length + 1 (the extra position is for EOS)

    def calc_loss(self, targets, *res_lists) -> Tensor:
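        # Each element of res_lists is the output of one submodel (vision, language, or alignment),
        # given either as a single dict or as a list of dicts (one per refinement iteration).
        # The cross-entropy is averaged over iterations and weighted by the submodel's loss_weight.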
        total_loss = 0
        for res_list in res_lists:
            loss = 0
            if isinstance(res_list, dict):
                res_list = [res_list]
            for res in res_list:
                logits = res['logits'].flatten(end_dim=1)
                loss += F.cross_entropy(logits, targets.flatten(), ignore_index=self.pad_id)
            loss /= len(res_list)
            self.log('loss_' + res_list[0]['name'], loss)
            total_loss += res_list[0]['loss_weight'] * loss
        return total_loss

    def on_train_batch_start(self, batch: Any, batch_idx: int) -> None:
        if not self._pretraining and self._reset_optimizers:
            log.info('Pretraining ends. Updating base LRs.')
            self._reset_optimizers = False
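            # Make every param group share the base LR of the first (vision) group, so the LM
            # no longer keeps its separate pretraining LR.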
            base_lr = self.scheduler.base_lrs[0]
            self.scheduler.base_lrs = [base_lr] * len(self.scheduler.base_lrs)

    def _prepare_inputs_and_targets(self, labels):
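        # Prepend a dummy label of maximum length so the encoded batch is always padded to
        # max_label_length, then drop the dummy row.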
        dummy = ['0' * self.max_label_length]
        targets = self.tokenizer.encode(dummy + list(labels), self.device)[1:]  # drop the dummy row
        targets = targets[:, 1:]  # drop the BOS column; it is not predicted
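        # LM inputs: replace PAD with EOS, then one-hot encode over the prediction classes.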
        inputs = torch.where(targets == self.pad_id, self.eos_id, targets)
        inputs = F.one_hot(inputs, self.num_classes).float()
        lengths = torch.as_tensor(list(map(len, labels)), device=self.device) + 1  # +1 for EOS
        return inputs, lengths, targets

    def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
        images, labels = batch
        inputs, lengths, targets = self._prepare_inputs_and_targets(labels)
        if self.lm_only:
            l_res = self.model.language(inputs, lengths)
            loss = self.calc_loss(targets, l_res)
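        # Pretraining phase: train the vision and language models on the ground-truth inputs.
        # The alignment model is fed detached features, so it also trains without sending
        # gradients back into the other two submodels.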
        elif self._pretraining:
            v_res = self.model.vision(images)
            l_res = self.model.language(inputs, lengths)
            a_res = self.model.alignment(l_res['feature'].detach(), v_res['feature'].detach())
            loss = self.calc_loss(targets, v_res, l_res, a_res)
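        # Full training phase: re-initialize the alignment model once at the transition,
        # then train the whole iterative model end to end.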
        else:
            if self._reset_alignment:
                log.info('Pretraining ends. Resetting alignment model.')
                self._reset_alignment = False
                self.model.alignment.apply(init_weights)
            all_a_res, all_l_res, v_res = self.model.forward(images)
            loss = self.calc_loss(targets, v_res, all_l_res, all_a_res)
        self.log('loss', loss)
        return loss

    def forward_logits_loss(self, images: Tensor, labels: list[str]) -> tuple[Tensor, Tensor, int]:
        if self.lm_only:
            inputs, lengths, targets = self._prepare_inputs_and_targets(labels)
            l_res = self.model.language(inputs, lengths)
            loss = self.calc_loss(targets, l_res)
            loss_numel = (targets != self.pad_id).sum()  # number of non-padding target tokens
            return l_res['logits'], loss, loss_numel
        else:
            return super().forward_logits_loss(images, labels)