# PyCIL/models/coil.py
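"""COIL: Co-Transport for Class-Incremental Learning.

Reference: Zhou, Ye, Zhan, "Co-Transport for Class-Incremental Learning,"
ACM Multimedia 2021. The learner initializes new-class classifier weights by
optimally transporting the old-class weights (prospective transport) and,
during training, transports new-class weights back onto the old classes for
distillation (retrospective transport).
"""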
import logging
import numpy as np
from tqdm import tqdm
import torch
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from models.base import BaseLearner
from utils.inc_net import SimpleCosineIncrementalNet
from utils.toolkit import target2onehot, tensor2numpy
import ot
EPSILON = 1e-8
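# Default COIL training hyper-parameters: schedule, memory budget, and the
# distillation temperature T.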
epochs = 100
lrate = 0.1
milestones = [40, 80]
lrate_decay = 0.1
batch_size = 32
memory_size = 2000
T = 2
class COIL(BaseLearner):
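    """Co-transport incremental learner built on a cosine classifier head."""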
def __init__(self, args):
super().__init__(args)
self._network = SimpleCosineIncrementalNet(args, False)
self.data_manager = None
self.nextperiod_initialization = None
self.sinkhorn_reg = args["sinkhorn"]
self.calibration_term = args["calibration_term"]
self.args = args
def after_task(self):
self.nextperiod_initialization = self.solving_ot()
self._old_network = self._network.copy().freeze()
self._known_classes = self._total_classes
def solving_ot(self):
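        """Prospective transport: solve entropic OT from old-class prototypes to
        the next task's prototypes and return a norm-calibrated initialization
        for the new classifier weights (None once all classes have been seen)."""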
with torch.no_grad():
            if self._total_classes == self.data_manager.get_total_classnum():
                logging.info("Training finished; no further OT solving needed.")
                return None
            each_time_class_num = self.data_manager.get_task_size(self._cur_task + 1)
self._extract_class_means(
self.data_manager, 0, self._total_classes + each_time_class_num
)
former_class_means = torch.tensor(
self._ot_prototype_means[: self._total_classes]
)
next_period_class_means = torch.tensor(
self._ot_prototype_means[
self._total_classes : self._total_classes + each_time_class_num
]
)
Q_cost_matrix = torch.cdist(
former_class_means, next_period_class_means, p=self.args["norm_term"]
)
# solving ot
            # Uniform marginals: both sides must sum to one for balanced Sinkhorn.
            _mu1_vec = torch.ones(len(former_class_means)) / len(former_class_means)
            _mu2_vec = torch.ones(len(next_period_class_means)) / len(
                next_period_class_means
            )
            # Optimal transport plan of shape (old classes, new classes).
            plan = ot.sinkhorn(_mu1_vec, _mu2_vec, Q_cost_matrix, self.sinkhorn_reg)
            plan = torch.as_tensor(plan).float().to(self._device)
            transformed_hat_W = torch.mm(
                plan.T, F.normalize(self._network.fc.weight, p=2, dim=1)
            )
oldnorm = torch.norm(self._network.fc.weight, p=2, dim=1)
newnorm = torch.norm(
transformed_hat_W * len(former_class_means), p=2, dim=1
)
meannew = torch.mean(newnorm)
meanold = torch.mean(oldnorm)
gamma = meanold / meannew
self.calibration_term = gamma
            self._ot_new_branch = (
                transformed_hat_W * len(former_class_means) * self.calibration_term
            )
            return self._ot_new_branch
def solving_ot_to_old(self):
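        """Retrospective transport: map the current task's class weights back
        onto the old classes, yielding a transported old-class head used as a
        distillation target during training."""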
current_class_num = self.data_manager.get_task_size(self._cur_task)
self._extract_class_means_with_memory(
self.data_manager, self._known_classes, self._total_classes
)
former_class_means = torch.tensor(
self._ot_prototype_means[: self._known_classes]
)
next_period_class_means = torch.tensor(
self._ot_prototype_means[self._known_classes : self._total_classes]
)
Q_cost_matrix = (
torch.cdist(
next_period_class_means, former_class_means, p=self.args["norm_term"]
)
+ EPSILON
) # in case of numerical err
        _mu1_vec = torch.ones(len(former_class_means)) / len(former_class_means)
        _mu2_vec = torch.ones(len(next_period_class_means)) / len(
            next_period_class_means
        )
        # Transport plan from new classes (rows) to old classes (columns).
        plan = ot.sinkhorn(_mu2_vec, _mu1_vec, Q_cost_matrix, self.sinkhorn_reg)
        plan = torch.as_tensor(plan).float().to(self._device)
        transformed_hat_W = torch.mm(
            plan.T,
            F.normalize(self._network.fc.weight[-current_class_num:, :], p=2, dim=1),
        )
return transformed_hat_W * len(former_class_means) * self.calibration_term
def incremental_train(self, data_manager):
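        """Expand the classifier for the new task, train on new data plus the
        exemplar memory, then update the exemplar set."""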
self._cur_task += 1
self._total_classes = self._known_classes + data_manager.get_task_size(
self._cur_task
)
self._network.update_fc(self._total_classes, self.nextperiod_initialization)
self.data_manager = data_manager
logging.info(
"Learning on {}-{}".format(self._known_classes, self._total_classes)
)
self.lamda = self._known_classes / self._total_classes
# Loader
train_dataset = data_manager.get_dataset(
np.arange(self._known_classes, self._total_classes),
source="train",
mode="train",
appendent=self._get_memory(),
)
self.train_loader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=True, num_workers=4
)
test_dataset = data_manager.get_dataset(
np.arange(0, self._total_classes), source="test", mode="test"
)
self.test_loader = DataLoader(
test_dataset, batch_size=batch_size, shuffle=False, num_workers=4
)
self._train(self.train_loader, self.test_loader)
        if self.args["fixed_memory"]:
            exemplar_size = self.args["memory_per_class"]
        else:
            exemplar_size = memory_size // self._total_classes
        self._reduce_exemplar(data_manager, exemplar_size)
        self._construct_exemplar(data_manager, exemplar_size)
def _train(self, train_loader, test_loader):
self._network.to(self._device)
if self._old_network is not None:
self._old_network.to(self._device)
        optimizer = optim.SGD(
            self._network.parameters(), lr=lrate, momentum=0.9, weight_decay=5e-4
        )
scheduler = optim.lr_scheduler.MultiStepLR(
optimizer=optimizer, milestones=milestones, gamma=lrate_decay
)
self._update_representation(train_loader, test_loader, optimizer, scheduler)
def _update_representation(self, train_loader, test_loader, optimizer, scheduler):
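        """Joint training loop: classification loss, knowledge distillation, and
        the two OT-guided distillation branches with epoch-dependent weights."""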
prog_bar = tqdm(range(epochs))
        for epoch in prog_bar:
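            # weight_ot_init decays to zero within the first two epochs, while
            # weight_ot_co_tuning ramps up quadratically over training.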
weight_ot_init = max(1.0 - (epoch / 2) ** 2, 0)
weight_ot_co_tuning = (epoch / epochs) ** 2.0
self._network.train()
losses = 0.0
correct, total = 0, 0
for i, (_, inputs, targets) in enumerate(train_loader):
inputs, targets = inputs.to(self._device), targets.to(self._device)
output = self._network(inputs)
logits = output["logits"]
clf_loss = F.cross_entropy(logits, targets)
if self._old_network is not None:
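                    # Knowledge distillation from the frozen previous-task network.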
old_logits = self._old_network(inputs)["logits"].detach()
hat_pai_k = F.softmax(old_logits / T, dim=1)
log_pai_k = F.log_softmax(
logits[:, : self._known_classes] / T, dim=1
)
distill_loss = -torch.mean(torch.sum(hat_pai_k * log_pai_k, dim=1))
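                    # In the first epoch, additionally pull the new-class logits
                    # toward the OT-transported branch from the previous task.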
if epoch < 1:
features = F.normalize(output["features"], p=2, dim=1)
current_logit_new = F.log_softmax(
logits[:, self._known_classes :] / T, dim=1
)
new_logit_by_wnew_init_by_ot = F.linear(
features, F.normalize(self._ot_new_branch, p=2, dim=1)
)
new_logit_by_wnew_init_by_ot = F.softmax(
new_logit_by_wnew_init_by_ot / T, dim=1
)
new_branch_distill_loss = -torch.mean(
torch.sum(
current_logit_new * new_logit_by_wnew_init_by_ot, dim=1
)
)
loss = (
distill_loss * self.lamda
+ clf_loss * (1 - self.lamda)
+ 0.001 * (weight_ot_init * new_branch_distill_loss)
)
else:
features = F.normalize(output["features"], p=2, dim=1)
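                        # Periodically re-solve the reverse OT map (new -> old weights).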
if i % 30 == 0:
with torch.no_grad():
self._ot_old_branch = self.solving_ot_to_old()
old_logit_by_wold_init_by_ot = F.linear(
features, F.normalize(self._ot_old_branch, p=2, dim=1)
)
old_logit_by_wold_init_by_ot = F.log_softmax(
old_logit_by_wold_init_by_ot / T, dim=1
)
old_branch_distill_loss = -torch.mean(
torch.sum(hat_pai_k * old_logit_by_wold_init_by_ot, dim=1)
)
loss = (
distill_loss * self.lamda
+ clf_loss * (1 - self.lamda)
+ self.args["reg_term"]
* (weight_ot_co_tuning * old_branch_distill_loss)
)
else:
loss = clf_loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
losses += loss.item()
_, preds = torch.max(logits, dim=1)
correct += preds.eq(targets.expand_as(preds)).cpu().sum()
total += len(targets)
scheduler.step()
train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
test_acc = self._compute_accuracy(self._network, test_loader)
info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
self._cur_task,
epoch + 1,
epochs,
losses / len(train_loader),
train_acc,
test_acc,
)
prog_bar.set_description(info)
logging.info(info)
def _extract_class_means(self, data_manager, low, high):
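        """Compute an L2-normalized feature mean (prototype) for every class in
        [low, high) from the full training data."""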
self._ot_prototype_means = np.zeros(
(data_manager.get_total_classnum(), self._network.feature_dim)
)
with torch.no_grad():
for class_idx in range(low, high):
data, targets, idx_dataset = data_manager.get_dataset(
np.arange(class_idx, class_idx + 1),
source="train",
mode="test",
ret_data=True,
)
idx_loader = DataLoader(
idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4
)
vectors, _ = self._extract_vectors(idx_loader)
vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
class_mean = np.mean(vectors, axis=0)
class_mean = class_mean / (np.linalg.norm(class_mean))
self._ot_prototype_means[class_idx, :] = class_mean
self._network.train()
def _extract_class_means_with_memory(self, data_manager, low, high):
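        """Like _extract_class_means, but old-class prototypes are computed from
        the exemplar memory while new-class prototypes use the training data."""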
self._ot_prototype_means = np.zeros(
(data_manager.get_total_classnum(), self._network.feature_dim)
)
memoryx, memoryy = self._data_memory, self._targets_memory
with torch.no_grad():
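            # Old classes: prototypes from the exemplar memory only.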
for class_idx in range(0, low):
                idxes = np.where(memoryy == class_idx)[0]
data, targets = memoryx[idxes], memoryy[idxes]
_, _, idx_dataset = data_manager.get_dataset(
[],
source="train",
appendent=(data, targets),
mode="test",
ret_data=True,
)
idx_loader = DataLoader(
idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4
)
vectors, _ = self._extract_vectors(idx_loader)
vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
class_mean = np.mean(vectors, axis=0)
class_mean = class_mean / np.linalg.norm(class_mean)
self._ot_prototype_means[class_idx, :] = class_mean
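            # New classes: prototypes from the full training data.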
for class_idx in range(low, high):
data, targets, idx_dataset = data_manager.get_dataset(
np.arange(class_idx, class_idx + 1),
source="train",
mode="test",
ret_data=True,
)
idx_loader = DataLoader(
idx_dataset, batch_size=batch_size, shuffle=False, num_workers=4
)
vectors, _ = self._extract_vectors(idx_loader)
vectors = (vectors.T / (np.linalg.norm(vectors.T, axis=0) + EPSILON)).T
class_mean = np.mean(vectors, axis=0)
class_mean = class_mean / np.linalg.norm(class_mean)
self._ot_prototype_means[class_idx, :] = class_mean
self._network.train()