{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "LcO6E1lNBh1P" }, "source": [ "## 1. Import necessary packages" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dBaDYH8WBh1U" }, "outputs": [], "source": [ "import numpy as np\n", "import tensorflow as tf\n", "import scipy.io\n", "from torch.utils.data import TensorDataset\n", "from torch.utils.data import DataLoader\n", "import torch\n", "import matplotlib.pyplot as plt\n", "import random\n", "import os\n", "import torch.backends.cudnn as cudnn\n", "import torch.optim as optim\n", "import torch.utils.data\n", "from torchvision import datasets\n", "from torchvision import transforms\n", "import torch.nn as nn\n", "from torch.autograd import Function\n", "\n", "cudnn.benchmark = False\n", "cudnn.deterministic = True\n", "cuda = True\n", "\n", "lr = 3e-4\n", "batch_size = 16\n", "image_size = 28\n", "n_epoch = 200\n", "\n", "def dataprocess(data, target):\n", " data = torch.from_numpy(data).float()\n", " target = torch.from_numpy(target).long() \n", " dataset = TensorDataset(data, target)\n", " trainloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)\n", "\n", " return trainloader" ] }, { "cell_type": "markdown", "metadata": { "id": "8n9FMTyTBh1U" }, "source": [ "## 2. Prepare the datasets for clients (source) and the cloud (target)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Zocl_klGBh1V" }, "outputs": [], "source": [ "mat = scipy.io.loadmat('Digit-Five/mnist_data.mat')\n", "data = np.transpose((np.array((tf.image.grayscale_to_rgb(tf.convert_to_tensor(mat['train_28'])))).astype('float32')/255.0).reshape(-1,28,28,3), (0,3,1,2))\n", "target = np.argmax((mat['label_train']), axis = 1)\n", "c1_mt = [data, target]\n", "\n", "mat = scipy.io.loadmat('Digit-Five/mnistm_with_label.mat')\n", "data = np.transpose((np.array((tf.convert_to_tensor(mat['train']))).astype('float32')/255.0).reshape(-1,28,28,3), (0,3,1,2)) \n", "target = np.argmax((mat['label_train']), axis = 1)\n", "c2_mm =[data, target]\n", "\n", "mat = scipy.io.loadmat('Digit-Five/usps_28x28.mat')\n", "data = np.transpose((np.array((tf.image.grayscale_to_rgb(tf.convert_to_tensor(mat['dataset'][0][0].reshape(-1,28,28,1))))).astype('float32')).reshape(-1,28,28,3), (0,3,1,2))\n", "target = mat['dataset'][0][1].flatten()\n", "c3_up = [data, target]\n", "\n", "mat = scipy.io.loadmat('Digit-Five/svhn_train_32x32.mat')\n", "data = np.transpose((np.array((tf.image.resize(np.moveaxis(mat['X'], -1, 0), [28,28]) )).astype('float32')/255.0).reshape(-1,28,28,3), (0,3,1,2))\n", "target = (mat['y']-1).flatten()\n", "c4_sv = [data, target]\n", "\n", "mat = scipy.io.loadmat('Digit-Five/syn_number.mat')\n", "data = np.transpose((np.array((tf.image.resize(mat['train_data'], [28,28]) )).astype('float32')/255.0).reshape(-1,28,28,3), (0,3,1,2)) \n", "target = mat['train_label'].flatten()\n", "c5_sy = [data, target]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6uECrd2eBh1V" }, "outputs": [], "source": [ "c1 = dataprocess(c1_mt[0], c1_mt[1])\n", "c2 = dataprocess(c2_mm[0], c2_mm[1])\n", "c3 = dataprocess(c3_up[0], c3_up[1])\n", "c4 = dataprocess(c4_sv[0], c4_sv[1])\n", "c5 = dataprocess(c5_sy[0], c5_sy[1])\n", "\n", "data_all = [c1_mt[0],c2_mm[0], c3_up[0], c4_sv[0], c5_sy[0]]" ] }, { "cell_type": "markdown", "metadata": { "id": "Th2GBdjoBh1X" }, "source": [ "## 3. 
{ "cell_type": "markdown", "metadata": { "id": "Th2GBdjoBh1X" }, "source": [ "## 3. Define the MK-MMD loss (Gaussian kernels)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gjz5bZS0Bh1X" }, "outputs": [], "source": [ "class MMD_loss(nn.Module):\n", "    def __init__(self, kernel_mul=2.0, kernel_num=5):\n", "        super(MMD_loss, self).__init__()\n", "        self.kernel_num = kernel_num\n", "        self.kernel_mul = kernel_mul\n", "        self.fix_sigma = None\n", "\n", "    def gaussian_kernel(self, source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):\n", "        n_samples = int(source.size(0)) + int(target.size(0))\n", "        total = torch.cat([source, target], dim=0)\n", "\n", "        # pairwise squared Euclidean distances between all samples\n", "        total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))\n", "        total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))\n", "        L2_distance = ((total0 - total1)**2).sum(2)\n", "\n", "        # base bandwidth: mean pairwise distance (unless fixed), then a\n", "        # geometric family of kernel_num bandwidths around it\n", "        if fix_sigma:\n", "            bandwidth = fix_sigma\n", "        else:\n", "            bandwidth = torch.sum(L2_distance.data) / (n_samples**2 - n_samples)\n", "        bandwidth /= kernel_mul ** (kernel_num // 2)\n", "        bandwidth_list = [bandwidth * (kernel_mul**i) for i in range(kernel_num)]\n", "        kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]\n", "        return sum(kernel_val)\n", "\n", "    def forward(self, source, target):\n", "        batch_size = int(source.size(0))\n", "        kernels = self.gaussian_kernel(source, target, kernel_mul=self.kernel_mul, kernel_num=self.kernel_num, fix_sigma=self.fix_sigma)\n", "        XX = kernels[:batch_size, :batch_size]\n", "        YY = kernels[batch_size:, batch_size:]\n", "        XY = kernels[:batch_size, batch_size:]\n", "        YX = kernels[batch_size:, :batch_size]\n", "        loss = torch.mean(XX + YY - XY - YX)\n", "        return loss" ] },
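{ "cell_type": "markdown", "metadata": {}, "source": [ "A minimal illustrative check of `MMD_loss` (not part of the training pipeline), using random 800-dimensional features (the flattened feature size of the model below): the estimate should be close to zero for two samples from the same distribution and clearly larger after a mean shift." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "torch.manual_seed(0)\n", "mmd = MMD_loss()\n", "x = torch.randn(16, 800)\n", "y_same = torch.randn(16, 800)\n", "y_shifted = torch.randn(16, 800) + 1.0  # mean-shifted distribution\n", "print(f'MMD(same):    {mmd(x, y_same).item():.4f}')\n", "print(f'MMD(shifted): {mmd(x, y_shifted).item():.4f}')" ] },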
{ "cell_type": "markdown", "metadata": { "id": "bw4hooRKBh1Y" }, "source": [ "## 4. Define the model architecture with a gradient reversal layer" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_dH5urjaBh1Y" }, "outputs": [], "source": [ "class ReverseLayerF(Function):\n", "    \"\"\"Gradient reversal layer: identity on the forward pass,\n", "    multiplies gradients by -alpha on the backward pass.\"\"\"\n", "\n", "    @staticmethod\n", "    def forward(ctx, x, alpha):\n", "        ctx.alpha = alpha\n", "        return x.view_as(x)\n", "\n", "    @staticmethod\n", "    def backward(ctx, grad_output):\n", "        output = grad_output.neg() * ctx.alpha\n", "        return output, None\n", "\n", "class CNNModel(nn.Module):\n", "\n", "    def __init__(self):\n", "        super(CNNModel, self).__init__()\n", "        # feature extractor: 28x28 -> conv5 -> 24x24 -> pool -> 12x12\n", "        #                    -> conv5 -> 8x8 -> pool -> 4x4 (50 channels)\n", "        self.feature = nn.Sequential()\n", "        self.feature.add_module('f_conv1', nn.Conv2d(3, 64, kernel_size=5))\n", "        self.feature.add_module('f_bn1', nn.BatchNorm2d(64))\n", "        self.feature.add_module('f_pool1', nn.MaxPool2d(2))\n", "        self.feature.add_module('f_relu1', nn.ReLU(True))\n", "        self.feature.add_module('f_conv2', nn.Conv2d(64, 50, kernel_size=5))\n", "        self.feature.add_module('f_bn2', nn.BatchNorm2d(50))\n", "        self.feature.add_module('f_drop1', nn.Dropout2d())\n", "        self.feature.add_module('f_pool2', nn.MaxPool2d(2))\n", "        self.feature.add_module('f_relu2', nn.ReLU(True))\n", "\n", "        # label predictor (10 digit classes)\n", "        self.class_classifier = nn.Sequential()\n", "        self.class_classifier.add_module('c_fc1', nn.Linear(50 * 4 * 4, 100))\n", "        self.class_classifier.add_module('c_bn1', nn.BatchNorm1d(100))\n", "        self.class_classifier.add_module('c_relu1', nn.ReLU(True))\n", "        self.class_classifier.add_module('c_fc3', nn.Linear(100, 10))\n", "        self.class_classifier.add_module('c_softmax', nn.LogSoftmax(dim=1))\n", "\n", "        # domain discriminator (source vs. target), fed reversed gradients\n", "        self.domain_classifier = nn.Sequential()\n", "        self.domain_classifier.add_module('d_fc1', nn.Linear(50 * 4 * 4, 100))\n", "        self.domain_classifier.add_module('d_bn1', nn.BatchNorm1d(100))\n", "        self.domain_classifier.add_module('d_relu1', nn.ReLU(True))\n", "        self.domain_classifier.add_module('d_fc2', nn.Linear(100, 2))\n", "        self.domain_classifier.add_module('d_softmax', nn.LogSoftmax(dim=1))\n", "\n", "    def forward(self, input_data, alpha):\n", "        input_data = input_data.expand(input_data.data.shape[0], 3, 28, 28)\n", "        feature = self.feature(input_data)\n", "        feature = feature.view(-1, 50 * 4 * 4)\n", "        reverse_feature = ReverseLayerF.apply(feature, alpha)\n", "        class_output = self.class_classifier(feature)\n", "        domain_output = self.domain_classifier(reverse_feature)\n", "\n", "        return class_output, domain_output" ] },
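{ "cell_type": "markdown", "metadata": {}, "source": [ "A quick illustrative check (not part of the pipeline) that `ReverseLayerF` behaves as intended: it is the identity on the forward pass, while gradients are multiplied by `-alpha` on the backward pass." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x = torch.ones(3, requires_grad=True)\n", "y = ReverseLayerF.apply(x, 0.5).sum()\n", "y.backward()\n", "print(x.grad)  # expected: tensor([-0.5000, -0.5000, -0.5000])" ] },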
{ "cell_type": "markdown", "metadata": { "id": "YAZ2eOkbBh1Y" }, "source": [ "## 5. Federated Knowledge Alignment (FedKA)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Su0klLOiBh1Z", "scrolled": true }, "outputs": [], "source": [ "from tqdm import notebook\n", "\n", "def run(method, voting):\n", "    # uses the globals clients_datasets / cloud_dataset set up in Section 6\n", "    global_net = CNNModel()\n", "    global_optimizer = optim.Adam(global_net.parameters(), lr=lr)\n", "    clients = []\n", "    optims = []\n", "    client_num = 4  # four source clients; the fifth domain is the target\n", "    for n in range(client_num):\n", "        local_net = CNNModel()\n", "        local_optimizer = optim.Adam(local_net.parameters(), lr=lr)\n", "        clients.append(local_net)\n", "        optims.append(local_optimizer)\n", "\n", "    loss_class = torch.nn.NLLLoss().cuda()\n", "    loss_domain = torch.nn.NLLLoss().cuda()\n", "    loss_mmd = MMD_loss()\n", "\n", "    acc_list = []\n", "    for epoch in notebook.tqdm(range(n_epoch)):\n", "        print(f\"===========Round {epoch} ===========\")\n", "        if cuda:\n", "            global_net = global_net.cuda()\n", "\n", "        data_target_iter = iter(cloud_dataset[0])\n", "\n", "        acc = []\n", "        for n in range(client_num):\n", "            # First pass: local training. Enumerating the shuffled dataloader\n", "            # gives a random selection of 32 x 16 = 512 samples every round.\n", "            for i, (s_img, s_label) in enumerate(clients_datasets[n]):\n", "                len_dataloader = 32\n", "                if i > 31:\n", "                    break\n", "\n", "                # DANN-style schedule: alpha ramps from 0 towards 1 over training\n", "                p = float(i + epoch * len_dataloader) / n_epoch / len_dataloader\n", "                alpha = 2. / (1. + np.exp(-5 * p)) - 1\n", "\n", "                optims[n].zero_grad()\n", "                batch_size = len(s_label)\n", "\n", "                input_img = torch.FloatTensor(batch_size, 3, image_size, image_size)\n", "                class_label = torch.LongTensor(batch_size)\n", "                domain_label = torch.zeros(batch_size).long()  # 0 = source domain\n", "\n", "                if cuda:\n", "                    clients[n] = clients[n].cuda()\n", "                    s_img = s_img.cuda()\n", "                    s_label = s_label.cuda()\n", "                    input_img = input_img.cuda()\n", "                    class_label = class_label.cuda()\n", "                    domain_label = domain_label.cuda()\n", "\n", "                input_img.resize_as_(s_img).copy_(s_img)\n", "                class_label.resize_as_(s_label).copy_(s_label)\n", "                class_output, domain_output = clients[n](input_data=input_img, alpha=alpha)\n", "\n", "                err_s_label = loss_class(class_output, class_label)\n", "                err_s_domain = loss_domain(domain_output, domain_label)\n", "\n", "                t_img, _ = next(data_target_iter)\n", "                batch_size = len(t_img)\n", "                input_img = torch.FloatTensor(batch_size, 3, image_size, image_size)\n", "                domain_label = torch.ones(batch_size).long()  # 1 = target domain\n", "\n", "                if cuda:\n", "                    t_img = t_img.cuda()\n", "                    input_img = input_img.cuda()\n", "                    domain_label = domain_label.cuda()\n", "\n", "                input_img.resize_as_(t_img).copy_(t_img)\n", "\n", "                _, domain_output = clients[n](input_data=input_img, alpha=alpha)\n", "                err_t_domain = loss_domain(domain_output, domain_label)\n", "\n", "                # methods 0 and 3 use the source classification loss only;\n", "                # methods 1 and 2 add the domain-adversarial losses\n", "                if method == 0 or method == 3:\n", "                    err = err_s_label\n", "                else:\n", "                    err = err_s_label + err_s_domain + err_t_domain\n", "\n", "                err.backward()\n", "                optims[n].step()\n", "\n",
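"            # Second pass (methods 2 and 3): pull this client's feature\n", "            # distribution towards the global model's features on the target\n", "            # domain via the MK-MMD loss, accumulated over 8 batches and\n", "            # weighted by the same alpha schedule.\n",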
"            mmd_loss_total = 0\n", "            for i, (s_img, s_label) in enumerate(clients_datasets[n]):\n", "                len_dataloader = 32\n", "                if i > 31:\n", "                    break\n", "\n", "                p = float(i + epoch * len_dataloader) / n_epoch / len_dataloader\n", "                alpha = 2. / (1. + np.exp(-5 * p)) - 1\n", "                batch_size = len(s_label)\n", "                t_img, _ = next(data_target_iter)\n", "                s_img = s_img.cuda()\n", "                t_img = t_img.cuda()\n", "\n", "                hidden_c = clients[n].feature(s_img).reshape(batch_size, -1)\n", "                hidden_avg = global_net.feature(t_img).reshape(batch_size, -1)\n", "                mmd_loss = loss_mmd(hidden_c, hidden_avg)\n", "                mmd_loss_total += mmd_loss * alpha  # accumulate the alpha-weighted MMD\n", "\n", "                if method == 2 or method == 3:\n", "                    if (i + 1) % 8 == 0:\n", "                        optims[n].zero_grad()\n", "                        err_mmd = mmd_loss_total / 8\n", "                        err_mmd.backward()\n", "                        optims[n].step()\n", "                        mmd_loss_total = 0\n", "\n", "        # FedAvg: average the clients' weights into the global model\n", "        global_sd = global_net.state_dict()\n", "        for key in global_sd:\n", "            global_sd[key] = torch.sum(torch.stack([model.state_dict()[key] for model in clients]), dim=0) / client_num\n", "        # update the global model\n", "        global_net.load_state_dict(global_sd)\n", "\n", "        if voting:\n", "            # Refine the global model on the target domain with pseudo-labels\n", "            # obtained by majority voting over the clients' predictions.\n", "            total = 0\n", "            num_correct = 0\n", "            for i, (images, labels) in enumerate(cloud_dataset[0]):\n", "                len_dataloader = 128\n", "                if i > 127:\n", "                    break\n", "\n", "                p = float(i + epoch * len_dataloader) / n_epoch / len_dataloader\n", "                alpha = 2. / (1. + np.exp(-5 * p)) - 1\n", "\n", "                images = images.cuda()\n", "                labels = labels.cuda()\n", "                global_optimizer.zero_grad()\n", "                class_output, _ = global_net(images, 0)\n", "\n", "                votes = []\n", "                for n in range(client_num):\n", "                    clients[n] = clients[n].cuda()\n", "                    output, _ = clients[n](images, 0)\n", "                    pred = torch.argmax(output, 1)\n", "                    votes.append(pred)\n", "\n", "                # per-image majority vote across the four clients\n", "                class_label = torch.mode(torch.stack(votes), dim=0).values\n", "\n", "                total += labels.size(0)\n", "                num_correct += (class_label == labels).sum().item()\n", "\n", "                # pseudo-label loss, weighted by the adoption rate alpha\n", "                err = loss_class(class_output, class_label) * alpha\n", "                err.backward()\n", "                global_optimizer.step()\n", "\n", "            print(f'Voting accuracy: {num_correct * 100 / total}% Adoption rate: {alpha*100}%')\n", "\n", "        # Evaluate the global model on the target task every round\n", "        with torch.no_grad():\n", "            num_correct = 0\n", "            total = 0\n", "\n", "            for i, (images, labels) in enumerate(cloud_dataset[0]):\n", "                if i > 312:\n", "                    break\n", "\n", "                if cuda:\n", "                    global_net = global_net.cuda()\n", "                    images = images.cuda()\n", "                    labels = labels.cuda()\n", "\n", "                output, _ = global_net(images, 0)\n", "                pred = torch.argmax(output, 1)\n", "                total += labels.size(0)\n", "                num_correct += (pred == labels).sum().item()\n", "\n", "            print(f'Global: Accuracy of the model on {total} test images: {num_correct * 100 / total}% \\n')\n", "            acc.append(num_correct * 100 / total)\n", "\n", "        acc_list.append(acc)\n", "\n", "        # broadcast the updated global weights back to the clients\n", "        for n in range(client_num):\n", "            clients[n].load_state_dict(global_sd)\n", "\n", "    return acc_list" ] },
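{ "cell_type": "markdown", "metadata": {}, "source": [ "A toy illustration (hypothetical values, not part of the pipeline) of the majority-vote pseudo-labelling used in `run`: each client predicts a label for every target image, and the global model is trained on the per-image mode of those predictions." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "votes = torch.tensor([[1, 2, 3],   # client 1's predictions for 3 images\n", "                      [1, 2, 0],   # client 2\n", "                      [1, 0, 3],   # client 3\n", "                      [2, 2, 3]])  # client 4\n", "pseudo_labels = torch.mode(votes, dim=0).values\n", "print(pseudo_labels)  # tensor([1, 2, 3])" ] },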
Run experiments\n", "\n", "### Method 0: Source Only\n", "\n", "### Method 1: $f$-DANN\n", "\n", "### Method 2: FedKA" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "eeqIdMAXBh1b", "outputId": "38dd4827-bd80-4f6d-9112-82dda29ab28c" }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "\n", "methods = [2]\n", "voting = True\n", "\n", "# run 5 tasks\n", "for t in range(5):\n", " acc = []\n", " clients_datasets = [c1, c2, c3, c4, c5]\n", " cloud_dataset = [clients_datasets.pop(t)]\n", " target = f\"c{t+1}\"\n", "\n", " # use three seeds\n", " for s in range(3):\n", " torch.manual_seed(s)\n", " random.seed(s)\n", " np.random.seed(s)\n", " \n", " acc_m = []\n", " for method in methods:\n", "\n", " print(f\"Task: c{t+1} Seed: {s} Method: {method}\")\n", "\n", " evl = run(method, voting)\n", " \n", " result = np.array((evl)).T\n", " plt.plot(result[0], label = \"Global\", color = 'C4')\n", " plt.legend()\n", " plt.show()\n", "\n", " acc_m.append(max(np.array((evl))[:,0]))\n", " acc.append(acc_m)\n", " \n", " print(f'Task: c{t+1} Mean: {np.mean((np.array((acc)).T), axis =1)}')\n", " print(f'Task: c{t+1} Std: {np.std((np.array((acc)).T), axis =1)}')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "colab": { "name": "proposal.ipynb", "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.11" } }, "nbformat": 4, "nbformat_minor": 1 }