"""Train a soft-attention model that scores query (text, image) pairs against key pairs.

Usage: ``python <script> batch_size learning_rate epochs wandb find_lr unfreeze``
(``wandb`` and ``find_lr`` are the literal strings ``True``/``False``); with no
arguments, built-in defaults are used.
"""
import math
import pickle
import sys

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
import torch.nn.functional as F
import lightning as L
from lightning.pytorch.tuner import Tuner
from lightning.pytorch.callbacks import LearningRateMonitor
from lightning.pytorch.loggers import WandbLogger
import wandb
from transformers import CLIPTokenizer, CLIPTextModelWithProjection

# Sentinel value that marks padded key columns in the ``real_index`` label tensors.
_PAD_VALUE = -100

# Fixed 512-value text embedding substituted for every text input when
# ``fixed_text=True`` (ablation test). Presumably a precomputed text embedding
# of a single caption; its origin is not recorded here.
_FIXED_TEXT_EMBEDDING = [
    2.2875e-01, 2.3762e-02, 1.3448e-01, 6.5997e-02, 2.5605e-01, -1.6183e-01, 7.1169e-03, -1.6895e+00,
    1.8110e-01, 1.7249e-01, 7.0582e-02, -6.3566e-02, -1.5862e-01, -2.3586e-01, 6.9382e-02, 9.4649e-02,
    6.3127e-01, -4.1287e-02, -4.9883e-02,
    -2.1821e-01, 5.8677e-01, -2.5353e-01, 1.4792e-01, 2.2195e-02, -6.8436e-02, -1.5512e-01, -9.8894e-02,
    6.3377e-02, -2.3078e-01, 9.3588e-02, 5.2875e-02, -5.1388e-01, -7.0461e-02, 2.4253e-02, -7.8069e-02,
    7.6921e-02, -1.1610e-01, -1.3345e-01, 7.8038e-03, -2.0226e-01, 1.1381e-01, -9.6335e-02, -2.2195e-02,
    -6.5028e-02, 1.4025e-01, 2.6969e-01, -1.0758e-01, 3.6736e-02, 3.2893e-01, -1.9067e-01, 4.9070e-02,
    8.0207e-02, 7.2942e-02, 7.7496e-03, 2.0883e-01, 1.7339e-01, 1.0072e-01, -1.7874e-01, -4.6898e-02,
    -6.2682e-02, 5.9596e-02, 5.2925e-02, 2.4633e-01, -7.2811e-02, -1.4157e-01, 8.8013e-03, -4.6815e-02,
    -7.4260e-02, 8.6530e-03, -1.8174e-01, 1.6101e-01, -4.8832e-02, -5.8030e-02, -3.2518e-02, -6.2896e-02,
    -2.3472e-01, -8.0996e-02, 1.1261e-01, -2.1039e-01, -2.3837e-01, -2.6827e-02, -2.3075e-01, -2.2087e-02,
    5.4009e-01, 3.7671e-02, 3.3140e-01, -4.2569e-02, -1.6946e-01, 1.7165e-01, 3.0887e-01, 4.9847e-02,
    1.2438e-02, -2.0701e+00, 2.7104e-01, 1.9001e-01, 3.1907e-01, -9.1116e-02, -8.3141e-02, 4.5765e-03,
    -2.5675e-01, -2.2119e-02, 3.4949e-02, 2.8192e-01, 7.9688e-02, -2.1810e-01, 8.1565e-02, 3.3208e-01,
    -9.1857e-02, -2.1145e-01, -1.6843e-01, 6.7942e-02, 5.1067e-01, -1.6835e-01, 2.2090e-02, 1.8061e-02,
    -2.1313e-01, 2.6867e-02, -2.2734e-01, 8.4164e-02, -4.7868e-02, 2.0980e-02, -2.1424e-01, -2.2919e-02,
    1.7554e-01, 5.2253e-02, -2.2049e-01, 6.9408e-02, 7.0811e-02, -1.1892e-02, -4.7958e-02, 7.9476e-02,
    1.8851e-01, 2.2516e-02, 8.6119e+00, -7.8583e-02, 1.0218e-01, 1.6675e-01, -4.0961e-01, 4.5291e-02,
    7.9783e-02, -1.1764e-01, -2.3162e-01, -2.7717e-02, 1.2963e-01, -3.0165e-01, -2.1588e-02, -1.2324e-01,
    1.9732e-02, -1.9312e-01, -7.1229e-02, 2.5102e-01, -4.1674e-01, -1.5610e-01, -6.1321e-03, -4.5332e-02,
    6.1500e-02, -1.5942e-01, 3.5142e-01, -2.1119e-01, 4.5057e-02, -5.6277e-02, -3.4298e-01, -1.6499e-01,
    -2.9384e-02, -2.7163e-01, 6.5339e-03, 2.7674e-02, -1.1302e-01, -2.6373e-02, -1.4370e-01, 2.1936e-01,
    1.3103e-01, 2.5538e-01, 1.9502e-01, -1.5278e-01, 1.4978e-01, -2.5552e-01, 2.2397e-01,
    -1.0369e-01, -1.0491e-01, 5.1112e-01, 2.4879e-01, 7.0940e-02, 1.7351e-01, -3.6831e-02, 1.5027e-01,
    -1.9452e-01, 2.0322e-01, 8.5931e-02, -2.8588e-03, 3.1146e-02, -3.3307e-01, 1.1595e-01, 1.9435e-01,
    -3.4536e-02, 2.5245e-01, 4.5388e-02, 2.1197e-02, 4.2232e-02, 4.2436e-02, 4.9622e-02, -2.0907e-01,
    1.2264e-01, -7.3529e-02, -2.1788e-01, -1.2429e-01, -8.1422e-02, 1.6572e-01, -6.0989e-02, 8.0322e-02,
    3.3477e-01, -7.2207e-02, -8.8658e-02, -2.4944e-01, 9.9211e-02, 8.6244e-02, 8.8807e-02, -1.9676e-01,
    -4.5365e-03, -3.7754e-01, -1.7204e-01, -1.3001e-01, 6.4961e-02, -5.8192e-03, 2.4670e-01, -8.3591e-02,
    -3.0810e-01, -3.4549e-02, -1.4452e-01, -5.5416e-02, 1.0527e-02, 3.1159e-01, -1.3857e-01, -2.2676e-01,
    1.4768e-01, 3.2650e-01, 2.3971e-01, 6.8196e-02, -2.6235e-02, -2.9741e-01, 4.7721e-02, -1.2859e-02,
    2.0340e-01, 1.7823e-02, -1.1337e-01, 4.4077e-02, -1.3949e-01, 2.9229e-01, 1.7425e-01, -5.0722e-03,
    -6.3722e-02, 1.0181e-01, 2.3344e-02, 2.2200e-01, 3.5022e-02, 1.5361e-01, -1.0702e-03, 2.9319e-02,
    1.8938e-01, -7.2263e-02, 2.2192e-02, 9.5394e-02, -4.4459e-03, 7.6698e-02, -1.7830e-01, 1.0213e-01,
    -8.8493e-02, -1.6439e-01, -1.1085e-01, 1.2938e-01, 2.3929e-01, -4.9047e-02, -1.2814e-01, -2.1075e-01,
    2.4423e-01, -4.4565e-02, -5.1225e-02, -4.0214e-02, -1.4033e-01, 6.3284e-02, 4.7094e-01, -2.6821e-02,
    2.1138e-02, 1.1590e-01, -2.0023e-02, 1.7200e-01, 3.8215e-01, -2.4871e-01, -1.5359e-01, 2.4691e-01,
    1.4904e-01, -1.0636e-01, 2.4185e-01, 1.7119e-03, 1.4618e-01, -1.6813e-01, -4.4372e-01, -1.7475e-01,
    -6.9891e-02, -4.5553e-02, 9.3102e-02, 1.7686e-02, -1.1781e-01, 6.9423e-02, 1.0211e-02, 3.2742e-01,
    7.5272e-02, 8.5080e-02, -1.7731e-01, 1.4030e-01, 2.7764e-01, -6.5041e-02, 8.5968e+00, 2.5900e-01,
    -2.0825e-01, 9.6241e-02, -1.5257e-01, -3.4269e-01, -1.1251e-01, 3.0549e-01, 3.1628e-01, 6.1856e-01,
    1.5791e-03, 6.5656e-02, 1.8862e-02, -7.1927e-02, 1.3239e-01, -1.1126e-01, 1.1135e-02, -3.2411e+00,
    -4.7349e-02, 1.4775e-01, -9.7712e-02, 4.5727e-02, -1.3868e-01, 2.1260e-01, 1.5465e-01, 1.1308e-01,
    -8.0110e-02, -1.3123e-01, 1.8527e-01, -8.6424e-02, -1.9778e-01, -1.3295e-01, -1.5880e-01, 2.0800e-01,
    -3.6513e-02, 2.6472e-02, 2.7275e-01, 1.8995e-01, -7.7340e-02, 1.2059e-02, 3.5163e-02, 1.5442e-02,
    5.1417e-02, 5.0993e-01, 1.2994e-01, 2.3873e-01, -7.2816e-02, 1.5850e-01, -2.0404e-01, -2.2941e-01,
    2.3660e-01, 2.0418e-01, 6.7775e-02, -3.9195e-01, 3.6655e-01, 1.6498e-01, 6.4065e-02, 4.9579e-02,
    2.8265e-01, -5.9919e-03, 4.0163e-02, 8.9072e-02, 1.5125e-01, 9.0711e-02, -1.2608e-01, -1.0413e-01,
    -2.1931e-01, 5.0183e-02, -3.4841e-02, -8.1449e-02, -1.1225e-01, -4.5787e-02, -7.8871e-02, 3.8858e-02,
    9.2660e-02, 1.5991e-01, -6.7528e-02, -6.3166e-02, -4.7824e-03, -1.3528e-01, 1.4845e-01, 2.0460e-01,
    -9.3238e-02, 1.4902e-03, 1.1896e-01, -3.1337e-01, 2.1637e-02, 1.4990e-01, -2.1179e-03, -8.1374e-02,
    -1.0241e-01, -8.0754e-02, -1.4449e-01, -1.3549e-01, -7.5588e-02, -8.0083e-02, -1.4114e-01, 2.9467e-03,
    3.5340e-01, -4.3351e-02, 9.6934e-02, 1.3625e-01, 1.3339e-01, -1.2059e-02, -1.4325e-01, -2.1202e-01,
    3.8758e-02, 2.5965e-01, -7.8454e-02, 1.5983e-01, 1.0115e-02, 2.2192e-01, -1.4043e-01, 6.7966e-02,
    -1.4672e-01, -1.8846e-01, 1.9488e-01, 1.2942e-01, -1.3165e-02, -1.6099e-01, -9.6146e-02, 1.3439e-01,
    -5.0560e-02, 8.2779e-02, -2.4827e-01, -7.8047e-02, -3.1163e-01, -1.7481e-01, 2.1450e-01, -7.6112e-02,
    -1.9967e-02, 5.7099e-02, 7.7664e-02, -7.9647e-02, 3.3941e-02, 2.9551e-02, 1.4231e-01, 2.3480e-02,
    1.5209e-01, -2.0011e-01, 1.1153e-01, 1.2694e-01, 8.7853e-02, 2.6997e-01, 1.3525e-01, 1.9541e-01,
    3.4429e-03, -9.6446e-02, 7.6708e-02, -3.0698e-02, -1.8507e-01, 2.5645e-01, 2.8182e-01, -1.2282e-01,
    -1.1017e-01, 2.2249e-01, 2.1966e-01, 3.5795e-01, 1.6279e-01, 1.7276e-01, 2.1410e-01, -3.2499e-01,
    5.0327e-02, 7.9813e-02, -1.5915e-01, -3.6175e-02, 1.4376e-01, 2.9565e-01, 6.9097e-02, -8.0661e-01,
    4.9966e-02, 6.2506e-02, 1.8852e-02, -8.6921e-02, 6.0899e-02, 2.2442e-01, -1.4272e-01, -4.0656e-04,
    -1.2531e-01, 1.5240e-01, -6.8841e-02, 4.2114e-01, -4.4379e-02, -3.5105e-02, 1.4931e-01, -8.3358e-02,
    -1.0498e-01, 1.4575e-01, -1.6491e-01, 4.7820e-02, 2.5958e-01, 1.1974e-01, 1.8271e-01, 1.7439e-02,
    -1.5855e-01, -9.0135e-02, -2.6199e-01, -2.5709e-01, 6.3203e-03, 7.5823e-02,
]


class SoftAttention(L.LightningModule):
    """Scaled dot-product attention between query and key (text, image) embeddings.

    Text and image inputs are 512-wide feature vectors (the input size of the
    ``Linear`` layers below). When both text and image are given, each is
    reduced to 256 dims, projected, and concatenated back to 512; image-only
    inputs are projected directly at 512 dims.

    Args:
        learning_rate: Adam learning rate.
        batch_size: stored as a hyperparameter only.
        unfreeze: epoch at which ``image_reduction`` is unfrozen; when non-zero,
            ``image_reduction`` is frozen from epoch 0 until then.
        random_text / random_images / random_everything / fixed_text: training-only
            input ablations (see ``_apply_input_ablations``).
    """

    # Which (query_text, query_image, key_text, key_image) presence patterns
    # ``weight_pass`` accepts.
    _SUPPORTED_INPUTS = (
        (True, True, True, True),    # Input: text and image  Context: text and image
        (False, True, False, True),  # Input: image           Context: image
        (False, True, True, True),   # Input: image           Context: text and image
    )

    def __init__(self, learning_rate=0.001, batch_size=10, unfreeze=0, random_text=False,
                 random_everything=False, fixed_text=False, random_images=False):
        super().__init__()
        self.my_optimizer = None
        self.my_scheduler = None
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.frozen = False
        self.unfreeze_epoch = unfreeze
        self.loss_method = torch.nn.CrossEntropyLoss()

        # Running sums of per-step metrics, averaged and reset at epoch end.
        self.train_sum_precision = 0
        self.train_sum_accuracy = 0
        self.train_sum_recall = 0
        self.train_sum_runs = 0
        self.val_sum_precision = 0
        self.val_sum_accuracy = 0
        self.val_sum_recall = 0
        self.val_sum_runs = 0

        # NETWORK
        # Linear layers to reduce dimensionality (512 -> 256) before the
        # text+image projections are concatenated back to 512.
        self.text_reduction = torch.nn.Linear(512, 256)
        self.image_reduction = torch.nn.Linear(512, 256)

        # Soft attention weights
        self.W_query_text_half_dim = torch.nn.Linear(256, 256)
        self.W_query_image_half_dim = torch.nn.Linear(256, 256)
        self.W_query_text_full_dim = torch.nn.Linear(512, 512)
        self.W_query_image_full_dim = torch.nn.Linear(512, 512)
        self.W_key_text_half_dim = torch.nn.Linear(256, 256)
        self.W_key_image_half_dim = torch.nn.Linear(256, 256)
        self.W_key_image_full_dim = torch.nn.Linear(512, 512)
        self.W_key_text_full_dim = torch.nn.Linear(512, 512)

        # TO TEST THE MODEL WITH SAME TEXT (plain attribute, not a buffer; it is
        # moved to the right device where it is used).
        self.fixed_text = torch.tensor(_FIXED_TEXT_EMBEDDING)

        self.random_text_flag = random_text
        self.random_everything_flag = random_everything
        self.fixed_text_flag = fixed_text
        self.random_image_flag = random_images

        # Weight Stacks (plain dicts: not registered with nn.Module and not
        # read anywhere in this file).
        self.W_query = {
            "multimodal": [self.text_reduction, self.image_reduction,
                           self.W_query_text_half_dim, self.W_query_image_half_dim],
            "image": [self.W_query_image_full_dim],
        }
        self.W_key = {
            "multimodal": [self.text_reduction, self.image_reduction,
                           self.W_key_text_half_dim, self.W_key_image_half_dim],
            "image": [self.W_key_image_full_dim],
        }

    # ------------------------------------------------------------------ model

    def weight_pass(self, query_text, query_image, key_text, key_image):
        """Project queries and keys into the shared attention space.

        Returns:
            ``(query, key)`` projections, both 512 wide on the last dimension.

        Raises:
            ValueError: if an image input is missing, or the combination of
                present/absent text inputs is not in ``_SUPPORTED_INPUTS``.
        """
        if query_image is None or key_image is None:
            raise ValueError("Query and Key image cannot be None")
        present = (query_text is not None, query_image is not None,
                   key_text is not None, key_image is not None)
        if present not in self._SUPPORTED_INPUTS:
            raise ValueError("Invalid input")
        query = self._queries_inference(query_text, query_image)
        key = self._keys_inference(key_text, key_image)
        return query, key

    def _queries_inference(self, query_text, query_image):
        """Project queries: image-only at full width, or text+image as two concatenated halves."""
        if query_text is None:
            return self.W_query_image_full_dim(query_image)
        if query_image is None:
            raise ValueError("Query image cannot be None")
        text_half = self.W_query_text_half_dim(self.text_reduction(query_text))
        image_half = self.W_query_image_half_dim(self.image_reduction(query_image))
        return torch.cat((text_half, image_half), dim=-1)

    def _keys_inference(self, key_text, key_image):
        """Project keys: image-only at full width, or text+image as two concatenated halves."""
        if key_text is None:
            return self.W_key_image_full_dim(key_image)
        if key_image is None:
            raise ValueError("Key image cannot be None")
        text_half = self.W_key_text_half_dim(self.text_reduction(key_text))
        image_half = self.W_key_image_half_dim(self.image_reduction(key_image))
        return torch.cat((text_half, image_half), dim=-1)

    def forward(self, query_text, query_image, key_text, key_image):
        """Compute attention logits and their softmax.

        Inputs must have rank >= 3 (``key.transpose(1, 2)``); presumably
        ``(batch, n_items, 512)`` — TODO confirm. Text inputs may be ``None``
        for the image-only modes accepted by ``weight_pass``.

        Returns:
            ``(softmax, logits)`` where ``logits = query @ key^T / sqrt(d_k)`` with
            all size-1 dims squeezed; the softmax is over dim 0 when the squeezed
            logits have rank <= 2, else over dim 1.
        """
        query_text, query_image, key_text, key_image = (
            None if tensor is None else tensor.to(self.device)
            for tensor in (query_text, query_image, key_text, key_image)
        )
        query, key = self.weight_pass(query_text, query_image, key_text, key_image)
        d_k = key.size(-1)
        logits = torch.matmul(query, key.transpose(1, 2)) / math.sqrt(d_k)
        logits = logits.squeeze()
        softmax_dim = 0 if logits.dim() <= 2 else 1
        softmax = F.softmax(logits, dim=softmax_dim)
        return softmax, logits

    # -------------------------------------------------------- step helpers

    @staticmethod
    def _split_pairs(groups):
        """Split nested ``(text, image)`` pairs into two stacked tensors.

        ``groups`` is iterated as a sequence of groups, each a sequence of
        ``(text, image)`` pairs; every group is stacked, then the groups are
        stacked, adding two leading dimensions (group, pair).
        """
        texts, images = [], []
        for group in groups:
            group_texts, group_images = [], []
            for text, image in group:
                group_texts.append(text)
                group_images.append(image)
            texts.append(torch.stack(group_texts))
            images.append(torch.stack(group_images))
        return torch.stack(texts), torch.stack(images)

    def _unpack_batch(self, batch):
        """Return ``(queries_text, queries_image, keys_text, keys_image)`` from a batch dict."""
        queries_text, queries_image = self._split_pairs(batch['queries'])
        keys_text, keys_image = self._split_pairs(batch['keys'])
        return queries_text, queries_image, keys_text, keys_image

    def _update_freeze_state(self):
        """Freeze ``image_reduction`` at epoch 0 and unfreeze it at ``unfreeze_epoch``.

        Does nothing when ``unfreeze_epoch`` is 0.
        """
        if self.current_epoch == 0 and not self.frozen and self.unfreeze_epoch != 0:
            print("Freezing....................................................")
            for param in self.image_reduction.parameters():
                param.requires_grad = False
            self.frozen = True
        if self.current_epoch == self.unfreeze_epoch and self.frozen:
            print("Unfreezing....................................................")
            for param in self.image_reduction.parameters():
                param.requires_grad = True
            self.frozen = False

    def _apply_input_ablations(self, queries_text, queries_image, keys_text, keys_image):
        """Replace inputs according to the ablation flags (training only).

        Applied in order: ``fixed_text`` (broadcast the fixed embedding),
        ``random_text``, ``random_images``, ``random_everything`` (standard-normal
        noise of the same shape). Later flags override earlier ones.
        """
        if self.fixed_text_flag:
            print("Fixed text flag")
            queries_text = self.fixed_text.expand(*queries_text.shape).to(queries_text.device)
            keys_text = self.fixed_text.expand(*keys_text.shape).to(keys_text.device)
        if self.random_text_flag:
            print("Random text flag")
            queries_text = torch.randn(queries_text.shape).to(queries_text.device)
            keys_text = torch.randn(keys_text.shape).to(keys_text.device)
        if self.random_image_flag:
            print("Random image flag")
            queries_image = torch.randn(queries_image.shape).to(queries_image.device)
            keys_image = torch.randn(keys_image.shape).to(keys_image.device)
        if self.random_everything_flag:
            print("Random everything flag")
            queries_text = torch.randn(queries_text.shape).to(queries_text.device)
            keys_text = torch.randn(keys_text.shape).to(keys_text.device)
            queries_image = torch.randn(queries_image.shape).to(queries_image.device)
            keys_image = torch.randn(keys_image.shape).to(keys_image.device)
        return queries_text, queries_image, keys_text, keys_image

    @staticmethod
    def _align_outputs(softmax, logits, real_labels):
        """Squeeze outputs/labels, cast labels to float, and restore a leading batch dim.

        If the squeezed labels have rank < 3 (a single sample), all three
        tensors get a leading dimension of size 1 so they can be iterated per sample.
        """
        softmax = softmax.squeeze()
        logits = logits.squeeze()
        real_labels = real_labels.squeeze().float()
        if real_labels.dim() < 3:
            real_labels = real_labels.unsqueeze(0)
            softmax = softmax.unsqueeze(0)
            logits = logits.unsqueeze(0)
        return softmax, logits, real_labels

    @staticmethod
    def _padding_start(sample_labels):
        """Index of the first padded column in row 0 of ``sample_labels``, or ``None``."""
        padding = torch.nonzero(sample_labels[0] == _PAD_VALUE)
        if padding.nelement() == 0:
            return None
        return int(padding[0])

    def _masked_loss(self, logits, real_labels):
        """Cross-entropy over every sample, with padded columns masked out.

        For each sample, columns from the first padded position onward get
        label 0 and logit -100. The soft-label ``CrossEntropyLoss`` does not
        honour an ignore index, so unmasked -100 labels would corrupt the loss.
        Samples without padding are used unchanged.
        """
        masked_labels, masked_logits = [], []
        for sample_logits, sample_labels in zip(logits, real_labels):
            start = self._padding_start(sample_labels)
            if start is not None:
                sample_labels = sample_labels.clone()
                sample_labels[:, start:] = 0
                sample_logits = sample_logits.clone()
                sample_logits[:, start:] = -100
            masked_labels.append(sample_labels)
            masked_logits.append(sample_logits)
        stacked_labels = torch.stack(masked_labels).float()
        stacked_logits = torch.stack(masked_logits)
        # ``.mT`` moves the last dimension into position 1, where
        # CrossEntropyLoss expects the class dimension.
        return self.loss_method(stacked_logits.mT, stacked_labels.mT)

    @classmethod
    def _batch_metrics(cls, softmax, real_labels):
        """Average per-sample precision, accuracy and recall over the batch.

        Padded columns are dropped first. A column counts as correct when the
        argmax over dim 0 of the prediction matches that of the label; every
        mismatch is counted as both a false positive and a false negative.
        """
        precisions, accuracies, recalls = [], [], []
        for sample_softmax, sample_labels in zip(softmax, real_labels):
            start = cls._padding_start(sample_labels)
            if start is not None:
                sample_labels = sample_labels[:, :start]
                sample_softmax = sample_softmax[:, :start]
            predicted = sample_softmax.argmax(dim=0)
            target = sample_labels.argmax(dim=0)
            mismatches = torch.count_nonzero(predicted - target)
            samples = sample_softmax.shape[1] * sample_softmax.shape[0]
            tp = len(target) - mismatches
            fp = mismatches
            fn = mismatches
            tn = samples - tp - fp - fn
            precisions.append((tp / (tp + fp)).item())
            accuracies.append(((tp + tn) / samples).item())
            recalls.append((tp / (tp + fn)).item())
        precision = sum(precisions) / len(precisions)
        accuracy = sum(accuracies) / len(accuracies)
        recall = sum(recalls) / len(recalls)
        return precision, accuracy, recall

    def _log_step(self, prefix, loss, precision, accuracy, recall):
        """Log epoch-aggregated step metrics under ``<prefix>_*`` names."""
        for name, value in (("loss", loss), ("precision", precision),
                            ("accuracy", accuracy), ("recall", recall)):
            self.log(f"{prefix}_{name}", value, on_epoch=True, on_step=False,
                     prog_bar=True, logger=True)

    # ------------------------------------------------------------- training

    def training_step(self, train_batch, batch_idx):
        """Run one training step and return the masked cross-entropy loss."""
        self._update_freeze_state()

        inputs = self._unpack_batch(train_batch)
        inputs = self._apply_input_ablations(*inputs)

        softmax, logits = self.forward(*inputs)
        softmax, logits, real_labels = self._align_outputs(
            softmax, logits, train_batch['real_index'])

        loss = self._masked_loss(logits, real_labels)
        precision, accuracy, recall = self._batch_metrics(softmax, real_labels)

        self.train_sum_precision += precision
        self.train_sum_accuracy += accuracy
        self.train_sum_recall += recall
        self.train_sum_runs += 1

        self._log_step("train", loss, precision, accuracy, recall)
        return loss

    def on_train_epoch_end(self) -> None:
        """Log epoch-mean training metrics (skipped if no steps ran) and reset the sums."""
        if self.train_sum_runs:
            self.log("train_precision_epoch", self.train_sum_precision / self.train_sum_runs)
            self.log("train_accuracy_epoch", self.train_sum_accuracy / self.train_sum_runs)
            self.log("train_recall_epoch", self.train_sum_recall / self.train_sum_runs)
        self.train_sum_precision = 0
        self.train_sum_accuracy = 0
        self.train_sum_recall = 0
        self.train_sum_runs = 0

    def configure_optimizers(self):
        """Create an Adam optimizer over all parameters at ``self.learning_rate``."""
        self.my_optimizer = torch.optim.Adam(params=self.parameters(), lr=self.learning_rate)
        # A CyclicLR scheduler was tried previously and is currently disabled:
        # torch.optim.lr_scheduler.CyclicLR(self.my_optimizer, base_lr=0.01, max_lr=0.05,
        #                                   step_size_up=100, cycle_momentum=False)
        # stepped every 'step' and logged as 'learning_rate'.
        return [self.my_optimizer]

    # ----------------------------------------------------------- validation

    def validation_step(self, val_batch, batch_idx):
        """Compute validation loss and metrics for one batch (no input ablations)."""
        inputs = self._unpack_batch(val_batch)

        softmax, logits = self.forward(*inputs)
        softmax, logits, real_labels = self._align_outputs(
            softmax, logits, val_batch['real_index'])

        loss = self._masked_loss(logits, real_labels)
        if loss < 0:
            # Soft-label cross-entropy is non-negative unless padding labels leak through.
            raise RuntimeError(f"Negative validation loss ({loss.item()}): "
                               f"padding in real_index was not fully masked")

        precision, accuracy, recall = self._batch_metrics(softmax, real_labels)

        self.val_sum_precision += precision
        self.val_sum_accuracy += accuracy
        self.val_sum_recall += recall
        self.val_sum_runs += 1

        self._log_step("val", loss, precision, accuracy, recall)

    def on_validation_epoch_end(self) -> None:
        """Log epoch-mean validation metrics (skipped if no steps ran) and reset the sums."""
        if self.val_sum_runs:
            self.log("val_precision_epoch", self.val_sum_precision / self.val_sum_runs)
            self.log("val_accuracy_epoch", self.val_sum_accuracy / self.val_sum_runs)
            self.log("val_recall_epoch", self.val_sum_recall / self.val_sum_runs)
        self.val_sum_precision = 0
        self.val_sum_accuracy = 0
        self.val_sum_recall = 0
        self.val_sum_runs = 0


def _load_pickle(path):
    """Load a pickled dataset from ``path``.

    NOTE: pickle can execute arbitrary code — only load trusted, locally produced files.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


def main():
    """Parse CLI arguments, load the pickled datasets, then train or run the LR finder."""
    if len(sys.argv) > 1:
        if len(sys.argv) < 7:
            sys.exit("Usage: <script> batch_size learning_rate epochs "
                     "wandb(True/False) find_lr(True/False) unfreeze_epoch")
        print("Using arguments")
        batch_size = int(sys.argv[1])
        learning_rate = float(sys.argv[2])
        epochs = int(sys.argv[3])
        wandb_flag = sys.argv[4] == "True"
        find_lr = sys.argv[5] == "True"
        unfreeze = int(sys.argv[6])
    else:
        print("Using default values")
        batch_size = 500
        learning_rate = 0.01
        epochs = 50
        wandb_flag = True
        find_lr = False
        unfreeze = 10

    # Input-ablation switches (not exposed on the command line).
    random_text = False
    random_everything = False
    random_images = False
    fixed_text = False

    print("Batch size: ", batch_size)
    print("Learning rate: ", learning_rate)
    print("Epochs: ", epochs)
    print("Wandb flag: ", wandb_flag)
    print("Find lr: ", find_lr)
    print("Unfreeze: ", unfreeze)

    train_path = "./recipe_dataset_3500_real_1.pkl"
    val_path = "./recipe_dataset_3500_real_2.pkl"
    train = _load_pickle(train_path)
    val = _load_pickle(val_path)

    if "wrong" in train_path and "wrong" in val_path:
        print("Using dataset with false positives")
        string_wrong = "WRONG_"
    elif "wrong" in train_path or "wrong" in val_path:
        raise ValueError("One of the datasets is wrong")
    else:
        print("Using normal dataset")
        string_wrong = ""

    if random_text:
        string_wrong += "RANDOM_TEXT_"
    elif random_everything:
        string_wrong += "RANDOM_EVERYTHING_"
    elif random_images:
        string_wrong += "RANDOM_IMAGES_"
    elif fixed_text:
        string_wrong += "FIXED_TEXT_"

    # remove fields that are not needed
    for sample in (*train, *val):
        sample.pop('ids_queries')
        sample.pop('ids_keys')

    train_dataset = DataLoader(train, num_workers=0, shuffle=False, batch_size=batch_size)
    print("Train dataset size:", len(train_dataset))
    val_dataset = DataLoader(val, num_workers=0, shuffle=False, batch_size=batch_size)
    print("Val dataset size:", len(val_dataset))

    model = SoftAttention(learning_rate=learning_rate, batch_size=batch_size, unfreeze=unfreeze,
                          random_text=random_text, random_everything=random_everything,
                          fixed_text=fixed_text, random_images=random_images)
    lr_monitor = LearningRateMonitor(logging_interval='step')

    if wandb_flag:
        run_name = (f"{string_wrong}MORE_RECIPES_{len(train_dataset)}_batch_{batch_size}"
                    f"_lr_{learning_rate}_epochs_{epochs}_unfreeze_{unfreeze}")
        wandb_logger = WandbLogger(project='reference_training', name=run_name, log_model="all")
        wandb_logger.experiment.config["batch_size"] = batch_size
        wandb_logger.experiment.config["max_epochs"] = epochs
        wandb_logger.experiment.config["learning_rate"] = learning_rate
        trainer = L.Trainer(max_epochs=epochs, detect_anomaly=False, logger=wandb_logger,
                            callbacks=[lr_monitor])
    else:
        trainer = L.Trainer(max_epochs=epochs, default_root_dir="./", callbacks=[lr_monitor])

    if find_lr:
        tuner = Tuner(trainer)
        lr_finder = tuner.lr_find(model, train_dataloaders=train_dataset,
                                  val_dataloaders=val_dataset)
        print(lr_finder.suggestion())
    else:
        trainer.fit(model, train_dataloaders=train_dataset, val_dataloaders=val_dataset)

    if wandb_flag:
        wandb.finish()


if __name__ == '__main__':
    main()