diff --git "a/examples/PrithviWxC_rollout.ipynb" "b/examples/PrithviWxC_rollout.ipynb" new file mode 100644--- /dev/null +++ "b/examples/PrithviWxC_rollout.ipynb" @@ -0,0 +1,3670 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# PrithviWxC Rollout Inference\n", + "If you haven't already, take a look at the exmaple for the PrithviWxC core\n", + "model, as we will pass over the points covered there.\n", + "\n", + "Here we will introduce the PrithviWxC model that was trained furhter for\n", + "autoregressive rollout, a common strategy to increase accuracy and stability of\n", + "models when applied to forecasting-type tasks. " + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import random\n", + "from pathlib import Path\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import torch\n", + "from huggingface_hub import hf_hub_download, snapshot_download\n", + "\n", + "# Set backend etc.\n", + "torch.jit.enable_onednn_fusion(True)\n", + "if torch.cuda.is_available():\n", + " torch.backends.cudnn.benchmark = True\n", + " torch.backends.cudnn.deterministic = True\n", + "\n", + "# Set seeds\n", + "random.seed(42)\n", + "if torch.cuda.is_available():\n", + " torch.cuda.manual_seed(42)\n", + "torch.manual_seed(42)\n", + "np.random.seed(42)\n", + "\n", + "# Set device\n", + "if torch.cuda.is_available():\n", + " device = torch.device(\"cuda\")\n", + "else:\n", + " device = torch.device(\"cpu\")\n", + "\n", + "# Set variables\n", + "surface_vars = [\n", + " \"EFLUX\",\n", + " \"GWETROOT\",\n", + " \"HFLUX\",\n", + " \"LAI\",\n", + " \"LWGAB\",\n", + " \"LWGEM\",\n", + " \"LWTUP\",\n", + " \"PS\",\n", + " \"QV2M\",\n", + " \"SLP\",\n", + " \"SWGNT\",\n", + " \"SWTNT\",\n", + " \"T2M\",\n", + " \"TQI\",\n", + " \"TQL\",\n", + " \"TQV\",\n", + " \"TS\",\n", + " \"U10M\",\n", + " \"V10M\",\n", + " \"Z0M\",\n", + "]\n", + "static_surface_vars = [\"FRACI\", 
\"FRLAND\", \"FROCEAN\", \"PHIS\"]\n", + "vertical_vars = [\"CLOUD\", \"H\", \"OMEGA\", \"PL\", \"QI\", \"QL\", \"QV\", \"T\", \"U\", \"V\"]\n", + "levels = [\n", + " 34.0,\n", + " 39.0,\n", + " 41.0,\n", + " 43.0,\n", + " 44.0,\n", + " 45.0,\n", + " 48.0,\n", + " 51.0,\n", + " 53.0,\n", + " 56.0,\n", + " 63.0,\n", + " 68.0,\n", + " 71.0,\n", + " 72.0,\n", + "]\n", + "padding = {\"level\": [0, 0], \"lat\": [0, -1], \"lon\": [0, 0]}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Lead time\n", + "When performing auto-regressive rollout, the intermediate steps require the\n", + "static data at those times and---if using `residual=climate`---the intermediate\n", + "climatology. We provide a dataloader that extends the MERRA2 loader of the\n", + "core model, adding in these additional terms. Further, it return target data for\n", + "the intermediate steps if those are required for loss terms. \n", + "\n", + "The `lead_time` flag still lets the target time for the model, however now it\n", + "only a single value and must be a positive integer multiple of the `-input_time`. " + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "lead_time = 3 # This variable can be change to change the task\n", + "input_time = -3 # This variable can be change to change the task" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Data file\n", + "MERRA-2 data is available from 1980 to the present day,\n", + "at 3-hour temporal resolution. 
The dataloader we have provided\n", + "expects the surface data and vertical data to be saved in\n", + "separate files, and when provided with the directories, will\n", + "search for the relevant data that falls within the provided time range.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "159bec6eee1846d680fe284324094487", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Fetching 1 files: 0%| | 0/1 [00:00 dict[str, Tensor]:\n", + " \"\"\"Prepressing function for MERRA2 Dataset\n", + "\n", + " Args:\n", + " batch (dict): List of training samples, each sample should be a\n", + " dictionary with the following keys::\n", + "\n", + " 'sur_static': Numpy array of shape (3, lat, lon). For each pixel (lat, lon), the first dimension indexes sin(lat), cos(lon), sin(lon).\n", + " 'sur_vals': Torch tensor of shape (parameter, time, lat, lon).\n", + " 'sur_tars': Torch tensor of shape (parameter, time, lat, lon).\n", + " 'ulv_vals': Torch tensor of shape (parameter, level, time, lat, lon).\n", + " 'ulv_tars': Torch tensor of shape (parameter, level, time, lat, lon).\n", + " 'sur_climate': Torch tensor of shape (parameter, lat, lon)\n", + " 'ulv_climate': Torch tensor of shape (parameter, level, lat, lon)\n", + " 'lead_time': Integer.\n", + " 'input_time': Integer.\n", + "\n", + " padding: Dictionary with keys 'level', 'lat', 'lon', each of dim 2.\n", + "\n", + " Returns:\n", + " Dictionary with the following keys::\n", + "\n", + " 'x': [batch, time, parameter, lat, lon]\n", + " 'y': [batch, parameter, lat, lon]\n", + " 'static': [batch, parameter, lat, lon]\n", + " 'lead_time': [batch]\n", + " 'input_time': [batch]\n", + " 'climate (Optional)': [batch, parameter, lat, lon]\n", + "\n", + " Note:\n", + " Here, for x and y, 'parameter' is [surface parameter, upper level,\n", + " parameter x level]. 
Similarly for the static information we have\n", + " [sin(lat), cos(lon), sin(lon), cos(doy), sin(doy), cos(hod), sin(hod),\n", + " ...].\n", + " \"\"\" # noqa: E501\n", + " b0 = batch[0]\n", + " nbatch = len(batch)\n", + " data_keys = set(b0.keys())\n", + "\n", + " essential_keys = {\n", + " \"sur_static\",\n", + " \"sur_vals\",\n", + " \"sur_tars\",\n", + " \"ulv_vals\",\n", + " \"ulv_tars\",\n", + " \"input_time\",\n", + " \"lead_time\",\n", + " }\n", + "\n", + " climate_keys = {\n", + " \"sur_climate\",\n", + " \"ulv_climate\",\n", + " }\n", + "\n", + " all_keys = essential_keys | climate_keys\n", + "\n", + " if not essential_keys.issubset(data_keys):\n", + " raise ValueError(\"Missing essential keys.\")\n", + "\n", + " if not data_keys.issubset(all_keys):\n", + " raise ValueError(\"Unexpected keys in batch.\")\n", + "\n", + " # Bring all tensors from the batch into a single tensor\n", + " upl_x = torch.empty((nbatch, *b0[\"ulv_vals\"].shape))\n", + " upl_y = torch.empty((nbatch, *b0[\"ulv_tars\"].shape))\n", + "\n", + " sur_x = torch.empty((nbatch, *b0[\"sur_vals\"].shape))\n", + " sur_y = torch.empty((nbatch, *b0[\"sur_tars\"].shape))\n", + "\n", + " sur_sta = torch.empty((nbatch, *b0[\"sur_static\"].shape))\n", + "\n", + " lead_time = torch.empty((nbatch,), dtype=torch.float32)\n", + " input_time = torch.empty((nbatch,), dtype=torch.float32)\n", + "\n", + " for i, rec in enumerate(batch):\n", + " sur_x[i] = rec[\"sur_vals\"]\n", + " sur_y[i] = rec[\"sur_tars\"]\n", + "\n", + " upl_x[i] = rec[\"ulv_vals\"]\n", + " upl_y[i] = rec[\"ulv_tars\"]\n", + "\n", + " sur_sta[i] = rec[\"sur_static\"]\n", + "\n", + " lead_time[i] = rec[\"lead_time\"]\n", + " input_time[i] = rec[\"input_time\"]\n", + "\n", + " return_value = {\n", + " \"lead_time\": lead_time,\n", + " \"input_time\": input_time,\n", + " }\n", + "\n", + " # Reshape (batch, parameter, level, time, lat, lon) ->\n", + " # (batch, time, parameter, level, lat, lon)\n", + " upl_x = upl_x.permute((0, 3, 1, 2, 4, 
def input_scalers(
    surf_vars: list[str],
    vert_vars: list[str],
    levels: list[float],
    surf_path: str | Path,
    vert_path: str | Path,
) -> tuple[Tensor, Tensor]:
    """Load the per-channel input normalisation statistics.

    Both files are read through h5py. Each is expected to hold a
    ``"statistic"`` dataset of byte strings naming the rows of every variable
    dataset (``"mu"`` and ``"sigma"`` are looked up case-insensitively). The
    vertical file additionally holds a ``"lev"`` dataset that maps level values
    to column positions.

    Args:
        surf_vars: surface variables to be used.
        vert_vars: vertical variables to be used.
        levels: level values to select from the vertical file; each must occur
            exactly once in its ``"lev"`` dataset.
        surf_path: path to surface scalers file.
        vert_path: path to vertical level scalers file.

    Returns:
        mu (Tensor): float32 means, surface variables first, followed by the
            vertical variables flattened as (variable, level).
        sig (Tensor): float32 "sigma" values in the same order, clamped to the
            range [1e-4, 1e4].
    """
    surface_means = []
    surface_sigmas = []
    with h5py.File(Path(surf_path), "r", libver="latest") as surface_stats:
        stat_names = [
            raw.decode().lower() for raw in surface_stats["statistic"][()]
        ]
        row_mu = stat_names.index("mu")
        row_sigma = stat_names.index("sigma")

        for name in surf_vars:
            surface_means.append(surface_stats[name][()][row_mu])
        for name in surf_vars:
            surface_sigmas.append(surface_stats[name][()][row_sigma])

    surface_mu = torch.tensor(surface_means)
    surface_sig = torch.tensor(surface_sigmas)

    with h5py.File(Path(vert_path), "r", libver="latest") as vertical_stats:
        stat_names = [
            raw.decode().lower() for raw in vertical_stats["statistic"][()]
        ]
        row_mu = stat_names.index("mu")
        row_sigma = stat_names.index("sigma")

        file_levels = vertical_stats["lev"][()]
        # `.item()` only succeeds when a requested level matches exactly once.
        level_cols = [
            np.where(file_levels == wanted)[0].item() for wanted in levels
        ]

        vertical_means = np.array(
            [vertical_stats[name][()][row_mu, level_cols] for name in vert_vars]
        )
        vertical_sigmas = np.array(
            [
                vertical_stats[name][()][row_sigma, level_cols]
                for name in vert_vars
            ]
        )

    # Flatten (variable, level) into a single channel axis.
    vertical_mu = torch.from_numpy(vertical_means).view(-1)
    vertical_sig = torch.from_numpy(vertical_sigmas).view(-1)

    mu = torch.cat((surface_mu, vertical_mu), dim=0).to(torch.float32)
    sig = torch.cat((surface_sig, vertical_sig), dim=0).to(torch.float32)
    return mu, sig.clamp(1e-4, 1e4)
Path(scalar_path)\n", + "\n", + " with h5py.File(scalar_path, \"r\", libver=\"latest\") as scaler_file:\n", + " stats = [x.decode().lower() for x in scaler_file[\"statistic\"][()]]\n", + " mu_idx = stats.index(\"mu\")\n", + " sig_idx = stats.index(\"sigma\")\n", + "\n", + " mu = torch.tensor([scaler_file[k][()][mu_idx] for k in stat_vars])\n", + " sig = torch.tensor([scaler_file[k][()][sig_idx] for k in stat_vars])\n", + "\n", + " z = torch.zeros(unscaled_params, dtype=mu.dtype, device=mu.device)\n", + " o = torch.ones(unscaled_params, dtype=sig.dtype, device=sig.device)\n", + " mu = torch.cat((z, mu), dim=0).to(torch.float32)\n", + " sig = torch.cat((o, sig), dim=0).to(torch.float32)\n", + "\n", + " return mu, sig.clamp(1e-4, 1e4)\n", + "\n", + "\n", + "def output_scalers(\n", + " surf_vars: list[str],\n", + " vert_vars: list[str],\n", + " levels: list[float],\n", + " surf_path: str | Path,\n", + " vert_path: str | Path,\n", + ") -> Tensor:\n", + " surf_path = Path(surf_path)\n", + " vert_path = Path(vert_path)\n", + "\n", + " with h5py.File(surf_path, \"r\", libver=\"latest\") as surf_file:\n", + " svars = torch.tensor([surf_file[k][()] for k in surf_vars])\n", + "\n", + " with h5py.File(vert_path, \"r\", libver=\"latest\") as vert_file:\n", + " lvl = vert_file[\"lev\"][()]\n", + " l_idx = [np.where(lvl == v)[0].item() for v in levels]\n", + " vvars = np.array([vert_file[k][()][l_idx] for k in vert_vars])\n", + " vvars = torch.from_numpy(vvars).view(-1)\n", + "\n", + " var = torch.cat((svars, vvars), dim=0).to(torch.float32).clamp(1e-7, 1e7)\n", + "\n", + " return var\n", + "\n", + "\n", + "class SampleSpec:\n", + " \"\"\"\n", + " A data class to collect the information used to define a sample.\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " inputs: tuple[pd.Timestamp, pd.Timestamp],\n", + " lead_time: int,\n", + " target: pd.Timestamp | list[pd.Timestamp],\n", + " ):\n", + " \"\"\"\n", + " Args:\n", + " inputs: Tuple of timestamps. 
In ascending order.\n", + " lead_time: Lead time. In hours.\n", + " target: Timestamp of the target. Can be before or after the inputs.\n", + " \"\"\"\n", + " if not inputs[0] < inputs[1]:\n", + " raise ValueError(\n", + " \"Timestamps in `inputs` should be in strictly ascending order.\"\n", + " )\n", + "\n", + " self.inputs = inputs\n", + " self.input_time = (inputs[1] - inputs[0]).total_seconds() / 3600\n", + " self.lead_time = lead_time\n", + " self.target = target\n", + "\n", + " self.times = [*inputs, target]\n", + " self.stat_times = [inputs[-1]]\n", + "\n", + " @property\n", + " def climatology_info(self) -> tuple[int, int]:\n", + " \"\"\"Get the required climatology info.\n", + "\n", + " :return: information required to obtain climatology data. Essentially\n", + " this is the day of the year and hour of the day of the target\n", + " timestamp, with the former restricted to the interval [1, 365].\n", + " :rtype: tuple\n", + " \"\"\"\n", + " return (min(self.target.dayofyear, 365), self.target.hour)\n", + "\n", + " @property\n", + " def year(self) -> int:\n", + " return self.inputs[1].year\n", + "\n", + " @property\n", + " def dayofyear(self) -> int:\n", + " return self.inputs[1].dayofyear\n", + "\n", + " @property\n", + " def hourofday(self) -> int:\n", + " return self.inputs[1].hour\n", + "\n", + " def _info_str(self) -> str:\n", + " iso_8601 = \"%Y-%m-%dT%H:%M:%S\"\n", + "\n", + " return (\n", + " f\"Issue time: {self.inputs[1].strftime(iso_8601)}\\n\"\n", + " f\"Lead time: {self.lead_time} hours ahead\\n\"\n", + " f\"Input delta: {self.input_time} hours\\n\"\n", + " f\"Target time: {self.target.strftime(iso_8601)}\"\n", + " )\n", + "\n", + " @classmethod\n", + " def get(cls, timestamp: pd.Timestamp, dt: int, lead_time: int):\n", + " \"\"\"Given a timestamp and lead time, generates a SampleSpec object\n", + " describing the sample further.\n", + "\n", + " Args:\n", + " timestamp: Timstamp of the sample, Ie this is the larger of the two\n", + " input 
timstamps.\n", + " dt: Time between input samples, in hours.\n", + " lead_time: Lead time. In hours.\n", + "\n", + " Returns:\n", + " SampleSpec\n", + " \"\"\" # noqa: E501\n", + " assert dt > 0, \"dt should be possitive\"\n", + " lt = pd.to_timedelta(lead_time, unit=\"h\")\n", + " dt = pd.to_timedelta(dt, unit=\"h\")\n", + "\n", + " if lead_time >= 0:\n", + " timestamp_target = timestamp + lt\n", + " else:\n", + " timestamp_target = timestamp - dt + lt\n", + "\n", + " spec = cls(\n", + " inputs=(timestamp - dt, timestamp),\n", + " lead_time=lead_time,\n", + " target=timestamp_target,\n", + " )\n", + "\n", + " return spec\n", + "\n", + " def __repr__(self) -> str:\n", + " return self._info_str()\n", + "\n", + " def __str__(self) -> str:\n", + " return self._info_str()\n", + "\n", + "\n", + "class Merra2Dataset(Dataset):\n", + " \"\"\"MERRA2 dataset. The dataset unifies surface and vertical data as well as\n", + " optional climatology.\n", + "\n", + " Samples come in the form of a dictionary. Not all keys support all\n", + " variables, yet the general ordering of dimensions is\n", + " parameter, level, time, lat, lon\n", + "\n", + " Note:\n", + " Data is assumed to be in NetCDF files containing daily data at 3-hourly\n", + " intervals. These follow the naming patterns\n", + " MERRA2_sfc_YYYYMMHH.nc and MERRA_pres_YYYYMMHH.nc and can be located in\n", + " two different locations. Optional climatology data comes from files\n", + " climate_surface_doyDOY_hourHOD.nc and\n", + " climate_vertical_doyDOY_hourHOD.nc.\n", + "\n", + "\n", + " Note:\n", + " `_get_valid_timestamps` assembles a set of all timestamps for which\n", + " there is data (with hourly resolutions). The result is stored in\n", + " `_valid_timestamps`. `_get_valid_climate_timestamps` does the same with\n", + " climatology data and stores it in `_valid_climate_timestamps`.\n", + "\n", + " Based on this information, `samples` generates a list of valid samples,\n", + " stored in `samples`. 
def __init__(
    self,
    time_range: tuple[str | pd.Timestamp, str | pd.Timestamp],
    lead_times: list[int],
    input_times: list[int],
    data_path_surface: str | Path,
    data_path_vertical: str | Path,
    climatology_path_surface: str | Path | None = None,
    climatology_path_vertical: str | Path | None = None,
    surface_vars: list[str] | None = None,
    static_surface_vars: list[str] | None = None,
    vertical_vars: list[str] | None = None,
    levels: list[float] | None = None,
    roll_longitudes: int = 0,
    positional_encoding: str = "absolute",
    rtype: type = np.float32,
    dtype: torch.dtype = torch.float32,
) -> None:
    """
    Args:
        time_range: Used to subset data. Both ends are converted with
            `pd.to_datetime`.
        lead_times: Lead times for generalized forecasting.
        input_times: Offsets between the two input timestamps. `samples`
            negates each value before passing it on as a positive `dt`.
        data_path_surface: Location of surface data.
        data_path_vertical: Location of vertical data.
        climatology_path_surface: Location of (optional) surface
            climatology.
        climatology_path_vertical: Location of (optional) vertical
            climatology.
        surface_vars: Surface variables. Defaults to `valid_surface_vars`.
        static_surface_vars: Static surface variables. Defaults to
            `valid_static_surface_vars`.
        vertical_vars: Vertical variables. Defaults to
            `valid_vertical_vars`.
        levels: Levels. Defaults to `valid_levels`.
        roll_longitudes: Largest shift along the longitude dimension; the
            candidate shifts 0..roll_longitudes are stored in
            `_roll_longitudes`.
        positional_encoding: possible values are
            ['absolute' (default), 'fourier'].
            'absolute' returns lat lon encoded in 3 dimensions using sine
            and cosine
            'fourier' returns lat/lon to be encoded by model
        rtype: numpy data type used during read
        dtype: torch data type of data output

    Raises:
        ValueError: If a data or climatology directory does not exist, if
            only one of the two climatology paths is given, or if a
            requested variable or level is not among the valid ones.
    """

    self.time_range = (
        pd.to_datetime(time_range[0]),
        pd.to_datetime(time_range[1]),
    )
    self.lead_times = lead_times
    self.input_times = input_times
    self._roll_longitudes = list(range(roll_longitudes + 1))

    self._uvars = vertical_vars or self.valid_vertical_vars
    self._level = levels or self.valid_levels
    self._svars = surface_vars or self.valid_surface_vars
    self._sstat = static_surface_vars or self.valid_static_surface_vars
    self._nuvars = len(self._uvars)
    self._nlevel = len(self._level)
    self._nsvars = len(self._svars)
    self._nsstat = len(self._sstat)

    self.rtype = rtype
    self.dtype = dtype

    self.positional_encoding = positional_encoding

    self._data_path_surface = Path(data_path_surface)
    self._data_path_vertical = Path(data_path_vertical)

    self.dir_exists(self._data_path_surface)
    self.dir_exists(self._data_path_vertical)

    self._get_coordinates()

    # `Path(None)` raises TypeError, so the None case has to be decided
    # before any conversion to Path takes place.
    surface_given = climatology_path_surface is not None
    vertical_given = climatology_path_vertical is not None

    if surface_given != vertical_given:
        raise ValueError(
            "Either both or neither of "
            "`climatology_path_surface` and "
            "`climatology_path_vertical` should be None."
        )

    self._require_clim = surface_given and vertical_given

    if self._require_clim:
        self._climatology_path_surface = Path(climatology_path_surface)
        self._climatology_path_vertical = Path(climatology_path_vertical)
        self.dir_exists(self._climatology_path_surface)
        self.dir_exists(self._climatology_path_vertical)
    else:
        self._climatology_path_surface = None
        self._climatology_path_vertical = None

    if not set(self._svars).issubset(set(self.valid_surface_vars)):
        raise ValueError("Invalid surface variable.")

    if not set(self._sstat).issubset(set(self.valid_static_surface_vars)):
        raise ValueError("Invalid static surface variable.")

    if not set(self._uvars).issubset(set(self.valid_vertical_vars)):
        raise ValueError("Invalid vertical variable.")

    if not set(self._level).issubset(set(self.valid_levels)):
        raise ValueError("Invalid level.")
@property
def surface_shape(self) -> tuple[int, int, int, int]:
    """Returns the surface variables shape

    Surface variables have no level axis, so the shape has four entries.
    The time, latitude and longitude sizes are the fixed values 2, 361
    and 576.

    Returns:
        tuple: surface shape in the following order::

            [VAR, TIME, LAT, LON]
    """
    return self._nsvars, 2, 361, 576
0 to 23.\n", + " Returns:\n", + " Path: Path to climatology file.\n", + " \"\"\"\n", + " if timestamp is not None and (\n", + " (dayofyear is not None) or (hourofday is not None)\n", + " ):\n", + " raise ValueError(\n", + " \"Provide either timestamp or both dayofyear and hourofday.\"\n", + " )\n", + "\n", + " if timestamp is not None:\n", + " dayofyear = min(timestamp.dayofyear, 365)\n", + " hourofday = timestamp.hour\n", + "\n", + " file_name = f\"climate_surface_doy{dayofyear:03}_hour{hourofday:02}.nc\"\n", + " data_file = self._climatology_path_surface / file_name\n", + " return data_file\n", + "\n", + " def data_file_vertical_climate(\n", + " self,\n", + " timestamp: pd.Timestamp | None = None,\n", + " dayofyear: int | None = None,\n", + " hourofday: int | None = None,\n", + " ) -> Path:\n", + " \"\"\"Returns the path to a climatology file based either on a timestamp\n", + " or the dayofyear / hourofday combination.\n", + "\n", + " Args:\n", + " timestamp: A timestamp. dayofyear: Day of the year. 1 to 366.\n", + " hourofday: Hour of the day. 
0 to 23.\n", + " Returns:\n", + " Path: Path to climatology file.\n", + " \"\"\"\n", + " if timestamp is not None and (\n", + " (dayofyear is not None) or (hourofday is not None)\n", + " ):\n", + " raise ValueError(\n", + " \"Provide either timestamp or both dayofyear and hourofday.\"\n", + " )\n", + "\n", + " if timestamp is not None:\n", + " dayofyear = min(timestamp.dayofyear, 365)\n", + " hourofday = timestamp.hour\n", + "\n", + " file_name = f\"climate_vertical_doy{dayofyear:03}_hour{hourofday:02}.nc\"\n", + " data_file = self._climatology_path_vertical / file_name\n", + " return data_file\n", + "\n", + " def _get_coordinates(self) -> None:\n", + " \"\"\"\n", + " Obtains the coordiantes (latitudes and longitudes) from a single data\n", + " file.\n", + " \"\"\"\n", + " timestamp = next(iter(self.valid_timestamps))\n", + "\n", + " file = self.data_file_surface(timestamp)\n", + " with h5py.File(file, \"r\", libver=\"latest\") as handle:\n", + " self.lats = lats = handle[\"lat\"][()].astype(self.rtype)\n", + " self.lons = lons = handle[\"lon\"][()].astype(self.rtype)\n", + "\n", + " deg_to_rad = np.pi / 180\n", + " self._embed_lat = np.sin(lats * deg_to_rad).reshape(-1, 1)\n", + "\n", + " self._embed_lon = np.empty((2, 1, len(lons)), dtype=self.rtype)\n", + " self._embed_lon[0, 0] = np.cos(lons * deg_to_rad)\n", + " self._embed_lon[1, 0] = np.sin(lons * deg_to_rad)\n", + "\n", + " @ft.cached_property\n", + " def lats(self) -> np.ndarray:\n", + " timestamp = next(iter(self.valid_timestamps))\n", + "\n", + " file = self.data_file_surface(timestamp)\n", + " with h5py.File(file, \"r\", libver=\"latest\") as handle:\n", + " return handle[\"lat\"][()].astype(self.rtype)\n", + "\n", + " @ft.cached_property\n", + " def lons(self) -> np.ndarray:\n", + " timestamp = next(iter(self.valid_timestamps))\n", + "\n", + " file = self.data_file_surface(timestamp)\n", + " with h5py.File(file, \"r\", libver=\"latest\") as handle:\n", + " return 
@ft.cached_property
def valid_timestamps(self) -> set[pd.Timestamp]:
    """Generates the set of valid timestamps based on available files.
    Only timestamps for which both surface and vertical information is
    available are considered valid.

    Surface files are looked up in the surface data directory and vertical
    files in the vertical data directory. Every day that has both files
    contributes eight timestamps (3-hourly), which are then restricted to
    `time_range` (both ends inclusive).

    Returns:
        set: set of timestamps
    """

    s_glob = self._data_path_surface.glob("MERRA2_sfc_????????.nc")
    s_files = [os.path.basename(f) for f in s_glob]
    # Vertical files must be searched in the vertical directory; looking
    # in the surface directory finds nothing when the two differ.
    v_glob = self._data_path_vertical.glob("MERRA_pres_????????.nc")
    v_files = [os.path.basename(f) for f in v_glob]

    s_re = re.compile(r"MERRA2_sfc_(\d{8}).nc\Z")
    v_re = re.compile(r"MERRA_pres_(\d{8}).nc\Z")
    fmt = "%Y%m%d"

    s_times = {
        (datetime.strptime(m[1], fmt))
        for f in s_files
        if (m := s_re.match(f))
    }
    v_times = {
        (datetime.strptime(m[1], fmt))
        for f in v_files
        if (m := v_re.match(f))
    }

    times = s_times.intersection(v_times)

    # Each file contains a day at 3 hour intervals
    times = {
        t + timedelta(hours=i) for i in range(0, 24, 3) for t in times
    }

    start_time, end_time = self.time_range
    times = {pd.Timestamp(t) for t in times if start_time <= t <= end_time}

    return times
def _data_available(self, spec: SampleSpec) -> bool:
    """
    Tells whether all data a sample needs is present.

    The decision is made purely from the previously assembled sets
    `valid_timestamps` and `valid_climate_timestamps`; the file system is
    not consulted here.
    Args:
        spec: SampleSpec object as returned by SampleSpec.get
    Returns:
        bool: True when every timestamp of `spec` is valid and, if
        climatology is required, every climatology entry of `spec` is
        valid as well.
    """
    needed_times = set(spec.times)
    times_ok = needed_times.issubset(self.valid_timestamps)

    if not self._require_clim:
        return times_ok

    # The climatology info may be a single entry or a list of entries.
    info = spec.climatology_info
    if isinstance(info, list):
        needed_clim = set(info)
    else:
        needed_clim = {info}

    return times_ok & needed_clim.issubset(self.valid_climate_timestamps)
shifts=n, dims=-1)\n", + "\n", + " return out\n", + "\n", + " def _read_static_data(\n", + " self, file: str | Path, doy: int, hod: int\n", + " ) -> np.ndarray:\n", + " with h5py.File(file, \"r\", libver=\"latest\") as handle:\n", + " lats_surf = handle[\"lat\"]\n", + " lons_surf = handle[\"lon\"]\n", + "\n", + " nll = (len(lats_surf), len(lons_surf))\n", + "\n", + " npos = len(self.position_signal)\n", + " ntime = 4\n", + "\n", + " nstat = npos + ntime + self._nsstat\n", + " data = np.empty((nstat, *nll), dtype=self.rtype)\n", + "\n", + " for i, key in enumerate(self._sstat, start=npos + ntime):\n", + " data[i] = handle[key][()].astype(dtype=self.rtype)\n", + "\n", + " # [possition signal], cos(doy), sin(doy), cos(hod), sin(hod)\n", + " data[0:npos] = self.position_signal\n", + " data[npos + 0] = np.cos(2 * np.pi * doy / 366)\n", + " data[npos + 1] = np.sin(2 * np.pi * doy / 366)\n", + " data[npos + 2] = np.cos(2 * np.pi * hod / 24)\n", + " data[npos + 3] = np.sin(2 * np.pi * hod / 24)\n", + "\n", + " return data\n", + "\n", + " def _read_surface(\n", + " self, tidx: int, nll: tuple[int, int], handle: h5py.File\n", + " ) -> np.ndarray:\n", + " data = np.empty((self._nsvars, *nll), dtype=self.rtype)\n", + "\n", + " for i, key in enumerate(self._svars):\n", + " data[i] = handle[key][tidx][()].astype(dtype=self.rtype)\n", + "\n", + " return data\n", + "\n", + " def _read_levels(\n", + " self, tidx: int, nll: tuple[int, int], handle: h5py.File\n", + " ) -> np.ndarray:\n", + " lvls = handle[\"lev\"][()]\n", + " lidx = self._level_idxs(lvls)\n", + "\n", + " data = np.empty((self._nuvars, self._nlevel, *nll), dtype=self.rtype)\n", + "\n", + " for i, key in enumerate(self._uvars):\n", + " data[i] = handle[key][tidx, lidx][()].astype(dtype=self.rtype)\n", + "\n", + " return np.ascontiguousarray(np.flip(data, axis=1))\n", + "\n", + " def _level_idxs(self, lvls):\n", + " lidx = [np.argwhere(lvls == int(lvl)).item() for lvl in self._level]\n", + " return sorted(lidx)\n", + 
"\n", + " @staticmethod\n", + " def _date_to_tidx(date: datetime | pd.Timestamp, handle: h5py.File) -> int:\n", + " if isinstance(date, pd.Timestamp):\n", + " date = date.to_pydatetime()\n", + "\n", + " time = handle[\"time\"]\n", + "\n", + " t0 = time.attrs[\"begin_time\"][()].item()\n", + " d0 = f\"{time.attrs['begin_date'][()].item()}\"\n", + "\n", + " offset = datetime.strptime(d0, \"%Y%m%d\")\n", + "\n", + " times = [offset + timedelta(minutes=int(t + t0)) for t in time[()]]\n", + " return times.index(date)\n", + "\n", + " def _read_data(\n", + " self, file_pair: tuple[str, str], date: datetime\n", + " ) -> dict[str, np.ndarray]:\n", + " s_file, v_file = file_pair\n", + "\n", + " with h5py.File(s_file, \"r\", libver=\"latest\") as shandle:\n", + " lats_surf = shandle[\"lat\"]\n", + " lons_surf = shandle[\"lon\"]\n", + "\n", + " nll = (len(lats_surf), len(lons_surf))\n", + "\n", + " tidx = self._date_to_tidx(date, shandle)\n", + "\n", + " sdata = self._read_surface(tidx, nll, shandle)\n", + "\n", + " with h5py.File(v_file, \"r\", libver=\"latest\") as vhandle:\n", + " lats_vert = vhandle[\"lat\"]\n", + " lons_vert = vhandle[\"lon\"]\n", + "\n", + " nll = (len(lats_vert), len(lons_vert))\n", + "\n", + " tidx = self._date_to_tidx(date, vhandle)\n", + "\n", + " vdata = self._read_levels(tidx, nll, vhandle)\n", + "\n", + " data = {\"vert\": vdata, \"surf\": sdata}\n", + "\n", + " return data\n", + "\n", + " def _read_climate(\n", + " self, file_pair: tuple[str, str]\n", + " ) -> dict[str, np.ndarray]:\n", + " s_file, v_file = file_pair\n", + "\n", + " with h5py.File(s_file, \"r\", libver=\"latest\") as shandle:\n", + " lats_surf = shandle[\"lat\"]\n", + " lons_surf = shandle[\"lon\"]\n", + "\n", + " nll = (len(lats_surf), len(lons_surf))\n", + "\n", + " sdata = np.empty((self._nsvars, *nll), dtype=self.rtype)\n", + "\n", + " for i, key in enumerate(self._svars):\n", + " sdata[i] = shandle[key][()].astype(dtype=self.rtype)\n", + "\n", + " with h5py.File(v_file, 
\"r\", libver=\"latest\") as vhandle:\n", + " lats_vert = vhandle[\"lat\"]\n", + " lons_vert = vhandle[\"lon\"]\n", + "\n", + " nll = (len(lats_vert), len(lons_vert))\n", + "\n", + " lvls = vhandle[\"lev\"][()]\n", + " lidx = self._level_idxs(lvls)\n", + "\n", + " vdata = np.empty(\n", + " (self._nuvars, self._nlevel, *nll), dtype=self.rtype\n", + " )\n", + "\n", + " for i, key in enumerate(self._uvars):\n", + " vdata[i] = vhandle[key][lidx][()].astype(dtype=self.rtype)\n", + "\n", + " data = {\n", + " \"vert\": np.ascontiguousarray(np.flip(vdata, axis=1)),\n", + " \"surf\": sdata,\n", + " }\n", + "\n", + " return data\n", + "\n", + " def get_data_from_sample_spec(\n", + " self, spec: SampleSpec\n", + " ) -> dict[str, Tensor | int | float]:\n", + " \"\"\"Loads and assembles sample data given a SampleSpec object.\n", + "\n", + " Args:\n", + " spec (SampleSpec): Full details regarding the data to be loaded\n", + " Returns:\n", + " dict: Dictionary with the following keys::\n", + "\n", + " 'sur_static': Torch tensor of shape [parameter, lat, lon]. 
For\n", + " each pixel (lat, lon), the first 7 dimensions index sin(lat),\n", + " cos(lon), sin(lon), cos(doy), sin(doy), cos(hod), sin(hod).\n", + " Where doy is the day of the year [1, 366] and hod the hour of\n", + " the day [0, 23].\n", + " 'sur_vals': Torch tensor of shape [parameter, time, lat, lon].\n", + " 'sur_tars': Torch tensor of shape [parameter, time, lat, lon].\n", + " 'ulv_vals': Torch tensor of shape [parameter, level, time, lat, lon].\n", + " 'ulv_tars': Torch tensor of shape [parameter, level, time, lat, lon].\n", + " 'sur_climate': Torch tensor of shape [parameter, lat, lon].\n", + " 'ulv_climate': Torch tensor of shape [paramter, level, lat, lon].\n", + " 'lead_time': Float.\n", + " 'input_time': Float.\n", + "\n", + " \"\"\" # noqa: E501\n", + "\n", + " # We assemble the unique timestamps for which we need data.\n", + " vals_required = {*spec.times}\n", + " stat_required = {*spec.stat_times}\n", + "\n", + " # We assemble the unique data files from which we need value data\n", + " vals_file_map = defaultdict(list)\n", + " for t in vals_required:\n", + " data_files = (\n", + " self.data_file_surface(t),\n", + " self.data_file_vertical(t),\n", + " )\n", + " vals_file_map[data_files].append(t)\n", + "\n", + " # We assemble the unique data files from which we need static data\n", + " stat_file_map = defaultdict(list)\n", + " for t in stat_required:\n", + " data_files = (\n", + " self.data_file_surface(t),\n", + " self.data_file_vertical(t),\n", + " )\n", + " stat_file_map[data_files].append(t)\n", + "\n", + " # Load the value data\n", + " data = {}\n", + " for data_files, times in vals_file_map.items():\n", + " for time in times:\n", + " data[time] = self._read_data(data_files, time)\n", + "\n", + " # Combine times\n", + " sample_data = {}\n", + "\n", + " input_upl = np.stack([data[t][\"vert\"] for t in spec.inputs], axis=2)\n", + " sample_data[\"ulv_vals\"] = input_upl\n", + "\n", + " target_upl = data[spec.target][\"vert\"]\n", + " 
sample_data[\"ulv_tars\"] = target_upl[:, :, None]\n", + "\n", + " input_sur = np.stack([data[t][\"surf\"] for t in spec.inputs], axis=1)\n", + " sample_data[\"sur_vals\"] = input_sur\n", + "\n", + " target_sur = data[spec.target][\"surf\"]\n", + " sample_data[\"sur_tars\"] = target_sur[:, None]\n", + "\n", + " # Load the static data\n", + " data_files, times = stat_file_map.popitem()\n", + " time = times[0].dayofyear, times[0].hour\n", + " sample_data[\"sur_static\"] = self._read_static_data(\n", + " data_files[0], *time\n", + " )\n", + "\n", + " # If required load the surface data\n", + " if self._require_clim:\n", + " ci_year, ci_hour = spec.climatology_info\n", + "\n", + " surf_file = self.data_file_surface_climate(\n", + " dayofyear=ci_year,\n", + " hourofday=ci_hour,\n", + " )\n", + "\n", + " vert_file = self.data_file_vertical_climate(\n", + " dayofyear=ci_year,\n", + " hourofday=ci_hour,\n", + " )\n", + "\n", + " clim_data = self._read_climate((surf_file, vert_file))\n", + "\n", + " sample_data[\"sur_climate\"] = clim_data[\"surf\"]\n", + " sample_data[\"ulv_climate\"] = clim_data[\"vert\"]\n", + "\n", + " # Move the data from numpy to torch\n", + " sample_data = self._to_torch(sample_data, dtype=self.dtype)\n", + "\n", + " # Optionally roll\n", + " if len(self._roll_longitudes) > 0:\n", + " roll_by = random.choice(self._roll_longitudes)\n", + " sample_data = self._lat_roll(sample_data, roll_by)\n", + "\n", + " # Now that we have rolled, we can add the static data\n", + " sample_data[\"lead_time\"] = spec.lead_time\n", + " sample_data[\"input_time\"] = spec.input_time\n", + "\n", + " return sample_data\n", + "\n", + " def get_data(\n", + " self, timestamp: pd.Timestamp, input_time: int, lead_time: int\n", + " ) -> dict[str, Tensor | int]:\n", + " \"\"\"\n", + " Loads data based on timestamp and lead time.\n", + " Args:\n", + " timestamp: Timestamp.\n", + " input_time: time between input samples.\n", + " lead_time: lead time.\n", + " Returns:\n", + " 
Dictionary with keys 'sur_static', 'sur_vals', 'sur_tars',\n", + " 'ulv_vals', 'ulv_tars', 'sur_climate', 'ulv_climate',\n", + " 'lead_time'.\n", + " \"\"\"\n", + " spec = SampleSpec.get(timestamp, -input_time, lead_time)\n", + " sample_data = self.get_data_from_sample_spec(spec)\n", + " return sample_data\n", + "\n", + " def __getitem__(self, idx: int) -> dict[str, Tensor | int]:\n", + " \"\"\"\n", + " Loads data based on sample index and random choice of sample.\n", + " Args:\n", + " idx: Sample index.\n", + " Returns:\n", + " Dictionary with keys 'sur_static', 'sur_vals', 'sur_tars',\n", + " 'ulv_vals', 'ulv_tars', 'sur_climate', 'ulv_climate',\n", + " 'lead_time', 'input_time'.\n", + " \"\"\"\n", + " sample_set = self.samples[idx]\n", + " timestamp, input_time, lead_time, *nsteps = random.choice(sample_set)\n", + " sample_data = self.get_data(timestamp, input_time, lead_time)\n", + " return sample_data\n", + "\n", + " def __len__(self):\n", + " return len(self.samples)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "import functools as ft\n", + "import random\n", + "from collections import defaultdict\n", + "from copy import deepcopy\n", + "from pathlib import Path\n", + "\n", + "import numpy as np\n", + "import pandas as pd\n", + "import torch\n", + "from torch import Tensor\n", + "\n", + "# from PrithviWxC.dataloaders.merra2 import Merra2Dataset, SampleSpec\n", + "\n", + "\n", + "def preproc(\n", + " batch: list[dict[str, int | float | Tensor]], padding: dict[tuple[int]]\n", + ") -> dict[str, Tensor]:\n", + " \"\"\"Prepressing function for MERRA2 Dataset\n", + "\n", + " Args:\n", + " batch (dict): List of training samples, each sample should be a\n", + " dictionary with the following keys::\n", + "\n", + " 'sur_static': Numpy array of shape (3, lat, lon). 
For each pixel (lat, lon), the first dimension indexes sin(lat), cos(lon), sin(lon).\n", + " 'sur_vals': Torch tensor of shape (parameter, time, lat, lon).\n", + " 'sur_tars': Torch tensor of shape (parameter, time, lat, lon).\n", + " 'ulv_vals': Torch tensor of shape (parameter, level, time, lat, lon).\n", + " 'ulv_tars': Torch tensor of shape (parameter, level, time, lat, lon).\n", + " 'sur_climate': Torch tensor of shape (nstep, parameter, lat, lon)\n", + " 'ulv_climate': Torch tensor of shape (nstep parameter, level, lat, lon)\n", + " 'lead_time': Integer.\n", + " 'input_time': Interger\n", + "\n", + " padding: Dictionary with keys 'level', 'lat', 'lon', each of dim 2.\n", + "\n", + " Returns:\n", + " Dictionary with the following keys::\n", + "\n", + " 'x': [batch, time, parameter, lat, lon]\n", + " 'ys': [batch, nsteps, parameter, lat, lon]\n", + " 'static': [batch, nstep, parameter, lat, lon]\n", + " 'lead_time': [batch]\n", + " 'input_time': [batch]\n", + " 'climate (Optional)': [batch, nsteps, parameter, lat, lon]\n", + "\n", + " Note:\n", + " Here, for x and ys, 'parameter' is [surface parameter, upper level,\n", + " parameter x level]. 
Similarly for the static information we have\n", + " [sin(lat), cos(lon), sin(lon), cos(doy), sin(doy), cos(hod), sin(hod),\n", + " ...].\n", + " \"\"\" # noqa: E501\n", + "\n", + " b0 = batch[0]\n", + " nbatch = len(batch)\n", + " data_keys = set(b0.keys())\n", + "\n", + " essential_keys = {\n", + " \"sur_static\",\n", + " \"sur_vals\",\n", + " \"sur_tars\",\n", + " \"ulv_vals\",\n", + " \"ulv_tars\",\n", + " \"input_time\",\n", + " \"lead_time\",\n", + " }\n", + "\n", + " climate_keys = {\n", + " \"sur_climate\",\n", + " \"ulv_climate\",\n", + " }\n", + "\n", + " all_keys = essential_keys | climate_keys\n", + "\n", + " if not essential_keys.issubset(data_keys):\n", + " raise ValueError(\"Missing essential keys.\")\n", + "\n", + " if not data_keys.issubset(all_keys):\n", + " raise ValueError(\"Unexpected keys in batch.\")\n", + "\n", + " # Bring all tensors from the batch into a single tensor\n", + " upl_x = torch.empty((nbatch, *b0[\"ulv_vals\"].shape))\n", + " upl_y = torch.empty((nbatch, *b0[\"ulv_tars\"].shape))\n", + "\n", + " sur_x = torch.empty((nbatch, *b0[\"sur_vals\"].shape))\n", + " sur_y = torch.empty((nbatch, *b0[\"sur_tars\"].shape))\n", + "\n", + " sur_sta = torch.empty((nbatch, *b0[\"sur_static\"].shape))\n", + "\n", + " lead_time = torch.empty(\n", + " (nbatch, *b0[\"lead_time\"].shape),\n", + " dtype=torch.float32,\n", + " )\n", + " input_time = torch.empty((nbatch,), dtype=torch.float32)\n", + "\n", + " for i, rec in enumerate(batch):\n", + " sur_x[i] = torch.Tensor(rec[\"sur_vals\"])\n", + " sur_y[i] = torch.Tensor(rec[\"sur_tars\"])\n", + "\n", + " upl_x[i] = torch.Tensor(rec[\"ulv_vals\"])\n", + " upl_y[i] = torch.Tensor(rec[\"ulv_tars\"])\n", + "\n", + " sur_sta[i] = torch.Tensor(rec[\"sur_static\"])\n", + "\n", + " lead_time[i] = rec[\"lead_time\"]\n", + " input_time[i] = rec[\"input_time\"]\n", + "\n", + " return_value = {\n", + " \"lead_time\": lead_time,\n", + " \"input_time\": input_time,\n", + " \"target_time\": 
torch.sum(lead_time).reshape(-1),\n", + " }\n", + "\n", + " # Reshape (batch, parameter, level, time, lat, lon)\n", + " # -> (batch, time, parameter, level, lat, lon)\n", + " upl_x = upl_x.permute((0, 3, 1, 2, 4, 5))\n", + " upl_y = upl_y.permute((0, 3, 1, 2, 4, 5))\n", + "\n", + " # Reshape (batch, parameter, time, lat, lon)\n", + " # -> (batch, time, parameter, lat, lon)\n", + " sur_x = sur_x.permute((0, 2, 1, 3, 4))\n", + " sur_y = sur_y.permute((0, 2, 1, 3, 4))\n", + "\n", + " # Pad\n", + " padding_2d = (*padding[\"lon\"], *padding[\"lat\"])\n", + "\n", + " def pad2d(x):\n", + " return torch.nn.functional.pad(x, padding_2d, mode=\"constant\", value=0)\n", + "\n", + " padding_3d = (*padding[\"lon\"], *padding[\"lat\"], *padding[\"level\"])\n", + "\n", + " def pad3d(x):\n", + " return torch.nn.functional.pad(x, padding_3d, mode=\"constant\", value=0)\n", + "\n", + " sur_x = pad2d(sur_x).contiguous()\n", + " upl_x = pad3d(upl_x).contiguous()\n", + " sur_y = pad2d(sur_y).contiguous()\n", + " upl_y = pad3d(upl_y).contiguous()\n", + " return_value[\"statics\"] = pad2d(sur_sta).contiguous()\n", + "\n", + " # We stack along the combined parameter level dimension\n", + " return_value[\"x\"] = torch.cat(\n", + " (sur_x, upl_x.view(*upl_x.shape[:2], -1, *upl_x.shape[4:])), dim=2\n", + " )\n", + " return_value[\"ys\"] = torch.cat(\n", + " (sur_y, upl_y.view(*upl_y.shape[:2], -1, *upl_y.shape[4:])), dim=2\n", + " )\n", + "\n", + " if climate_keys.issubset(data_keys):\n", + " sur_climate = torch.empty((nbatch, *b0[\"sur_climate\"].shape))\n", + " ulv_climate = torch.empty((nbatch, *b0[\"ulv_climate\"].shape))\n", + " for i, rec in enumerate(batch):\n", + " sur_climate[i] = rec[\"sur_climate\"]\n", + " ulv_climate[i] = rec[\"ulv_climate\"]\n", + " sur_climate = pad2d(sur_climate)\n", + " ulv_climate = pad3d(ulv_climate)\n", + "\n", + " ulv_climate = ulv_climate.view(\n", + " *ulv_climate.shape[:2], -1, *ulv_climate.shape[4:]\n", + " )\n", + " return_value[\"climates\"] = 
torch.cat((sur_climate, ulv_climate), dim=2)\n", + "\n", + " return return_value\n", + "\n", + "\n", + "class RolloutSpec(SampleSpec):\n", + " \"\"\"\n", + " A data class to collect the information used to define a rollout sample.\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " inputs: tuple[pd.Timestamp, pd.Timestamp],\n", + " lead_time: int,\n", + " target: pd.Timestamp,\n", + " ):\n", + " \"\"\"\n", + " Args:\n", + " inputs: Tuple of timestamps. In ascending order.\n", + " lead_time: Lead time. In hours.\n", + " target: Timestamp of the target. Can be before or after the inputs.\n", + " \"\"\"\n", + " super().__init__(inputs, lead_time, target)\n", + "\n", + " self.dt = dt = pd.Timedelta(lead_time, unit=\"h\")\n", + " self.inters = list(pd.date_range(inputs[-1], target, freq=dt))\n", + "\n", + " self._ctimes = deepcopy(self.inters)\n", + " self.stat_times = deepcopy(self.inters)\n", + "\n", + " self.stat_times.pop(-1)\n", + " self._ctimes.pop(0)\n", + " self.inters.pop(0)\n", + " self.inters.pop(-1)\n", + "\n", + " self.times = [*inputs, *self.inters, target]\n", + " self.targets = self.times[2:]\n", + " self.nsteps = len(self.times) - 2\n", + "\n", + " @property\n", + " def climatology_info(self) -> dict[pd.Timestamp, tuple[int, int]]:\n", + " \"\"\"Returns information required to obtain climatology data.\n", + " Returns:\n", + " list: list containing required climatology info.\n", + " \"\"\"\n", + " return [(min(t.dayofyear, 365), t.hour) for t in self._ctimes]\n", + "\n", + " def _info_str(self) -> str:\n", + " iso_8601 = \"%Y-%m-%dT%H:%M:%S\"\n", + "\n", + " inter_str = \"\\n\".join(t.strftime(iso_8601) for t in self.inters)\n", + "\n", + " return (\n", + " f\"Issue time: {self.inputs[1].strftime(iso_8601)}\\n\"\n", + " f\"Lead time: {self.lead_time} hours ahead\\n\"\n", + " f\"Target time: {self.target.strftime(iso_8601)}\\n\"\n", + " f\"Intermediate times: {inter_str}\"\n", + " )\n", + "\n", + " @classmethod\n", + " def get(cls, 
class Merra2RolloutDataset(Merra2Dataset):
    """Dataset class that reads MERRA2 data for performing rollout.

    Implementation details::

        ``samples`` stores the list of valid samples. This takes the form
        ```
        [
            [(timestamp 1, input_time, lead_time, n_steps)],
            [(timestamp 2, input_time, lead_time, n_steps)],
        ]
        ```
        The nested list is for compatibility reasons with Merra2Dataset.
        Input time, lead time and n_steps are the same for every sample.

        The rollout step is ``-input_time``. ``RolloutSpec.get`` only accepts
        a positive step, so ``input_time`` has to be negative here.
    """

    input_time_len = 2

    def __init__(
        self,
        time_range: tuple[str | pd.Timestamp, str | pd.Timestamp],
        input_time: int | float | pd.Timedelta,
        lead_time: int | float,
        data_path_surface: str | Path,
        data_path_vertical: str | Path,
        climatology_path_surface: str | Path | None,
        climatology_path_vertical: str | Path | None,
        surface_vars: list[str],
        static_surface_vars: list[str],
        vertical_vars: list[str],
        levels: list[float],
        roll_longitudes: int = 0,
        positional_encoding: str = "absolute",
    ):
        """
        Args:
            time_range: time range to consider when building dataset
            input_time: requested time between inputs, in hours or as a
                Timedelta. Its negation is used as the rollout step.
            lead_time: total time to predict, in hours. Must be an integer
                multiple of the rollout step.
            data_path_surface: path of surface data directory
            data_path_vertical: path of vertical data directory
            climatology_path_surface: path of surface climatology data
                directory
            climatology_path_vertical: path of vertical climatology data
                directory
            surface_vars: surface variables to return
            static_surface_vars: static surface variables to return
            vertical_vars: vertical variables to return
            levels: MERRA2 vertical levels to consider
            roll_longitudes: Forwarded to Merra2Dataset. Defaults to 0.
            positional_encoding: The type of positional encoding to use,
                forwarded to Merra2Dataset. Defaults to "absolute".

        Raises:
            ValueError: If lead time is not an integer multiple of the
                rollout step.
        """

        self._target_lead = lead_time

        # The rollout step is the negated input time.
        if isinstance(input_time, (int, float)):
            self.timedelta_input = pd.to_timedelta(-input_time, unit="h")
        else:
            self.timedelta_input = -input_time

        # The base class works with lead times in hours; a single rollout
        # step is the only lead time it needs to know about.
        lead_times = [self.timedelta_input / pd.to_timedelta(1, unit="h")]

        super().__init__(
            time_range,
            lead_times,
            [input_time],
            data_path_surface,
            data_path_vertical,
            climatology_path_surface,
            climatology_path_vertical,
            surface_vars,
            static_surface_vars,
            vertical_vars,
            levels,
            roll_longitudes,
            positional_encoding,
        )

        nstep_float = (
            pd.to_timedelta(self._target_lead, unit="h") / self.timedelta_input
        )

        # Measure the distance to the nearest integer. ``nstep_float % 1``
        # would report ~1 for a value that sits just below an integer and
        # reject it.
        if abs(nstep_float - round(nstep_float)) > 1e-5:
            raise ValueError("Leadtime not multiple of input time")

        self.nsteps = round(nstep_float)

    @ft.cached_property
    def samples(
        self,
    ) -> list[list[tuple[pd.Timestamp, int | float, int | float, int]]]:
        """Generates list of all valid samples.

        Returns:
            One inner list per timestamp that has at least one valid sample.
            Each entry is a tuple (timestamp, input time, lead time, nsteps).
        """
        valid_samples = []

        for timestamp in sorted(self.valid_timestamps):
            timestamp_samples = []
            for lt in self.lead_times:
                spec = RolloutSpec.get(timestamp, lt, self.nsteps)

                if self._data_available(spec):
                    timestamp_samples.append(
                        (timestamp, self.input_times[0], lt, self.nsteps)
                    )

            if timestamp_samples:
                valid_samples.append(timestamp_samples)

        return valid_samples

    def get_data_from_rollout_spec(
        self, spec: RolloutSpec
    ) -> dict[str, Tensor | int | float]:
        """Loads and assembles sample data given a RolloutSpec object.

        Args:
            spec (RolloutSpec): Full details regarding the data to be loaded
        Returns:
            dict: Dictionary with the following keys::

                'sur_static': [nsteps, parameter, lat, lon], one entry per
                    time in ``spec.stat_times``.
                'sur_vals': [parameter, time, lat, lon], stacked over
                    ``spec.inputs``.
                'sur_tars': [parameter, nsteps, lat, lon], stacked over
                    ``spec.targets``.
                'ulv_vals': [parameter, level, time, lat, lon].
                'ulv_tars': [parameter, level, nsteps, lat, lon].
                'sur_climate': [nsteps, parameter, lat, lon]. Only present
                    when climatology is required.
                'ulv_climate': [nsteps, parameter, level, lat, lon]. Only
                    present when climatology is required.
                'lead_time': Tensor of length nsteps, every entry equal to
                    ``spec.lead_time``.
                'input_time': ``spec.input_time``.
        """

        # Read every required instant once, even if it is listed twice.
        data = {}
        for t in set(spec.times):
            files = (self.data_file_surface(t), self.data_file_vertical(t))
            data[t] = self._read_data(files, t)

        # Static fields come from the surface file of each rollout step.
        # _read_static_data expects (file, doy, hod). The arguments used to be
        # passed as (hod, doy), which swapped the day-of-year and hour-of-day
        # encodings.
        # NOTE(review): confirm that the rollout weights were not fine-tuned
        # with the swapped encoding.
        stat = {}
        for t in set(spec.stat_times):
            stat[t] = self._read_static_data(
                self.data_file_surface(t), t.dayofyear, t.hour
            )

        # Combine times
        sample_data = {}

        input_upl = np.stack([data[t]["vert"] for t in spec.inputs], axis=2)
        sample_data["ulv_vals"] = input_upl

        target_upl = np.stack([data[t]["vert"] for t in spec.targets], axis=2)
        sample_data["ulv_tars"] = target_upl

        input_sur = np.stack([data[t]["surf"] for t in spec.inputs], axis=1)
        sample_data["sur_vals"] = input_sur

        target_sur = np.stack([data[t]["surf"] for t in spec.targets], axis=1)
        sample_data["sur_tars"] = target_sur

        static = np.stack([stat[t] for t in spec.stat_times], axis=0)
        sample_data["sur_static"] = static

        # If required load the climate data
        if self._require_clim:
            # Keyed by (dayofyear, hour) so a repeated instant is read once.
            clim_data = {}
            for ci in spec.climatology_info:
                ci_doy, ci_hour = ci

                surf_file = self.data_file_surface_climate(
                    dayofyear=ci_doy,
                    hourofday=ci_hour,
                )

                vert_file = self.data_file_vertical_climate(
                    dayofyear=ci_doy,
                    hourofday=ci_hour,
                )

                clim_data[ci] = self._read_climate((surf_file, vert_file))

            clim_surf = [clim_data[ci]["surf"] for ci in spec.climatology_info]
            sample_data["sur_climate"] = np.stack(clim_surf, axis=0)

            clim_vert = [clim_data[ci]["vert"] for ci in spec.climatology_info]
            sample_data["ulv_climate"] = np.stack(clim_vert, axis=0)

        # Move the data from numpy to torch
        sample_data = self._to_torch(sample_data, dtype=self.dtype)

        # Optionally roll
        if len(self._roll_longitudes) > 0:
            roll_by = random.choice(self._roll_longitudes)
            sample_data = self._lat_roll(sample_data, roll_by)

        # The time metadata is added after rolling so it is never rolled.
        lt = torch.tensor([spec.lead_time] * self.nsteps).to(self.dtype)
        sample_data["lead_time"] = lt
        sample_data["input_time"] = spec.input_time

        return sample_data

    def get_data(
        self, timestamp: pd.Timestamp, *args, **kwargs
    ) -> dict[str, Tensor | int | float]:
        """Loads data based on timestamp.

        The rollout step and number of steps are fixed by the dataset, so any
        further positional or keyword arguments are accepted and ignored.

        Args:
            timestamp: Timestamp.
        Returns:
            Dictionary with keys 'sur_static', 'sur_vals', 'sur_tars',
            'ulv_vals', 'ulv_tars', 'sur_climate', 'ulv_climate',
            'lead_time', 'input_time'
        """
        rollout_spec = RolloutSpec.get(
            timestamp, self.lead_times[0], self.nsteps
        )
        return self.get_data_from_rollout_spec(rollout_spec)
+ ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "# from PrithviWxC.dataloaders.merra2 import (\n", + "# input_scalers,\n", + "# output_scalers,\n", + "# static_input_scalers,\n", + "# )\n", + "\n", + "surf_in_scal_path = Path(\"./climatology/musigma_surface.nc\")\n", + "hf_hub_download(\n", + " repo_id=\"Prithvi-WxC/prithvi.wxc.2300m.v1\",\n", + " filename=f\"climatology/{surf_in_scal_path.name}\",\n", + " local_dir=\".\",\n", + ")\n", + "\n", + "vert_in_scal_path = Path(\"./climatology/musigma_vertical.nc\")\n", + "hf_hub_download(\n", + " repo_id=\"Prithvi-WxC/prithvi.wxc.2300m.v1\",\n", + " filename=f\"climatology/{vert_in_scal_path.name}\",\n", + " local_dir=\".\",\n", + ")\n", + "\n", + "surf_out_scal_path = Path(\"./climatology/anomaly_variance_surface.nc\")\n", + "hf_hub_download(\n", + " repo_id=\"Prithvi-WxC/prithvi.wxc.2300m.v1\",\n", + " filename=f\"climatology/{surf_out_scal_path.name}\",\n", + " local_dir=\".\",\n", + ")\n", + "\n", + "vert_out_scal_path = Path(\"./climatology/anomaly_variance_vertical.nc\")\n", + "hf_hub_download(\n", + " repo_id=\"Prithvi-WxC/prithvi.wxc.2300m.v1\",\n", + " filename=f\"climatology/{vert_out_scal_path.name}\",\n", + " local_dir=\".\",\n", + ")\n", + "\n", + "hf_hub_download(\n", + " repo_id=\"Prithvi-WxC/prithvi.wxc.rollout.2300m.v1\",\n", + " filename=\"config.yaml\",\n", + " local_dir=\".\",\n", + ")\n", + "\n", + "in_mu, in_sig = input_scalers(\n", + " surface_vars,\n", + " vertical_vars,\n", + " levels,\n", + " surf_in_scal_path,\n", + " vert_in_scal_path,\n", + ")\n", + "\n", + "output_sig = output_scalers(\n", + " surface_vars,\n", + " vertical_vars,\n", + " levels,\n", + " surf_out_scal_path,\n", + " vert_out_scal_path,\n", + ")\n", + "\n", + "static_mu, static_sig = static_input_scalers(\n", + " surf_in_scal_path,\n", + " static_surface_vars,\n", + ")\n", + "\n", + "residual = \"none\"\n", + "masking_mode = \"local\"\n", + "decoder_shifting = 
True\n", + "masking_ratio = 0.99" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Model init\n", + "We can now build and load the pretrained weights, note that you should use the\n", + "rollout version of the weights." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'weights\\\\prithvi.wxc.rollout.2300m.v1.pt'" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "weights_path = Path(\"./weights/prithvi.wxc.rollout.2300m.v1.pt\")\n", + "hf_hub_download(\n", + " repo_id=\"Prithvi-WxC/prithvi.wxc.rollout.2300m.v1\",\n", + " filename=weights_path.name,\n", + " local_dir=\"./weights\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "from functools import cached_property\n", + "from importlib.metadata import version\n", + "\n", + "from torch import Tensor\n", + "from torch.utils.checkpoint import checkpoint\n", + "\n", + "if version(\"torch\") > \"2.3.0\":\n", + " from torch.nn.attention import SDPBackend, sdpa_kernel\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "\n", + "# DropPath code is straight from timm\n", + "# (https://huggingface.co/spaces/Roll20/pet_score/blame/main/lib/timm/models/layers/drop.py)\n", + "def drop_path(\n", + " x: Tensor,\n", + " drop_prob: float = 0.0,\n", + " training: bool = False,\n", + " scale_by_keep: bool = True,\n", + ") -> Tensor:\n", + " \"\"\"Drop paths (Stochastic Depth) per sample (when applied in main path of\n", + " residual blocks). 
Taken form timm.\n", + "\n", + " Args:\n", + " x (Tensor): Input tensor.\n", + " drop_prob (float): Probability of dropping `x`, defaults to 0.\n", + " training (bool): Whether model is in in traingin of eval mode,\n", + " defaults to False.\n", + " scale_by_keep (bool): Whether the output should scaled by\n", + " (`1 - drop_prob`), defaults to True.\n", + " Returns:\n", + " Tensor: Tensor that may have randomly dropped with proability\n", + " `drop_path`\n", + " \"\"\"\n", + " if drop_prob == 0.0 or not training:\n", + " return x\n", + " keep_prob = 1 - drop_prob\n", + " shape = (x.shape[0],) + (1,) * (x.ndim - 1)\n", + " random_tensor = x.new_empty(shape).bernoulli_(keep_prob)\n", + " if keep_prob > 0.0 and scale_by_keep:\n", + " random_tensor.div_(keep_prob)\n", + " return x * random_tensor\n", + "\n", + "\n", + "class DropPath(nn.Module):\n", + " \"\"\"\n", + " Drop paths (Stochastic Depth) per sample (when applied in main path of\n", + " residual blocks).\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self, drop_prob: float | None = None, scale_by_keep: bool = True\n", + " ) -> None:\n", + " super(DropPath, self).__init__()\n", + " self.drop_prob = drop_prob\n", + " self.scale_by_keep = scale_by_keep\n", + "\n", + " def forward(self, x: Tensor) -> Tensor:\n", + " \"\"\"Runs drop path on input tensor\n", + "\n", + " Args:\n", + " x: input\n", + "\n", + " Returns:\n", + " tensor: output after drop_path\n", + " \"\"\"\n", + " return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)\n", + "\n", + "\n", + "class Mlp(nn.Module):\n", + " \"\"\"\n", + " Multi layer perceptron.\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self, features: int, hidden_features: int, dropout: float = 0.0\n", + " ) -> None:\n", + " \"\"\"\n", + " Args:\n", + " features: Input/output dimension.\n", + " hidden_features: Hidden dimension.\n", + " dropout: Dropout.\n", + " \"\"\"\n", + " super().__init__()\n", + " self.net = nn.Sequential(\n", + " 
nn.Linear(features, hidden_features),\n", + " nn.GELU(),\n", + " nn.Dropout(dropout),\n", + " nn.Linear(hidden_features, features),\n", + " nn.Dropout(dropout),\n", + " )\n", + "\n", + " def forward(self, x: Tensor) -> Tensor:\n", + " \"\"\"\n", + " Args:\n", + " x (Tesnor): Tensor of shape [..., channel]\n", + " Returns:\n", + " Tenosr: Tensor of same shape as x.\n", + " \"\"\"\n", + " return self.net(x)\n", + "\n", + "\n", + "class LayerNormPassThrough(nn.LayerNorm):\n", + " \"\"\"Normalising layer that allows the attention mask to be passed through\"\"\"\n", + "\n", + " def __init__(self, *args, **kwargs):\n", + " super().__init__(*args, **kwargs)\n", + "\n", + " def forward(\n", + " self, d: tuple[Tensor, Tensor | None]\n", + " ) -> tuple[Tensor, Tensor | None]:\n", + " \"\"\"Forwards function\n", + "\n", + " Args:\n", + " d (tuple): tuple of the data tensor and the attention mask\n", + " Returns:\n", + " output (Tensor): normalised output data\n", + " attn_mask (Tensor): the attention mask that was passed in\n", + " \"\"\"\n", + " input, attn_mask = d\n", + " output = F.layer_norm(\n", + " input, self.normalized_shape, self.weight, self.bias, self.eps\n", + " )\n", + " return output, attn_mask\n", + "\n", + "\n", + "class MultiheadAttention(nn.Module):\n", + " \"\"\"Multihead attention layer for inputs of shape\n", + " [..., sequence, features].\n", + " \"\"\"\n", + "\n", + " def __init__(self, features: int, n_heads: int, dropout: float) -> None:\n", + " \"\"\"\n", + " Args:\n", + " features: Number of features for inputs to the layer.\n", + " n_heads: Number of attention heads. Should be a factor of features.\n", + " (I.e. 
the layer uses features // n_heads.)\n", + " dropout: Dropout.\n", + " \"\"\" # noqa: E501\n", + " super().__init__()\n", + "\n", + " if (features % n_heads) != 0:\n", + " raise ValueError(\n", + " f\"Features '{features}' is not divisible by heads '{n_heads}'.\"\n", + " )\n", + "\n", + " self.features = features\n", + " self.n_heads = n_heads\n", + " self.dropout = dropout\n", + "\n", + " self.qkv_layer = torch.nn.Linear(features, features * 3, bias=False)\n", + " self.w_layer = torch.nn.Linear(features, features, bias=False)\n", + "\n", + " def forward(self, d: tuple[Tensor, Tensor | None]) -> Tensor:\n", + " \"\"\"\n", + " Args:\n", + " d (tuple): tuple containing Tensor of shape [..., sequence, features] and the attention mask\n", + " Returns:\n", + " Tensor: Tensor of shape [..., sequence, features]\n", + " \"\"\" # noqa: E501\n", + " x, attn_mask = d\n", + "\n", + " if not x.shape[-1] == self.features:\n", + " raise ValueError(\n", + " f\"Expecting tensor with last dimension size {self.features}.\"\n", + " )\n", + "\n", + " passenger_dims = x.shape[:-2]\n", + " B = passenger_dims.numel()\n", + " S = x.shape[-2]\n", + " C = x.shape[-1]\n", + " x = x.reshape(B, S, C)\n", + "\n", + " # x [B, S, C]\n", + " # q, k, v [B, H, S, C/H]\n", + " q, k, v = (\n", + " self.qkv_layer(x)\n", + " .view(B, S, self.n_heads, 3 * (C // self.n_heads))\n", + " .transpose(1, 2)\n", + " .chunk(chunks=3, dim=3)\n", + " )\n", + "\n", + " # Let us enforce either flash (A100+) or memory efficient attention.\n", + " if version(\"torch\") > \"2.3.0\":\n", + " with sdpa_kernel(\n", + " [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]\n", + " ):\n", + " # x [B, H, S, C//H]\n", + " x = F.scaled_dot_product_attention(\n", + " q, k, v, attn_mask=attn_mask, dropout_p=self.dropout\n", + " )\n", + " else:\n", + " with torch.backends.cuda.sdp_kernel(\n", + " enable_flash=True, enable_math=False, enable_mem_efficient=True\n", + " ):\n", + " # x [B, H, S, C//H]\n", + " x = 
F.scaled_dot_product_attention(\n", + " q, k, v, dropout_p=self.dropout\n", + " )\n", + "\n", + " # x [B, S, C]\n", + " x = x.transpose(1, 2).view(B, S, C)\n", + "\n", + " # x [B, S, C]\n", + " x = self.w_layer(x)\n", + "\n", + " # Back to input shape\n", + " x = x.view(*passenger_dims, S, self.features)\n", + " return x\n", + "\n", + "\n", + "class Transformer(nn.Module):\n", + " \"\"\"\n", + " Transformer for inputs of shape [..., S, features].\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " features: int,\n", + " mlp_multiplier: int,\n", + " n_heads: int,\n", + " dropout: float,\n", + " drop_path: float,\n", + " ) -> None:\n", + " \"\"\"\n", + " Args:\n", + " features: Number of features for inputs to the layer.\n", + " mlp_multiplier: Model uses features*mlp_multiplier hidden units.\n", + " n_heads: Number of attention heads. Should be a factor of features.\n", + " (I.e. the layer uses features // n_heads.) dropout: Dropout.\n", + " drop_path: DropPath.\n", + " \"\"\"\n", + " super().__init__()\n", + "\n", + " self.features = features\n", + " self.mlp_multiplier = mlp_multiplier\n", + " self.n_heads = n_heads\n", + " self.dropout = dropout\n", + " self.drop_path = (\n", + " DropPath(drop_path) if drop_path > 0.0 else nn.Identity()\n", + " )\n", + "\n", + " self.attention = nn.Sequential(\n", + " LayerNormPassThrough(features),\n", + " MultiheadAttention(features, n_heads, dropout),\n", + " )\n", + "\n", + " self.ff = nn.Sequential(\n", + " nn.LayerNorm(features),\n", + " Mlp(\n", + " features=features,\n", + " hidden_features=features * mlp_multiplier,\n", + " dropout=dropout,\n", + " ),\n", + " )\n", + "\n", + " def forward(self, d: tuple[Tensor, Tensor | None]) -> Tensor:\n", + " \"\"\"\n", + " Args:\n", + " x: Tensor of shape [..., sequence, features]\n", + " Returns:\n", + " Tensor: Tensor of shape [..., sequence, features]\n", + " \"\"\"\n", + " x, attn_mask = d\n", + " if not x.shape[-1] == self.features:\n", + " raise ValueError(\n", + 
" f\"Expecting tensor with last dimension size {self.features}.\"\n", + " )\n", + "\n", + " attention_x = self.attention(d)\n", + "\n", + " x = x + self.drop_path(attention_x)\n", + " x = x + self.drop_path(self.ff(x))\n", + "\n", + " return x\n", + "\n", + "\n", + "class _Shift(nn.Module):\n", + " \"\"\"Private base class for the shifter. This allows some behaviour to be\n", + " easily handled when the shifter isn't used.\n", + " \"\"\"\n", + "\n", + " def __init__(self):\n", + " super().__init__()\n", + "\n", + " self._shifted = False\n", + "\n", + " @torch.no_grad()\n", + " def reset(self) -> None:\n", + " \"\"\"\n", + " Resets the bool tracking whether the data is shifted\n", + " \"\"\"\n", + " self._shifted: bool = False\n", + "\n", + " def forward(self, data: Tensor) -> tuple[Tensor, dict[bool, None]]:\n", + " return data, {True: None, False: None}\n", + "\n", + "\n", + "class SWINShift(_Shift):\n", + " \"\"\"\n", + " Handles the shifting of patches similar to how SWIN works. However if we\n", + " shift the latitudes then the poles will wrap and potentially that might be\n", + " problematic. 
The possition tokens should handle it but masking is safer.\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " mu_shape: tuple[int, int],\n", + " global_shape: tuple[int, int],\n", + " local_shape: tuple[int, int],\n", + " patch_shape: tuple[int, int],\n", + " n_context_tokens: int = 2,\n", + " ) -> None:\n", + " \"\"\"\n", + " Args:\n", + " mu_shape: the shape to the masking units\n", + " global_shape: number of global patches in lat and lon\n", + " local_shape: size of the local patches\n", + " patch_shape: patch size\n", + " n_context_token: number of additional context tokens at start of\n", + " _each_ local sequence\n", + " \"\"\"\n", + " super().__init__()\n", + "\n", + " self._mu_shape = ms = mu_shape\n", + " self._g_shape = gs = global_shape\n", + " self._l_shape = ls = local_shape\n", + " self._p_shape = ps = patch_shape\n", + " self._lat_patch = (gs[0], ls[0], gs[1], ls[1])\n", + " self._n_context_tokens = n_context_tokens\n", + "\n", + " self._g_shift_to = tuple(\n", + " int(0.5 * x / p) for x, p in zip(ms, ps, strict=False)\n", + " )\n", + " self._g_shift_from = tuple(\n", + " -int(0.5 * x / p) for x, p in zip(ms, ps, strict=False)\n", + " )\n", + "\n", + " # Define the attention masks for the shifted MaxViT.\n", + " nglobal = global_shape[0] * global_shape[1]\n", + " nlocal = (\n", + " local_shape[0] * local_shape[1] + self._n_context_tokens\n", + " ) # \"+ 1\" for leadtime\n", + "\n", + " lm = torch.ones((nglobal, 1, nlocal, nlocal), dtype=bool)\n", + " mwidth = int(0.5 * local_shape[1]) * local_shape[0]\n", + " lm[\n", + " : gs[1],\n", + " :,\n", + " self._n_context_tokens : mwidth + self._n_context_tokens,\n", + " self._n_context_tokens : mwidth + self._n_context_tokens,\n", + " ] = False\n", + " self.register_buffer(\"local_mask\", lm)\n", + "\n", + " gm = torch.ones((nlocal, 1, nglobal, nglobal), dtype=bool)\n", + " gm[: int(0.5 * ls[1]) * ls[0], :, : gs[1], : gs[1]] = False\n", + " self.register_buffer(\"global_mask\", gm)\n", + 
"\n", + " def _to_grid_global(self, x: Tensor) -> Tensor:\n", + " \"\"\"\n", + " Shuffle and reshape the data from the global/local setting back to the\n", + " lat/lon grid setting\n", + " Args:\n", + " x: the data tensor to be shuffled.\n", + " Returns:\n", + " x: data in the global/local setting\n", + " \"\"\"\n", + " nbatch, *other = x.shape\n", + "\n", + " y1 = x.view(nbatch, *self._g_shape, *self._l_shape, -1)\n", + " y2 = y1.permute(0, 5, 1, 3, 2, 4).contiguous()\n", + "\n", + " s = y2.shape\n", + " return y2.view((nbatch, -1, s[2] * s[3], s[4] * s[5]))\n", + "\n", + " def _to_grid_local(self, x: Tensor) -> Tensor:\n", + " \"\"\"\n", + " Shuffle and reshape the data from the local/global setting to the\n", + " lat/lon grid setting\n", + " Args:\n", + " x: the data tensor to be shuffled.\n", + " Returns:\n", + " x: data in the lat/lon setting.\n", + " \"\"\"\n", + " x = x.transpose(2, 1).contiguous()\n", + " return self._to_grid_global(x)\n", + "\n", + " def _from_grid_global(self, x: Tensor) -> Tensor:\n", + " \"\"\"\n", + " Shuffle and reshape the data from the lat/lon grid to the global/local\n", + " setting\n", + " Args:\n", + " x: the data tensor to be shuffled.\n", + " Returns:\n", + " x: data in the global/local setting\n", + " \"\"\"\n", + " nbatch, *other = x.shape\n", + "\n", + " z1 = x.view(nbatch, -1, *self._lat_patch)\n", + " z2 = z1.permute(0, 2, 4, 3, 5, 1).contiguous()\n", + "\n", + " s = z2.shape\n", + " return z2.view(nbatch, s[1] * s[2], s[3] * s[4], -1)\n", + "\n", + " def _from_grid_local(self, x: Tensor) -> Tensor:\n", + " \"\"\"\n", + " Shuffle and reshape the data from the lat/lon grid to the local/global\n", + " setting\n", + " Args:\n", + " x: the data tensor to be shuffled.\n", + " Returns:\n", + " x: data in the local/global setting\n", + " \"\"\"\n", + " x = self._from_grid_global(x)\n", + " return x.transpose(2, 1).contiguous()\n", + "\n", + " def _shift(self, x: Tensor) -> Tensor:\n", + " \"\"\"\n", + " Shifts data in the gridded 
lat/lon setting by half the mask unit shape\n", + " Args:\n", + " x: data to be shifted\n", + " Returns:\n", + " x: either the hsifted or unshifted data\n", + " \"\"\"\n", + " shift = self._g_shift_from if self._shifted else self._g_shift_to\n", + " x_shifted = torch.roll(x, shift, (-2, -1))\n", + "\n", + " self._shifted = not self._shifted\n", + " return x_shifted\n", + "\n", + " def _sep_lt(self, x: Tensor) -> tuple[Tensor, Tensor]:\n", + " \"\"\"\n", + " Seperate off the leadtime from the local patches\n", + " Args:\n", + " x: data to have leadtime removed from\n", + " Returns:\n", + " lt: leadtime\n", + " x: data without the lead time in the local patch\n", + " \"\"\"\n", + " lt_it = x[:, : self._n_context_tokens, :, :]\n", + " x_stripped = x[:, self._n_context_tokens :, :, :]\n", + "\n", + " return lt_it, x_stripped\n", + "\n", + " def forward(self, data: Tensor) -> tuple[Tensor, Tensor]:\n", + " \"\"\"Shift or unshift the the data depending on whether the data is\n", + " already shifted, as defined by self._shifte.\n", + "\n", + " Args:\n", + " data: data to be shifted\n", + " Returns:\n", + " Tensor: shifted data Tensor\n", + " \"\"\"\n", + " lt, x = self._sep_lt(data)\n", + "\n", + " x_grid = self._to_grid_local(x)\n", + " x_shifted = self._shift(x_grid)\n", + " x_patched = self._from_grid_local(x_shifted)\n", + "\n", + " # Mask has to be repeated based on batch size\n", + " n_batch = x_grid.shape[0]\n", + " local_rep = [n_batch] + [1] * (self.local_mask.ndim - 1)\n", + " global_rep = [n_batch] + [1] * (self.global_mask.ndim - 1)\n", + "\n", + " if self._shifted:\n", + " attn_mask = {\n", + " True: self.local_mask.repeat(local_rep),\n", + " False: self.global_mask.repeat(global_rep),\n", + " }\n", + " else:\n", + " attn_mask = {True: None, False: None}\n", + "\n", + " return torch.cat((lt, x_patched), axis=1), attn_mask\n", + "\n", + "\n", + "class LocalGlobalLocalBlock(nn.Module):\n", + " \"\"\"\n", + " Applies alternating block and grid attention. 
Given a parameter n_blocks,\n", + " the entire module contains 2*n_blocks+1 transformer blocks. The first,\n", + " third, ..., last apply local (block) attention. The second, fourth, ...\n", + " global (grid) attention.\n", + "\n", + " This is heavily inspired by\n", + " Tu et al. \"MaxViT: Multi-Axis Vision Transformer\"\n", + " (https://arxiv.org/abs/2204.01697).\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " features: int,\n", + " mlp_multiplier: int,\n", + " n_heads: int,\n", + " dropout: float,\n", + " n_blocks: int,\n", + " drop_path: float,\n", + " shifter: nn.Module | None = None,\n", + " checkpoint: list[int] | None = None,\n", + " ) -> None:\n", + " \"\"\"\n", + " Args:\n", + " features: Number of features for inputs to the layer.\n", + " mlp_multiplier: Model uses features*mlp_multiplier hidden units.\n", + " n_heads: Number of attention heads. Should be a factor of features.\n", + " (I.e. the layer uses features // n_heads.)\n", + " dropout: Dropout.\n", + " drop_path: DropPath.\n", + " n_blocks: Number of local-global transformer pairs.\n", + " \"\"\"\n", + " super().__init__()\n", + "\n", + " self.features = features\n", + " self.mlp_multiplier = mlp_multiplier\n", + " self.n_heads = n_heads\n", + " self.dropout = dropout\n", + " self.drop_path = drop_path\n", + " self.n_blocks = n_blocks\n", + " self._checkpoint = checkpoint or []\n", + "\n", + " if not all(0 <= c < 2 * n_blocks + 1 for c in self._checkpoint):\n", + " raise ValueError(\n", + " \"Checkpoints should be 0 <= i < 2*n_blocks+1. 
\"\n", + " f\"{self._checkpoint=}.\"\n", + " )\n", + "\n", + " self.transformers = nn.ModuleList(\n", + " [\n", + " Transformer(\n", + " features=features,\n", + " mlp_multiplier=mlp_multiplier,\n", + " n_heads=n_heads,\n", + " dropout=dropout,\n", + " drop_path=drop_path,\n", + " )\n", + " for _ in range(2 * n_blocks + 1)\n", + " ]\n", + " )\n", + "\n", + " self.evaluator = [\n", + " self._checkpoint_wrapper\n", + " if i in self._checkpoint\n", + " else lambda m, x: m(x)\n", + " for i, _ in enumerate(self.transformers)\n", + " ]\n", + "\n", + " self.shifter = shifter or _Shift()\n", + "\n", + " @staticmethod\n", + " def _checkpoint_wrapper(\n", + " model: nn.Module, data: tuple[Tensor, Tensor | None]\n", + " ) -> Tensor:\n", + " return checkpoint(model, data, use_reentrant=False)\n", + "\n", + " def forward(self, x: Tensor) -> Tensor:\n", + " \"\"\"\n", + " Args:\n", + " x: Tensor of shape::\n", + "\n", + " [batch, global_sequence, local_sequence, features]\n", + "\n", + " Returns:\n", + " Tensor: Tensor of shape::\n", + "\n", + " [batch, global_sequence, local_sequence, features]\n", + " \"\"\"\n", + " if x.shape[-1] != self.features:\n", + " raise ValueError(\n", + " f\"Expecting tensor with last dimension size {self.features}.\"\n", + " )\n", + " if x.ndim != 4:\n", + " raise ValueError(\n", + " f\"Expecting tensor with exactly four dimensions. 
{x.shape=}.\"\n", + " )\n", + "\n", + " self.shifter.reset()\n", + " local: bool = True\n", + " attn_mask = {True: None, False: None}\n", + "\n", + " transformer_iter = zip(self.evaluator, self.transformers, strict=False)\n", + "\n", + " # First local block\n", + " evaluator, transformer = next(transformer_iter)\n", + " x = evaluator(transformer, (x, attn_mask[local]))\n", + "\n", + " for evaluator, transformer in transformer_iter:\n", + " local = not local\n", + " # We are making exactly 2*n_blocks transposes.\n", + " # So the output has the same shape as input.\n", + " x = x.transpose(1, 2)\n", + "\n", + " x = evaluator(transformer, (x, attn_mask[local]))\n", + "\n", + " if not local:\n", + " x, attn_mask = self.shifter(x)\n", + "\n", + " return x\n", + "\n", + "\n", + "class PatchEmbed(nn.Module):\n", + " \"\"\"\n", + " Patch embedding via 2D convolution.\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self, patch_size: int | tuple[int, ...], channels: int, embed_dim: int\n", + " ):\n", + " super().__init__()\n", + "\n", + " self.patch_size = patch_size\n", + " self.channels = channels\n", + " self.embed_dim = embed_dim\n", + "\n", + " self.proj = nn.Conv2d(\n", + " channels,\n", + " embed_dim,\n", + " kernel_size=patch_size,\n", + " stride=patch_size,\n", + " bias=True,\n", + " )\n", + "\n", + " def forward(self, x: Tensor) -> Tensor:\n", + " \"\"\"\n", + " Args:\n", + " x: Tensor of shape [batch, channels, lat, lon].\n", + " Returns:\n", + " Tensor: Tensor with shape\n", + " [batch, embed_dim, lat//patch_size, lon//patch_size]\n", + " \"\"\"\n", + "\n", + " H, W = x.shape[-2:]\n", + "\n", + " if W % self.patch_size[1] != 0:\n", + " raise ValueError(\n", + " f\"Cannot do patch embedding for tensor of shape {x.size()}\"\n", + " \" with patch size {self.patch_size}. 
(Dimensions are BSCHW.)\"\n", + " )\n", + " if H % self.patch_size[0] != 0:\n", + " raise ValueError(\n", + " f\"Cannot do patch embedding for tensor of shape {x.size()}\"\n", + " f\" with patch size {self.patch_size}. (Dimensions are BSCHW.)\"\n", + " )\n", + "\n", + " x = self.proj(x)\n", + "\n", + " return x\n", + "\n", + "\n", + "class PrithviWxCEncoderDecoder(nn.Module):\n", + " \"\"\"\n", + " Hiera-MaxViT encoder/decoder code.\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " embed_dim: int,\n", + " n_blocks: int,\n", + " mlp_multiplier: float,\n", + " n_heads: int,\n", + " dropout: float,\n", + " drop_path: float,\n", + " shifter: nn.Module | None = None,\n", + " transformer_cp: list[int] | None = None,\n", + " ) -> None:\n", + " \"\"\"\n", + " Args:\n", + " embed_dim: Embedding dimension\n", + " n_blocks: Number of local-global transformer pairs.\n", + " mlp_multiplier: MLP multiplier for hidden features in feed forward\n", + " networks.\n", + " n_heads: Number of attention heads.\n", + " dropout: Dropout.\n", + " drop_path: DropPath.\n", + " \"\"\"\n", + " super().__init__()\n", + "\n", + " self.embed_dim = embed_dim\n", + " self.n_blocks = n_blocks\n", + " self.mlp_multiplier = mlp_multiplier\n", + " self.n_heads = n_heads\n", + " self.dropout = dropout\n", + " self._transformer_cp = transformer_cp\n", + "\n", + " self.lgl_block = LocalGlobalLocalBlock(\n", + " features=embed_dim,\n", + " mlp_multiplier=mlp_multiplier,\n", + " n_heads=n_heads,\n", + " dropout=dropout,\n", + " drop_path=drop_path,\n", + " n_blocks=n_blocks,\n", + " shifter=shifter,\n", + " checkpoint=transformer_cp,\n", + " )\n", + "\n", + " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", + " \"\"\"\n", + " Args:\n", + " x: Tensor of shape\n", + " [batch, global sequence, local sequence, embed_dim]\n", + " Returns:\n", + " Tensor of shape\n", + " [batch, mask_unit_sequence, local_sequence, embed_dim].\n", + " Identical in shape to the input x.\n", + " \"\"\"\n", + 
"\n", + " x = self.lgl_block(x)\n", + "\n", + " return x\n", + "\n", + "\n", + "class PrithviWxC(nn.Module):\n", + " \"\"\"Encoder-decoder fusing Hiera with MaxViT. See\n", + " - Ryali et al. \"Hiera: A Hierarchical Vision Transformer without the\n", + " Bells-and-Whistles\" (https://arxiv.org/abs/2306.00989)\n", + " - Tu et al. \"MaxViT: Multi-Axis Vision Transformer\"\n", + " (https://arxiv.org/abs/2204.01697)\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " in_channels: int,\n", + " input_size_time: int,\n", + " in_channels_static: int,\n", + " input_scalers_mu: Tensor,\n", + " input_scalers_sigma: Tensor,\n", + " input_scalers_epsilon: float,\n", + " static_input_scalers_mu: Tensor,\n", + " static_input_scalers_sigma: Tensor,\n", + " static_input_scalers_epsilon: float,\n", + " output_scalers: Tensor,\n", + " n_lats_px: int,\n", + " n_lons_px: int,\n", + " patch_size_px: tuple[int],\n", + " mask_unit_size_px: tuple[int],\n", + " mask_ratio_inputs: float,\n", + " embed_dim: int,\n", + " n_blocks_encoder: int,\n", + " n_blocks_decoder: int,\n", + " mlp_multiplier: float,\n", + " n_heads: int,\n", + " dropout: float,\n", + " drop_path: float,\n", + " parameter_dropout: float,\n", + " residual: str,\n", + " masking_mode: str,\n", + " positional_encoding: str,\n", + " decoder_shifting: bool = False,\n", + " checkpoint_encoder: list[int] | None = None,\n", + " checkpoint_decoder: list[int] | None = None,\n", + " ) -> None:\n", + " \"\"\"\n", + " Args:\n", + " in_channels: Number of input channels.\n", + " input_size_time: Number of timestamps in input.\n", + " in_channels_static: Number of input channels for static data.\n", + " input_scalers_mu: Tensor of size (in_channels,). Used to rescale\n", + " input.\n", + " input_scalers_sigma: Tensor of size (in_channels,). Used to rescale\n", + " input.\n", + " input_scalers_epsilon: Float. Used to rescale input.\n", + " static_input_scalers_mu: Tensor of size (in_channels_static). 
Used\n", + " to rescale static inputs.\n", + " static_input_scalers_sigma: Tensor of size (in_channels_static).\n", + " Used to rescale static inputs.\n", + " static_input_scalers_epsilon: Float. Used to rescale static inputs.\n", + " output_scalers: Tensor of shape (in_channels,). Used to rescale\n", + " output.\n", + " n_lats_px: Total latitudes in data. In pixels.\n", + " n_lons_px: Total longitudes in data. In pixels.\n", + " patch_size_px: Patch size for tokenization. In pixels lat/lon.\n", + " mask_unit_size_px: Size of each mask unit. In pixels lat/lon.\n", + " mask_ratio_inputs: Masking ratio for inputs. 0 to 1.\n", + " embed_dim: Embedding dimension\n", + " n_blocks_encoder: Number of local-global transformer pairs in\n", + " encoder.\n", + " n_blocks_decoder: Number of local-global transformer pairs in\n", + " decoder.\n", + " mlp_multiplier: MLP multiplier for hidden features in feed forward\n", + " networks.\n", + " n_heads: Number of attention heads.\n", + " dropout: Dropout.\n", + " drop_path: DropPath.\n", + " parameter_dropout: Dropout applied to parameters.\n", + " residual: Indicates whether and how model should work as residual\n", + " model. Accepted values are 'climate', 'temporal' and 'none'\n", + " positional_encoding: possible values are\n", + " ['absolute' (default), 'fourier'].\n", + " 'absolute' lat lon encoded in 3 dimensions using sine and\n", + " cosine\n", + " 'fourier' lat/lon to be encoded using various frequencies\n", + " masking_mode: String ['local', 'global', 'both'] that controls the\n", + " type of masking used.\n", + " checkpoint_encoder: List of integers controlling if gradient\n", + " checkpointing is used on encoder.\n", + " Format: [] for no gradient checkpointing. 
[3, 7] for\n", + " checkpointing after 4th and 8th layer etc.\n", + " checkpoint_decoder: List of integers controlling if gradient\n", + " checkpointing is used on decoder.\n", + " Format: See `checkpoint_encoder`.\n", + " masking_mode: The type of masking to use\n", + " {'global', 'local', 'both'}\n", + " decoder_shifting: Whether to use swin shifting in the decoder.\n", + " \"\"\"\n", + " super().__init__()\n", + "\n", + " self.in_channels = in_channels\n", + " self.input_size_time = input_size_time\n", + " self.in_channels_static = in_channels_static\n", + " self.n_lats_px = n_lats_px\n", + " self.n_lons_px = n_lons_px\n", + " self.patch_size_px = patch_size_px\n", + " self.mask_unit_size_px = mask_unit_size_px\n", + " self.mask_ratio_inputs = mask_ratio_inputs\n", + " self.embed_dim = embed_dim\n", + " self.n_blocks_encoder = n_blocks_encoder\n", + " self.n_blocks_decoder = n_blocks_decoder\n", + " self.mlp_multiplier = mlp_multiplier\n", + " self.n_heads = n_heads\n", + " self.dropout = dropout\n", + " self.drop_path = drop_path\n", + " self.residual = residual\n", + " self._decoder_shift = decoder_shifting\n", + " self.positional_encoding = positional_encoding\n", + " self._checkpoint_encoder = checkpoint_encoder\n", + " self._checkpoint_decoder = checkpoint_decoder\n", + "\n", + " assert self.n_lats_px % self.mask_unit_size_px[0] == 0\n", + " assert self.n_lons_px % self.mask_unit_size_px[1] == 0\n", + " assert self.mask_unit_size_px[0] % self.patch_size_px[0] == 0\n", + " assert self.mask_unit_size_px[1] % self.patch_size_px[1] == 0\n", + "\n", + " if self.patch_size_px[0] != self.patch_size_px[1]:\n", + " raise NotImplementedError(\n", + " \"Current pixel shuffle symmetric patches.\"\n", + " )\n", + "\n", + " self.local_shape_mu = (\n", + " self.mask_unit_size_px[0] // self.patch_size_px[0],\n", + " self.mask_unit_size_px[1] // self.patch_size_px[1],\n", + " )\n", + " self.global_shape_mu = (\n", + " self.n_lats_px // self.mask_unit_size_px[0],\n", + " 
self.n_lons_px // self.mask_unit_size_px[1],\n", + " )\n", + "\n", + " assert input_scalers_mu.shape == (in_channels,)\n", + " assert input_scalers_sigma.shape == (in_channels,)\n", + " assert output_scalers.shape == (in_channels,)\n", + "\n", + " if self.positional_encoding != \"fourier\":\n", + " assert static_input_scalers_mu.shape == (in_channels_static,)\n", + " assert static_input_scalers_sigma.shape == (in_channels_static,)\n", + "\n", + " # Input shape [batch, time, parameter, lat, lon]\n", + " self.input_scalers_epsilon = input_scalers_epsilon\n", + " self.register_buffer(\n", + " \"input_scalers_mu\", input_scalers_mu.reshape(1, 1, -1, 1, 1)\n", + " )\n", + " self.register_buffer(\n", + " \"input_scalers_sigma\", input_scalers_sigma.reshape(1, 1, -1, 1, 1)\n", + " )\n", + "\n", + " # Static inputs shape [batch, parameter, lat, lon]\n", + " self.static_input_scalers_epsilon = static_input_scalers_epsilon\n", + " self.register_buffer(\n", + " \"static_input_scalers_mu\",\n", + " static_input_scalers_mu.reshape(1, -1, 1, 1),\n", + " )\n", + " self.register_buffer(\n", + " \"static_input_scalers_sigma\",\n", + " static_input_scalers_sigma.reshape(1, -1, 1, 1),\n", + " )\n", + "\n", + " # Output shape [batch, parameter, lat, lon]\n", + " self.register_buffer(\n", + " \"output_scalers\", output_scalers.reshape(1, -1, 1, 1)\n", + " )\n", + "\n", + " self.parameter_dropout = nn.Dropout2d(p=parameter_dropout)\n", + "\n", + " self.patch_embedding = PatchEmbed(\n", + " patch_size=patch_size_px,\n", + " channels=in_channels * input_size_time,\n", + " embed_dim=embed_dim,\n", + " )\n", + "\n", + " if self.residual == \"climate\":\n", + " self.patch_embedding_static = PatchEmbed(\n", + " patch_size=patch_size_px,\n", + " channels=in_channels + in_channels_static,\n", + " embed_dim=embed_dim,\n", + " )\n", + " else:\n", + " self.patch_embedding_static = PatchEmbed(\n", + " patch_size=patch_size_px,\n", + " channels=in_channels_static,\n", + " embed_dim=embed_dim,\n", + 
" )\n", + "\n", + " self.input_time_embedding = nn.Linear(1, embed_dim // 4, bias=True)\n", + " self.lead_time_embedding = nn.Linear(1, embed_dim // 4, bias=True)\n", + "\n", + " self.mask_token = nn.Parameter(torch.randn(1, 1, 1, self.embed_dim))\n", + " self._nglobal_mu = np.prod(self.global_shape_mu)\n", + " self._global_idx = torch.arange(self._nglobal_mu)\n", + "\n", + " self._nlocal_mu = np.prod(self.local_shape_mu)\n", + " self._local_idx = torch.arange(self._nlocal_mu)\n", + "\n", + " self.encoder = PrithviWxCEncoderDecoder(\n", + " embed_dim=embed_dim,\n", + " n_blocks=n_blocks_encoder,\n", + " mlp_multiplier=mlp_multiplier,\n", + " n_heads=n_heads,\n", + " dropout=dropout,\n", + " drop_path=drop_path,\n", + " transformer_cp=checkpoint_encoder,\n", + " )\n", + "\n", + " if n_blocks_decoder != 0:\n", + " if self._decoder_shift:\n", + " self.decoder_shifter = d_shifter = SWINShift(\n", + " self.mask_unit_size_px,\n", + " self.global_shape_mu,\n", + " self.local_shape_mu,\n", + " self.patch_size_px,\n", + " n_context_tokens=0,\n", + " )\n", + " else:\n", + " self.decoder_shifter = d_shifter = None\n", + "\n", + " self.decoder = PrithviWxCEncoderDecoder(\n", + " embed_dim=embed_dim,\n", + " n_blocks=n_blocks_decoder,\n", + " mlp_multiplier=mlp_multiplier,\n", + " n_heads=n_heads,\n", + " dropout=dropout,\n", + " drop_path=0.0,\n", + " shifter=d_shifter,\n", + " transformer_cp=checkpoint_decoder,\n", + " )\n", + "\n", + " self.unembed = nn.Linear(\n", + " self.embed_dim,\n", + " self.in_channels\n", + " * self.patch_size_px[0]\n", + " * self.patch_size_px[1],\n", + " bias=True,\n", + " )\n", + "\n", + " self.masking_mode = masking_mode.lower()\n", + " match self.masking_mode:\n", + " case \"local\":\n", + " self.generate_mask = self._gen_mask_local\n", + " case \"global\":\n", + " self.generate_mask = self._gen_mask_global\n", + " case \"both\":\n", + " self._mask_both_local: bool = True\n", + " self.generate_mask = self._gen_mask_both\n", + " case _:\n", + " 
def swap_masking(self) -> None:
    """Toggle which mask generator ``masking_mode="both"`` uses.

    Flips ``_mask_both_local``, the flag ``_gen_mask_both`` reads to choose
    between ``_gen_mask_local`` (True) and ``_gen_mask_global`` (False).
    """
    self._mask_both_local = not self._mask_both_local


@property
def n_masked_global(self) -> int:
    """Number of global mask units that get masked.

    Computed as ``int(mask_ratio_inputs * prod(global_shape_mu))``.

    This is deliberately a plain ``property`` rather than a
    ``cached_property``: ``mask_ratio_inputs`` is changed on a live model
    (the rollout helper sets it to ``0.0`` after the first step), and a cached
    value would silently keep the old ratio.
    """
    return int(self.mask_ratio_inputs * np.prod(self.global_shape_mu))


@property
def n_masked_local(self) -> int:
    """Number of local tokens masked inside every mask unit.

    Computed as ``int(mask_ratio_inputs * prod(local_shape_mu))``. Evaluated
    on every access for the same reason as ``n_masked_global``.
    """
    return int(self.mask_ratio_inputs * np.prod(self.local_shape_mu))


@staticmethod
def _shuffle_along_axis(a, axis):
    """Randomly permute ``a`` along ``axis``, independently per slice.

    Args:
        a: Tensor to shuffle.
        axis: Dimension along which the entries are permuted.

    Returns:
        Tensor of the same shape as ``a``.
    """
    # Sorting i.i.d. uniform noise yields a random permutation per slice.
    # The noise is created on ``a``'s device so that ``gather`` never mixes
    # devices.
    idx = torch.argsort(
        input=torch.rand(a.shape, device=a.device), dim=axis
    )
    return torch.gather(a, dim=axis, index=idx)


def _gen_mask_local(self, sizes: tuple[int, ...]) -> tuple[Tensor, Tensor]:
    """Draw a random local mask for every (batch, mask unit) pair.

    Args:
        sizes: Only the first two entries are used; they give the leading
            two dimensions of the returned index tensors (the caller passes
            ``(batch_size, n_global_mask_units)``).

    Returns:
        Tuple ``(indices_masked, indices_unmasked)``. Both hold local token
        indices and have shape ``(sizes[0], sizes[1], k)`` where ``k`` is
        ``n_masked_local`` for the first tensor and the remaining number of
        local tokens for the second.
    """
    n_masked = self.n_masked_local

    # Every (batch, mask unit) row starts as 0..n_local-1 and is shuffled
    # independently along the last dimension.
    maskable_indices = self._local_idx.view(1, -1).expand(*sizes[:2], -1)
    maskable_indices = self._shuffle_along_axis(maskable_indices, 2)

    indices_masked = maskable_indices[:, :, :n_masked]
    indices_unmasked = maskable_indices[:, :, n_masked:]

    return indices_masked, indices_unmasked


def _gen_mask_global(self, sizes: tuple[int, ...]) -> tuple[Tensor, Tensor]:
    """Draw a random global mask (whole mask units) for every batch element.

    Args:
        sizes: Only the first entry is used; it gives the leading dimension
            of the returned index tensors (the caller passes the batch size).

    Returns:
        Tuple ``(indices_masked, indices_unmasked)``. Both hold global mask
        unit indices and have shape ``(sizes[0], k)`` where ``k`` is
        ``n_masked_global`` for the first tensor and the remaining number of
        mask units for the second.
    """
    n_masked = self.n_masked_global

    maskable_indices = self._global_idx.view(1, -1).expand(*sizes[:1], -1)
    maskable_indices = self._shuffle_along_axis(maskable_indices, 1)

    indices_masked = maskable_indices[:, :n_masked]
    indices_unmasked = maskable_indices[:, n_masked:]

    return indices_masked, indices_unmasked


def _gen_mask_both(self, sizes: tuple[int, ...]) -> tuple[Tensor, Tensor]:
    """Dispatch to local or global masking depending on ``_mask_both_local``.

    Args:
        sizes: Forwarded unchanged to the selected generator.

    Returns:
        Whatever ``_gen_mask_local`` / ``_gen_mask_global`` returns.
    """
    if self._mask_both_local:
        return self._gen_mask_local(sizes)
    return self._gen_mask_global(sizes)


@staticmethod
def reconstruct_batch(
    idx_masked: Tensor,
    idx_unmasked: Tensor,
    data_masked: Tensor,
    data_unmasked: Tensor,
) -> tuple[Tensor, Tensor]:
    """Merge masked and unmasked data back into index order (batched).

    Args:
        idx_masked: Indices of the masked entries. Its number of dimensions
            determines along which dimension the data is merged (the last
            index dimension).
        idx_unmasked: Indices of the unmasked entries; same number of
            dimensions as ``idx_masked``.
        data_masked: Data belonging to ``idx_masked``. Its leading dimensions
            match ``idx_masked``; trailing dimensions (e.g.
            ``local_sequence, channel``) must agree with ``data_unmasked``.
        data_unmasked: Data belonging to ``idx_unmasked``. Its leading
            dimensions match ``idx_unmasked``; trailing dimensions must agree
            with ``data_masked``.

    Returns:
        Tuple ``(data, idx_total)``. ``data`` is the concatenation of
        ``data_masked`` and ``data_unmasked`` reordered so that entries appear
        in ascending index order. ``idx_total`` is the gather index that was
        used, expanded to the shape of ``data``.
    """
    dim: int = idx_masked.ndim

    # argsort of the concatenated indices gives, for every target position,
    # where that entry sits in the concatenated (masked, unmasked) data.
    idx_total = torch.argsort(
        torch.cat([idx_masked, idx_unmasked], dim=-1), dim=-1
    )
    # Append singleton dims and expand so the index covers the trailing
    # data dimensions as required by ``torch.gather``.
    idx_total = idx_total.view(
        *idx_total.shape, *[1] * (data_unmasked.ndim - dim)
    )
    idx_total = idx_total.expand(
        *idx_total.shape[:dim], *data_unmasked.shape[dim:]
    )

    data = torch.cat([data_masked, data_unmasked], dim=dim - 1)
    data = torch.gather(data, dim=dim - 1, index=idx_total)

    return data, idx_total


def fourier_pos_encoding(self, x_static: Tensor) -> Tensor:
    """Build a sine/cosine positional encoding from per-pixel coordinates.

    Args:
        x_static: Tensor of shape B x C x H x W whose first two channels are
            latitude and longitude.

    Returns:
        Tensor of shape B x E x H/P x W/P, where P is the patch size and
        E = 4 * (embed_dim // 4): sin(lat), sin(lon), cos(lat), cos(lon),
        each evaluated for modes 1 .. embed_dim // 4.
    """
    # B x C x H x W -> B x 1 x H/P x W/P (mean coordinate of each patch)
    latitudes_patch = F.avg_pool2d(
        x_static[:, [0]],
        kernel_size=self.patch_size_px,
        stride=self.patch_size_px,
    )
    longitudes_patch = F.avg_pool2d(
        x_static[:, [1]],
        kernel_size=self.patch_size_px,
        stride=self.patch_size_px,
    )

    # Integer frequencies 1, 2, ..., embed_dim // 4 on the channel axis.
    modes = (
        torch.arange(self.embed_dim // 4, device=x_static.device).view(
            1, -1, 1, 1
        )
        + 1.0
    )
    pos_encoding = torch.cat(
        (
            torch.sin(latitudes_patch * modes),
            torch.sin(longitudes_patch * modes),
            torch.cos(latitudes_patch * modes),
            torch.cos(longitudes_patch * modes),
        ),
        dim=1,
    )

    return pos_encoding  # B x E x H/P x W/P


def time_encoding(self, input_time, lead_time):
    """Encode input time and lead time as learned sine/cosine features.

    Args:
        input_time: Tensor of shape [batch].
        lead_time: Tensor of shape [batch].

    Returns:
        Tensor of shape [batch, 1, 1, E]: cos(input), cos(lead), sin(input),
        sin(lead) concatenated along the last dimension, where each part has
        the output width of the corresponding embedding layer. The singleton
        dimensions let the result broadcast against
        [batch, global_seq, local_seq, embed_dim] tokens.
    """
    input_time = self.input_time_embedding(input_time.view(-1, 1, 1, 1))
    lead_time = self.lead_time_embedding(lead_time.view(-1, 1, 1, 1))

    time_encoding = torch.cat(
        (
            torch.cos(input_time),
            torch.cos(lead_time),
            torch.sin(input_time),
            torch.sin(lead_time),
        ),
        dim=3,
    )
    return time_encoding


def to_patching(self, x: Tensor) -> Tensor:
    """Transform data from lat/lon space to two-axis (global, local) patching.

    Args:
        x: Tensor in lat/lon patch space of shape
            (N, C, Nlat // P_0, Nlon // P_1).

    Returns:
        Tensor in patch space of shape (N, G, L, C), with
        G = prod(global_shape_mu) and L = prod(local_shape_mu).
    """
    n_batch = x.shape[0]

    # Split each spatial axis into (global, local) and move channels last.
    x = x.view(
        n_batch,
        -1,
        self.global_shape_mu[0],
        self.local_shape_mu[0],
        self.global_shape_mu[1],
        self.local_shape_mu[1],
    )
    x = x.permute(0, 2, 4, 3, 5, 1).contiguous()

    s = x.shape
    return x.view(n_batch, s[1] * s[2], s[3] * s[4], -1)


def from_patching(self, x: Tensor) -> Tensor:
    """Transform data from two-axis patching back to lat/lon space.

    Inverse of ``to_patching``.

    Args:
        x: Tensor in patch space of shape (N, G, L, C).

    Returns:
        Tensor in lat/lon patch space of shape
        (N, C, Nlat // P_0, Nlon // P_1).
    """
    n_batch = x.shape[0]

    x = x.view(
        n_batch,
        self.global_shape_mu[0],
        self.global_shape_mu[1],
        self.local_shape_mu[0],
        self.local_shape_mu[1],
        -1,
    )
    x = x.permute(0, 5, 1, 3, 2, 4).contiguous()

    s = x.shape
    return x.view(n_batch, -1, s[2] * s[3], s[4] * s[5])
def forward(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
    """Run one masked encode/decode pass and add the configured residual.

    Args:
        batch: Dictionary with the following keys::

            'x': Tensor of shape [batch, time, parameter, lat, lon]
            'static': Tensor of shape [batch, channel_static, lat, lon]
            'climate': Tensor of shape [batch, parameter, lat, lon].
                Only read when ``self.residual == "climate"``.
            'input_time': Tensor of shape [batch]. Required.
            'lead_time': Tensor of shape [batch]. Required.

            A 'y' entry may be present but is not read here.

    Returns:
        Tensor: Tensor of shape [batch, parameter, lat, lon].
    """
    # Standardise the dynamic inputs.
    x_rescaled = (batch["x"] - self.input_scalers_mu) / (
        self.input_scalers_sigma + self.input_scalers_epsilon
    )
    batch_size = x_rescaled.shape[0]

    if self.positional_encoding == "fourier":
        # The first two static channels are coordinates; they feed the
        # Fourier encoding and are dropped from the scaled static input.
        x_static_pos = self.fourier_pos_encoding(batch["static"])
        # NOTE(review): the data slice starts at channel 2 while the scaler
        # slices start at channel 3 -- confirm the scaler channel layout.
        x_static = (
            batch["static"][:, 2:] - self.static_input_scalers_mu[:, 3:]
        ) / (
            self.static_input_scalers_sigma[:, 3:]
            + self.static_input_scalers_epsilon
        )
    else:
        x_static = (batch["static"] - self.static_input_scalers_mu) / (
            self.static_input_scalers_sigma
            + self.static_input_scalers_epsilon
        )

    if self.residual == "temporal":
        # We create a residual of same shape as y: the last input time step
        # where lead_time > 0, otherwise the first one.
        index = torch.where(
            batch["lead_time"] > 0, batch["x"].shape[1] - 1, 0
        )
        index = index.view(-1, 1, 1, 1, 1)
        index = index.expand(batch_size, 1, *batch["x"].shape[2:])
        x_hat = torch.gather(batch["x"], dim=1, index=index)
        x_hat = x_hat.squeeze(1)
    elif self.residual == "climate":
        climate_scaled = (
            batch["climate"] - self.input_scalers_mu.view(1, -1, 1, 1)
        ) / (
            self.input_scalers_sigma.view(1, -1, 1, 1)
            + self.input_scalers_epsilon
        )

    # [batch, time, parameter, lat, lon]
    # -> [batch, time x parameter, lat, lon]
    x_rescaled = x_rescaled.flatten(1, 2)
    # Parameter dropout
    x_rescaled = self.parameter_dropout(x_rescaled)

    x_embedded = self.patch_embedding(x_rescaled)

    if self.residual == "climate":
        static_embedded = self.patch_embedding_static(
            torch.cat((x_static, climate_scaled), dim=1)
        )
    else:
        static_embedded = self.patch_embedding_static(x_static)

    if self.positional_encoding == "fourier":
        static_embedded += x_static_pos

    x_embedded = self.to_patching(x_embedded)
    static_embedded = self.to_patching(static_embedded)

    time_encoding = self.time_encoding(
        batch["input_time"], batch["lead_time"]
    )

    tokens = x_embedded + static_embedded + time_encoding

    # Now we generate masks based on masking_mode
    indices_masked, indices_unmasked = self.generate_mask(
        (batch_size, self._nglobal_mu)
    )
    indices_masked = indices_masked.to(device=tokens.device)
    indices_unmasked = indices_unmasked.to(device=tokens.device)
    # Number of index dims decides the gather dimension: the last index
    # dimension of the mask is the one that gets selected from.
    maskdim: int = indices_masked.ndim

    # Unmasking: keep only the visible tokens for the encoder.
    unmask_view = (*indices_unmasked.shape, *[1] * (tokens.ndim - maskdim))
    unmasked = torch.gather(
        tokens,
        dim=maskdim - 1,
        index=indices_unmasked.view(*unmask_view).expand(
            *indices_unmasked.shape, *tokens.shape[maskdim:]
        ),
    )

    # Encoder
    x_encoded = self.encoder(unmasked)

    # Generate and position encode the mask tokens.
    # The [1, 1, 1, embed_dim] mask token broadcasts against
    # [batch, global_seq, local_seq, embed_dim]; this yields the same values
    # (and gradients) as repeating it first, without the extra copy.
    mask_view = (*indices_masked.shape, *[1] * (tokens.ndim - maskdim))
    masked = self.mask_token + static_embedded
    masked = torch.gather(
        masked,
        dim=maskdim - 1,
        index=indices_masked.view(*mask_view).expand(
            *indices_masked.shape, *tokens.shape[maskdim:]
        ),
    )

    recon, _ = self.reconstruct_batch(
        indices_masked, indices_unmasked, masked, x_encoded
    )

    x_decoded = self.decoder(recon)

    # Output: [batch, global sequence, local sequence,
    # in_channels * patch_size[0] * patch_size[1]]
    x_unembed = self.unembed(x_decoded)

    # Back to lat/lon patch space:
    # [batch, in_channels * patch_size[0] * patch_size[1], lat // P, lon // P]
    x_out = self.from_patching(x_unembed)

    # Pixel shuffle to [batch, in_channels, lat, lon]
    # NOTE(review): only patch_size_px[0] is used as the upscale factor, so
    # this presumably relies on square patches -- confirm.
    x_out = F.pixel_shuffle(x_out, self.patch_size_px[0])

    # Undo the output normalisation around the chosen reference state.
    if self.residual == "temporal":
        x_out = self.output_scalers * x_out + x_hat
    elif self.residual == "climate":
        x_out = self.output_scalers * x_out + batch["climate"]
    elif self.residual == "none":
        x_out = (
            self.output_scalers * x_out
            + self.input_scalers_mu.reshape(1, -1, 1, 1)
        )

    return x_out
# %%
import yaml

# from PrithviWxC.model import PrithviWxC

with open("./config.yaml", "r") as f:
    config = yaml.safe_load(f)

params = config["params"]

model = PrithviWxC(
    in_channels=params["in_channels"],
    input_size_time=params["input_size_time"],
    in_channels_static=params["in_channels_static"],
    input_scalers_mu=in_mu,
    input_scalers_sigma=in_sig,
    input_scalers_epsilon=params["input_scalers_epsilon"],
    static_input_scalers_mu=static_mu,
    static_input_scalers_sigma=static_sig,
    static_input_scalers_epsilon=params["static_input_scalers_epsilon"],
    output_scalers=output_sig**0.5,
    n_lats_px=params["n_lats_px"],
    n_lons_px=params["n_lons_px"],
    patch_size_px=params["patch_size_px"],
    mask_unit_size_px=params["mask_unit_size_px"],
    mask_ratio_inputs=masking_ratio,
    embed_dim=params["embed_dim"],
    n_blocks_encoder=params["n_blocks_encoder"],
    n_blocks_decoder=params["n_blocks_decoder"],
    mlp_multiplier=params["mlp_multiplier"],
    n_heads=params["n_heads"],
    dropout=params["dropout"],
    drop_path=params["drop_path"],
    parameter_dropout=params["parameter_dropout"],
    residual=residual,
    masking_mode=masking_mode,
    decoder_shifting=decoder_shifting,
    positional_encoding=positional_encoding,
    checkpoint_encoder=[],
    checkpoint_decoder=[],
)

# Load onto the CPU first: the weights are copied into the (CPU) model below
# and the model is moved to `device` afterwards. Without `map_location` a
# checkpoint saved from a GPU cannot be opened on a CPU-only machine.
# SECURITY NOTE(review): `weights_only=False` unpickles arbitrary objects, so
# only load checkpoints from a trusted source. Switch to `weights_only=True`
# once it is confirmed that the checkpoint contains nothing but tensors and
# plain containers.
state_dict = torch.load(weights_path, map_location="cpu", weights_only=False)
if "model_state" in state_dict:
    state_dict = state_dict["model_state"]
model.load_state_dict(state_dict, strict=True)

# `.to()` is a no-op when the model already lives on `device`.
model = model.to(device)

# %% [markdown]
# ## Rollout
# We are now ready to perform the rollout. Again, the data has to be run
# through a preprocessor. However, this time we use a preprocessor that can
# handle the additional intermediate data. Also, rather than calling the model
# directly, we have a convenient wrapper function that performs the iteration.
# This also simplifies the model loading when using a sharded checkpoint. If
# you attempt to perform training steps with this function, you should use an
# aggressive number of activation checkpoints, as the memory consumption
# becomes quite high.

# %%
import torch
from torch import Tensor, nn


def rollout_iter(
    nsteps: int,
    model: nn.Module,
    batch: dict[str, Tensor | int | float],
) -> Tensor:
    """A helper function for performing autoregressive rollout.

    The model is called ``nsteps`` times. After every call its output is
    paired with the previous state to form the next two-step input ``x``.

    Note that ``batch`` is modified in place: ``lead_time`` is replaced by
    its first entry along the last dimension, ``static``, ``climate`` and
    ``y`` are set for every step from ``statics``, ``climates`` and ``ys``,
    and ``x`` is overwritten with the input of the following step.

    Args:
        nsteps (int): The number of rollout steps to take
        model (nn.Module): A model. It must expose a writable
            ``mask_ratio_inputs`` attribute.
        batch (dict): A data dictionary common to the Prithvi models. Must
            contain ``x``, ``lead_time``, ``statics``, ``climates`` and
            ``ys``; the last three are indexed by step along dimension 1.

    Raises:
        ValueError: If the number of steps isn't positive.

    Returns:
        Tensor: the output of the model after nsteps autoregressive iterations.
    """
    if nsteps < 1:
        raise ValueError("'nsteps' shouold be a positive int.")

    # Most recent state: entry 1 along the time dimension of the input.
    xlast = batch["x"][:, 1]
    batch["lead_time"] = batch["lead_time"][..., 0]

    # Save the masking ratio to be restored later
    mask_ratio_tmp = model.mask_ratio_inputs

    try:
        for step in range(nsteps):
            # After first step, turn off masking
            # NOTE(review): this only has an effect if the model reads
            # `mask_ratio_inputs` at call time; a model that caches values
            # derived from it will keep using the cached ones.
            if step > 0:
                model.mask_ratio_inputs = 0.0

            batch["static"] = batch["statics"][:, step]
            batch["climate"] = batch["climates"][:, step]
            batch["y"] = batch["ys"][:, step]

            out = model(batch)

            # Next input: (previous state, new prediction) along time.
            batch["x"] = torch.cat((xlast[:, None], out[:, None]), dim=1)
            xlast = out
    finally:
        # Restore the masking ratio, even if a step raised, so the model is
        # never left with masking silently disabled.
        model.mask_ratio_inputs = mask_ratio_tmp

    return xlast


# %%
# from PrithviWxC.dataloaders.merra2_rollout import preproc
# from PrithviWxC.rollout import rollout_iter

data = next(iter(dataset))
batch = preproc([data], padding)

# Move every tensor entry to the compute device; other entries stay as-is.
for k, v in batch.items():
    if isinstance(v, torch.Tensor):
        batch[k] = v.to(device)

# Snapshot of the CPU RNG state taken right before the rollout.
rng_state_1 = torch.get_rng_state()
with torch.no_grad():
    model.eval()
    out = rollout_iter(dataset.nsteps, model, batch)

# %% [markdown]
# ## Plotting
# %%
# Look the channel up by name instead of hard-coding its position
# ("T2M" is entry 12 of `surface_vars`, matching the original index).
t2m_index = surface_vars.index("T2M")
t2m = out[0, t2m_index].cpu().numpy()

# Evenly spaced display coordinates spanning the globe.
lat = np.linspace(-90, 90, out.shape[-2])
lon = np.linspace(-180, 180, out.shape[-1])
X, Y = np.meshgrid(lon, lat)

fig, ax = plt.subplots()
contours = ax.contourf(X, Y, t2m, 100)
ax.set_aspect("equal")
ax.set(
    title="Rollout prediction: T2M",
    xlabel="Longitude (degrees)",
    ylabel="Latitude (degrees)",
)
fig.colorbar(contours, ax=ax, label="T2M")
plt.show()