## 1. Import necessary packages

In [None]:
import numpy as np
import tensorflow as tf
import scipy.io
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
import torch
import matplotlib.pyplot as plt
import random
import os
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
from torchvision import datasets
from torchvision import transforms
import torch.nn as nn
from torch.autograd import Function

# Reproducibility: turn off cuDNN autotuning and request deterministic kernels.
cudnn.benchmark, cudnn.deterministic = False, True
# Device flag consulted by the training code before moving work to the GPU.
cuda = True

# Hyper-parameters: Adam learning rate, samples per batch,
# image side length in pixels, number of federated rounds.
lr, batch_size, image_size, n_epoch = 3e-4, 16, 28, 200

def dataprocess(data, target, batch_size=None, shuffle=True):
    """Wrap a NumPy image/label pair in a ``DataLoader``.

    Args:
        data: array of images; converted to a float32 tensor.
        target: array of integer class labels; converted to an int64 tensor.
        batch_size: samples per batch. When None (default) the module-level
            ``batch_size`` hyper-parameter is used, as before.
        shuffle: reshuffle the samples on every pass (default True).

    Returns:
        A ``DataLoader`` yielding ``(images, labels)`` batches.
    """
    if batch_size is None:
        # The parameter shadows the notebook-level hyper-parameter of the same
        # name, so look the default up explicitly.
        batch_size = globals()['batch_size']
    data = torch.from_numpy(data).float()
    target = torch.from_numpy(target).long()
    dataset = TensorDataset(data, target)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

## 2. Prepare the datasets for clients (source) and the cloud (target)

In [None]:
def _to_nchw(images, scale=255.0):
    """Return ``images`` as a float32 (N, 3, 28, 28) NumPy array.

    ``images`` is any NHWC image stack accepted by ``np.array`` (here, TF
    tensors) that reshapes to (-1, 28, 28, 3). Pixel values are divided by
    ``scale`` unless it is None.
    """
    pixels = np.array(images).astype('float32')
    if scale is not None:
        pixels = pixels / scale
    return np.transpose(pixels.reshape(-1, 28, 28, 3), (0, 3, 1, 2))


# MNIST: single-channel digits replicated to 3 channels.
# Labels appear to be stored one row per sample (e.g. one-hot); argmax gives the class index.
mat = scipy.io.loadmat('Digit-Five/mnist_data.mat')
data = _to_nchw(tf.image.grayscale_to_rgb(tf.convert_to_tensor(mat['train_28'])))
target = np.argmax(mat['label_train'], axis=1)
c1_mt = [data, target]

# MNIST-M: no channel conversion is applied (assumed to be stored as RGB already).
mat = scipy.io.loadmat('Digit-Five/mnistm_with_label.mat')
data = _to_nchw(tf.convert_to_tensor(mat['train']))
target = np.argmax(mat['label_train'], axis=1)
c2_mm = [data, target]

# USPS: 28x28 single-channel digits replicated to 3 channels.
# NOTE(review): unlike the other domains there is no /255 here — presumably the
# .mat already stores pixels in [0, 1]; confirm.
mat = scipy.io.loadmat('Digit-Five/usps_28x28.mat')
grey = tf.convert_to_tensor(mat['dataset'][0][0].reshape(-1, 28, 28, 1))
data = _to_nchw(tf.image.grayscale_to_rgb(grey), scale=None)
target = mat['dataset'][0][1].flatten()
c3_up = [data, target]

# SVHN: the sample axis is stored last, so move it first, then resize to 28x28.
# Labels are shifted down by one (presumably 1..10 -> 0..9; confirm).
mat = scipy.io.loadmat('Digit-Five/svhn_train_32x32.mat')
data = _to_nchw(tf.image.resize(np.moveaxis(mat['X'], -1, 0), [28, 28]))
target = (mat['y'] - 1).flatten()
c4_sv = [data, target]

# Synthetic digits: resized to 28x28.
mat = scipy.io.loadmat('Digit-Five/syn_number.mat')
data = _to_nchw(tf.image.resize(mat['train_data'], [28, 28]))
target = mat['train_label'].flatten()
c5_sy = [data, target]

In [None]:
# One shuffling DataLoader per domain, in the fixed order c1..c5.
c1, c2, c3, c4, c5 = (
    dataprocess(images, labels)
    for images, labels in (c1_mt, c2_mm, c3_up, c4_sv, c5_sy)
)

# Raw image arrays of every domain, in the same order.
data_all = [domain[0] for domain in (c1_mt, c2_mm, c3_up, c4_sv, c5_sy)]

## 3. Define the MK-MMD loss (multi-kernel Gaussian)

In [None]:
class MMD_loss(nn.Module):
    """Multi-kernel Maximum Mean Discrepancy (MK-MMD) with Gaussian kernels.

    The kernel is the sum of ``kernel_num`` Gaussian kernels whose bandwidths
    form a geometric series with ratio ``kernel_mul``, starting at
    ``base / kernel_mul ** (kernel_num // 2)``. ``base`` is either
    ``fix_sigma`` or the mean pairwise squared distance of the pooled samples.

    Args:
        kernel_mul: ratio between consecutive kernel bandwidths.
        kernel_num: number of Gaussian kernels summed.
        fix_sigma: fixed base bandwidth. When None (default) it is estimated
            from the data on every call.
    """

    def __init__(self, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
        super(MMD_loss, self).__init__()
        self.kernel_num = kernel_num
        self.kernel_mul = kernel_mul
        self.fix_sigma = fix_sigma

    def guassian_kernel(self, source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
        """Return the summed multi-bandwidth Gaussian kernel matrix.

        Args:
            source: (n_s, d) feature tensor.
            target: (n_t, d) feature tensor.

        Returns:
            (n_s + n_t, n_s + n_t) tensor; rows/columns list ``source`` first,
            then ``target``.
        """
        n_samples = int(source.size()[0]) + int(target.size()[0])
        total = torch.cat([source, target], dim=0)

        # Pairwise squared Euclidean distances by broadcasting: (N, N, d) -> (N, N).
        total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
        total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
        L2_distance = ((total0 - total1) ** 2).sum(2)
        if fix_sigma:
            bandwidth = fix_sigma
        else:
            # Mean off-diagonal distance (the diagonal is zero). Uses `.data`,
            # so the bandwidth is treated as a constant by autograd.
            bandwidth = torch.sum(L2_distance.data) / (n_samples ** 2 - n_samples)
            # If every sample is identical all distances are 0 and the estimate
            # is 0 (0/0 = NaN below). Any positive bandwidth then yields the
            # correct all-ones kernel, so fall back to 1.
            bandwidth = torch.where(bandwidth > 0, bandwidth, torch.ones_like(bandwidth))
        bandwidth /= kernel_mul ** (kernel_num // 2)
        bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)]
        kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
        return sum(kernel_val)

    def forward(self, source, target):
        """Return the biased MK-MMD^2 estimate between two (n, d) feature sets.

        ``source`` and ``target`` may contain different numbers of samples.
        """
        n_source = int(source.size()[0])
        kernels = self.guassian_kernel(source, target, kernel_mul=self.kernel_mul,
                                       kernel_num=self.kernel_num, fix_sigma=self.fix_sigma)
        XX = kernels[:n_source, :n_source]
        YY = kernels[n_source:, n_source:]
        XY = kernels[:n_source, n_source:]
        YX = kernels[n_source:, :n_source]
        # Averaging each block separately equals mean(XX + YY - XY - YX) when
        # the two sets have equal size, and also works for unequal sizes.
        return XX.mean() + YY.mean() - XY.mean() - YX.mean()

## 4. Define the model architecture with a gradient reversal layer

In [None]:
class ReverseLayerF(Function):
    """Gradient reversal layer, as used by DANN-style domain adaptation.

    The forward pass is the identity. The backward pass negates the incoming
    gradient and scales it by ``alpha``. ``alpha`` itself receives no gradient.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        # Keep the coefficient for backward. Returning a view (not `x` itself)
        # makes autograd record this Function as the output's producer.
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        reversed_grad = grad_output.neg().mul(ctx.alpha)
        return reversed_grad, None

class CNNModel(nn.Module):
    """Digit classifier with a label head and a gradient-reversed domain head.

    A shared convolutional feature extractor feeds both heads. The flattened
    feature size ``50 * 4 * 4`` assumes 28x28 inputs.
    """

    def __init__(self):
        super(CNNModel, self).__init__()
        # Submodule names are state_dict keys; federated averaging matches
        # parameters across models by these keys, so keep them stable.
        self.feature = nn.Sequential()
        self.feature.add_module('f_conv1', nn.Conv2d(3, 64, kernel_size=5))
        self.feature.add_module('f_bn1', nn.BatchNorm2d(64))
        self.feature.add_module('f_pool1', nn.MaxPool2d(2))
        self.feature.add_module('f_relu1', nn.ReLU(True))
        self.feature.add_module('f_conv2', nn.Conv2d(64, 50, kernel_size=5))
        self.feature.add_module('f_bn2', nn.BatchNorm2d(50))
        self.feature.add_module('f_drop1', nn.Dropout2d())
        self.feature.add_module('f_pool2', nn.MaxPool2d(2))
        self.feature.add_module('f_relu2', nn.ReLU(True))

        self.class_classifier = nn.Sequential()
        self.class_classifier.add_module('c_fc1', nn.Linear(50 * 4 * 4, 100))
        self.class_classifier.add_module('c_bn1', nn.BatchNorm1d(100))
        self.class_classifier.add_module('c_relu1', nn.ReLU(True))
        self.class_classifier.add_module('c_fc3', nn.Linear(100, 10))
        # Explicit dim=1 (the class axis); the implicit-dim form is deprecated.
        self.class_classifier.add_module('c_softmax', nn.LogSoftmax(dim=1))

        self.domain_classifier = nn.Sequential()
        self.domain_classifier.add_module('d_fc1', nn.Linear(50 * 4 * 4, 100))
        self.domain_classifier.add_module('d_bn1', nn.BatchNorm1d(100))
        self.domain_classifier.add_module('d_relu1', nn.ReLU(True))
        self.domain_classifier.add_module('d_fc2', nn.Linear(100, 2))
        self.domain_classifier.add_module('d_softmax', nn.LogSoftmax(dim=1))

    def forward(self, input_data, alpha):
        """Return ``(class_output, domain_output)`` as log-probabilities.

        Args:
            input_data: (N, C, 28, 28) images with C == 3, or C == 1, which is
                broadcast to 3 channels.
            alpha: gradient-reversal coefficient for the domain branch.
        """
        # Keep N/H/W and broadcast a single channel to three (no-op for RGB).
        input_data = input_data.expand(-1, 3, -1, -1)
        feature = self.feature(input_data)
        feature = feature.flatten(1)  # (N, 50 * 4 * 4)
        reverse_feature = ReverseLayerF.apply(feature, alpha)
        class_output = self.class_classifier(feature)
        domain_output = self.domain_classifier(reverse_feature)

        return class_output, domain_output

## 5. Federated Knowledge Alignment (FedKA) 

In [None]:
from tqdm import notebook


def _grl_alpha(step, epoch, steps_per_epoch, num_epochs):
    """Return the ramp-up coefficient ``2 / (1 + exp(-5 p)) - 1``.

    ``p`` is the fraction of training completed, so the coefficient rises
    from 0 towards 1 over ``num_epochs`` rounds of ``steps_per_epoch`` steps.
    """
    p = float(step + epoch * steps_per_epoch) / num_epochs / steps_per_epoch
    return 2. / (1. + np.exp(-5 * p)) - 1


def _next_batch(iterator, loader):
    """Return ``(batch, iterator)``, restarting ``loader`` once if ``iterator`` is exhausted."""
    try:
        return next(iterator), iterator
    except StopIteration:
        iterator = iter(loader)
        return next(iterator), iterator


def _train_client(net, optimizer, source_loader, target_iter, target_loader,
                  epoch, method, loss_class, loss_domain, device, max_batches=32):
    """Run one round of local training on one client.

    The loss is the source label loss, plus the source/target domain losses
    unless ``method`` is 0 or 3. Returns the (possibly restarted) target iterator.
    """
    for i, (s_img, s_label) in enumerate(source_loader):
        if i >= max_batches:
            break
        alpha = _grl_alpha(i, epoch, max_batches, n_epoch)
        optimizer.zero_grad()

        # Source batch: class loss, and domain label 0.
        s_img = s_img.to(device)
        s_label = s_label.to(device)
        class_output, domain_output = net(input_data=s_img, alpha=alpha)
        err_s_label = loss_class(class_output, s_label)
        s_domain = torch.zeros(len(s_label), dtype=torch.long, device=device)
        err_s_domain = loss_domain(domain_output, s_domain)

        # Target batch: domain label 1 only; its class labels are not used.
        (t_img, _), target_iter = _next_batch(target_iter, target_loader)
        t_img = t_img.to(device)
        _, domain_output = net(input_data=t_img, alpha=alpha)
        t_domain = torch.ones(len(t_img), dtype=torch.long, device=device)
        err_t_domain = loss_domain(domain_output, t_domain)

        if method in (0, 3):
            err = err_s_label
        else:
            err = err_s_label + err_s_domain + err_t_domain
        err.backward()
        optimizer.step()
    return target_iter


def _align_client_mmd(net, global_net, optimizer, source_loader, target_iter,
                      target_loader, epoch, loss_mmd, device,
                      max_batches=32, accum_steps=8):
    """Align one client's source features with the global model's target features.

    Every ``accum_steps`` batch pairs, take one optimizer step on the mean of
    the alpha-weighted MK-MMD losses. Returns the (possibly restarted) target
    iterator.
    """
    mmd_loss_total = 0
    for i, (s_img, _) in enumerate(source_loader):
        if i >= max_batches:
            break
        alpha = _grl_alpha(i, epoch, max_batches, n_epoch)
        (t_img, _), target_iter = _next_batch(target_iter, target_loader)
        s_img = s_img.to(device)
        t_img = t_img.to(device)

        hidden_c = net.feature(s_img).flatten(1)
        # The global model is only a reference here and is never stepped, so
        # don't build its autograd graph. The client gradients are unchanged.
        with torch.no_grad():
            hidden_avg = global_net.feature(t_img).flatten(1)
        # `+=` (not `=+`): sum all accum_steps terms before averaging below.
        mmd_loss_total += loss_mmd(hidden_c, hidden_avg) * alpha

        if (i + 1) % accum_steps == 0:
            optimizer.zero_grad()
            (mmd_loss_total / accum_steps).backward()
            optimizer.step()
            mmd_loss_total = 0
    return target_iter


def _fedavg(global_net, clients):
    """Load the unweighted mean of the clients' state_dicts into ``global_net``.

    Every state_dict entry is averaged, including BatchNorm buffers. Returns the
    averaged state_dict, whose tensors are independent of ``global_net``.
    """
    client_states = [net.state_dict() for net in clients]
    global_sd = global_net.state_dict()
    for key in global_sd:
        stacked = torch.stack([state[key] for state in client_states])
        global_sd[key] = torch.sum(stacked, dim=0) / len(clients)
    global_net.load_state_dict(global_sd)
    return global_sd


def _federated_vote(global_net, global_optimizer, clients, target_loader,
                    epoch, loss_class, device, max_batches=128):
    """Fine-tune ``global_net`` on the clients' majority-vote pseudo-labels.

    Target ground-truth labels are used only to report the vote accuracy.
    """
    total = 0
    num_correct = 0
    alpha = 0.0
    for i, (images, labels) in enumerate(target_loader):
        if i >= max_batches:
            break
        alpha = _grl_alpha(i, epoch, max_batches, n_epoch)
        images = images.to(device)
        labels = labels.to(device)
        global_optimizer.zero_grad()
        class_output, _ = global_net(images, 0)

        # Clients act as fixed teachers: predict in eval mode (no dropout,
        # running BN statistics) and without building an autograd graph.
        with torch.no_grad():
            votes = []
            for net in clients:
                net.eval()
                output, _ = net(images, 0)
                votes.append(torch.argmax(output, 1))
                net.train()
        # Per-sample most frequent prediction across clients.
        class_label = torch.mode(torch.stack(votes), dim=0).values

        total += labels.size(0)
        num_correct += (class_label == labels).sum().item()

        err = loss_class(class_output, class_label) * alpha
        err.backward()
        global_optimizer.step()

    if total:
        print(f'Voting accuracy: {num_correct * 100 / total}% Adoption rate: {alpha*100}%')


def _evaluate(net, loader, device, max_batches=313):
    """Return ``(accuracy_percent, n_images)`` of ``net`` on up to ``max_batches`` batches."""
    was_training = net.training
    # Evaluate in eval mode: disables Dropout2d and freezes BatchNorm statistics.
    net.eval()
    num_correct = 0
    total = 0
    with torch.no_grad():
        for i, (images, labels) in enumerate(loader):
            if i >= max_batches:
                break
            images = images.to(device)
            labels = labels.to(device)
            output, _ = net(images, 0)
            num_correct += (torch.argmax(output, 1) == labels).sum().item()
            total += labels.size(0)
    net.train(was_training)
    if total == 0:
        return float('nan'), 0
    return num_correct * 100 / total, total


def run(method, voting):
    """Run one federated multi-source domain adaptation experiment.

    Reads the module-level ``clients_datasets`` (one source DataLoader per
    client), ``cloud_dataset`` (``[target DataLoader]``), ``n_epoch``, ``lr``
    and ``cuda``.

    Args:
        method: 0 = source only; 1 = f-DANN (label + domain losses);
            2 = FedKA (f-DANN + MK-MMD alignment to the global model);
            3 = label loss + MK-MMD alignment (no domain losses).
        voting: if True, fine-tune the global model on the clients'
            majority-vote pseudo-labels after every FedAvg step.

    Returns:
        One ``[accuracy]`` list per round: the global model's accuracy (%) on
        up to 313 target batches.
    """
    device = torch.device('cuda' if cuda else 'cpu')
    target_loader = cloud_dataset[0]
    client_num = len(clients_datasets)

    global_net = CNNModel().to(device)
    global_optimizer = optim.Adam(global_net.parameters(), lr=lr)
    clients = [CNNModel().to(device) for _ in range(client_num)]
    optims = [optim.Adam(net.parameters(), lr=lr) for net in clients]

    loss_class = nn.NLLLoss().to(device)
    loss_domain = nn.NLLLoss().to(device)
    loss_mmd = MMD_loss()

    acc_list = []
    for epoch in notebook.tqdm(range(n_epoch)):
        print(f"===========Round {epoch} ===========")
        target_iter = iter(target_loader)

        for n in range(client_num):
            # A fresh pass over a shuffled source loader, cut at 32 batches, gives
            # each client a new random subset (512 samples at batch_size=16) per round.
            target_iter = _train_client(
                clients[n], optims[n], clients_datasets[n], target_iter,
                target_loader, epoch, method, loss_class, loss_domain, device)
            if method in (2, 3):
                target_iter = _align_client_mmd(
                    clients[n], global_net, optims[n], clients_datasets[n],
                    target_iter, target_loader, epoch, loss_mmd, device)

        global_sd = _fedavg(global_net, clients)

        if voting:
            _federated_vote(global_net, global_optimizer, clients, target_loader,
                            epoch, loss_class, device)

        acc, total = _evaluate(global_net, target_loader, device)
        print(f'Global: Accuracy of the model on {total} test images: {acc}% \n')
        acc_list.append([acc])

        # NOTE(review): clients reload the FedAvg weights from *before* the voting
        # fine-tune. The fine-tuned global_net is only evaluated and used as next
        # round's MMD reference before FedAvg overwrites it — confirm intended.
        for net in clients:
            net.load_state_dict(global_sd)

    return acc_list

## 6. Run experiments

### Method 0: Source Only

### Method 1: $f$-DANN

### Method 2: FedKA

In [None]:
import matplotlib.pyplot as plt

methods = [2]
voting = True

# Leave-one-domain-out: each of the five digit domains takes a turn as the
# cloud (target) while the remaining four act as source clients.
for t in range(5):
    clients_datasets = [c1, c2, c3, c4, c5]
    cloud_dataset = [clients_datasets.pop(t)]
    target = f"c{t+1}"
    acc = []

    # Repeat every method under three seeds.
    for s in range(3):
        torch.manual_seed(s)
        random.seed(s)
        np.random.seed(s)

        acc_m = []
        for method in methods:
            print(f"Task: {target} Seed: {s} Method: {method}")
            evl = run(method, voting)

            # Per-round accuracy curve of the global model.
            result = np.array(evl).T
            plt.plot(result[0], label="Global", color='C4')
            plt.legend()
            plt.show()

            # Keep the best round's accuracy for this method/seed.
            acc_m.append(max(np.array(evl)[:, 0]))
        acc.append(acc_m)

    # Rows = methods, columns = seeds.
    per_method = np.array(acc).T
    print(f'Task: {target} Mean: {np.mean(per_method, axis=1)}')
    print(f'Task: {target} Std: {np.std(per_method, axis=1)}')